diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index c973c1fe..aa22360a 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -57,7 +57,12 @@ func NewKey(passphrase string) (*Key, error) { // We store a certain hash in the key metadata so that clients can check if the user entered the correct key. ivBytes := random.Bytes(utils.AESCTRIVLength) keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes) - keyData.MAC = keyData.calculateHash(ssssKey) + var err error + keyData.MAC, err = keyData.calculateHash(ssssKey) + if err != nil { + // This should never happen because we just generated the IV and key. + return nil, fmt.Errorf("failed to calculate hash: %w", err) + } return &Key{ Key: ssssKey, diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index 210bcdcf..474c85d8 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -33,8 +33,8 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) ssssKey, err := kd.Passphrase.GetKey(passphrase) if err != nil { return nil, err - } else if !kd.VerifyKey(ssssKey) { - return nil, ErrIncorrectSSSSKey + } else if err = kd.verifyKey(ssssKey); err != nil { + return nil, err } return &Key{ @@ -49,8 +49,8 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey) if ssssKey == nil { return nil, ErrInvalidRecoveryKey - } else if !kd.VerifyKey(ssssKey) { - return nil, ErrIncorrectSSSSKey + } else if err := kd.verifyKey(ssssKey); err != nil { + return nil, err } return &Key{ @@ -60,22 +60,46 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error }, nil } +func (kd *KeyMetadata) verifyKey(key []byte) error { + unpaddedMAC := strings.TrimRight(kd.MAC, "=") + expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength) + if len(unpaddedMAC) != expectedMACLength { + return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength) + } + hash, err := kd.calculateHash(key) + if err != nil { + return err + } + if unpaddedMAC != hash { + return ErrIncorrectSSSSKey + } + return nil +} + // VerifyKey verifies the SSSS key is valid by calculating and comparing its MAC. func (kd *KeyMetadata) VerifyKey(key []byte) bool { - return strings.TrimRight(kd.MAC, "=") == kd.calculateHash(key) + return kd.verifyKey(key) == nil } // calculateHash calculates the hash used for checking if the key is entered correctly as described // in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2 -func (kd *KeyMetadata) calculateHash(key []byte) string { +func (kd *KeyMetadata) calculateHash(key []byte) (string, error) { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") + unpaddedIV := strings.TrimRight(kd.IV, "=") + expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength) + if len(unpaddedIV) != expectedIVLength { + return "", fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) + } var ivBytes [utils.AESCTRIVLength]byte - _, _ = base64.RawStdEncoding.Decode(ivBytes[:], []byte(strings.TrimRight(kd.IV, "="))) + _, err := base64.RawStdEncoding.Decode(ivBytes[:], []byte(unpaddedIV)) + if err != nil { + return "", fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) + } cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) - return utils.HMACSHA256B64(cipher, hmacKey) + return utils.HMACSHA256B64(cipher, hmacKey), nil } // PassphraseMetadata represents server-side metadata about a SSSS key passphrase. diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index 96c97282..4f2ff378 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -41,12 +41,28 @@ const key2Meta = ` } ` +const key2MetaBrokenIV = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +} +` + +const key2MetaBrokenMAC = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow" +} +` + const key2ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv" const key2RecoveryKey = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe" -func getKey1Meta() *ssss.KeyMetadata { +func getKeyMeta(meta string) *ssss.KeyMetadata { var km ssss.KeyMetadata - err := json.Unmarshal([]byte(key1Meta), &km) + err := json.Unmarshal([]byte(meta), &km) if err != nil { panic(err) } @@ -54,7 +70,7 @@ func getKey1Meta() *ssss.KeyMetadata { } func getKey1() *ssss.Key { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) if err != nil { panic(err) @@ -63,17 +79,8 @@ func getKey1() *ssss.Key { return key } -func getKey2Meta() *ssss.KeyMetadata { - var km ssss.KeyMetadata - err := json.Unmarshal([]byte(key2Meta), &km) - if err != nil { - panic(err) - } - return &km -} - func getKey2() *ssss.Key { - km := getKey2Meta() + km := getKeyMeta(key2Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) if err != nil { panic(err) @@ -83,7 +90,7 @@ func getKey2() *ssss.Key { } func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) @@ -91,7 +98,7 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { } func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { - km := getKey2Meta() + km := getKeyMeta(key2Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) @@ -99,21 +106,21 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { } func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, "foo") assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyPassphrase(key1ID, key1Passphrase) assert.NoError(t, err) assert.NotNil(t, key) @@ -121,15 +128,29 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { } func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple") assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) { - km := getKey2Meta() + km := getKeyMeta(key2Meta) key, err := km.VerifyPassphrase(key2ID, "hmm") assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err) assert.Nil(t, key) } + +func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) { + km := getKeyMeta(key2MetaBrokenIV) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) + assert.Nil(t, key) +} + +func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) { + km := getKeyMeta(key2MetaBrokenMAC) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) + assert.Nil(t, key) +} diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index 60852c55..345393b0 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -26,6 +26,7 @@ var ( ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm") ErrIncorrectSSSSKey = errors.New("incorrect SSSS key") ErrInvalidRecoveryKey = errors.New("invalid recovery key") + ErrCorruptedKeyMetadata = errors.New("corrupted key metadata") ) // Algorithm is the identifier for an SSSS encryption algorithm.