mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
crypto/ssss: remove id from key metadata
Instead, we will pass it into the key constructor functions directly. This avoids the footgun where you don't set the key ID on the metadata and then the ID is not properly propagated to the Key that is returned. Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
parent
7402f5a705
commit
0a17ac1cbe
4 changed files with 17 additions and 19 deletions
|
|
@ -53,7 +53,7 @@ func (mach *Machine) SetDefaultKeyID(ctx context.Context, keyID string) error {
|
|||
|
||||
// GetKeyData gets the details about the given key ID.
|
||||
func (mach *Machine) GetKeyData(ctx context.Context, keyID string) (keyData *KeyMetadata, err error) {
|
||||
keyData = &KeyMetadata{id: keyID}
|
||||
keyData = &KeyMetadata{}
|
||||
err = mach.Client.GetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,8 +17,6 @@ import (
|
|||
// KeyMetadata represents server-side metadata about a SSSS key. The metadata can be used to get
|
||||
// the actual SSSS key from a passphrase or recovery key.
|
||||
type KeyMetadata struct {
|
||||
id string
|
||||
|
||||
Name string `json:"name"`
|
||||
Algorithm Algorithm `json:"algorithm"`
|
||||
|
||||
|
|
@ -31,7 +29,7 @@ type KeyMetadata struct {
|
|||
}
|
||||
|
||||
// VerifyRecoveryKey verifies that the given passphrase is valid and returns the computed SSSS key.
|
||||
func (kd *KeyMetadata) VerifyPassphrase(passphrase string) (*Key, error) {
|
||||
func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) {
|
||||
ssssKey, err := kd.Passphrase.GetKey(passphrase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -40,15 +38,15 @@ func (kd *KeyMetadata) VerifyPassphrase(passphrase string) (*Key, error) {
|
|||
}
|
||||
|
||||
return &Key{
|
||||
ID: kd.id,
|
||||
ID: keyID,
|
||||
Key: ssssKey,
|
||||
Metadata: kd,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyRecoveryKey verifies that the given recovery key is valid and returns the decoded SSSS key.
|
||||
func (kd *KeyMetadata) VerifyRecoveryKey(recoverKey string) (*Key, error) {
|
||||
ssssKey := utils.DecodeBase58RecoveryKey(recoverKey)
|
||||
func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error) {
|
||||
ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey)
|
||||
if ssssKey == nil {
|
||||
return nil, ErrInvalidRecoveryKey
|
||||
} else if !kd.VerifyKey(ssssKey) {
|
||||
|
|
@ -56,7 +54,7 @@ func (kd *KeyMetadata) VerifyRecoveryKey(recoverKey string) (*Key, error) {
|
|||
}
|
||||
|
||||
return &Key{
|
||||
ID: kd.id,
|
||||
ID: keyID,
|
||||
Key: ssssKey,
|
||||
Metadata: kd,
|
||||
}, nil
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ func getKey1Meta() *ssss.KeyMetadata {
|
|||
|
||||
func getKey1() *ssss.Key {
|
||||
km := getKey1Meta()
|
||||
key, err := km.VerifyRecoveryKey(key1RecoveryKey)
|
||||
key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
@ -74,7 +74,7 @@ func getKey2Meta() *ssss.KeyMetadata {
|
|||
|
||||
func getKey2() *ssss.Key {
|
||||
km := getKey2Meta()
|
||||
key, err := km.VerifyRecoveryKey(key2RecoveryKey)
|
||||
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
@ -84,7 +84,7 @@ func getKey2() *ssss.Key {
|
|||
|
||||
func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) {
|
||||
km := getKey1Meta()
|
||||
key, err := km.VerifyRecoveryKey(key1RecoveryKey)
|
||||
key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, key)
|
||||
assert.Equal(t, key1RecoveryKey, key.RecoveryKey())
|
||||
|
|
@ -92,7 +92,7 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) {
|
|||
|
||||
func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) {
|
||||
km := getKey2Meta()
|
||||
key, err := km.VerifyRecoveryKey(key2RecoveryKey)
|
||||
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, key)
|
||||
assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
|
||||
|
|
@ -100,21 +100,21 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) {
|
|||
|
||||
func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) {
|
||||
km := getKey1Meta()
|
||||
key, err := km.VerifyRecoveryKey("foo")
|
||||
key, err := km.VerifyRecoveryKey(key1ID, "foo")
|
||||
assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err)
|
||||
assert.Nil(t, key)
|
||||
}
|
||||
|
||||
func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) {
|
||||
km := getKey1Meta()
|
||||
key, err := km.VerifyRecoveryKey(key2RecoveryKey)
|
||||
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
|
||||
assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err)
|
||||
assert.Nil(t, key)
|
||||
}
|
||||
|
||||
func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) {
|
||||
km := getKey1Meta()
|
||||
key, err := km.VerifyPassphrase(key1Passphrase)
|
||||
key, err := km.VerifyPassphrase(key1ID, key1Passphrase)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, key)
|
||||
assert.Equal(t, key1RecoveryKey, key.RecoveryKey())
|
||||
|
|
@ -122,14 +122,14 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) {
|
|||
|
||||
func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) {
|
||||
km := getKey1Meta()
|
||||
key, err := km.VerifyPassphrase("incorrect horse battery staple")
|
||||
key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple")
|
||||
assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err)
|
||||
assert.Nil(t, key)
|
||||
}
|
||||
|
||||
func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) {
|
||||
km := getKey2Meta()
|
||||
key, err := km.VerifyPassphrase("hmm")
|
||||
key, err := km.VerifyPassphrase(key2ID, "hmm")
|
||||
assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err)
|
||||
assert.Nil(t, key)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -125,11 +125,11 @@ func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (h *HiClient) VerifyWithRecoveryCode(ctx context.Context, code string) error {
|
||||
_, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx)
|
||||
keyID, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get default SSSS key data: %w", err)
|
||||
}
|
||||
key, err := keyData.VerifyRecoveryKey(code)
|
||||
key, err := keyData.VerifyRecoveryKey(keyID, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue