mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
all: fix staticcheck issues
This commit is contained in:
parent
6017612c55
commit
315d2ab17d
56 changed files with 358 additions and 338 deletions
1
.github/workflows/go.yml
vendored
1
.github/workflows/go.yml
vendored
|
|
@ -24,6 +24,7 @@ jobs:
|
|||
- name: Install goimports
|
||||
run: |
|
||||
go install golang.org/x/tools/cmd/goimports@latest
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
export PATH="$HOME/go/bin:$PATH"
|
||||
|
||||
- name: Run pre-commit
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ repos:
|
|||
- "-w"
|
||||
- id: go-vet-repo-mod
|
||||
- id: go-mod-tidy
|
||||
# TODO enable this
|
||||
#- id: go-staticcheck-repo-mod
|
||||
- id: go-staticcheck-repo-mod
|
||||
|
||||
- repo: https://github.com/beeper/pre-commit-go
|
||||
rev: v0.4.2
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
|
|||
var prefixMessage string
|
||||
for unwrappedErr != nil {
|
||||
errorData, jsonErr = json.Marshal(unwrappedErr)
|
||||
if errorData != nil && len(errorData) > 2 && jsonErr == nil {
|
||||
if len(errorData) > 2 && jsonErr == nil {
|
||||
prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
|
||||
prefixMessage = strings.TrimRight(prefixMessage, ": ")
|
||||
break
|
||||
|
|
|
|||
|
|
@ -41,10 +41,7 @@ func (pc PermissionConfig) IsConfigured() bool {
|
|||
_, hasExampleDomain := pc["example.com"]
|
||||
_, hasExampleUser := pc["@admin:example.com"]
|
||||
exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain)
|
||||
if len(pc) <= exampleLen {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
return len(pc) > exampleLen
|
||||
}
|
||||
|
||||
func (pc PermissionConfig) Get(userID id.UserID) Permissions {
|
||||
|
|
|
|||
|
|
@ -102,9 +102,9 @@ func (bsq *BridgeStateQueue) loop() {
|
|||
}
|
||||
}
|
||||
|
||||
func (bsq *BridgeStateQueue) scheduleNotice(ctx context.Context, triggeredBy status.BridgeState) {
|
||||
func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) {
|
||||
log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger()
|
||||
ctx = log.WithContext(bsq.bridge.BackgroundCtx)
|
||||
ctx := log.WithContext(bsq.bridge.BackgroundCtx)
|
||||
if !bsq.waitForTransientDisconnectReconnect(ctx) {
|
||||
return
|
||||
}
|
||||
|
|
@ -135,7 +135,7 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge
|
|||
if bsq.firstTransientDisconnect.IsZero() {
|
||||
bsq.firstTransientDisconnect = time.Now()
|
||||
}
|
||||
go bsq.scheduleNotice(ctx, state)
|
||||
go bsq.scheduleNotice(state)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,13 +7,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"golang.org/x/exp/constraints"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
|
||||
|
|
@ -158,55 +152,3 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID)
|
|||
panic("bridge ID mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) {
|
||||
if val, found := m[key]; found {
|
||||
floatVal, ok := val.(float64)
|
||||
if ok {
|
||||
return T(floatVal), true
|
||||
}
|
||||
tVal, ok := val.(T)
|
||||
if ok {
|
||||
return tVal, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func unmarshalMerge(input []byte, data any, extra *map[string]any) error {
|
||||
err := json.Unmarshal(input, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(input, extra)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *extra == nil {
|
||||
*extra = make(map[string]any)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func marshalMerge(data any, extra map[string]any) ([]byte, error) {
|
||||
if extra == nil {
|
||||
return json.Marshal(data)
|
||||
}
|
||||
merged := make(map[string]any)
|
||||
maps.Copy(merged, extra)
|
||||
dataRef := reflect.ValueOf(data).Elem()
|
||||
dataType := dataRef.Type()
|
||||
for _, field := range reflect.VisibleFields(dataType) {
|
||||
parts := strings.Split(field.Tag.Get("json"), ",")
|
||||
if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" {
|
||||
continue
|
||||
}
|
||||
fieldVal := dataRef.FieldByIndex(field.Index)
|
||||
if fieldVal.IsZero() {
|
||||
delete(merged, parts[0])
|
||||
} else {
|
||||
merged[parts[0]] = fieldVal.Interface()
|
||||
}
|
||||
}
|
||||
return json.Marshal(merged)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,9 +38,9 @@ func init() {
|
|||
|
||||
var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil)
|
||||
|
||||
var NoSessionFound = crypto.NoSessionFound
|
||||
var DuplicateMessageIndex = crypto.DuplicateMessageIndex
|
||||
var UnknownMessageIndex = olm.UnknownMessageIndex
|
||||
var NoSessionFound = crypto.ErrNoSessionFound
|
||||
var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex
|
||||
var UnknownMessageIndex = olm.ErrUnknownMessageIndex
|
||||
|
||||
type CryptoHelper struct {
|
||||
bridge *Connector
|
||||
|
|
@ -439,7 +439,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy
|
|||
var encrypted *event.EncryptedEventContent
|
||||
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
|
||||
if err != nil {
|
||||
if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) {
|
||||
if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) {
|
||||
return
|
||||
}
|
||||
helper.log.Debug().Err(err).
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event,
|
|||
Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).
|
||||
Msg("Couldn't find session, requesting keys and waiting longer...")
|
||||
|
||||
//lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
|
||||
go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
|
||||
go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false)
|
||||
|
||||
|
|
|
|||
|
|
@ -135,7 +135,10 @@ func (br *BridgeMain) CheckLegacyDB(
|
|||
}
|
||||
var dbVersion int
|
||||
err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion)
|
||||
if dbVersion < expectedVersion {
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to get database version")
|
||||
return
|
||||
} else if dbVersion < expectedVersion {
|
||||
log.Fatal().
|
||||
Int("expected_version", expectedVersion).
|
||||
Int("version", dbVersion).
|
||||
|
|
|
|||
|
|
@ -85,10 +85,9 @@ const (
|
|||
provisioningUserKey provisioningContextKey = iota
|
||||
provisioningUserLoginKey
|
||||
provisioningLoginProcessKey
|
||||
ProvisioningKeyRequest
|
||||
)
|
||||
|
||||
const ProvisioningKeyRequest = "fi.mau.provision.request"
|
||||
|
||||
func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
|
||||
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -410,6 +410,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
|
|||
if reaction.Timestamp.IsZero() {
|
||||
reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond)
|
||||
}
|
||||
//lint:ignore SA4006 it's a todo
|
||||
targetPart, ok := partMap[*reaction.TargetPart]
|
||||
if !ok {
|
||||
// TODO warning log and/or skip reaction?
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
|
|||
go func() {
|
||||
_, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{
|
||||
Parsed: &event.TombstoneEventContent{
|
||||
Body: fmt.Sprintf("This room has been merged"),
|
||||
Body: "This room has been merged",
|
||||
ReplacementRoom: targetPortal.MXID,
|
||||
},
|
||||
}, time.Now())
|
||||
|
|
|
|||
10
client.go
10
client.go
|
|
@ -742,7 +742,7 @@ func (cli *Client) executeCompiledRequest(
|
|||
cli.RequestStart(req)
|
||||
startTime := time.Now()
|
||||
res, err := client.Do(req)
|
||||
duration := time.Now().Sub(startTime)
|
||||
duration := time.Since(startTime)
|
||||
if res != nil && !dontReadResponse {
|
||||
defer res.Body.Close()
|
||||
}
|
||||
|
|
@ -862,7 +862,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp
|
|||
}
|
||||
start := time.Now()
|
||||
_, err = cli.MakeFullRequest(ctx, fullReq)
|
||||
duration := time.Now().Sub(start)
|
||||
duration := time.Since(start)
|
||||
timeout := time.Duration(req.Timeout) * time.Millisecond
|
||||
buffer := 10 * time.Second
|
||||
if req.Since == "" {
|
||||
|
|
@ -966,7 +966,7 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe
|
|||
// }
|
||||
// token := res.AccessToken
|
||||
func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) {
|
||||
res, uia, err := cli.Register(ctx, req)
|
||||
_, uia, err := cli.Register(ctx, req)
|
||||
if err != nil && uia == nil {
|
||||
return nil, err
|
||||
} else if uia == nil {
|
||||
|
|
@ -975,7 +975,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe
|
|||
return nil, errors.New("server does not support m.login.dummy")
|
||||
}
|
||||
req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session}
|
||||
res, _, err = cli.Register(ctx, req)
|
||||
res, _, err := cli.Register(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -1751,6 +1751,8 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
type RoomStateMap = map[event.Type]map[string]*event.Event
|
||||
|
||||
// State gets all state in a room.
|
||||
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate
|
||||
func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) {
|
||||
|
|
|
|||
|
|
@ -21,13 +21,24 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
HashMismatch = errors.New("mismatching SHA-256 digest")
|
||||
UnsupportedVersion = errors.New("unsupported Matrix file encryption version")
|
||||
UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
|
||||
InvalidKey = errors.New("failed to decode key")
|
||||
InvalidInitVector = errors.New("failed to decode initialization vector")
|
||||
InvalidHash = errors.New("failed to decode SHA-256 hash")
|
||||
ReaderClosed = errors.New("encrypting reader was already closed")
|
||||
ErrHashMismatch = errors.New("mismatching SHA-256 digest")
|
||||
ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version")
|
||||
ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
|
||||
ErrInvalidKey = errors.New("failed to decode key")
|
||||
ErrInvalidInitVector = errors.New("failed to decode initialization vector")
|
||||
ErrInvalidHash = errors.New("failed to decode SHA-256 hash")
|
||||
ErrReaderClosed = errors.New("encrypting reader was already closed")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
HashMismatch = ErrHashMismatch
|
||||
UnsupportedVersion = ErrUnsupportedVersion
|
||||
UnsupportedAlgorithm = ErrUnsupportedAlgorithm
|
||||
InvalidKey = ErrInvalidKey
|
||||
InvalidInitVector = ErrInvalidInitVector
|
||||
InvalidHash = ErrInvalidHash
|
||||
ReaderClosed = ErrReaderClosed
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -85,25 +96,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error {
|
|||
if ef.decoded != nil {
|
||||
return nil
|
||||
} else if len(ef.Key.Key) != keyBase64Length {
|
||||
return InvalidKey
|
||||
return ErrInvalidKey
|
||||
} else if len(ef.InitVector) != ivBase64Length {
|
||||
return InvalidInitVector
|
||||
return ErrInvalidInitVector
|
||||
} else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length {
|
||||
return InvalidHash
|
||||
return ErrInvalidHash
|
||||
}
|
||||
ef.decoded = &decodedKeys{}
|
||||
_, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key))
|
||||
if err != nil {
|
||||
return InvalidKey
|
||||
return ErrInvalidKey
|
||||
}
|
||||
_, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector))
|
||||
if err != nil {
|
||||
return InvalidInitVector
|
||||
return ErrInvalidInitVector
|
||||
}
|
||||
if includeHash {
|
||||
_, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256))
|
||||
if err != nil {
|
||||
return InvalidHash
|
||||
return ErrInvalidHash
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
@ -179,7 +190,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil)
|
|||
|
||||
func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
|
||||
if r.closed {
|
||||
return 0, ReaderClosed
|
||||
return 0, ErrReaderClosed
|
||||
}
|
||||
if offset != 0 || whence != io.SeekStart {
|
||||
return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported")
|
||||
|
|
@ -200,7 +211,7 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
|
|||
|
||||
func (r *encryptingReader) Read(dst []byte) (n int, err error) {
|
||||
if r.closed {
|
||||
return 0, ReaderClosed
|
||||
return 0, ErrReaderClosed
|
||||
} else if r.isDecrypting && r.file.decoded == nil {
|
||||
if err = r.file.PrepareForDecryption(); err != nil {
|
||||
return
|
||||
|
|
@ -224,7 +235,7 @@ func (r *encryptingReader) Close() (err error) {
|
|||
}
|
||||
if r.isDecrypting {
|
||||
if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) {
|
||||
return HashMismatch
|
||||
return ErrHashMismatch
|
||||
}
|
||||
} else {
|
||||
r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil))
|
||||
|
|
@ -265,9 +276,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) {
|
|||
// DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function.
|
||||
func (ef *EncryptedFile) PrepareForDecryption() error {
|
||||
if ef.Version != "v2" {
|
||||
return UnsupportedVersion
|
||||
return ErrUnsupportedVersion
|
||||
} else if ef.Key.Algorithm != "A256CTR" {
|
||||
return UnsupportedAlgorithm
|
||||
return ErrUnsupportedAlgorithm
|
||||
} else if err := ef.decodeKeys(true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -281,7 +292,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error {
|
|||
}
|
||||
dataHash := sha256.Sum256(data)
|
||||
if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) {
|
||||
return HashMismatch
|
||||
return ErrHashMismatch
|
||||
}
|
||||
utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv)
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) {
|
|||
file := parseHelloWorld()
|
||||
file.Version = "foo"
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, UnsupportedVersion)
|
||||
assert.ErrorIs(t, err, ErrUnsupportedVersion)
|
||||
}
|
||||
|
||||
func TestUnsupportedAlgorithm(t *testing.T) {
|
||||
file := parseHelloWorld()
|
||||
file.Key.Algorithm = "bar"
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, UnsupportedAlgorithm)
|
||||
assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
|
||||
}
|
||||
|
||||
func TestHashMismatch(t *testing.T) {
|
||||
file := parseHelloWorld()
|
||||
file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes))
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, HashMismatch)
|
||||
assert.ErrorIs(t, err, ErrHashMismatch)
|
||||
}
|
||||
|
||||
func TestTooLongHash(t *testing.T) {
|
||||
file := parseHelloWorld()
|
||||
file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg"
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, InvalidHash)
|
||||
assert.ErrorIs(t, err, ErrInvalidHash)
|
||||
}
|
||||
|
||||
func TestTooShortHash(t *testing.T) {
|
||||
file := parseHelloWorld()
|
||||
file.Hashes.SHA256 = "5/Gy1JftyyQ"
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, InvalidHash)
|
||||
assert.ErrorIs(t, err, ErrInvalidHash)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,8 +63,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id
|
|||
if len(dbKeys) > 0 {
|
||||
masterKey, ok := dbKeys[id.XSUsageMaster]
|
||||
if ok {
|
||||
selfSigning, _ := dbKeys[id.XSUsageSelfSigning]
|
||||
userSigning, _ := dbKeys[id.XSUsageUserSigning]
|
||||
selfSigning := dbKeys[id.XSUsageSelfSigning]
|
||||
userSigning := dbKeys[id.XSUsageUserSigning]
|
||||
return &CrossSigningPublicKeysCache{
|
||||
MasterKey: masterKey.Key,
|
||||
SelfSigningKey: selfSigning.Key,
|
||||
|
|
|
|||
|
|
@ -26,24 +26,22 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
|||
log.Error().Err(err).
|
||||
Msg("Error fetching current cross-signing keys of user")
|
||||
}
|
||||
if currentKeys != nil {
|
||||
for curKeyUsage, curKey := range currentKeys {
|
||||
log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
|
||||
// got a new key with the same usage as an existing key
|
||||
for _, newKeyUsage := range userKeys.Usage {
|
||||
if newKeyUsage == curKeyUsage {
|
||||
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
|
||||
// old key is not in the new key map, so we drop signatures made by it
|
||||
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
|
||||
log.Error().Err(err).Msg("Error deleting old signatures made by user")
|
||||
} else {
|
||||
log.Debug().
|
||||
Int64("signature_count", count).
|
||||
Msg("Dropped signatures made by old key as it has been replaced")
|
||||
}
|
||||
for curKeyUsage, curKey := range currentKeys {
|
||||
log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
|
||||
// got a new key with the same usage as an existing key
|
||||
for _, newKeyUsage := range userKeys.Usage {
|
||||
if newKeyUsage == curKeyUsage {
|
||||
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
|
||||
// old key is not in the new key map, so we drop signatures made by it
|
||||
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
|
||||
log.Error().Err(err).Msg("Error deleting old signatures made by user")
|
||||
} else {
|
||||
log.Debug().
|
||||
Int64("signature_count", count).
|
||||
Msg("Dropped signatures made by old key as it has been replaced")
|
||||
}
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -278,7 +278,7 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error
|
|||
}
|
||||
}
|
||||
|
||||
var NoSessionFound = crypto.NoSessionFound
|
||||
var NoSessionFound = crypto.ErrNoSessionFound
|
||||
|
||||
const initialSessionWaitTimeout = 3 * time.Second
|
||||
const extendedSessionWaitTimeout = 22 * time.Second
|
||||
|
|
@ -371,6 +371,7 @@ func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event
|
|||
content := evt.Content.AsEncrypted()
|
||||
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
|
||||
|
||||
//lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
|
||||
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
|
||||
|
||||
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
|
|
@ -418,7 +419,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R
|
|||
defer helper.lock.RUnlock()
|
||||
encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content)
|
||||
if err != nil {
|
||||
if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
|
||||
if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) {
|
||||
return
|
||||
}
|
||||
helper.log.Debug().
|
||||
|
|
|
|||
|
|
@ -24,13 +24,24 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
|
||||
NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
|
||||
DuplicateMessageIndex = errors.New("duplicate megolm message index")
|
||||
WrongRoom = errors.New("encrypted megolm event is not intended for this room")
|
||||
DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
|
||||
SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
|
||||
RatchetError = errors.New("failed to ratchet session after use")
|
||||
ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
|
||||
ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
|
||||
ErrDuplicateMessageIndex = errors.New("duplicate megolm message index")
|
||||
ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room")
|
||||
ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
|
||||
ErrSenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
|
||||
ErrRatchetError = errors.New("failed to ratchet session after use")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType
|
||||
NoSessionFound = ErrNoSessionFound
|
||||
DuplicateMessageIndex = ErrDuplicateMessageIndex
|
||||
WrongRoom = ErrWrongRoom
|
||||
DeviceKeyMismatch = ErrDeviceKeyMismatch
|
||||
SenderKeyMismatch = ErrSenderKeyMismatch
|
||||
RatchetError = ErrRatchetError
|
||||
)
|
||||
|
||||
type megolmEvent struct {
|
||||
|
|
@ -49,9 +60,9 @@ var (
|
|||
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
|
||||
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
|
||||
if !ok {
|
||||
return nil, IncorrectEncryptedContentType
|
||||
return nil, ErrIncorrectEncryptedContentType
|
||||
} else if content.Algorithm != id.AlgorithmMegolmV1 {
|
||||
return nil, UnsupportedAlgorithm
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
}
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
Str("action", "decrypt megolm event").
|
||||
|
|
@ -97,7 +108,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
|
|||
Msg("Couldn't resolve trust level of session: sent by unknown device")
|
||||
trustLevel = id.TrustStateUnknownDevice
|
||||
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
|
||||
return nil, DeviceKeyMismatch
|
||||
return nil, ErrDeviceKeyMismatch
|
||||
} else {
|
||||
trustLevel, err = mach.ResolveTrustContext(ctx, device)
|
||||
if err != nil {
|
||||
|
|
@ -147,7 +158,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse megolm payload: %w", err)
|
||||
} else if megolmEvt.RoomID != encryptionRoomID {
|
||||
return nil, WrongRoom
|
||||
return nil, ErrWrongRoom
|
||||
}
|
||||
if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState {
|
||||
megolmEvt.Type.Class = event.StateEventType
|
||||
|
|
@ -201,19 +212,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
|
|||
messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext)
|
||||
if decodeErr != nil {
|
||||
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
|
||||
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
|
||||
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex)
|
||||
}
|
||||
firstKnown := sess.Internal.FirstKnownIndex()
|
||||
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
|
||||
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
||||
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
|
||||
return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
|
||||
return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
|
||||
} else if !ok {
|
||||
log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate")
|
||||
return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown)
|
||||
return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown)
|
||||
}
|
||||
log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate")
|
||||
return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
|
||||
return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
|
||||
|
|
@ -224,13 +235,13 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
|
||||
} else if sess == nil {
|
||||
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
|
||||
return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID)
|
||||
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
|
||||
return sess, nil, 0, SenderKeyMismatch
|
||||
return sess, nil, 0, ErrSenderKeyMismatch
|
||||
}
|
||||
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
|
||||
if err != nil {
|
||||
if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
|
||||
if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
|
||||
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
|
||||
return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err)
|
||||
}
|
||||
|
|
@ -238,7 +249,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
||||
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
|
||||
} else if !ok {
|
||||
return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex)
|
||||
return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex)
|
||||
}
|
||||
|
||||
// Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function
|
||||
|
|
@ -290,24 +301,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to delete fully used session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
return sess, plaintext, messageIndex, ErrRatchetError
|
||||
} else {
|
||||
log.Info().Msg("Deleted fully used session")
|
||||
}
|
||||
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
|
||||
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
|
||||
log.Err(err).Msg("Failed to ratchet session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
return sess, plaintext, messageIndex, ErrRatchetError
|
||||
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store ratcheted session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
return sess, plaintext, messageIndex, ErrRatchetError
|
||||
} else {
|
||||
log.Info().Msg("Ratcheted session forward")
|
||||
}
|
||||
} else if didModify {
|
||||
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store updated ratchet safety data")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
return sess, plaintext, messageIndex, ErrRatchetError
|
||||
} else {
|
||||
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,15 +26,27 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
|
||||
NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
|
||||
UnsupportedOlmMessageType = errors.New("unsupported olm message type")
|
||||
DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
|
||||
DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
|
||||
SenderMismatch = errors.New("mismatched sender in olm payload")
|
||||
RecipientMismatch = errors.New("mismatched recipient in olm payload")
|
||||
RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
|
||||
ErrDuplicateMessage = errors.New("duplicate olm message")
|
||||
ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
|
||||
ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
|
||||
ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type")
|
||||
ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
|
||||
ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
|
||||
ErrSenderMismatch = errors.New("mismatched sender in olm payload")
|
||||
ErrRecipientMismatch = errors.New("mismatched recipient in olm payload")
|
||||
ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
|
||||
ErrDuplicateMessage = errors.New("duplicate olm message")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
UnsupportedAlgorithm = ErrUnsupportedAlgorithm
|
||||
NotEncryptedForMe = ErrNotEncryptedForMe
|
||||
UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType
|
||||
DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession
|
||||
DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage
|
||||
SenderMismatch = ErrSenderMismatch
|
||||
RecipientMismatch = ErrRecipientMismatch
|
||||
RecipientKeyMismatch = ErrRecipientKeyMismatch
|
||||
)
|
||||
|
||||
// DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm.
|
||||
|
|
@ -56,13 +68,13 @@ type DecryptedOlmEvent struct {
|
|||
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
|
||||
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
|
||||
if !ok {
|
||||
return nil, IncorrectEncryptedContentType
|
||||
return nil, ErrIncorrectEncryptedContentType
|
||||
} else if content.Algorithm != id.AlgorithmOlmV1 {
|
||||
return nil, UnsupportedAlgorithm
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
}
|
||||
ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()]
|
||||
if !ok {
|
||||
return nil, NotEncryptedForMe
|
||||
return nil, ErrNotEncryptedForMe
|
||||
}
|
||||
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body)
|
||||
if err != nil {
|
||||
|
|
@ -78,7 +90,7 @@ type OlmEventKeys struct {
|
|||
|
||||
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
|
||||
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
|
||||
return nil, UnsupportedOlmMessageType
|
||||
return nil, ErrUnsupportedOlmMessageType
|
||||
}
|
||||
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
|
|
@ -102,11 +114,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
|
|||
}
|
||||
olmEvt.Type.Class = evt.Type.Class
|
||||
if evt.Sender != olmEvt.Sender {
|
||||
return nil, SenderMismatch
|
||||
return nil, ErrSenderMismatch
|
||||
} else if mach.Client.UserID != olmEvt.Recipient {
|
||||
return nil, RecipientMismatch
|
||||
return nil, ErrRecipientMismatch
|
||||
} else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 {
|
||||
return nil, RecipientKeyMismatch
|
||||
return nil, ErrRecipientKeyMismatch
|
||||
}
|
||||
|
||||
if len(olmEvt.Content.VeryRaw) > 0 {
|
||||
|
|
@ -151,7 +163,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
|
||||
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash)
|
||||
if err != nil {
|
||||
if err == DecryptionFailedWithMatchingSession {
|
||||
if err == ErrDecryptionFailedWithMatchingSession {
|
||||
log.Warn().Msg("Found matching session, but decryption failed")
|
||||
go mach.unwedgeDevice(log, sender, senderKey)
|
||||
}
|
||||
|
|
@ -169,10 +181,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
// if it isn't one at this point in time anymore, so return early.
|
||||
if olmType != id.OlmMsgTypePreKey {
|
||||
go mach.unwedgeDevice(log, sender, senderKey)
|
||||
return nil, DecryptionFailedForNormalMessage
|
||||
return nil, ErrDecryptionFailedForNormalMessage
|
||||
}
|
||||
|
||||
accountBackup, err := mach.account.Internal.Pickle([]byte("tmp"))
|
||||
accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp"))
|
||||
log.Trace().Msg("Trying to create inbound session")
|
||||
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
|
||||
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
|
||||
|
|
@ -302,7 +314,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
|
|||
Str("session_description", session.Describe()).
|
||||
Msg("Failed to decrypt olm message")
|
||||
if olmType == id.OlmMsgTypePreKey {
|
||||
return nil, DecryptionFailedWithMatchingSession
|
||||
return nil, ErrDecryptionFailedWithMatchingSession
|
||||
}
|
||||
} else {
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
|
||||
|
|
@ -345,7 +357,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send
|
|||
ctx := log.WithContext(mach.backgroundCtx)
|
||||
mach.recentlyUnwedgedLock.Lock()
|
||||
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
|
||||
delta := time.Now().Sub(prevUnwedge)
|
||||
delta := time.Since(prevUnwedge)
|
||||
if ok && delta < MinUnwedgeInterval {
|
||||
log.Debug().
|
||||
Str("previous_recreation", delta.String()).
|
||||
|
|
|
|||
|
|
@ -22,14 +22,23 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
|
||||
MismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
|
||||
MismatchingSigningKey = errors.New("received update for device with different signing key")
|
||||
NoSigningKeyFound = errors.New("didn't find ed25519 signing key")
|
||||
NoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
|
||||
InvalidKeySignature = errors.New("invalid signature on device keys")
|
||||
ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
|
||||
ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
|
||||
ErrMismatchingSigningKey = errors.New("received update for device with different signing key")
|
||||
ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key")
|
||||
ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
|
||||
ErrInvalidKeySignature = errors.New("invalid signature on device keys")
|
||||
ErrUserNotTracked = errors.New("user is not tracked")
|
||||
)
|
||||
|
||||
ErrUserNotTracked = errors.New("user is not tracked")
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
MismatchingDeviceID = ErrMismatchingDeviceID
|
||||
MismatchingUserID = ErrMismatchingUserID
|
||||
MismatchingSigningKey = ErrMismatchingSigningKey
|
||||
NoSigningKeyFound = ErrNoSigningKeyFound
|
||||
NoIdentityKeyFound = ErrNoIdentityKeyFound
|
||||
InvalidKeySignature = ErrInvalidKeySignature
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
|
||||
|
|
@ -312,28 +321,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID)
|
|||
|
||||
func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) {
|
||||
if deviceID != deviceKeys.DeviceID {
|
||||
return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID)
|
||||
return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID)
|
||||
} else if userID != deviceKeys.UserID {
|
||||
return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID)
|
||||
return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID)
|
||||
}
|
||||
|
||||
signingKey := deviceKeys.Keys.GetEd25519(deviceID)
|
||||
identityKey := deviceKeys.Keys.GetCurve25519(deviceID)
|
||||
if signingKey == "" {
|
||||
return nil, NoSigningKeyFound
|
||||
return nil, ErrNoSigningKeyFound
|
||||
} else if identityKey == "" {
|
||||
return nil, NoIdentityKeyFound
|
||||
return nil, ErrNoIdentityKeyFound
|
||||
}
|
||||
|
||||
if existing != nil && existing.SigningKey != signingKey {
|
||||
return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey)
|
||||
return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey)
|
||||
}
|
||||
|
||||
ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey)
|
||||
if err != nil {
|
||||
return existing, fmt.Errorf("failed to verify signature: %w", err)
|
||||
} else if !ok {
|
||||
return existing, InvalidKeySignature
|
||||
return existing, ErrInvalidKeySignature
|
||||
}
|
||||
|
||||
name, ok := deviceKeys.Unsigned["device_display_name"].(string)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,12 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
NoGroupSession = errors.New("no group session created")
|
||||
ErrNoGroupSession = errors.New("no group session created")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
NoGroupSession = ErrNoGroupSession
|
||||
)
|
||||
|
||||
func getRawJSON[T any](content json.RawMessage, path ...string) *T {
|
||||
|
|
@ -82,7 +87,7 @@ type rawMegolmEvent struct {
|
|||
|
||||
// IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession
|
||||
func IsShareError(err error) bool {
|
||||
return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
|
||||
return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession
|
||||
}
|
||||
|
||||
func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) {
|
||||
|
|
@ -120,7 +125,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
|
||||
} else if session == nil {
|
||||
return nil, NoGroupSession
|
||||
return nil, ErrNoGroupSession
|
||||
}
|
||||
plaintext, err := json.Marshal(&rawMegolmEvent{
|
||||
RoomID: roomID,
|
||||
|
|
|
|||
|
|
@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
} else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 {
|
||||
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
|
||||
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
|
||||
} else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair
|
||||
return err
|
||||
} else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) {
|
|||
account, err := account.NewAccount()
|
||||
assert.NoError(t, err)
|
||||
err = account.Unpickle(pickled, pickleKey)
|
||||
assert.ErrorIs(t, err, olm.ErrBadVersion)
|
||||
assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion)
|
||||
}
|
||||
|
||||
func TestLoopback(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ import (
|
|||
"encoding/base64"
|
||||
)
|
||||
|
||||
// Deprecated: base64.RawStdEncoding should be used directly
|
||||
// These methods should only be used for raw byte operations, never with string conversion
|
||||
|
||||
func Decode(input []byte) ([]byte, error) {
|
||||
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
|
||||
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
|
||||
|
|
@ -14,7 +15,6 @@ func Decode(input []byte) ([]byte, error) {
|
|||
return decoded[:writtenBytes], nil
|
||||
}
|
||||
|
||||
// Deprecated: base64.RawStdEncoding should be used directly
|
||||
func Encode(input []byte) []byte {
|
||||
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
|
||||
base64.RawStdEncoding.Encode(encoded, input)
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
|
|||
}
|
||||
}
|
||||
if decrypted[0] != pickleVersion {
|
||||
return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion)
|
||||
return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion)
|
||||
}
|
||||
err = json.Unmarshal(decrypted[1:], object)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error {
|
|||
return fmt.Errorf("decrypt: %w", olm.ErrBadInput)
|
||||
}
|
||||
if input[0] != sessionExportVersion {
|
||||
return fmt.Errorf("decrypt: %w", olm.ErrBadVersion)
|
||||
return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion)
|
||||
}
|
||||
s.Counter = binary.BigEndian.Uint32(input[1:5])
|
||||
copy(s.RatchetData[:], input[5:133])
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error {
|
|||
}
|
||||
s.PublicKey = publicKey
|
||||
if input[0] != sessionSharingVersion {
|
||||
return fmt.Errorf("verify: %w", olm.ErrBadVersion)
|
||||
return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion)
|
||||
}
|
||||
s.Counter = binary.BigEndian.Uint32(input[1:5])
|
||||
copy(s.RatchetData[:], input[5:133])
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error {
|
|||
if pickledVersion == decryptionPickleVersionLibOlm {
|
||||
return a.KeyPair.UnpickleLibOlm(decoder)
|
||||
} else {
|
||||
return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm)
|
||||
return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat
|
|||
return nil, nil, err
|
||||
}
|
||||
cipher, err := aessha2.NewAESSHA2(sharedSecret, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ciphertext, err = cipher.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet,
|
|||
}
|
||||
if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) {
|
||||
// the counter is before our initial ratchet - we can't decode this
|
||||
return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable)
|
||||
return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex)
|
||||
}
|
||||
// otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet
|
||||
copiedRatchet := o.InitialRatchet
|
||||
|
|
@ -206,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error {
|
|||
return err
|
||||
}
|
||||
if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 {
|
||||
return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
|
||||
return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
|
||||
}
|
||||
|
||||
if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil {
|
||||
|
|
|
|||
|
|
@ -101,8 +101,10 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
|
|||
func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error {
|
||||
decoder := libolmpickle.NewDecoder(buf)
|
||||
pickledVersion, err := decoder.ReadUInt32()
|
||||
if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
|
||||
return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err)
|
||||
} else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
|
||||
return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
|
||||
}
|
||||
if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -168,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received
|
|||
msg := message.Message{}
|
||||
err = msg.Decode(oneTimeMsg.Message)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Message decode: %w", err)
|
||||
return nil, fmt.Errorf("message decode: %w", err)
|
||||
}
|
||||
|
||||
if len(msg.RatchetKey) == 0 {
|
||||
return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat)
|
||||
return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat)
|
||||
}
|
||||
//Init Ratchet
|
||||
s.Ratchet.InitializeAsBob(secret, msg.RatchetKey)
|
||||
|
|
@ -203,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID {
|
|||
copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey)
|
||||
copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey)
|
||||
hash := sha256.Sum256(message)
|
||||
res := id.SessionID(goolmbase64.Encode(hash[:]))
|
||||
res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:]))
|
||||
return res
|
||||
}
|
||||
|
||||
|
|
@ -325,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e
|
|||
if len(crypttext) == 0 {
|
||||
return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput)
|
||||
}
|
||||
decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext))
|
||||
decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -365,6 +365,9 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error {
|
|||
func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
|
||||
decoder := libolmpickle.NewDecoder(buf)
|
||||
pickledVersion, err := decoder.ReadUInt32()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unpickle olmSession: failed to read version: %w", err)
|
||||
}
|
||||
|
||||
var includesChainIndex bool
|
||||
switch pickledVersion {
|
||||
|
|
@ -373,7 +376,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
|
|||
case uint32(0x80000001):
|
||||
includesChainIndex = true
|
||||
default:
|
||||
return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
|
||||
return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
|
||||
}
|
||||
|
||||
if o.ReceivedMessage, err = decoder.ReadBool(); err != nil {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ func Register() {
|
|||
// Inbound Session
|
||||
olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
if len(key) == 0 {
|
||||
key = []byte(" ")
|
||||
|
|
@ -23,13 +23,13 @@ func Register() {
|
|||
}
|
||||
olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
return NewMegolmInboundSession(sessionKey)
|
||||
}
|
||||
olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
return NewMegolmInboundSessionFromExport(sessionKey)
|
||||
}
|
||||
|
|
@ -40,7 +40,7 @@ func Register() {
|
|||
// Outbound Session
|
||||
olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
lenKey := len(key)
|
||||
if lenKey == 0 {
|
||||
|
|
|
|||
|
|
@ -56,11 +56,12 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context,
|
|||
// ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private
|
||||
// key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing
|
||||
// from a verified device belonging to the same user."
|
||||
megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
|
||||
if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
|
||||
log.Debug().Msg("key backup is trusted based on derived public key")
|
||||
return versionInfo, nil
|
||||
} else {
|
||||
if megolmBackupKey != nil {
|
||||
megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
|
||||
if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
|
||||
log.Debug().Msg("Key backup is trusted based on derived public key")
|
||||
return versionInfo, nil
|
||||
}
|
||||
log.Debug().
|
||||
Stringer("expected_key", megolmBackupDerivedPublicKey).
|
||||
Stringer("actual_key", versionInfo.AuthData.PublicKey).
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ var _ olm.Account = (*Account)(nil)
|
|||
// "INVALID_BASE64".
|
||||
func AccountFromPickled(pickled, key []byte) (*Account, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
a := NewBlankAccount()
|
||||
return a, a.Unpickle(pickled, key)
|
||||
|
|
@ -53,7 +53,7 @@ func NewAccount() (*Account, error) {
|
|||
random := make([]byte, a.createRandomLen()+1)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
panic(olm.NotEnoughGoRandom)
|
||||
panic(olm.ErrNotEnoughGoRandom)
|
||||
}
|
||||
ret := C.olm_create_account(
|
||||
(*C.OlmAccount)(a.int),
|
||||
|
|
@ -128,7 +128,7 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
|
|||
// supplied key.
|
||||
func (a *Account) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, olm.NoKeyProvided
|
||||
return nil, olm.ErrNoKeyProvided
|
||||
}
|
||||
pickled := make([]byte, a.pickleLen())
|
||||
r := C.olm_pickle_account(
|
||||
|
|
@ -145,7 +145,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) {
|
|||
|
||||
func (a *Account) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return olm.NoKeyProvided
|
||||
return olm.ErrNoKeyProvided
|
||||
}
|
||||
r := C.olm_unpickle_account(
|
||||
(*C.OlmAccount)(a.int),
|
||||
|
|
@ -198,7 +198,7 @@ func (a *Account) MarshalJSON() ([]byte, error) {
|
|||
// Deprecated
|
||||
func (a *Account) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return olm.InputNotJSONString
|
||||
return olm.ErrInputNotJSONString
|
||||
}
|
||||
if a.int == nil {
|
||||
*a = *NewBlankAccount()
|
||||
|
|
@ -235,7 +235,7 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
|
|||
// Account.
|
||||
func (a *Account) Sign(message []byte) ([]byte, error) {
|
||||
if len(message) == 0 {
|
||||
panic(olm.EmptyInput)
|
||||
panic(olm.ErrEmptyInput)
|
||||
}
|
||||
signature := make([]byte, a.signatureLen())
|
||||
r := C.olm_account_sign(
|
||||
|
|
@ -299,7 +299,7 @@ func (a *Account) GenOneTimeKeys(num uint) error {
|
|||
random := make([]byte, a.genOneTimeKeysRandomLen(num)+1)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
return olm.NotEnoughGoRandom
|
||||
return olm.ErrNotEnoughGoRandom
|
||||
}
|
||||
r := C.olm_account_generate_one_time_keys(
|
||||
(*C.OlmAccount)(a.int),
|
||||
|
|
@ -319,13 +319,13 @@ func (a *Account) GenOneTimeKeys(num uint) error {
|
|||
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
|
||||
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) {
|
||||
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
s := NewBlankSession()
|
||||
random := make([]byte, s.createOutboundRandomLen()+1)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
panic(olm.NotEnoughGoRandom)
|
||||
panic(olm.ErrNotEnoughGoRandom)
|
||||
}
|
||||
theirIdentityKeyCopy := []byte(theirIdentityKey)
|
||||
theirOneTimeKeyCopy := []byte(theirOneTimeKey)
|
||||
|
|
@ -357,7 +357,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2
|
|||
// time key then the error will be "BAD_MESSAGE_KEY_ID".
|
||||
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
|
||||
if len(oneTimeKeyMsg) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
s := NewBlankSession()
|
||||
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
|
||||
|
|
@ -383,7 +383,7 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
|
|||
// time key then the error will be "BAD_MESSAGE_KEY_ID".
|
||||
func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) {
|
||||
if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
theirIdentityKeyCopy := []byte(*theirIdentityKey)
|
||||
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
|
||||
|
|
|
|||
|
|
@ -11,21 +11,21 @@ import (
|
|||
)
|
||||
|
||||
var errorMap = map[string]error{
|
||||
"NOT_ENOUGH_RANDOM": olm.NotEnoughRandom,
|
||||
"OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall,
|
||||
"BAD_MESSAGE_VERSION": olm.BadMessageVersion,
|
||||
"BAD_MESSAGE_FORMAT": olm.BadMessageFormat,
|
||||
"BAD_MESSAGE_MAC": olm.BadMessageMAC,
|
||||
"BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID,
|
||||
"INVALID_BASE64": olm.InvalidBase64,
|
||||
"BAD_ACCOUNT_KEY": olm.BadAccountKey,
|
||||
"UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion,
|
||||
"CORRUPTED_PICKLE": olm.CorruptedPickle,
|
||||
"BAD_SESSION_KEY": olm.BadSessionKey,
|
||||
"UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex,
|
||||
"BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle,
|
||||
"BAD_SIGNATURE": olm.BadSignature,
|
||||
"INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall,
|
||||
"NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom,
|
||||
"OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall,
|
||||
"BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion,
|
||||
"BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat,
|
||||
"BAD_MESSAGE_MAC": olm.ErrBadMAC,
|
||||
"BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID,
|
||||
"INVALID_BASE64": olm.ErrLibolmInvalidBase64,
|
||||
"BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey,
|
||||
"UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion,
|
||||
"CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle,
|
||||
"BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey,
|
||||
"UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex,
|
||||
"BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle,
|
||||
"BAD_SIGNATURE": olm.ErrBadSignature,
|
||||
"INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall,
|
||||
}
|
||||
|
||||
func convertError(errCode string) error {
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil)
|
|||
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
|
||||
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
lenKey := len(key)
|
||||
if lenKey == 0 {
|
||||
|
|
@ -48,7 +48,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession,
|
|||
// "OLM_BAD_SESSION_KEY".
|
||||
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
s := NewBlankInboundGroupSession()
|
||||
r := C.olm_init_inbound_group_session(
|
||||
|
|
@ -69,7 +69,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
|
|||
// error will be "OLM_BAD_SESSION_KEY".
|
||||
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
s := NewBlankInboundGroupSession()
|
||||
r := C.olm_import_inbound_group_session(
|
||||
|
|
@ -124,7 +124,7 @@ func (s *InboundGroupSession) pickleLen() uint {
|
|||
// InboundGroupSession using the supplied key.
|
||||
func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, olm.NoKeyProvided
|
||||
return nil, olm.ErrNoKeyProvided
|
||||
}
|
||||
pickled := make([]byte, s.pickleLen())
|
||||
r := C.olm_pickle_inbound_group_session(
|
||||
|
|
@ -143,9 +143,9 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
|
|||
|
||||
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return olm.NoKeyProvided
|
||||
return olm.ErrNoKeyProvided
|
||||
} else if len(pickled) == 0 {
|
||||
return olm.EmptyInput
|
||||
return olm.ErrEmptyInput
|
||||
}
|
||||
r := C.olm_unpickle_inbound_group_session(
|
||||
(*C.OlmInboundGroupSession)(s.int),
|
||||
|
|
@ -200,7 +200,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
|
|||
// Deprecated
|
||||
func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return olm.InputNotJSONString
|
||||
return olm.ErrInputNotJSONString
|
||||
}
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankInboundGroupSession()
|
||||
|
|
@ -217,7 +217,7 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
|
|||
// will be "BAD_MESSAGE_FORMAT".
|
||||
func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) {
|
||||
if len(message) == 0 {
|
||||
return 0, olm.EmptyInput
|
||||
return 0, olm.ErrEmptyInput
|
||||
}
|
||||
// olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it
|
||||
messageCopy := bytes.Clone(message)
|
||||
|
|
@ -244,7 +244,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro
|
|||
// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX".
|
||||
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, 0, olm.EmptyInput
|
||||
return nil, 0, olm.ErrEmptyInput
|
||||
}
|
||||
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ func (s *OutboundGroupSession) pickleLen() uint {
|
|||
// OutboundGroupSession using the supplied key.
|
||||
func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, olm.NoKeyProvided
|
||||
return nil, olm.ErrNoKeyProvided
|
||||
}
|
||||
pickled := make([]byte, s.pickleLen())
|
||||
r := C.olm_pickle_outbound_group_session(
|
||||
|
|
@ -103,7 +103,7 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
|
|||
|
||||
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return olm.NoKeyProvided
|
||||
return olm.ErrNoKeyProvided
|
||||
}
|
||||
r := C.olm_unpickle_outbound_group_session(
|
||||
(*C.OlmOutboundGroupSession)(s.int),
|
||||
|
|
@ -159,7 +159,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
|
|||
// Deprecated
|
||||
func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return olm.InputNotJSONString
|
||||
return olm.ErrInputNotJSONString
|
||||
}
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankOutboundGroupSession()
|
||||
|
|
@ -183,7 +183,7 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint {
|
|||
// as base64.
|
||||
func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
if len(plaintext) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
message := make([]byte, s.encryptMsgLen(len(plaintext)))
|
||||
r := C.olm_group_encrypt(
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ func NewPKSigning() (*PKSigning, error) {
|
|||
seed := make([]byte, pkSigningSeedLength())
|
||||
_, err := rand.Read(seed)
|
||||
if err != nil {
|
||||
panic(olm.NotEnoughGoRandom)
|
||||
panic(olm.ErrNotEnoughGoRandom)
|
||||
}
|
||||
pk, err := NewPKSigningFromSeed(seed)
|
||||
return pk, err
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ func Register() {
|
|||
|
||||
olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
s := NewBlankOutboundGroupSession()
|
||||
return s, s.Unpickle(pickled, key)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ func sessionSize() uint {
|
|||
// "INVALID_BASE64".
|
||||
func SessionFromPickled(pickled, key []byte) (*Session, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
s := NewBlankSession()
|
||||
return s, s.Unpickle(pickled, key)
|
||||
|
|
@ -118,7 +118,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint {
|
|||
// will be "BAD_MESSAGE_FORMAT".
|
||||
func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
|
||||
if len(message) == 0 {
|
||||
return 0, olm.EmptyInput
|
||||
return 0, olm.ErrEmptyInput
|
||||
}
|
||||
messageCopy := []byte(message)
|
||||
r := C.olm_decrypt_max_plaintext_length(
|
||||
|
|
@ -138,7 +138,7 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType)
|
|||
// supplied key.
|
||||
func (s *Session) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, olm.NoKeyProvided
|
||||
return nil, olm.ErrNoKeyProvided
|
||||
}
|
||||
pickled := make([]byte, s.pickleLen())
|
||||
r := C.olm_pickle_session(
|
||||
|
|
@ -158,7 +158,7 @@ func (s *Session) Pickle(key []byte) ([]byte, error) {
|
|||
// provided key. This function mutates the input pickled data slice.
|
||||
func (s *Session) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return olm.NoKeyProvided
|
||||
return olm.ErrNoKeyProvided
|
||||
}
|
||||
r := C.olm_unpickle_session(
|
||||
(*C.OlmSession)(s.int),
|
||||
|
|
@ -213,7 +213,7 @@ func (s *Session) MarshalJSON() ([]byte, error) {
|
|||
// Deprecated
|
||||
func (s *Session) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return olm.InputNotJSONString
|
||||
return olm.ErrInputNotJSONString
|
||||
}
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankSession()
|
||||
|
|
@ -256,7 +256,7 @@ func (s *Session) HasReceivedMessage() bool {
|
|||
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
|
||||
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
|
||||
if len(oneTimeKeyMsg) == 0 {
|
||||
return false, olm.EmptyInput
|
||||
return false, olm.ErrEmptyInput
|
||||
}
|
||||
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
|
||||
r := C.olm_matches_inbound_session(
|
||||
|
|
@ -284,7 +284,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
|
|||
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
|
||||
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
|
||||
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
|
||||
return false, olm.EmptyInput
|
||||
return false, olm.ErrEmptyInput
|
||||
}
|
||||
theirIdentityKeyCopy := []byte(theirIdentityKey)
|
||||
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
|
||||
|
|
@ -325,14 +325,14 @@ func (s *Session) EncryptMsgType() id.OlmMsgType {
|
|||
// as base64.
|
||||
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
|
||||
if len(plaintext) == 0 {
|
||||
return 0, nil, olm.EmptyInput
|
||||
return 0, nil, olm.ErrEmptyInput
|
||||
}
|
||||
// Make the slice be at least length 1
|
||||
random := make([]byte, s.encryptRandomLen()+1)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
// TODO can we just return err here?
|
||||
return 0, nil, olm.NotEnoughGoRandom
|
||||
return 0, nil, olm.ErrNotEnoughGoRandom
|
||||
}
|
||||
messageType := s.EncryptMsgType()
|
||||
message := make([]byte, s.encryptMsgLen(len(plaintext)))
|
||||
|
|
@ -362,7 +362,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
|
|||
// "BAD_MESSAGE_MAC".
|
||||
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, olm.EmptyInput
|
||||
return nil, olm.ErrEmptyInput
|
||||
}
|
||||
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error {
|
|||
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
|
||||
start := time.Now()
|
||||
return func() {
|
||||
duration := time.Now().Sub(start)
|
||||
duration := time.Since(start)
|
||||
if duration > expectedDuration {
|
||||
zerolog.Ctx(ctx).Warn().
|
||||
Str("action", thing).
|
||||
|
|
|
|||
|
|
@ -10,50 +10,67 @@ import "errors"
|
|||
|
||||
// Those are the most common used errors
|
||||
var (
|
||||
ErrBadSignature = errors.New("bad signature")
|
||||
ErrBadMAC = errors.New("bad mac")
|
||||
ErrBadMessageFormat = errors.New("bad message format")
|
||||
ErrBadVerification = errors.New("bad verification")
|
||||
ErrWrongProtocolVersion = errors.New("wrong protocol version")
|
||||
ErrEmptyInput = errors.New("empty input")
|
||||
ErrNoKeyProvided = errors.New("no key")
|
||||
ErrBadMessageKeyID = errors.New("bad message key id")
|
||||
ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key")
|
||||
ErrMsgIndexTooHigh = errors.New("message index too high")
|
||||
ErrProtocolViolation = errors.New("not protocol message order")
|
||||
ErrMessageKeyNotFound = errors.New("message key not found")
|
||||
ErrChainTooHigh = errors.New("chain index too high")
|
||||
ErrBadInput = errors.New("bad input")
|
||||
ErrBadVersion = errors.New("wrong version")
|
||||
ErrWrongPickleVersion = errors.New("wrong pickle version")
|
||||
ErrInputToSmall = errors.New("input too small (truncated?)")
|
||||
ErrOverflow = errors.New("overflow")
|
||||
ErrBadSignature = errors.New("bad signature")
|
||||
ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)")
|
||||
ErrBadMessageFormat = errors.New("the message couldn't be decoded")
|
||||
ErrBadVerification = errors.New("bad verification")
|
||||
ErrWrongProtocolVersion = errors.New("wrong protocol version")
|
||||
ErrEmptyInput = errors.New("empty input")
|
||||
ErrNoKeyProvided = errors.New("no key provided")
|
||||
ErrBadMessageKeyID = errors.New("the message references an unknown key ID")
|
||||
ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
|
||||
ErrMsgIndexTooHigh = errors.New("message index too high")
|
||||
ErrProtocolViolation = errors.New("not protocol message order")
|
||||
ErrMessageKeyNotFound = errors.New("message key not found")
|
||||
ErrChainTooHigh = errors.New("chain index too high")
|
||||
ErrBadInput = errors.New("bad input")
|
||||
ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version")
|
||||
ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version")
|
||||
ErrInputToSmall = errors.New("input too small (truncated?)")
|
||||
)
|
||||
|
||||
// Error codes from go-olm
|
||||
var (
|
||||
EmptyInput = errors.New("empty input")
|
||||
NoKeyProvided = errors.New("no pickle key provided")
|
||||
NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
|
||||
SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device")
|
||||
InputNotJSONString = errors.New("input doesn't look like a JSON string")
|
||||
ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
|
||||
ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
|
||||
)
|
||||
|
||||
// Error codes from olm code
|
||||
var (
|
||||
NotEnoughRandom = errors.New("not enough entropy was supplied")
|
||||
OutputBufferTooSmall = errors.New("supplied output buffer is too small")
|
||||
BadMessageVersion = errors.New("the message version is unsupported")
|
||||
BadMessageFormat = errors.New("the message couldn't be decoded")
|
||||
BadMessageMAC = errors.New("the message couldn't be decrypted")
|
||||
BadMessageKeyID = errors.New("the message references an unknown key ID")
|
||||
InvalidBase64 = errors.New("the input base64 was invalid")
|
||||
BadAccountKey = errors.New("the supplied account key is invalid")
|
||||
UnknownPickleVersion = errors.New("the pickled object is too new")
|
||||
CorruptedPickle = errors.New("the pickled object couldn't be decoded")
|
||||
BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
|
||||
UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
|
||||
BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
|
||||
BadSignature = errors.New("received message had a bad signature")
|
||||
InputBufferTooSmall = errors.New("the input data was too small to be valid")
|
||||
ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid")
|
||||
|
||||
ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied")
|
||||
ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small")
|
||||
ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid")
|
||||
ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded")
|
||||
ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
|
||||
ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
EmptyInput = ErrEmptyInput
|
||||
BadSignature = ErrBadSignature
|
||||
InvalidBase64 = ErrLibolmInvalidBase64
|
||||
BadMessageKeyID = ErrBadMessageKeyID
|
||||
BadMessageFormat = ErrBadMessageFormat
|
||||
BadMessageVersion = ErrWrongProtocolVersion
|
||||
BadMessageMAC = ErrBadMAC
|
||||
UnknownPickleVersion = ErrUnknownOlmPickleVersion
|
||||
NotEnoughRandom = ErrLibolmNotEnoughRandom
|
||||
OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall
|
||||
BadAccountKey = ErrLibolmBadAccountKey
|
||||
CorruptedPickle = ErrLibolmCorruptedPickle
|
||||
BadSessionKey = ErrLibolmBadSessionKey
|
||||
UnknownMessageIndex = ErrUnknownMessageIndex
|
||||
BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle
|
||||
InputBufferTooSmall = ErrInputToSmall
|
||||
NoKeyProvided = ErrNoKeyProvided
|
||||
|
||||
NotEnoughGoRandom = ErrNotEnoughGoRandom
|
||||
InputNotJSONString = ErrInputNotJSONString
|
||||
|
||||
ErrBadVersion = ErrUnknownJSONPickleVersion
|
||||
ErrWrongPickleVersion = ErrUnknownJSONPickleVersion
|
||||
ErrRatchetNotAvailable = ErrUnknownMessageIndex
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,8 +18,14 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
SessionNotShared = errors.New("session has not been shared")
|
||||
SessionExpired = errors.New("session has expired")
|
||||
ErrSessionNotShared = errors.New("session has not been shared")
|
||||
ErrSessionExpired = errors.New("session has expired")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
SessionNotShared = ErrSessionNotShared
|
||||
SessionExpired = ErrSessionExpired
|
||||
)
|
||||
|
||||
// OlmSessionList is a list of OlmSessions.
|
||||
|
|
@ -255,9 +261,9 @@ func (ogs *OutboundGroupSession) Expired() bool {
|
|||
|
||||
func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
if !ogs.Shared {
|
||||
return nil, SessionNotShared
|
||||
return nil, ErrSessionNotShared
|
||||
} else if ogs.Expired() {
|
||||
return nil, SessionExpired
|
||||
return nil, ErrSessionExpired
|
||||
}
|
||||
ogs.MessageCount++
|
||||
ogs.LastEncryptedTime = time.Now()
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error {
|
|||
return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext)
|
||||
case id.AlgorithmMegolmV1:
|
||||
if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' {
|
||||
return id.InputNotJSONString
|
||||
return id.ErrInputNotJSONString
|
||||
}
|
||||
content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS
|
|||
} else if wellKnown != nil {
|
||||
output.Expires = expiry
|
||||
output.HostHeader = wellKnown.Server
|
||||
hostname, port, ok = ParseServerName(wellKnown.Server)
|
||||
wkHost, wkPort, ok := ParseServerName(wellKnown.Server)
|
||||
if ok {
|
||||
hostname, port = wkHost, wkPort
|
||||
}
|
||||
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
|
||||
if net.ParseIP(hostname) != nil || port != 0 {
|
||||
if port == 0 {
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ type FilterPart struct {
|
|||
// Validate checks if the filter contains valid property values
|
||||
func (filter *Filter) Validate() error {
|
||||
if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation {
|
||||
return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]")
|
||||
return errors.New("bad event_format value")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,8 +17,14 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
InvalidContentURI = errors.New("invalid Matrix content URI")
|
||||
InputNotJSONString = errors.New("input doesn't look like a JSON string")
|
||||
ErrInvalidContentURI = errors.New("invalid Matrix content URI")
|
||||
ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
InvalidContentURI = ErrInvalidContentURI
|
||||
InputNotJSONString = ErrInputNotJSONString
|
||||
)
|
||||
|
||||
// ContentURIString is a string that's expected to be a Matrix content URI.
|
||||
|
|
@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) {
|
|||
if len(uri) == 0 {
|
||||
return
|
||||
} else if !strings.HasPrefix(uri, "mxc://") {
|
||||
err = InvalidContentURI
|
||||
err = ErrInvalidContentURI
|
||||
} else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
|
||||
err = InvalidContentURI
|
||||
err = ErrInvalidContentURI
|
||||
} else {
|
||||
parsed.Homeserver = uri[6 : 6+index]
|
||||
parsed.FileID = uri[6+index+1:]
|
||||
|
|
@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) {
|
|||
if len(uri) == 0 {
|
||||
return
|
||||
} else if !bytes.HasPrefix(uri, mxcBytes) {
|
||||
err = InvalidContentURI
|
||||
err = ErrInvalidContentURI
|
||||
} else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
|
||||
err = InvalidContentURI
|
||||
err = ErrInvalidContentURI
|
||||
} else {
|
||||
parsed.Homeserver = string(uri[6 : 6+index])
|
||||
parsed.FileID = string(uri[6+index+1:])
|
||||
|
|
@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) {
|
|||
*uri = ContentURI{}
|
||||
return nil
|
||||
} else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' {
|
||||
return InputNotJSONString
|
||||
return ErrInputNotJSONString
|
||||
}
|
||||
parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1])
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{
|
|||
|
||||
func (uri *MatrixURI) getQuery() url.Values {
|
||||
q := make(url.Values)
|
||||
if uri.Via != nil && len(uri.Via) > 0 {
|
||||
if len(uri.Via) > 0 {
|
||||
q["via"] = uri.Via
|
||||
}
|
||||
if len(uri.Action) > 0 {
|
||||
|
|
|
|||
|
|
@ -219,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) {
|
|||
for i := 0; i < len(strBytes); i++ {
|
||||
b := strBytes[i]
|
||||
if !isValidByte(b) {
|
||||
return "", fmt.Errorf("Byte pos %d: Invalid byte", i)
|
||||
return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b)
|
||||
}
|
||||
|
||||
if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _
|
||||
if i+1 >= len(strBytes) {
|
||||
return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i)
|
||||
return "", fmt.Errorf("unexpected end of string after underscore at %d", i)
|
||||
}
|
||||
if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping
|
||||
return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i)
|
||||
return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i)
|
||||
}
|
||||
if strBytes[i+1] == '_' {
|
||||
outputBuffer.WriteByte('_')
|
||||
|
|
@ -237,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) {
|
|||
i++ // skip next byte since we just handled it
|
||||
} else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8
|
||||
if i+2 >= len(strBytes) {
|
||||
return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i)
|
||||
return "", fmt.Errorf("unexpected end of string after equals sign at %d", i)
|
||||
}
|
||||
dst := make([]byte, 1)
|
||||
_, err := hex.Decode(dst, strBytes[i+1:i+3])
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error {
|
|||
if ok {
|
||||
action.Action = ActionSetTweak
|
||||
action.Tweak = PushActionTweak(tweak)
|
||||
action.Value, _ = val["value"]
|
||||
action.Value = val["value"]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi
|
|||
}
|
||||
}
|
||||
|
||||
func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition {
|
||||
return &pushrules.PushCondition{
|
||||
Kind: pushrules.KindEventPropertyContains,
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushCondition_Match_InvalidKind(t *testing.T) {
|
||||
condition := &pushrules.PushCondition{
|
||||
Kind: pushrules.PushCondKind("invalid"),
|
||||
|
|
|
|||
6
room.go
6
room.go
|
|
@ -5,8 +5,6 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type RoomStateMap = map[event.Type]map[string]*event.Event
|
||||
|
||||
// Room represents a single Matrix room.
|
||||
type Room struct {
|
||||
ID id.RoomID
|
||||
|
|
@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) {
|
|||
|
||||
// GetStateEvent returns the state event for the given type/state_key combo, or nil.
|
||||
func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event {
|
||||
stateEventMap, _ := room.State[eventType]
|
||||
evt, _ := stateEventMap[stateKey]
|
||||
stateEventMap := room.State[eventType]
|
||||
evt := stateEventMap[stateKey]
|
||||
return evt
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -75,8 +75,7 @@ type RespListRooms struct {
|
|||
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api
|
||||
func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) {
|
||||
var resp RespListRooms
|
||||
var reqURL string
|
||||
reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
|
||||
reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
|
||||
_, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
|
||||
return resp, err
|
||||
}
|
||||
|
|
|
|||
6
url.go
6
url.go
|
|
@ -98,10 +98,8 @@ func (saup SynapseAdminURLPath) FullPath() []any {
|
|||
// and appservice user ID set already.
|
||||
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
|
||||
return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) {
|
||||
if urlQuery != nil {
|
||||
for k, v := range urlQuery {
|
||||
q.Set(k, v)
|
||||
}
|
||||
for k, v := range urlQuery {
|
||||
q.Set(k, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue