all: fix staticcheck issues
Some checks are pending
Go / Lint (latest) (push) Waiting to run
Go / Build (old, libolm) (push) Waiting to run
Go / Build (latest, libolm) (push) Waiting to run
Go / Build (old, goolm) (push) Waiting to run
Go / Build (latest, goolm) (push) Waiting to run

This commit is contained in:
Tulir Asokan 2025-12-08 00:07:07 +02:00
commit 315d2ab17d
56 changed files with 358 additions and 338 deletions

View file

@ -24,6 +24,7 @@ jobs:
- name: Install goimports
run: |
go install golang.org/x/tools/cmd/goimports@latest
go install honnef.co/go/tools/cmd/staticcheck@latest
export PATH="$HOME/go/bin:$PATH"
- name: Run pre-commit

View file

@ -18,8 +18,7 @@ repos:
- "-w"
- id: go-vet-repo-mod
- id: go-mod-tidy
# TODO enable this
#- id: go-staticcheck-repo-mod
- id: go-staticcheck-repo-mod
- repo: https://github.com/beeper/pre-commit-go
rev: v0.4.2

View file

@ -56,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
var prefixMessage string
for unwrappedErr != nil {
errorData, jsonErr = json.Marshal(unwrappedErr)
if errorData != nil && len(errorData) > 2 && jsonErr == nil {
if len(errorData) > 2 && jsonErr == nil {
prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
prefixMessage = strings.TrimRight(prefixMessage, ": ")
break

View file

@ -41,10 +41,7 @@ func (pc PermissionConfig) IsConfigured() bool {
_, hasExampleDomain := pc["example.com"]
_, hasExampleUser := pc["@admin:example.com"]
exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain)
if len(pc) <= exampleLen {
return false
}
return true
return len(pc) > exampleLen
}
func (pc PermissionConfig) Get(userID id.UserID) Permissions {

View file

@ -102,9 +102,9 @@ func (bsq *BridgeStateQueue) loop() {
}
}
func (bsq *BridgeStateQueue) scheduleNotice(ctx context.Context, triggeredBy status.BridgeState) {
func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) {
log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger()
ctx = log.WithContext(bsq.bridge.BackgroundCtx)
ctx := log.WithContext(bsq.bridge.BackgroundCtx)
if !bsq.waitForTransientDisconnectReconnect(ctx) {
return
}
@ -135,7 +135,7 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge
if bsq.firstTransientDisconnect.IsZero() {
bsq.firstTransientDisconnect = time.Now()
}
go bsq.scheduleNotice(ctx, state)
go bsq.scheduleNotice(state)
}
return
}

View file

@ -7,13 +7,7 @@
package database
import (
"encoding/json"
"reflect"
"strings"
"go.mau.fi/util/dbutil"
"golang.org/x/exp/constraints"
"golang.org/x/exp/maps"
"maunium.net/go/mautrix/bridgev2/networkid"
@ -158,55 +152,3 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID)
panic("bridge ID mismatch")
}
}
func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) {
if val, found := m[key]; found {
floatVal, ok := val.(float64)
if ok {
return T(floatVal), true
}
tVal, ok := val.(T)
if ok {
return tVal, true
}
}
return 0, false
}
func unmarshalMerge(input []byte, data any, extra *map[string]any) error {
err := json.Unmarshal(input, data)
if err != nil {
return err
}
err = json.Unmarshal(input, extra)
if err != nil {
return err
}
if *extra == nil {
*extra = make(map[string]any)
}
return nil
}
func marshalMerge(data any, extra map[string]any) ([]byte, error) {
if extra == nil {
return json.Marshal(data)
}
merged := make(map[string]any)
maps.Copy(merged, extra)
dataRef := reflect.ValueOf(data).Elem()
dataType := dataRef.Type()
for _, field := range reflect.VisibleFields(dataType) {
parts := strings.Split(field.Tag.Get("json"), ",")
if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" {
continue
}
fieldVal := dataRef.FieldByIndex(field.Index)
if fieldVal.IsZero() {
delete(merged, parts[0])
} else {
merged[parts[0]] = fieldVal.Interface()
}
}
return json.Marshal(merged)
}

View file

@ -38,9 +38,9 @@ func init() {
var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil)
var NoSessionFound = crypto.NoSessionFound
var DuplicateMessageIndex = crypto.DuplicateMessageIndex
var UnknownMessageIndex = olm.UnknownMessageIndex
var NoSessionFound = crypto.ErrNoSessionFound
var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex
var UnknownMessageIndex = olm.ErrUnknownMessageIndex
type CryptoHelper struct {
bridge *Connector
@ -439,7 +439,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy
var encrypted *event.EncryptedEventContent
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) {
if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) {
return
}
helper.log.Debug().Err(err).

View file

@ -127,6 +127,7 @@ func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event,
Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, requesting keys and waiting longer...")
//lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false)

View file

@ -135,7 +135,10 @@ func (br *BridgeMain) CheckLegacyDB(
}
var dbVersion int
err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion)
if dbVersion < expectedVersion {
if err != nil {
log.Fatal().Err(err).Msg("Failed to get database version")
return
} else if dbVersion < expectedVersion {
log.Fatal().
Int("expected_version", expectedVersion).
Int("version", dbVersion).

View file

@ -85,10 +85,9 @@ const (
provisioningUserKey provisioningContextKey = iota
provisioningUserLoginKey
provisioningLoginProcessKey
ProvisioningKeyRequest
)
const ProvisioningKeyRequest = "fi.mau.provision.request"
func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
}

View file

@ -410,6 +410,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
if reaction.Timestamp.IsZero() {
reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond)
}
//lint:ignore SA4006 it's a todo
targetPart, ok := partMap[*reaction.TargetPart]
if !ok {
// TODO warning log and/or skip reaction?

View file

@ -96,7 +96,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
go func() {
_, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{
Parsed: &event.TombstoneEventContent{
Body: fmt.Sprintf("This room has been merged"),
Body: "This room has been merged",
ReplacementRoom: targetPortal.MXID,
},
}, time.Now())

View file

@ -742,7 +742,7 @@ func (cli *Client) executeCompiledRequest(
cli.RequestStart(req)
startTime := time.Now()
res, err := client.Do(req)
duration := time.Now().Sub(startTime)
duration := time.Since(startTime)
if res != nil && !dontReadResponse {
defer res.Body.Close()
}
@ -862,7 +862,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp
}
start := time.Now()
_, err = cli.MakeFullRequest(ctx, fullReq)
duration := time.Now().Sub(start)
duration := time.Since(start)
timeout := time.Duration(req.Timeout) * time.Millisecond
buffer := 10 * time.Second
if req.Since == "" {
@ -966,7 +966,7 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe
// }
// token := res.AccessToken
func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) {
res, uia, err := cli.Register(ctx, req)
_, uia, err := cli.Register(ctx, req)
if err != nil && uia == nil {
return nil, err
} else if uia == nil {
@ -975,7 +975,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe
return nil, errors.New("server does not support m.login.dummy")
}
req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session}
res, _, err = cli.Register(ctx, req)
res, _, err := cli.Register(ctx, req)
if err != nil {
return nil, err
}
@ -1751,6 +1751,8 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any
return nil, nil
}
type RoomStateMap = map[event.Type]map[string]*event.Event
// State gets all state in a room.
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate
func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) {

View file

@ -21,13 +21,24 @@ import (
)
var (
HashMismatch = errors.New("mismatching SHA-256 digest")
UnsupportedVersion = errors.New("unsupported Matrix file encryption version")
UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
InvalidKey = errors.New("failed to decode key")
InvalidInitVector = errors.New("failed to decode initialization vector")
InvalidHash = errors.New("failed to decode SHA-256 hash")
ReaderClosed = errors.New("encrypting reader was already closed")
ErrHashMismatch = errors.New("mismatching SHA-256 digest")
ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version")
ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
ErrInvalidKey = errors.New("failed to decode key")
ErrInvalidInitVector = errors.New("failed to decode initialization vector")
ErrInvalidHash = errors.New("failed to decode SHA-256 hash")
ErrReaderClosed = errors.New("encrypting reader was already closed")
)
// Deprecated: use variables prefixed with Err
var (
HashMismatch = ErrHashMismatch
UnsupportedVersion = ErrUnsupportedVersion
UnsupportedAlgorithm = ErrUnsupportedAlgorithm
InvalidKey = ErrInvalidKey
InvalidInitVector = ErrInvalidInitVector
InvalidHash = ErrInvalidHash
ReaderClosed = ErrReaderClosed
)
var (
@ -85,25 +96,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error {
if ef.decoded != nil {
return nil
} else if len(ef.Key.Key) != keyBase64Length {
return InvalidKey
return ErrInvalidKey
} else if len(ef.InitVector) != ivBase64Length {
return InvalidInitVector
return ErrInvalidInitVector
} else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length {
return InvalidHash
return ErrInvalidHash
}
ef.decoded = &decodedKeys{}
_, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key))
if err != nil {
return InvalidKey
return ErrInvalidKey
}
_, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector))
if err != nil {
return InvalidInitVector
return ErrInvalidInitVector
}
if includeHash {
_, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256))
if err != nil {
return InvalidHash
return ErrInvalidHash
}
}
return nil
@ -179,7 +190,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil)
func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
if r.closed {
return 0, ReaderClosed
return 0, ErrReaderClosed
}
if offset != 0 || whence != io.SeekStart {
return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported")
@ -200,7 +211,7 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
func (r *encryptingReader) Read(dst []byte) (n int, err error) {
if r.closed {
return 0, ReaderClosed
return 0, ErrReaderClosed
} else if r.isDecrypting && r.file.decoded == nil {
if err = r.file.PrepareForDecryption(); err != nil {
return
@ -224,7 +235,7 @@ func (r *encryptingReader) Close() (err error) {
}
if r.isDecrypting {
if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) {
return HashMismatch
return ErrHashMismatch
}
} else {
r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil))
@ -265,9 +276,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) {
// DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function.
func (ef *EncryptedFile) PrepareForDecryption() error {
if ef.Version != "v2" {
return UnsupportedVersion
return ErrUnsupportedVersion
} else if ef.Key.Algorithm != "A256CTR" {
return UnsupportedAlgorithm
return ErrUnsupportedAlgorithm
} else if err := ef.decodeKeys(true); err != nil {
return err
}
@ -281,7 +292,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error {
}
dataHash := sha256.Sum256(data)
if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) {
return HashMismatch
return ErrHashMismatch
}
utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv)
return nil

View file

@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) {
file := parseHelloWorld()
file.Version = "foo"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, UnsupportedVersion)
assert.ErrorIs(t, err, ErrUnsupportedVersion)
}
func TestUnsupportedAlgorithm(t *testing.T) {
file := parseHelloWorld()
file.Key.Algorithm = "bar"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, UnsupportedAlgorithm)
assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
}
func TestHashMismatch(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes))
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, HashMismatch)
assert.ErrorIs(t, err, ErrHashMismatch)
}
func TestTooLongHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, InvalidHash)
assert.ErrorIs(t, err, ErrInvalidHash)
}
func TestTooShortHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "5/Gy1JftyyQ"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, InvalidHash)
assert.ErrorIs(t, err, ErrInvalidHash)
}

View file

@ -63,8 +63,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id
if len(dbKeys) > 0 {
masterKey, ok := dbKeys[id.XSUsageMaster]
if ok {
selfSigning, _ := dbKeys[id.XSUsageSelfSigning]
userSigning, _ := dbKeys[id.XSUsageUserSigning]
selfSigning := dbKeys[id.XSUsageSelfSigning]
userSigning := dbKeys[id.XSUsageUserSigning]
return &CrossSigningPublicKeysCache{
MasterKey: masterKey.Key,
SelfSigningKey: selfSigning.Key,

View file

@ -26,24 +26,22 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
log.Error().Err(err).
Msg("Error fetching current cross-signing keys of user")
}
if currentKeys != nil {
for curKeyUsage, curKey := range currentKeys {
log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
// got a new key with the same usage as an existing key
for _, newKeyUsage := range userKeys.Usage {
if newKeyUsage == curKeyUsage {
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
// old key is not in the new key map, so we drop signatures made by it
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
log.Error().Err(err).Msg("Error deleting old signatures made by user")
} else {
log.Debug().
Int64("signature_count", count).
Msg("Dropped signatures made by old key as it has been replaced")
}
for curKeyUsage, curKey := range currentKeys {
log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
// got a new key with the same usage as an existing key
for _, newKeyUsage := range userKeys.Usage {
if newKeyUsage == curKeyUsage {
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
// old key is not in the new key map, so we drop signatures made by it
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
log.Error().Err(err).Msg("Error deleting old signatures made by user")
} else {
log.Debug().
Int64("signature_count", count).
Msg("Dropped signatures made by old key as it has been replaced")
}
break
}
break
}
}
}

View file

@ -278,7 +278,7 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error
}
}
var NoSessionFound = crypto.NoSessionFound
var NoSessionFound = crypto.ErrNoSessionFound
const initialSessionWaitTimeout = 3 * time.Second
const extendedSessionWaitTimeout = 22 * time.Second
@ -371,6 +371,7 @@ func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
//lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
@ -418,7 +419,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R
defer helper.lock.RUnlock()
encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content)
if err != nil {
if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) {
return
}
helper.log.Debug().

View file

@ -24,13 +24,24 @@ import (
)
var (
IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
DuplicateMessageIndex = errors.New("duplicate megolm message index")
WrongRoom = errors.New("encrypted megolm event is not intended for this room")
DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
RatchetError = errors.New("failed to ratchet session after use")
ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
ErrDuplicateMessageIndex = errors.New("duplicate megolm message index")
ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room")
ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
ErrSenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
ErrRatchetError = errors.New("failed to ratchet session after use")
)
// Deprecated: use variables prefixed with Err
var (
IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType
NoSessionFound = ErrNoSessionFound
DuplicateMessageIndex = ErrDuplicateMessageIndex
WrongRoom = ErrWrongRoom
DeviceKeyMismatch = ErrDeviceKeyMismatch
SenderKeyMismatch = ErrSenderKeyMismatch
RatchetError = ErrRatchetError
)
type megolmEvent struct {
@ -49,9 +60,9 @@ var (
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
return nil, IncorrectEncryptedContentType
return nil, ErrIncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmMegolmV1 {
return nil, UnsupportedAlgorithm
return nil, ErrUnsupportedAlgorithm
}
log := mach.machOrContextLog(ctx).With().
Str("action", "decrypt megolm event").
@ -97,7 +108,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
Msg("Couldn't resolve trust level of session: sent by unknown device")
trustLevel = id.TrustStateUnknownDevice
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
return nil, DeviceKeyMismatch
return nil, ErrDeviceKeyMismatch
} else {
trustLevel, err = mach.ResolveTrustContext(ctx, device)
if err != nil {
@ -147,7 +158,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
if err != nil {
return nil, fmt.Errorf("failed to parse megolm payload: %w", err)
} else if megolmEvt.RoomID != encryptionRoomID {
return nil, WrongRoom
return nil, ErrWrongRoom
}
if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState {
megolmEvt.Type.Class = event.StateEventType
@ -201,19 +212,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext)
if decodeErr != nil {
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex)
}
firstKnown := sess.Internal.FirstKnownIndex()
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
} else if !ok {
log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate")
return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown)
return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown)
}
log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate")
return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
}
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
@ -224,13 +235,13 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID)
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return sess, nil, 0, SenderKeyMismatch
return sess, nil, 0, ErrSenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err)
}
@ -238,7 +249,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex)
return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex)
}
// Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function
@ -290,24 +301,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
return sess, plaintext, messageIndex, RatchetError
return sess, plaintext, messageIndex, ErrRatchetError
} else {
log.Info().Msg("Deleted fully used session")
}
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
return sess, plaintext, messageIndex, RatchetError
return sess, plaintext, messageIndex, ErrRatchetError
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
return sess, plaintext, messageIndex, RatchetError
return sess, plaintext, messageIndex, ErrRatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
return sess, plaintext, messageIndex, RatchetError
return sess, plaintext, messageIndex, ErrRatchetError
} else {
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
}

View file

