From 315d2ab17d338f6aef6026929e8671726cd76ba7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 8 Dec 2025 00:07:07 +0200 Subject: [PATCH] all: fix staticcheck issues --- .github/workflows/go.yml | 1 + .pre-commit-config.yaml | 3 +- appservice/websocket.go | 2 +- bridgev2/bridgeconfig/permissions.go | 5 +- bridgev2/bridgestate.go | 6 +- bridgev2/database/database.go | 58 ------------ bridgev2/matrix/crypto.go | 8 +- bridgev2/matrix/matrix.go | 1 + bridgev2/matrix/mxmain/legacymigrate.go | 5 +- bridgev2/matrix/provisioning.go | 3 +- bridgev2/portalbackfill.go | 1 + bridgev2/portalreid.go | 2 +- client.go | 10 +- crypto/attachment/attachments.go | 49 ++++++---- crypto/attachment/attachments_test.go | 10 +- crypto/cross_sign_pubkey.go | 4 +- crypto/cross_sign_store.go | 30 +++--- crypto/cryptohelper/cryptohelper.go | 5 +- crypto/decryptmegolm.go | 57 +++++++----- crypto/decryptolm.go | 54 ++++++----- crypto/devicelist.go | 35 ++++--- crypto/encryptmegolm.go | 11 ++- crypto/goolm/account/account.go | 2 +- crypto/goolm/account/account_test.go | 2 +- crypto/goolm/goolmbase64/base64.go | 4 +- crypto/goolm/libolmpickle/picklejson.go | 2 +- crypto/goolm/message/session_export.go | 2 +- crypto/goolm/message/session_sharing.go | 2 +- crypto/goolm/pk/decryption.go | 2 +- crypto/goolm/pk/encryption.go | 3 + .../goolm/session/megolm_inbound_session.go | 4 +- .../goolm/session/megolm_outbound_session.go | 6 +- crypto/goolm/session/olm_session.go | 13 ++- crypto/goolm/session/register.go | 8 +- crypto/keybackup.go | 11 ++- crypto/libolm/account.go | 22 ++--- crypto/libolm/error.go | 30 +++--- crypto/libolm/inboundgroupsession.go | 18 ++-- crypto/libolm/outboundgroupsession.go | 8 +- crypto/libolm/pk.go | 2 +- crypto/libolm/register.go | 2 +- crypto/libolm/session.go | 20 ++-- crypto/machine.go | 2 +- crypto/olm/errors.go | 93 +++++++++++-------- crypto/sessions.go | 14 ++- event/encryption.go | 2 +- federation/resolution.go | 5 +- filter.go | 2 +- id/contenturi.go | 20 ++-- id/matrixuri.go | 2 +- id/userid.go | 8 +- pushrules/action.go | 2 +- pushrules/condition_test.go | 8 -- room.go | 6 +- synapseadmin/roomapi.go | 3 +- url.go | 6 +- 56 files changed, 358 insertions(+), 338 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index dc4f17e2..8bce4484 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -24,6 +24,7 @@ jobs: - name: Install goimports run: | go install golang.org/x/tools/cmd/goimports@latest + go install honnef.co/go/tools/cmd/staticcheck@latest export PATH="$HOME/go/bin:$PATH" - name: Run pre-commit diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b9785ae..4f769e56 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,8 +18,7 @@ repos: - "-w" - id: go-vet-repo-mod - id: go-mod-tidy - # TODO enable this - #- id: go-staticcheck-repo-mod + - id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go rev: v0.4.2 diff --git a/appservice/websocket.go b/appservice/websocket.go index 1e401c53..4f2538bf 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -56,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { var prefixMessage string for unwrappedErr != nil { errorData, jsonErr = json.Marshal(unwrappedErr) - if errorData != nil && len(errorData) > 2 && jsonErr == nil { + if len(errorData) > 2 && jsonErr == nil { prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1) prefixMessage = strings.TrimRight(prefixMessage, ": ") break diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index 898bf58a..9efe068e 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -41,10 +41,7 @@ func (pc PermissionConfig) IsConfigured() bool { _, hasExampleDomain := pc["example.com"] _, hasExampleUser := pc["@admin:example.com"] exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain) - if len(pc) <= exampleLen { - return false - } - return true + return len(pc) > exampleLen } func (pc PermissionConfig) Get(userID id.UserID) Permissions { diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index a1d3e70b..babbccab 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -102,9 +102,9 @@ func (bsq *BridgeStateQueue) loop() { } } -func (bsq *BridgeStateQueue) scheduleNotice(ctx context.Context, triggeredBy status.BridgeState) { +func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) { log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger() - ctx = log.WithContext(bsq.bridge.BackgroundCtx) + ctx := log.WithContext(bsq.bridge.BackgroundCtx) if !bsq.waitForTransientDisconnectReconnect(ctx) { return } @@ -135,7 +135,7 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge if bsq.firstTransientDisconnect.IsZero() { bsq.firstTransientDisconnect = time.Now() } - go bsq.scheduleNotice(ctx, state) + go bsq.scheduleNotice(state) } return } diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index 0729cb83..05abddf0 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -7,13 +7,7 @@ package database import ( - "encoding/json" - "reflect" - "strings" - "go.mau.fi/util/dbutil" - "golang.org/x/exp/constraints" - "golang.org/x/exp/maps" "maunium.net/go/mautrix/bridgev2/networkid" @@ -158,55 +152,3 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID) panic("bridge ID mismatch") } } - -func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) { - if val, found := m[key]; found { - floatVal, ok := val.(float64) - if ok { - return T(floatVal), true - } - tVal, ok := val.(T) - if ok { - return tVal, true - } - } - return 0, false -} - -func unmarshalMerge(input []byte, data any, extra *map[string]any) error { - err := json.Unmarshal(input, data) - if err != nil { - return err - } - err = json.Unmarshal(input, extra) - if err != nil { - return err - } - if *extra == nil { - *extra = make(map[string]any) - } - return nil -} - -func marshalMerge(data any, extra map[string]any) ([]byte, error) { - if extra == nil { - return json.Marshal(data) - } - merged := make(map[string]any) - maps.Copy(merged, extra) - dataRef := reflect.ValueOf(data).Elem() - dataType := dataRef.Type() - for _, field := range reflect.VisibleFields(dataType) { - parts := strings.Split(field.Tag.Get("json"), ",") - if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" { - continue - } - fieldVal := dataRef.FieldByIndex(field.Index) - if fieldVal.IsZero() { - delete(merged, parts[0]) - } else { - merged[parts[0]] = fieldVal.Interface() - } - } - return json.Marshal(merged) -} diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index f4a2e9a0..7f18f1f5 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -38,9 +38,9 @@ func init() { var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) -var NoSessionFound = crypto.NoSessionFound -var DuplicateMessageIndex = crypto.DuplicateMessageIndex -var UnknownMessageIndex = olm.UnknownMessageIndex +var NoSessionFound = crypto.ErrNoSessionFound +var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex +var UnknownMessageIndex = olm.ErrUnknownMessageIndex type CryptoHelper struct { bridge *Connector @@ -439,7 +439,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy var encrypted *event.EncryptedEventContent encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { + if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) { return } helper.log.Debug().Err(err). diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 6c94bccc..570ae5f1 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -127,6 +127,7 @@ func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). Msg("Couldn't find session, requesting keys and waiting longer...") + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index c8eb820b..97cdeddf 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -135,7 +135,10 @@ func (br *BridgeMain) CheckLegacyDB( } var dbVersion int err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion) - if dbVersion < expectedVersion { + if err != nil { + log.Fatal().Err(err).Msg("Failed to get database version") + return + } else if dbVersion < expectedVersion { log.Fatal(). Int("expected_version", expectedVersion). Int("version", dbVersion). diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 43d19380..44e00e64 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -85,10 +85,9 @@ const ( provisioningUserKey provisioningContextKey = iota provisioningUserLoginKey provisioningLoginProcessKey + ProvisioningKeyRequest ) -const ProvisioningKeyRequest = "fi.mau.provision.request" - func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index e8292388..879f07ae 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -410,6 +410,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin if reaction.Timestamp.IsZero() { reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) } + //lint:ignore SA4006 it's a todo targetPart, ok := partMap[*reaction.TargetPart] if !ok { // TODO warning log and/or skip reaction? diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index a25fe820..d1a9d5a6 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -96,7 +96,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta go func() { _, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{ Parsed: &event.TombstoneEventContent{ - Body: fmt.Sprintf("This room has been merged"), + Body: "This room has been merged", ReplacementRoom: targetPortal.MXID, }, }, time.Now()) diff --git a/client.go b/client.go index 9961e717..b740cba6 100644 --- a/client.go +++ b/client.go @@ -742,7 +742,7 @@ func (cli *Client) executeCompiledRequest( cli.RequestStart(req) startTime := time.Now() res, err := client.Do(req) - duration := time.Now().Sub(startTime) + duration := time.Since(startTime) if res != nil && !dontReadResponse { defer res.Body.Close() } @@ -862,7 +862,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp } start := time.Now() _, err = cli.MakeFullRequest(ctx, fullReq) - duration := time.Now().Sub(start) + duration := time.Since(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second if req.Since == "" { @@ -966,7 +966,7 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe // } // token := res.AccessToken func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { - res, uia, err := cli.Register(ctx, req) + _, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err } else if uia == nil { @@ -975,7 +975,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} - res, _, err = cli.Register(ctx, req) + res, _, err := cli.Register(ctx, req) if err != nil { return nil, err } @@ -1751,6 +1751,8 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any return nil, nil } +type RoomStateMap = map[event.Type]map[string]*event.Event + // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 155cca5c..727aacbf 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -21,13 +21,24 @@ import ( ) var ( - HashMismatch = errors.New("mismatching SHA-256 digest") - UnsupportedVersion = errors.New("unsupported Matrix file encryption version") - UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") - InvalidKey = errors.New("failed to decode key") - InvalidInitVector = errors.New("failed to decode initialization vector") - InvalidHash = errors.New("failed to decode SHA-256 hash") - ReaderClosed = errors.New("encrypting reader was already closed") + ErrHashMismatch = errors.New("mismatching SHA-256 digest") + ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version") + ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") + ErrInvalidKey = errors.New("failed to decode key") + ErrInvalidInitVector = errors.New("failed to decode initialization vector") + ErrInvalidHash = errors.New("failed to decode SHA-256 hash") + ErrReaderClosed = errors.New("encrypting reader was already closed") +) + +// Deprecated: use variables prefixed with Err +var ( + HashMismatch = ErrHashMismatch + UnsupportedVersion = ErrUnsupportedVersion + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + InvalidKey = ErrInvalidKey + InvalidInitVector = ErrInvalidInitVector + InvalidHash = ErrInvalidHash + ReaderClosed = ErrReaderClosed ) var ( @@ -85,25 +96,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error { if ef.decoded != nil { return nil } else if len(ef.Key.Key) != keyBase64Length { - return InvalidKey + return ErrInvalidKey } else if len(ef.InitVector) != ivBase64Length { - return InvalidInitVector + return ErrInvalidInitVector } else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length { - return InvalidHash + return ErrInvalidHash } ef.decoded = &decodedKeys{} _, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key)) if err != nil { - return InvalidKey + return ErrInvalidKey } _, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector)) if err != nil { - return InvalidInitVector + return ErrInvalidInitVector } if includeHash { _, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256)) if err != nil { - return InvalidHash + return ErrInvalidHash } } return nil @@ -179,7 +190,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil) func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { if r.closed { - return 0, ReaderClosed + return 0, ErrReaderClosed } if offset != 0 || whence != io.SeekStart { return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported") @@ -200,7 +211,7 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { func (r *encryptingReader) Read(dst []byte) (n int, err error) { if r.closed { - return 0, ReaderClosed + return 0, ErrReaderClosed } else if r.isDecrypting && r.file.decoded == nil { if err = r.file.PrepareForDecryption(); err != nil { return @@ -224,7 +235,7 @@ func (r *encryptingReader) Close() (err error) { } if r.isDecrypting { if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) { - return HashMismatch + return ErrHashMismatch } } else { r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil)) @@ -265,9 +276,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) { // DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function. func (ef *EncryptedFile) PrepareForDecryption() error { if ef.Version != "v2" { - return UnsupportedVersion + return ErrUnsupportedVersion } else if ef.Key.Algorithm != "A256CTR" { - return UnsupportedAlgorithm + return ErrUnsupportedAlgorithm } else if err := ef.decodeKeys(true); err != nil { return err } @@ -281,7 +292,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error { } dataHash := sha256.Sum256(data) if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) { - return HashMismatch + return ErrHashMismatch } utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) return nil diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go index d7f1394a..9fe929ab 100644 --- a/crypto/attachment/attachments_test.go +++ b/crypto/attachment/attachments_test.go @@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) { file := parseHelloWorld() file.Version = "foo" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedVersion) + assert.ErrorIs(t, err, ErrUnsupportedVersion) } func TestUnsupportedAlgorithm(t *testing.T) { file := parseHelloWorld() file.Key.Algorithm = "bar" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedAlgorithm) + assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) } func TestHashMismatch(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes)) err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, HashMismatch) + assert.ErrorIs(t, err, ErrHashMismatch) } func TestTooLongHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } func TestTooShortHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "5/Gy1JftyyQ" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index f85d1ea3..223fc7b5 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -63,8 +63,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id if len(dbKeys) > 0 { masterKey, ok := dbKeys[id.XSUsageMaster] if ok { - selfSigning, _ := dbKeys[id.XSUsageSelfSigning] - userSigning, _ := dbKeys[id.XSUsageUserSigning] + selfSigning := dbKeys[id.XSUsageSelfSigning] + userSigning := dbKeys[id.XSUsageUserSigning] return &CrossSigningPublicKeysCache{ MasterKey: masterKey.Key, SelfSigningKey: selfSigning.Key, diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index d30b7e32..57406b11 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -26,24 +26,22 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK log.Error().Err(err). Msg("Error fetching current cross-signing keys of user") } - if currentKeys != nil { - for curKeyUsage, curKey := range currentKeys { - log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() - // got a new key with the same usage as an existing key - for _, newKeyUsage := range userKeys.Usage { - if newKeyUsage == curKeyUsage { - if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { - // old key is not in the new key map, so we drop signatures made by it - if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { - log.Error().Err(err).Msg("Error deleting old signatures made by user") - } else { - log.Debug(). - Int64("signature_count", count). - Msg("Dropped signatures made by old key as it has been replaced") - } + for curKeyUsage, curKey := range currentKeys { + log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() + // got a new key with the same usage as an existing key + for _, newKeyUsage := range userKeys.Usage { + if newKeyUsage == curKeyUsage { + if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { + // old key is not in the new key map, so we drop signatures made by it + if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { + log.Error().Err(err).Msg("Error deleting old signatures made by user") + } else { + log.Debug(). + Int64("signature_count", count). + Msg("Dropped signatures made by old key as it has been replaced") } - break } + break } } } diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 1939ea79..b62dc128 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -278,7 +278,7 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error } } -var NoSessionFound = crypto.NoSessionFound +var NoSessionFound = crypto.ErrNoSessionFound const initialSessionWaitTimeout = 3 * time.Second const extendedSessionWaitTimeout = 22 * time.Second @@ -371,6 +371,7 @@ func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -418,7 +419,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R defer helper.lock.RUnlock() encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { + if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) { return } helper.log.Debug(). diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 47279474..d8b419ab 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -24,13 +24,24 @@ import ( ) var ( - IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") - NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") - DuplicateMessageIndex = errors.New("duplicate megolm message index") - WrongRoom = errors.New("encrypted megolm event is not intended for this room") - DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") - SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") - RatchetError = errors.New("failed to ratchet session after use") + ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") + ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") + ErrDuplicateMessageIndex = errors.New("duplicate megolm message index") + ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room") + ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") + ErrSenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") + ErrRatchetError = errors.New("failed to ratchet session after use") +) + +// Deprecated: use variables prefixed with Err +var ( + IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType + NoSessionFound = ErrNoSessionFound + DuplicateMessageIndex = ErrDuplicateMessageIndex + WrongRoom = ErrWrongRoom + DeviceKeyMismatch = ErrDeviceKeyMismatch + SenderKeyMismatch = ErrSenderKeyMismatch + RatchetError = ErrRatchetError ) type megolmEvent struct { @@ -49,9 +60,9 @@ var ( func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmMegolmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm } log := mach.machOrContextLog(ctx).With(). Str("action", "decrypt megolm event"). @@ -97,7 +108,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event Msg("Couldn't resolve trust level of session: sent by unknown device") trustLevel = id.TrustStateUnknownDevice } else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey { - return nil, DeviceKeyMismatch + return nil, ErrDeviceKeyMismatch } else { trustLevel, err = mach.ResolveTrustContext(ctx, device) if err != nil { @@ -147,7 +158,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if err != nil { return nil, fmt.Errorf("failed to parse megolm payload: %w", err) } else if megolmEvt.RoomID != encryptionRoomID { - return nil, WrongRoom + return nil, ErrWrongRoom } if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState { megolmEvt.Type.Class = event.StateEventType @@ -201,19 +212,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext) if decodeErr != nil { log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt") - return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex) + return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex) } firstKnown := sess.Internal.FirstKnownIndex() log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger() if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { log.Debug().Err(err).Msg("Failed to check if message index is duplicate") - return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } else if !ok { log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate") - return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown) } log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate") - return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) { @@ -224,13 +235,13 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { - return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) + return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID) } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { - return sess, nil, 0, SenderKeyMismatch + return sess, nil, 0, ErrSenderKeyMismatch } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { - if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt { + if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt { messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content) return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err) } @@ -238,7 +249,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve } else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err) } else if !ok { - return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex) + return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex) } // Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function @@ -290,24 +301,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Deleted fully used session") } } else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt { if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)") } diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 30cc4cfe..cd02726d 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -26,15 +26,27 @@ import ( ) var ( - UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") - NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") - UnsupportedOlmMessageType = errors.New("unsupported olm message type") - DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") - DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") - SenderMismatch = errors.New("mismatched sender in olm payload") - RecipientMismatch = errors.New("mismatched recipient in olm payload") - RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") - ErrDuplicateMessage = errors.New("duplicate olm message") + ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") + ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") + ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type") + ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") + ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") + ErrSenderMismatch = errors.New("mismatched sender in olm payload") + ErrRecipientMismatch = errors.New("mismatched recipient in olm payload") + ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") + ErrDuplicateMessage = errors.New("duplicate olm message") +) + +// Deprecated: use variables prefixed with Err +var ( + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + NotEncryptedForMe = ErrNotEncryptedForMe + UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType + DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession + DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage + SenderMismatch = ErrSenderMismatch + RecipientMismatch = ErrRecipientMismatch + RecipientKeyMismatch = ErrRecipientKeyMismatch ) // DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm. @@ -56,13 +68,13 @@ type DecryptedOlmEvent struct { func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmOlmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm } ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()] if !ok { - return nil, NotEncryptedForMe + return nil, ErrNotEncryptedForMe } decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body) if err != nil { @@ -78,7 +90,7 @@ type OlmEventKeys struct { func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { - return nil, UnsupportedOlmMessageType + return nil, ErrUnsupportedOlmMessageType } log := mach.machOrContextLog(ctx).With(). @@ -102,11 +114,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } olmEvt.Type.Class = evt.Type.Class if evt.Sender != olmEvt.Sender { - return nil, SenderMismatch + return nil, ErrSenderMismatch } else if mach.Client.UserID != olmEvt.Recipient { - return nil, RecipientMismatch + return nil, ErrRecipientMismatch } else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 { - return nil, RecipientKeyMismatch + return nil, ErrRecipientKeyMismatch } if len(olmEvt.Content.VeryRaw) > 0 { @@ -151,7 +163,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash) if err != nil { - if err == DecryptionFailedWithMatchingSession { + if err == ErrDecryptionFailedWithMatchingSession { log.Warn().Msg("Found matching session, but decryption failed") go mach.unwedgeDevice(log, sender, senderKey) } @@ -169,10 +181,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U // if it isn't one at this point in time anymore, so return early. if olmType != id.OlmMsgTypePreKey { go mach.unwedgeDevice(log, sender, senderKey) - return nil, DecryptionFailedForNormalMessage + return nil, ErrDecryptionFailedForNormalMessage } - accountBackup, err := mach.account.Internal.Pickle([]byte("tmp")) + accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp")) log.Trace().Msg("Trying to create inbound session") endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second) session, err := mach.createInboundSession(ctx, senderKey, ciphertext) @@ -302,7 +314,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( Str("session_description", session.Describe()). Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { - return nil, DecryptionFailedWithMatchingSession + return nil, ErrDecryptionFailedWithMatchingSession } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) @@ -345,7 +357,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send ctx := log.WithContext(mach.backgroundCtx) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] - delta := time.Now().Sub(prevUnwedge) + delta := time.Since(prevUnwedge) if ok && delta < MinUnwedgeInterval { log.Debug(). Str("previous_recreation", delta.String()). diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 61a22522..f0d2b129 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -22,14 +22,23 @@ import ( ) var ( - MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") - MismatchingUserID = errors.New("mismatching user ID in parameter and keys object") - MismatchingSigningKey = errors.New("received update for device with different signing key") - NoSigningKeyFound = errors.New("didn't find ed25519 signing key") - NoIdentityKeyFound = errors.New("didn't find curve25519 identity key") - InvalidKeySignature = errors.New("invalid signature on device keys") + ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") + ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object") + ErrMismatchingSigningKey = errors.New("received update for device with different signing key") + ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key") + ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key") + ErrInvalidKeySignature = errors.New("invalid signature on device keys") + ErrUserNotTracked = errors.New("user is not tracked") +) - ErrUserNotTracked = errors.New("user is not tracked") +// Deprecated: use variables prefixed with Err +var ( + MismatchingDeviceID = ErrMismatchingDeviceID + MismatchingUserID = ErrMismatchingUserID + MismatchingSigningKey = ErrMismatchingSigningKey + NoSigningKeyFound = ErrNoSigningKeyFound + NoIdentityKeyFound = ErrNoIdentityKeyFound + InvalidKeySignature = ErrInvalidKeySignature ) func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { @@ -312,28 +321,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) { if deviceID != deviceKeys.DeviceID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID) } else if userID != deviceKeys.UserID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID) } signingKey := deviceKeys.Keys.GetEd25519(deviceID) identityKey := deviceKeys.Keys.GetCurve25519(deviceID) if signingKey == "" { - return nil, NoSigningKeyFound + return nil, ErrNoSigningKeyFound } else if identityKey == "" { - return nil, NoIdentityKeyFound + return nil, ErrNoIdentityKeyFound } if existing != nil && existing.SigningKey != signingKey { - return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey) + return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey) } ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) if err != nil { return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { - return existing, InvalidKeySignature + return existing, ErrInvalidKeySignature } name, ok := deviceKeys.Unsigned["device_display_name"].(string) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index b3d19618..ea97f767 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -25,7 +25,12 @@ import ( ) var ( - NoGroupSession = errors.New("no group session created") + ErrNoGroupSession = errors.New("no group session created") +) + +// Deprecated: use variables prefixed with Err +var ( + NoGroupSession = ErrNoGroupSession ) func getRawJSON[T any](content json.RawMessage, path ...string) *T { @@ -82,7 +87,7 @@ type rawMegolmEvent struct { // IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession func IsShareError(err error) bool { - return err == SessionExpired || err == SessionNotShared || err == NoGroupSession + return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession } func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { @@ -120,7 +125,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { - return nil, NoGroupSession + return nil, ErrNoGroupSession } plaintext, err := json.Marshal(&rawMegolmEvent{ RoomID: roomID, diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 4da08a73..b48843a4 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error { if err != nil { return err } else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 { - return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair return err } else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index e1c9b452..d0dec5f0 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) { account, err := account.NewAccount() assert.NoError(t, err) err = account.Unpickle(pickled, pickleKey) - assert.ErrorIs(t, err, olm.ErrBadVersion) + assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion) } func TestLoopback(t *testing.T) { diff --git a/crypto/goolm/goolmbase64/base64.go b/crypto/goolm/goolmbase64/base64.go index 061a052a..58ee26f7 100644 --- a/crypto/goolm/goolmbase64/base64.go +++ b/crypto/goolm/goolmbase64/base64.go @@ -4,7 +4,8 @@ import ( "encoding/base64" ) -// Deprecated: base64.RawStdEncoding should be used directly +// These methods should only be used for raw byte operations, never with string conversion + func Decode(input []byte) ([]byte, error) { decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input))) writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input) @@ -14,7 +15,6 @@ func Decode(input []byte) ([]byte, error) { return decoded[:writtenBytes], nil } -// Deprecated: base64.RawStdEncoding should be used directly func Encode(input []byte) []byte { encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input))) base64.RawStdEncoding.Encode(encoded, input) diff --git a/crypto/goolm/libolmpickle/picklejson.go b/crypto/goolm/libolmpickle/picklejson.go index 308e472c..f765391f 100644 --- a/crypto/goolm/libolmpickle/picklejson.go +++ b/crypto/goolm/libolmpickle/picklejson.go @@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { } } if decrypted[0] != pickleVersion { - return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion) + return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion) } err = json.Unmarshal(decrypted[1:], object) if err != nil { diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go index 956868b2..d58dbb21 100644 --- a/crypto/goolm/message/session_export.go +++ b/crypto/goolm/message/session_export.go @@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error { return fmt.Errorf("decrypt: %w", olm.ErrBadInput) } if input[0] != sessionExportVersion { - return fmt.Errorf("decrypt: %w", olm.ErrBadVersion) + return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go index 16240945..d04ef15a 100644 --- a/crypto/goolm/message/session_sharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { } s.PublicKey = publicKey if input[0] != sessionSharingVersion { - return fmt.Errorf("verify: %w", olm.ErrBadVersion) + return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index afb01f74..cdb20eb1 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -103,7 +103,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error { if pickledVersion == decryptionPickleVersionLibOlm { return a.KeyPair.UnpickleLibOlm(decoder) } else { - return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm) + return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm) } } diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index 23f67ddf..2897d9b0 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -37,6 +37,9 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat return nil, nil, err } cipher, err := aessha2.NewAESSHA2(sharedSecret, nil) + if err != nil { + return nil, nil, err + } ciphertext, err = cipher.Encrypt(plaintext) if err != nil { return nil, nil, err diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 80dd71cc..fb88b73c 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -99,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { // the counter is before our initial ratchet - we can't decode this - return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable) + return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex) } // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet copiedRatchet := o.InitialRatchet @@ -206,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error { return err } if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil { diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 2b8e1c84..7f923534 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -101,8 +101,10 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() - if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + if err != nil { + return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err) + } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { return err diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index b99ab630..a1cb8d66 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -168,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received msg := message.Message{} err = msg.Decode(oneTimeMsg.Message) if err != nil { - return nil, fmt.Errorf("Message decode: %w", err) + return nil, fmt.Errorf("message decode: %w", err) } if len(msg.RatchetKey) == 0 { - return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat) + return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat) } //Init Ratchet s.Ratchet.InitializeAsBob(secret, msg.RatchetKey) @@ -203,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID { copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey) copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey) hash := sha256.Sum256(message) - res := id.SessionID(goolmbase64.Encode(hash[:])) + res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:])) return res } @@ -325,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e if len(crypttext) == 0 { return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput) } - decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext)) + decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext) if err != nil { return nil, err } @@ -365,6 +365,9 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error { func (o *OlmSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() + if err != nil { + return fmt.Errorf("unpickle olmSession: failed to read version: %w", err) + } var includesChainIndex bool switch pickledVersion { @@ -373,7 +376,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error { case uint32(0x80000001): includesChainIndex = true default: - return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if o.ReceivedMessage, err = decoder.ReadBool(); err != nil { diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go index a88d12f6..b95a44ac 100644 --- a/crypto/goolm/session/register.go +++ b/crypto/goolm/session/register.go @@ -14,7 +14,7 @@ func Register() { // Inbound Session olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } if len(key) == 0 { key = []byte(" ") @@ -23,13 +23,13 @@ func Register() { } olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } return NewMegolmInboundSession(sessionKey) } olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } return NewMegolmInboundSessionFromExport(sessionKey) } @@ -40,7 +40,7 @@ func Register() { // Outbound Session olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } lenKey := len(key) if lenKey == 0 { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index d8b3d715..ceec1d58 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -56,11 +56,12 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, // ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private // key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing // from a verified device belonging to the same user." - megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) - if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { - log.Debug().Msg("key backup is trusted based on derived public key") - return versionInfo, nil - } else { + if megolmBackupKey != nil { + megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) + if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { + log.Debug().Msg("Key backup is trusted based on derived public key") + return versionInfo, nil + } log.Debug(). Stringer("expected_key", megolmBackupDerivedPublicKey). Stringer("actual_key", versionInfo.AuthData.PublicKey). diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go index f6f916e7..0350f083 100644 --- a/crypto/libolm/account.go +++ b/crypto/libolm/account.go @@ -33,7 +33,7 @@ var _ olm.Account = (*Account)(nil) // "INVALID_BASE64". func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } a := NewBlankAccount() return a, a.Unpickle(pickled, key) @@ -53,7 +53,7 @@ func NewAccount() (*Account, error) { random := make([]byte, a.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } ret := C.olm_create_account( (*C.OlmAccount)(a.int), @@ -128,7 +128,7 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint { // supplied key. func (a *Account) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, a.pickleLen()) r := C.olm_pickle_account( @@ -145,7 +145,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { func (a *Account) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_account( (*C.OlmAccount)(a.int), @@ -198,7 +198,7 @@ func (a *Account) MarshalJSON() ([]byte, error) { // Deprecated func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if a.int == nil { *a = *NewBlankAccount() @@ -235,7 +235,7 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { // Account. func (a *Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { - panic(olm.EmptyInput) + panic(olm.ErrEmptyInput) } signature := make([]byte, a.signatureLen()) r := C.olm_account_sign( @@ -299,7 +299,7 @@ func (a *Account) GenOneTimeKeys(num uint) error { random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) _, err := rand.Read(random) if err != nil { - return olm.NotEnoughGoRandom + return olm.ErrNotEnoughGoRandom } r := C.olm_account_generate_one_time_keys( (*C.OlmAccount)(a.int), @@ -319,13 +319,13 @@ func (a *Account) GenOneTimeKeys(num uint) error { // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() random := make([]byte, s.createOutboundRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } theirIdentityKeyCopy := []byte(theirIdentityKey) theirOneTimeKeyCopy := []byte(theirOneTimeKey) @@ -357,7 +357,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) @@ -383,7 +383,7 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } theirIdentityKeyCopy := []byte(*theirIdentityKey) oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go index 9ca415ee..6fb5512b 100644 --- a/crypto/libolm/error.go +++ b/crypto/libolm/error.go @@ -11,21 +11,21 @@ import ( ) var errorMap = map[string]error{ - "NOT_ENOUGH_RANDOM": olm.NotEnoughRandom, - "OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall, - "BAD_MESSAGE_VERSION": olm.BadMessageVersion, - "BAD_MESSAGE_FORMAT": olm.BadMessageFormat, - "BAD_MESSAGE_MAC": olm.BadMessageMAC, - "BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID, - "INVALID_BASE64": olm.InvalidBase64, - "BAD_ACCOUNT_KEY": olm.BadAccountKey, - "UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion, - "CORRUPTED_PICKLE": olm.CorruptedPickle, - "BAD_SESSION_KEY": olm.BadSessionKey, - "UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex, - "BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle, - "BAD_SIGNATURE": olm.BadSignature, - "INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall, + "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom, + "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall, + "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion, + "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat, + "BAD_MESSAGE_MAC": olm.ErrBadMAC, + "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID, + "INVALID_BASE64": olm.ErrLibolmInvalidBase64, + "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey, + "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion, + "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle, + "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey, + "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex, + "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle, + "BAD_SIGNATURE": olm.ErrBadSignature, + "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall, } func convertError(errCode string) error { diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go index 5606475d..8815ac32 100644 --- a/crypto/libolm/inboundgroupsession.go +++ b/crypto/libolm/inboundgroupsession.go @@ -31,7 +31,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) // base64 couldn't be decoded then the error will be "INVALID_BASE64". func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } lenKey := len(key) if lenKey == 0 { @@ -48,7 +48,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, // "OLM_BAD_SESSION_KEY". func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_init_inbound_group_session( @@ -69,7 +69,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { // error will be "OLM_BAD_SESSION_KEY". func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_import_inbound_group_session( @@ -124,7 +124,7 @@ func (s *InboundGroupSession) pickleLen() uint { // InboundGroupSession using the supplied key. func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_inbound_group_session( @@ -143,9 +143,9 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } else if len(pickled) == 0 { - return olm.EmptyInput + return olm.ErrEmptyInput } r := C.olm_unpickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), @@ -200,7 +200,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() @@ -217,7 +217,7 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { // will be "BAD_MESSAGE_FORMAT". func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { if len(message) == 0 { - return 0, olm.EmptyInput + return 0, olm.ErrEmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it messageCopy := bytes.Clone(message) @@ -244,7 +244,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro // was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if len(message) == 0 { - return nil, 0, olm.EmptyInput + return nil, 0, olm.ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) if err != nil { diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go index 646929eb..ca5b68f7 100644 --- a/crypto/libolm/outboundgroupsession.go +++ b/crypto/libolm/outboundgroupsession.go @@ -84,7 +84,7 @@ func (s *OutboundGroupSession) pickleLen() uint { // OutboundGroupSession using the supplied key. func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_outbound_group_session( @@ -103,7 +103,7 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), @@ -159,7 +159,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() @@ -183,7 +183,7 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if len(plaintext) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_group_encrypt( diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go index 35532140..2683cf15 100644 --- a/crypto/libolm/pk.go +++ b/crypto/libolm/pk.go @@ -86,7 +86,7 @@ func NewPKSigning() (*PKSigning, error) { seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } pk, err := NewPKSigningFromSeed(seed) return pk, err diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go index f091d822..ddf84613 100644 --- a/crypto/libolm/register.go +++ b/crypto/libolm/register.go @@ -65,7 +65,7 @@ func Register() { olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankOutboundGroupSession() return s, s.Unpickle(pickled, key) diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go index 57e631c3..1441df26 100644 --- a/crypto/libolm/session.go +++ b/crypto/libolm/session.go @@ -51,7 +51,7 @@ func sessionSize() uint { // "INVALID_BASE64". func SessionFromPickled(pickled, key []byte) (*Session, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() return s, s.Unpickle(pickled, key) @@ -118,7 +118,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint { // will be "BAD_MESSAGE_FORMAT". func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { if len(message) == 0 { - return 0, olm.EmptyInput + return 0, olm.ErrEmptyInput } messageCopy := []byte(message) r := C.olm_decrypt_max_plaintext_length( @@ -138,7 +138,7 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) // supplied key. func (s *Session) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_session( @@ -158,7 +158,7 @@ func (s *Session) Pickle(key []byte) ([]byte, error) { // provided key. This function mutates the input pickled data slice. func (s *Session) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_session( (*C.OlmSession)(s.int), @@ -213,7 +213,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { // Deprecated func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankSession() @@ -256,7 +256,7 @@ func (s *Session) HasReceivedMessage() bool { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { if len(oneTimeKeyMsg) == 0 { - return false, olm.EmptyInput + return false, olm.ErrEmptyInput } oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_matches_inbound_session( @@ -284,7 +284,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, olm.EmptyInput + return false, olm.ErrEmptyInput } theirIdentityKeyCopy := []byte(theirIdentityKey) oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) @@ -325,14 +325,14 @@ func (s *Session) EncryptMsgType() id.OlmMsgType { // as base64. func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { if len(plaintext) == 0 { - return 0, nil, olm.EmptyInput + return 0, nil, olm.ErrEmptyInput } // Make the slice be at least length 1 random := make([]byte, s.encryptRandomLen()+1) _, err := rand.Read(random) if err != nil { // TODO can we just return err here? - return 0, nil, olm.NotEnoughGoRandom + return 0, nil, olm.ErrNotEnoughGoRandom } messageType := s.EncryptMsgType() message := make([]byte, s.encryptMsgLen(len(plaintext))) @@ -362,7 +362,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { // "BAD_MESSAGE_MAC". func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { if len(message) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) if err != nil { diff --git a/crypto/machine.go b/crypto/machine.go index 4d2e3880..f8ebe909 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -205,7 +205,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error { func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() { start := time.Now() return func() { - duration := time.Now().Sub(start) + duration := time.Since(start) if duration > expectedDuration { zerolog.Ctx(ctx).Warn(). Str("action", thing). diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go index 957d7928..9e522b2a 100644 --- a/crypto/olm/errors.go +++ b/crypto/olm/errors.go @@ -10,50 +10,67 @@ import "errors" // Those are the most common used errors var ( - ErrBadSignature = errors.New("bad signature") - ErrBadMAC = errors.New("bad mac") - ErrBadMessageFormat = errors.New("bad message format") - ErrBadVerification = errors.New("bad verification") - ErrWrongProtocolVersion = errors.New("wrong protocol version") - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no key") - ErrBadMessageKeyID = errors.New("bad message key id") - ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") - ErrMsgIndexTooHigh = errors.New("message index too high") - ErrProtocolViolation = errors.New("not protocol message order") - ErrMessageKeyNotFound = errors.New("message key not found") - ErrChainTooHigh = errors.New("chain index too high") - ErrBadInput = errors.New("bad input") - ErrBadVersion = errors.New("wrong version") - ErrWrongPickleVersion = errors.New("wrong pickle version") - ErrInputToSmall = errors.New("input too small (truncated?)") - ErrOverflow = errors.New("overflow") + ErrBadSignature = errors.New("bad signature") + ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)") + ErrBadMessageFormat = errors.New("the message couldn't be decoded") + ErrBadVerification = errors.New("bad verification") + ErrWrongProtocolVersion = errors.New("wrong protocol version") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no key provided") + ErrBadMessageKeyID = errors.New("the message references an unknown key ID") + ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") + ErrMsgIndexTooHigh = errors.New("message index too high") + ErrProtocolViolation = errors.New("not protocol message order") + ErrMessageKeyNotFound = errors.New("message key not found") + ErrChainTooHigh = errors.New("chain index too high") + ErrBadInput = errors.New("bad input") + ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version") + ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version") + ErrInputToSmall = errors.New("input too small (truncated?)") ) // Error codes from go-olm var ( - EmptyInput = errors.New("empty input") - NoKeyProvided = errors.New("no pickle key provided") - NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") ) // Error codes from olm code var ( - NotEnoughRandom = errors.New("not enough entropy was supplied") - OutputBufferTooSmall = errors.New("supplied output buffer is too small") - BadMessageVersion = errors.New("the message version is unsupported") - BadMessageFormat = errors.New("the message couldn't be decoded") - BadMessageMAC = errors.New("the message couldn't be decrypted") - BadMessageKeyID = errors.New("the message references an unknown key ID") - InvalidBase64 = errors.New("the input base64 was invalid") - BadAccountKey = errors.New("the supplied account key is invalid") - UnknownPickleVersion = errors.New("the pickled object is too new") - CorruptedPickle = errors.New("the pickled object couldn't be decoded") - BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") - UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") - BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") - BadSignature = errors.New("received message had a bad signature") - InputBufferTooSmall = errors.New("the input data was too small to be valid") + ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid") + + ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied") + ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small") + ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid") + ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded") + ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") + ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") +) + +// Deprecated: use variables prefixed with Err +var ( + EmptyInput = ErrEmptyInput + BadSignature = ErrBadSignature + InvalidBase64 = ErrLibolmInvalidBase64 + BadMessageKeyID = ErrBadMessageKeyID + BadMessageFormat = ErrBadMessageFormat + BadMessageVersion = ErrWrongProtocolVersion + BadMessageMAC = ErrBadMAC + UnknownPickleVersion = ErrUnknownOlmPickleVersion + NotEnoughRandom = ErrLibolmNotEnoughRandom + OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall + BadAccountKey = ErrLibolmBadAccountKey + CorruptedPickle = ErrLibolmCorruptedPickle + BadSessionKey = ErrLibolmBadSessionKey + UnknownMessageIndex = ErrUnknownMessageIndex + BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle + InputBufferTooSmall = ErrInputToSmall + NoKeyProvided = ErrNoKeyProvided + + NotEnoughGoRandom = ErrNotEnoughGoRandom + InputNotJSONString = ErrInputNotJSONString + + ErrBadVersion = ErrUnknownJSONPickleVersion + ErrWrongPickleVersion = ErrUnknownJSONPickleVersion + ErrRatchetNotAvailable = ErrUnknownMessageIndex ) diff --git a/crypto/sessions.go b/crypto/sessions.go index aecb0416..6b90c998 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -18,8 +18,14 @@ import ( ) var ( - SessionNotShared = errors.New("session has not been shared") - SessionExpired = errors.New("session has expired") + ErrSessionNotShared = errors.New("session has not been shared") + ErrSessionExpired = errors.New("session has expired") +) + +// Deprecated: use variables prefixed with Err +var ( + SessionNotShared = ErrSessionNotShared + SessionExpired = ErrSessionExpired ) // OlmSessionList is a list of OlmSessions. @@ -255,9 +261,9 @@ func (ogs *OutboundGroupSession) Expired() bool { func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if !ogs.Shared { - return nil, SessionNotShared + return nil, ErrSessionNotShared } else if ogs.Expired() { - return nil, SessionExpired + return nil, ErrSessionExpired } ogs.MessageCount++ ogs.LastEncryptedTime = time.Now() diff --git a/event/encryption.go b/event/encryption.go index cf9c2814..e07944af 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { - return id.InputNotJSONString + return id.ErrInputNotJSONString } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } diff --git a/federation/resolution.go b/federation/resolution.go index 81e19cfb..a3188266 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -80,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS } else if wellKnown != nil { output.Expires = expiry output.HostHeader = wellKnown.Server - hostname, port, ok = ParseServerName(wellKnown.Server) + wkHost, wkPort, ok := ParseServerName(wellKnown.Server) + if ok { + hostname, port = wkHost, wkPort + } // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known if net.ParseIP(hostname) != nil || port != 0 { if port == 0 { diff --git a/filter.go b/filter.go index c6c8211b..54973dab 100644 --- a/filter.go +++ b/filter.go @@ -57,7 +57,7 @@ type FilterPart struct { // Validate checks if the filter contains valid property values func (filter *Filter) Validate() error { if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation { - return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]") + return errors.New("bad event_format value") } return nil } diff --git a/id/contenturi.go b/id/contenturi.go index e6a313f5..be45eb2b 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -17,8 +17,14 @@ import ( ) var ( - InvalidContentURI = errors.New("invalid Matrix content URI") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrInvalidContentURI = errors.New("invalid Matrix content URI") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") +) + +// Deprecated: use variables prefixed with Err +var ( + InvalidContentURI = ErrInvalidContentURI + InputNotJSONString = ErrInputNotJSONString ) // ContentURIString is a string that's expected to be a Matrix content URI. @@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !strings.HasPrefix(uri, "mxc://") { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = uri[6 : 6+index] parsed.FileID = uri[6+index+1:] @@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !bytes.HasPrefix(uri, mxcBytes) { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = string(uri[6 : 6+index]) parsed.FileID = string(uri[6+index+1:]) @@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) { *uri = ContentURI{} return nil } else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' { - return InputNotJSONString + return ErrInputNotJSONString } parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1]) if err != nil { diff --git a/id/matrixuri.go b/id/matrixuri.go index 8f5ec849..d5c78bc7 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{ func (uri *MatrixURI) getQuery() url.Values { q := make(url.Values) - if uri.Via != nil && len(uri.Via) > 0 { + if len(uri.Via) > 0 { q["via"] = uri.Via } if len(uri.Action) > 0 { diff --git a/id/userid.go b/id/userid.go index 859d2358..726a0d58 100644 --- a/id/userid.go +++ b/id/userid.go @@ -219,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) { for i := 0; i < len(strBytes); i++ { b := strBytes[i] if !isValidByte(b) { - return "", fmt.Errorf("Byte pos %d: Invalid byte", i) + return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b) } if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _ if i+1 >= len(strBytes) { - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after underscore at %d", i) } if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i) + return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i) } if strBytes[i+1] == '_' { outputBuffer.WriteByte('_') @@ -237,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) { i++ // skip next byte since we just handled it } else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8 if i+2 >= len(strBytes) { - return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after equals sign at %d", i) } dst := make([]byte, 1) _, err := hex.Decode(dst, strBytes[i+1:i+3]) diff --git a/pushrules/action.go b/pushrules/action.go index 9838e88b..b5a884b2 100644 --- a/pushrules/action.go +++ b/pushrules/action.go @@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error { if ok { action.Action = ActionSetTweak action.Tweak = PushActionTweak(tweak) - action.Value, _ = val["value"] + action.Value = val["value"] } } return nil diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go index 0d3eaf7a..37af3e34 100644 --- a/pushrules/condition_test.go +++ b/pushrules/condition_test.go @@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi } } -func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition { - return &pushrules.PushCondition{ - Kind: pushrules.KindEventPropertyContains, - Key: key, - Value: value, - } -} - func TestPushCondition_Match_InvalidKind(t *testing.T) { condition := &pushrules.PushCondition{ Kind: pushrules.PushCondKind("invalid"), diff --git a/room.go b/room.go index c3ddb7e6..4292bff5 100644 --- a/room.go +++ b/room.go @@ -5,8 +5,6 @@ import ( "maunium.net/go/mautrix/id" ) -type RoomStateMap = map[event.Type]map[string]*event.Event - // Room represents a single Matrix room. type Room struct { ID id.RoomID @@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) { // GetStateEvent returns the state event for the given type/state_key combo, or nil. func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event { - stateEventMap, _ := room.State[eventType] - evt, _ := stateEventMap[stateKey] + stateEventMap := room.State[eventType] + evt := stateEventMap[stateKey] return evt } diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index c360acab..0925b748 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -75,8 +75,7 @@ type RespListRooms struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) { var resp RespListRooms - var reqURL string - reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) + reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } diff --git a/url.go b/url.go index d888956a..91b3d49d 100644 --- a/url.go +++ b/url.go @@ -98,10 +98,8 @@ func (saup SynapseAdminURLPath) FullPath() []any { // and appservice user ID set already. func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string { return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { - if urlQuery != nil { - for k, v := range urlQuery { - q.Set(k, v) - } + for k, v := range urlQuery { + q.Set(k, v) } }) }