@ -26,15 +26,27 @@ import (
)
var (
UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
UnsupportedOlmMessageType = errors.New("unsupported olm message type")
DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
SenderMismatch = errors.New("mismatched sender in olm payload")
RecipientMismatch = errors.New("mismatched recipient in olm payload")
RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
ErrDuplicateMessage = errors.New("duplicate olm message")
ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type")
ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
ErrSenderMismatch = errors.New("mismatched sender in olm payload")
ErrRecipientMismatch = errors.New("mismatched recipient in olm payload")
ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
ErrDuplicateMessage = errors.New("duplicate olm message")
)
// Deprecated: use variables prefixed with Err
var (
UnsupportedAlgorithm = ErrUnsupportedAlgorithm
NotEncryptedForMe = ErrNotEncryptedForMe
UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType
DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession
DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage
SenderMismatch = ErrSenderMismatch
RecipientMismatch = ErrRecipientMismatch
RecipientKeyMismatch = ErrRecipientKeyMismatch
)
// DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm.
@ -56,13 +68,13 @@ type DecryptedOlmEvent struct {
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
return nil, IncorrectEncryptedContentType
return nil, ErrIncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmOlmV1 {
return nil, UnsupportedAlgorithm
return nil, ErrUnsupportedAlgorithm
}
ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()]
if !ok {
return nil, NotEncryptedForMe
return nil, ErrNotEncryptedForMe
}
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body)
if err != nil {
@ -78,7 +90,7 @@ type OlmEventKeys struct {
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
return nil, UnsupportedOlmMessageType
return nil, ErrUnsupportedOlmMessageType
}
log := mach.machOrContextLog(ctx).With().
@ -102,11 +114,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
}
olmEvt.Type.Class = evt.Type.Class
if evt.Sender != olmEvt.Sender {
return nil, SenderMismatch
return nil, ErrSenderMismatch
} else if mach.Client.UserID != olmEvt.Recipient {
return nil, RecipientMismatch
return nil, ErrRecipientMismatch
} else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 {
return nil, RecipientKeyMismatch
return nil, ErrRecipientKeyMismatch
}
if len(olmEvt.Content.VeryRaw) > 0 {
@ -151,7 +163,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash)
if err != nil {
if err == DecryptionFailedWithMatchingSession {
if err == ErrDecryptionFailedWithMatchingSession {
log.Warn().Msg("Found matching session, but decryption failed")
go mach.unwedgeDevice(log, sender, senderKey)
}
@ -169,10 +181,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
// if it isn't one at this point in time anymore, so return early.
if olmType != id.OlmMsgTypePreKey {
go mach.unwedgeDevice(log, sender, senderKey)
return nil, DecryptionFailedForNormalMessage
return nil, ErrDecryptionFailedForNormalMessage
}
accountBackup, err := mach.account.Internal.Pickle([]byte("tmp"))
accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp"))
log.Trace().Msg("Trying to create inbound session")
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
@ -302,7 +314,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
Str("session_description", session.Describe()).
Msg("Failed to decrypt olm message")
if olmType == id.OlmMsgTypePreKey {
return nil, DecryptionFailedWithMatchingSession
return nil, ErrDecryptionFailedWithMatchingSession
}
} else {
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
@ -345,7 +357,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send
ctx := log.WithContext(mach.backgroundCtx)
mach.recentlyUnwedgedLock.Lock()
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
delta := time.Now().Sub(prevUnwedge)
delta := time.Since(prevUnwedge)
if ok && delta < MinUnwedgeInterval {
log.Debug().
Str("previous_recreation", delta.String()).

View file

@ -22,14 +22,23 @@ import (
)
var (
MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
MismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
MismatchingSigningKey = errors.New("received update for device with different signing key")
NoSigningKeyFound = errors.New("didn't find ed25519 signing key")
NoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
InvalidKeySignature = errors.New("invalid signature on device keys")
ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
ErrMismatchingSigningKey = errors.New("received update for device with different signing key")
ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key")
ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
ErrInvalidKeySignature = errors.New("invalid signature on device keys")
ErrUserNotTracked = errors.New("user is not tracked")
)
ErrUserNotTracked = errors.New("user is not tracked")
// Deprecated: use variables prefixed with Err
var (
MismatchingDeviceID = ErrMismatchingDeviceID
MismatchingUserID = ErrMismatchingUserID
MismatchingSigningKey = ErrMismatchingSigningKey
NoSigningKeyFound = ErrNoSigningKeyFound
NoIdentityKeyFound = ErrNoIdentityKeyFound
InvalidKeySignature = ErrInvalidKeySignature
)
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
@ -312,28 +321,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID)
func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) {
if deviceID != deviceKeys.DeviceID {
return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID)
return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID)
} else if userID != deviceKeys.UserID {
return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID)
return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID)
}
signingKey := deviceKeys.Keys.GetEd25519(deviceID)
identityKey := deviceKeys.Keys.GetCurve25519(deviceID)
if signingKey == "" {
return nil, NoSigningKeyFound
return nil, ErrNoSigningKeyFound
} else if identityKey == "" {
return nil, NoIdentityKeyFound
return nil, ErrNoIdentityKeyFound
}
if existing != nil && existing.SigningKey != signingKey {
return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey)
return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey)
}
ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey)
if err != nil {
return existing, fmt.Errorf("failed to verify signature: %w", err)
} else if !ok {
return existing, InvalidKeySignature
return existing, ErrInvalidKeySignature
}
name, ok := deviceKeys.Unsigned["device_display_name"].(string)

View file

@ -25,7 +25,12 @@ import (
)
var (
NoGroupSession = errors.New("no group session created")
ErrNoGroupSession = errors.New("no group session created")
)
// Deprecated: use variables prefixed with Err
var (
NoGroupSession = ErrNoGroupSession
)
func getRawJSON[T any](content json.RawMessage, path ...string) *T {
@ -82,7 +87,7 @@ type rawMegolmEvent struct {
// IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession
func IsShareError(err error) bool {
return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession
}
func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) {
@ -120,7 +125,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
} else if session == nil {
return nil, NoGroupSession
return nil, ErrNoGroupSession
}
plaintext, err := json.Marshal(&rawMegolmEvent{
RoomID: roomID,

View file

@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error {
if err != nil {
return err
} else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 {
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
} else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair
return err
} else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair

View file

@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) {
account, err := account.NewAccount()
assert.NoError(t, err)
err = account.Unpickle(pickled, pickleKey)
assert.ErrorIs(t, err, olm.ErrBadVersion)
assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion)
}
func TestLoopback(t *testing.T) {

View file

@ -4,7 +4,8 @@ import (
"encoding/base64"
)
// Deprecated: base64.RawStdEncoding should be used directly
// These methods should only be used for raw byte operations, never with string conversion
func Decode(input []byte) ([]byte, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
@ -14,7 +15,6 @@ func Decode(input []byte) ([]byte, error) {
return decoded[:writtenBytes], nil
}
// Deprecated: base64.RawStdEncoding should be used directly
func Encode(input []byte) []byte {
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
base64.RawStdEncoding.Encode(encoded, input)

View file

@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
}
}
if decrypted[0] != pickleVersion {
return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion)
return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion)
}
err = json.Unmarshal(decrypted[1:], object)
if err != nil {

View file

@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error {
return fmt.Errorf("decrypt: %w", olm.ErrBadInput)
}
if input[0] != sessionExportVersion {
return fmt.Errorf("decrypt: %w", olm.ErrBadVersion)
return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])

View file

@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error {
}
s.PublicKey = publicKey
if input[0] != sessionSharingVersion {
return fmt.Errorf("verify: %w", olm.ErrBadVersion)
return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])

View file

@ -103,7 +103,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error {
if pickledVersion == decryptionPickleVersionLibOlm {
return a.KeyPair.UnpickleLibOlm(decoder)
} else {
return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm)
return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm)
}
}

View file

@ -37,6 +37,9 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat
return nil, nil, err
}
cipher, err := aessha2.NewAESSHA2(sharedSecret, nil)
if err != nil {
return nil, nil, err
}
ciphertext, err = cipher.Encrypt(plaintext)
if err != nil {
return nil, nil, err

View file

@ -99,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet,
}
if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) {
// the counter is before our initial ratchet - we can't decode this
return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable)
return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex)
}
// otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet
copiedRatchet := o.InitialRatchet
@ -206,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error {
return err
}
if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 {
return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
}
if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil {

View file

@ -101,8 +101,10 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error {
decoder := libolmpickle.NewDecoder(buf)
pickledVersion, err := decoder.ReadUInt32()
if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
if err != nil {
return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err)
} else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
}
if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil {
return err

View file

@ -168,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received
msg := message.Message{}
err = msg.Decode(oneTimeMsg.Message)
if err != nil {
return nil, fmt.Errorf("Message decode: %w", err)
return nil, fmt.Errorf("message decode: %w", err)
}
if len(msg.RatchetKey) == 0 {
return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat)
return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat)
}
//Init Ratchet
s.Ratchet.InitializeAsBob(secret, msg.RatchetKey)
@ -203,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID {
copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey)
copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey)
hash := sha256.Sum256(message)
res := id.SessionID(goolmbase64.Encode(hash[:]))
res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:]))
return res
}
@ -325,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e
if len(crypttext) == 0 {
return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput)
}
decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext))
decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext)
if err != nil {
return nil, err
}
@ -365,6 +365,9 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error {
func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
decoder := libolmpickle.NewDecoder(buf)
pickledVersion, err := decoder.ReadUInt32()
if err != nil {
return fmt.Errorf("unpickle olmSession: failed to read version: %w", err)
}
var includesChainIndex bool
switch pickledVersion {
@ -373,7 +376,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
case uint32(0x80000001):
includesChainIndex = true
default:
return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
}
if o.ReceivedMessage, err = decoder.ReadBool(); err != nil {

View file

@ -14,7 +14,7 @@ func Register() {
// Inbound Session
olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) {
if len(pickled) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
if len(key) == 0 {
key = []byte(" ")
@ -23,13 +23,13 @@ func Register() {
}
olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
return NewMegolmInboundSession(sessionKey)
}
olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
return NewMegolmInboundSessionFromExport(sessionKey)
}
@ -40,7 +40,7 @@ func Register() {
// Outbound Session
olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
if len(pickled) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
lenKey := len(key)
if lenKey == 0 {

View file

@ -56,11 +56,12 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context,
// ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private
// key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing
// from a verified device belonging to the same user."
megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
log.Debug().Msg("key backup is trusted based on derived public key")
return versionInfo, nil
} else {
if megolmBackupKey != nil {
megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
log.Debug().Msg("Key backup is trusted based on derived public key")
return versionInfo, nil
}
log.Debug().
Stringer("expected_key", megolmBackupDerivedPublicKey).
Stringer("actual_key", versionInfo.AuthData.PublicKey).

View file

@ -33,7 +33,7 @@ var _ olm.Account = (*Account)(nil)
// "INVALID_BASE64".
func AccountFromPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
a := NewBlankAccount()
return a, a.Unpickle(pickled, key)
@ -53,7 +53,7 @@ func NewAccount() (*Account, error) {
random := make([]byte, a.createRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
panic(olm.NotEnoughGoRandom)
panic(olm.ErrNotEnoughGoRandom)
}
ret := C.olm_create_account(
(*C.OlmAccount)(a.int),
@ -128,7 +128,7 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
// supplied key.
func (a *Account) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, olm.NoKeyProvided
return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, a.pickleLen())
r := C.olm_pickle_account(
@ -145,7 +145,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) {
func (a *Account) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return olm.NoKeyProvided
return olm.ErrNoKeyProvided
}
r := C.olm_unpickle_account(
(*C.OlmAccount)(a.int),
@ -198,7 +198,7 @@ func (a *Account) MarshalJSON() ([]byte, error) {
// Deprecated
func (a *Account) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return olm.InputNotJSONString
return olm.ErrInputNotJSONString
}
if a.int == nil {
*a = *NewBlankAccount()
@ -235,7 +235,7 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
// Account.
func (a *Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
panic(olm.EmptyInput)
panic(olm.ErrEmptyInput)
}
signature := make([]byte, a.signatureLen())
r := C.olm_account_sign(
@ -299,7 +299,7 @@ func (a *Account) GenOneTimeKeys(num uint) error {
random := make([]byte, a.genOneTimeKeysRandomLen(num)+1)
_, err := rand.Read(random)
if err != nil {
return olm.NotEnoughGoRandom
return olm.ErrNotEnoughGoRandom
}
r := C.olm_account_generate_one_time_keys(
(*C.OlmAccount)(a.int),
@ -319,13 +319,13 @@ func (a *Account) GenOneTimeKeys(num uint) error {
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
s := NewBlankSession()
random := make([]byte, s.createOutboundRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
panic(olm.NotEnoughGoRandom)
panic(olm.ErrNotEnoughGoRandom)
}
theirIdentityKeyCopy := []byte(theirIdentityKey)
theirOneTimeKeyCopy := []byte(theirOneTimeKey)
@ -357,7 +357,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
if len(oneTimeKeyMsg) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
s := NewBlankSession()
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
@ -383,7 +383,7 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) {
if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
theirIdentityKeyCopy := []byte(*theirIdentityKey)
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)

View file

@ -11,21 +11,21 @@ import (
)
var errorMap = map[string]error{
"NOT_ENOUGH_RANDOM": olm.NotEnoughRandom,
"OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall,
"BAD_MESSAGE_VERSION": olm.BadMessageVersion,
"BAD_MESSAGE_FORMAT": olm.BadMessageFormat,
"BAD_MESSAGE_MAC": olm.BadMessageMAC,
"BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID,
"INVALID_BASE64": olm.InvalidBase64,
"BAD_ACCOUNT_KEY": olm.BadAccountKey,
"UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion,
"CORRUPTED_PICKLE": olm.CorruptedPickle,
"BAD_SESSION_KEY": olm.BadSessionKey,
"UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex,
"BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle,
"BAD_SIGNATURE": olm.BadSignature,
"INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall,
"NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom,
"OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall,
"BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion,
"BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat,
"BAD_MESSAGE_MAC": olm.ErrBadMAC,
"BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID,
"INVALID_BASE64": olm.ErrLibolmInvalidBase64,
"BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey,
"UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion,
"CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle,
"BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey,
"UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex,
"BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle,
"BAD_SIGNATURE": olm.ErrBadSignature,
"INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall,
}
func convertError(errCode string) error {

View file

@ -31,7 +31,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil)
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
if len(pickled) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
lenKey := len(key)
if lenKey == 0 {
@ -48,7 +48,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession,
// "OLM_BAD_SESSION_KEY".
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
s := NewBlankInboundGroupSession()
r := C.olm_init_inbound_group_session(
@ -69,7 +69,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
// error will be "OLM_BAD_SESSION_KEY".
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
s := NewBlankInboundGroupSession()
r := C.olm_import_inbound_group_session(
@ -124,7 +124,7 @@ func (s *InboundGroupSession) pickleLen() uint {
// InboundGroupSession using the supplied key.
func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, olm.NoKeyProvided
return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_inbound_group_session(
@ -143,9 +143,9 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return olm.NoKeyProvided
return olm.ErrNoKeyProvided
} else if len(pickled) == 0 {
return olm.EmptyInput
return olm.ErrEmptyInput
}
r := C.olm_unpickle_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
@ -200,7 +200,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return olm.InputNotJSONString
return olm.ErrInputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankInboundGroupSession()
@ -217,7 +217,7 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
// will be "BAD_MESSAGE_FORMAT".
func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) {
if len(message) == 0 {
return 0, olm.EmptyInput
return 0, olm.ErrEmptyInput
}
// olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it
messageCopy := bytes.Clone(message)
@ -244,7 +244,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro
// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX".
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
if len(message) == 0 {
return nil, 0, olm.EmptyInput
return nil, 0, olm.ErrEmptyInput
}
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message)
if err != nil {

View file

@ -84,7 +84,7 @@ func (s *OutboundGroupSession) pickleLen() uint {
// OutboundGroupSession using the supplied key.
func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, olm.NoKeyProvided
return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_outbound_group_session(
@ -103,7 +103,7 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return olm.NoKeyProvided
return olm.ErrNoKeyProvided
}
r := C.olm_unpickle_outbound_group_session(
(*C.OlmOutboundGroupSession)(s.int),
@ -159,7 +159,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return olm.InputNotJSONString
return olm.ErrInputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankOutboundGroupSession()
@ -183,7 +183,7 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint {
// as base64.
func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
if len(plaintext) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
message := make([]byte, s.encryptMsgLen(len(plaintext)))
r := C.olm_group_encrypt(

View file

@ -86,7 +86,7 @@ func NewPKSigning() (*PKSigning, error) {
seed := make([]byte, pkSigningSeedLength())
_, err := rand.Read(seed)
if err != nil {
panic(olm.NotEnoughGoRandom)
panic(olm.ErrNotEnoughGoRandom)
}
pk, err := NewPKSigningFromSeed(seed)
return pk, err

View file

@ -65,7 +65,7 @@ func Register() {
olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
if len(pickled) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
s := NewBlankOutboundGroupSession()
return s, s.Unpickle(pickled, key)

View file

@ -51,7 +51,7 @@ func sessionSize() uint {
// "INVALID_BASE64".
func SessionFromPickled(pickled, key []byte) (*Session, error) {
if len(pickled) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
s := NewBlankSession()
return s, s.Unpickle(pickled, key)
@ -118,7 +118,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint {
// will be "BAD_MESSAGE_FORMAT".
func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
if len(message) == 0 {
return 0, olm.EmptyInput
return 0, olm.ErrEmptyInput
}
messageCopy := []byte(message)
r := C.olm_decrypt_max_plaintext_length(
@ -138,7 +138,7 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType)
// supplied key.
func (s *Session) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, olm.NoKeyProvided
return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_session(
@ -158,7 +158,7 @@ func (s *Session) Pickle(key []byte) ([]byte, error) {
// provided key. This function mutates the input pickled data slice.
func (s *Session) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return olm.NoKeyProvided
return olm.ErrNoKeyProvided
}
r := C.olm_unpickle_session(
(*C.OlmSession)(s.int),
@ -213,7 +213,7 @@ func (s *Session) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *Session) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return olm.InputNotJSONString
return olm.ErrInputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankSession()
@ -256,7 +256,7 @@ func (s *Session) HasReceivedMessage() bool {
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
if len(oneTimeKeyMsg) == 0 {
return false, olm.EmptyInput
return false, olm.ErrEmptyInput
}
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
r := C.olm_matches_inbound_session(
@ -284,7 +284,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
return false, olm.EmptyInput
return false, olm.ErrEmptyInput
}
theirIdentityKeyCopy := []byte(theirIdentityKey)
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
@ -325,14 +325,14 @@ func (s *Session) EncryptMsgType() id.OlmMsgType {
// as base64.
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
if len(plaintext) == 0 {
return 0, nil, olm.EmptyInput
return 0, nil, olm.ErrEmptyInput
}
// Make the slice be at least length 1
random := make([]byte, s.encryptRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
// TODO can we just return err here?
return 0, nil, olm.NotEnoughGoRandom
return 0, nil, olm.ErrNotEnoughGoRandom
}
messageType := s.EncryptMsgType()
message := make([]byte, s.encryptMsgLen(len(plaintext)))
@ -362,7 +362,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
// "BAD_MESSAGE_MAC".
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
if len(message) == 0 {
return nil, olm.EmptyInput
return nil, olm.ErrEmptyInput
}
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType)
if err != nil {

View file

@ -205,7 +205,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error {
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
start := time.Now()
return func() {
duration := time.Now().Sub(start)
duration := time.Since(start)
if duration > expectedDuration {
zerolog.Ctx(ctx).Warn().
Str("action", thing).

View file

@ -10,50 +10,67 @@ import "errors"
// Those are the most common used errors
var (
ErrBadSignature = errors.New("bad signature")
ErrBadMAC = errors.New("bad mac")
ErrBadMessageFormat = errors.New("bad message format")
ErrBadVerification = errors.New("bad verification")
ErrWrongProtocolVersion = errors.New("wrong protocol version")
ErrEmptyInput = errors.New("empty input")
ErrNoKeyProvided = errors.New("no key")
ErrBadMessageKeyID = errors.New("bad message key id")
ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key")
ErrMsgIndexTooHigh = errors.New("message index too high")
ErrProtocolViolation = errors.New("not protocol message order")
ErrMessageKeyNotFound = errors.New("message key not found")
ErrChainTooHigh = errors.New("chain index too high")
ErrBadInput = errors.New("bad input")
ErrBadVersion = errors.New("wrong version")
ErrWrongPickleVersion = errors.New("wrong pickle version")
ErrInputToSmall = errors.New("input too small (truncated?)")
ErrOverflow = errors.New("overflow")
ErrBadSignature = errors.New("bad signature")
ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)")
ErrBadMessageFormat = errors.New("the message couldn't be decoded")
ErrBadVerification = errors.New("bad verification")
ErrWrongProtocolVersion = errors.New("wrong protocol version")
ErrEmptyInput = errors.New("empty input")
ErrNoKeyProvided = errors.New("no key provided")
ErrBadMessageKeyID = errors.New("the message references an unknown key ID")
ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
ErrMsgIndexTooHigh = errors.New("message index too high")
ErrProtocolViolation = errors.New("not protocol message order")
ErrMessageKeyNotFound = errors.New("message key not found")
ErrChainTooHigh = errors.New("chain index too high")
ErrBadInput = errors.New("bad input")
ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version")
ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version")
ErrInputToSmall = errors.New("input too small (truncated?)")
)
// Error codes from go-olm
var (
EmptyInput = errors.New("empty input")
NoKeyProvided = errors.New("no pickle key provided")
NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device")
InputNotJSONString = errors.New("input doesn't look like a JSON string")
ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
)
// Error codes from olm code
var (
NotEnoughRandom = errors.New("not enough entropy was supplied")
OutputBufferTooSmall = errors.New("supplied output buffer is too small")
BadMessageVersion = errors.New("the message version is unsupported")
BadMessageFormat = errors.New("the message couldn't be decoded")
BadMessageMAC = errors.New("the message couldn't be decrypted")
BadMessageKeyID = errors.New("the message references an unknown key ID")
InvalidBase64 = errors.New("the input base64 was invalid")
BadAccountKey = errors.New("the supplied account key is invalid")
UnknownPickleVersion = errors.New("the pickled object is too new")
CorruptedPickle = errors.New("the pickled object couldn't be decoded")
BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
BadSignature = errors.New("received message had a bad signature")
InputBufferTooSmall = errors.New("the input data was too small to be valid")
ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid")
ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied")
ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small")
ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid")
ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded")
ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
)
// Deprecated: use variables prefixed with Err
var (
EmptyInput = ErrEmptyInput
BadSignature = ErrBadSignature
InvalidBase64 = ErrLibolmInvalidBase64
BadMessageKeyID = ErrBadMessageKeyID
BadMessageFormat = ErrBadMessageFormat
BadMessageVersion = ErrWrongProtocolVersion
BadMessageMAC = ErrBadMAC
UnknownPickleVersion = ErrUnknownOlmPickleVersion
NotEnoughRandom = ErrLibolmNotEnoughRandom
OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall
BadAccountKey = ErrLibolmBadAccountKey
CorruptedPickle = ErrLibolmCorruptedPickle
BadSessionKey = ErrLibolmBadSessionKey
UnknownMessageIndex = ErrUnknownMessageIndex
BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle
InputBufferTooSmall = ErrInputToSmall
NoKeyProvided = ErrNoKeyProvided
NotEnoughGoRandom = ErrNotEnoughGoRandom
InputNotJSONString = ErrInputNotJSONString
ErrBadVersion = ErrUnknownJSONPickleVersion
ErrWrongPickleVersion = ErrUnknownJSONPickleVersion
ErrRatchetNotAvailable = ErrUnknownMessageIndex
)

View file

@ -18,8 +18,14 @@ import (
)
var (
SessionNotShared = errors.New("session has not been shared")
SessionExpired = errors.New("session has expired")
ErrSessionNotShared = errors.New("session has not been shared")
ErrSessionExpired = errors.New("session has expired")
)
// Deprecated: use variables prefixed with Err
var (
SessionNotShared = ErrSessionNotShared
SessionExpired = ErrSessionExpired
)
// OlmSessionList is a list of OlmSessions.
@ -255,9 +261,9 @@ func (ogs *OutboundGroupSession) Expired() bool {
func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
if !ogs.Shared {
return nil, SessionNotShared
return nil, ErrSessionNotShared
} else if ogs.Expired() {
return nil, SessionExpired
return nil, ErrSessionExpired
}
ogs.MessageCount++
ogs.LastEncryptedTime = time.Now()

View file

@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error {
return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext)
case id.AlgorithmMegolmV1:
if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' {
return id.InputNotJSONString
return id.ErrInputNotJSONString
}
content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1]
}

View file

@ -80,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS
} else if wellKnown != nil {
output.Expires = expiry
output.HostHeader = wellKnown.Server
hostname, port, ok = ParseServerName(wellKnown.Server)
wkHost, wkPort, ok := ParseServerName(wellKnown.Server)
if ok {
hostname, port = wkHost, wkPort
}
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {

View file

@ -57,7 +57,7 @@ type FilterPart struct {
// Validate checks if the filter contains valid property values
func (filter *Filter) Validate() error {
if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation {
return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]")
return errors.New("bad event_format value")
}
return nil
}

View file

@ -17,8 +17,14 @@ import (
)
var (
InvalidContentURI = errors.New("invalid Matrix content URI")
InputNotJSONString = errors.New("input doesn't look like a JSON string")
ErrInvalidContentURI = errors.New("invalid Matrix content URI")
ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
)
// Deprecated: use variables prefixed with Err
var (
InvalidContentURI = ErrInvalidContentURI
InputNotJSONString = ErrInputNotJSONString
)
// ContentURIString is a string that's expected to be a Matrix content URI.
@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !strings.HasPrefix(uri, "mxc://") {
err = InvalidContentURI
err = ErrInvalidContentURI
} else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
err = InvalidContentURI
err = ErrInvalidContentURI
} else {
parsed.Homeserver = uri[6 : 6+index]
parsed.FileID = uri[6+index+1:]
@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !bytes.HasPrefix(uri, mxcBytes) {
err = InvalidContentURI
err = ErrInvalidContentURI
} else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
err = InvalidContentURI
err = ErrInvalidContentURI
} else {
parsed.Homeserver = string(uri[6 : 6+index])
parsed.FileID = string(uri[6+index+1:])
@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) {
*uri = ContentURI{}
return nil
} else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' {
return InputNotJSONString
return ErrInputNotJSONString
}
parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1])
if err != nil {

View file

@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{
func (uri *MatrixURI) getQuery() url.Values {
q := make(url.Values)
if uri.Via != nil && len(uri.Via) > 0 {
if len(uri.Via) > 0 {
q["via"] = uri.Via
}
if len(uri.Action) > 0 {

View file

@ -219,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) {
for i := 0; i < len(strBytes); i++ {
b := strBytes[i]
if !isValidByte(b) {
return "", fmt.Errorf("Byte pos %d: Invalid byte", i)
return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b)
}
if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _
if i+1 >= len(strBytes) {
return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i)
return "", fmt.Errorf("unexpected end of string after underscore at %d", i)
}
if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping
return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i)
return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i)
}
if strBytes[i+1] == '_' {
outputBuffer.WriteByte('_')
@ -237,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) {
i++ // skip next byte since we just handled it
} else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8
if i+2 >= len(strBytes) {
return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i)
return "", fmt.Errorf("unexpected end of string after equals sign at %d", i)
}
dst := make([]byte, 1)
_, err := hex.Decode(dst, strBytes[i+1:i+3])

View file

@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error {
if ok {
action.Action = ActionSetTweak
action.Tweak = PushActionTweak(tweak)
action.Value, _ = val["value"]
action.Value = val["value"]
}
}
return nil

View file

@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi
}
}
func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition {
return &pushrules.PushCondition{
Kind: pushrules.KindEventPropertyContains,
Key: key,
Value: value,
}
}
func TestPushCondition_Match_InvalidKind(t *testing.T) {
condition := &pushrules.PushCondition{
Kind: pushrules.PushCondKind("invalid"),

View file

@ -5,8 +5,6 @@ import (
"maunium.net/go/mautrix/id"
)
type RoomStateMap = map[event.Type]map[string]*event.Event
// Room represents a single Matrix room.
type Room struct {
ID id.RoomID
@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) {
// GetStateEvent returns the state event for the given type/state_key combo, or nil.
func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event {
stateEventMap, _ := room.State[eventType]
evt, _ := stateEventMap[stateKey]
stateEventMap := room.State[eventType]
evt := stateEventMap[stateKey]
return evt
}

View file

@ -75,8 +75,7 @@ type RespListRooms struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api
func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) {
var resp RespListRooms
var reqURL string
reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
_, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}

6
url.go
View file

@ -98,10 +98,8 @@ func (saup SynapseAdminURLPath) FullPath() []any {
// and appservice user ID set already.
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) {
if urlQuery != nil {
for k, v := range urlQuery {
q.Set(k, v)
}
for k, v := range urlQuery {
q.Set(k, v)
}
})
}