crypto: propagate more errors
Some checks are pending
Go / Lint (latest) (push) Waiting to run
Go / Build (old, libolm) (push) Waiting to run
Go / Build (latest, libolm) (push) Waiting to run
Go / Build (old, goolm) (push) Waiting to run
Go / Build (latest, goolm) (push) Waiting to run

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
Sumner Evans 2024-10-25 09:39:19 -06:00
commit 4a2557ed15
No known key found for this signature in database
17 changed files with 86 additions and 97 deletions

View file

@ -179,7 +179,10 @@ func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.R
Msg("Failed to get encryption event in room")
return nil, fmt.Errorf("failed to get encryption event in room %s: %w", roomID, err)
}
session := NewOutboundGroupSession(roomID, encryptionEvent)
session, err := NewOutboundGroupSession(roomID, encryptionEvent)
if err != nil {
return nil, err
}
if !mach.DontStoreOutboundKeys {
signingKey, idKey := mach.account.Keys()
err := mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)

View file

@ -111,8 +111,11 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
func (a *Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
return nil, fmt.Errorf("sign: %w", olm.ErrEmptyInput)
} else if signature, err := a.IdKeys.Ed25519.Sign(message); err != nil {
return nil, err
} else {
return []byte(base64.RawStdEncoding.EncodeToString(signature)), nil
}
return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil
}
// OneTimeKeys returns the public parts of the unpublished one time keys of the Account.
@ -122,7 +125,7 @@ func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) {
oneTimeKeys := make(map[string]id.Curve25519)
for _, curKey := range a.OTKeys {
if !curKey.Published {
oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded())
oneTimeKeys[curKey.KeyIDEncoded()] = curKey.Key.PublicKey.B64Encoded()
}
}
return oneTimeKeys, nil
@ -259,7 +262,7 @@ func (a *Account) GenFallbackKey() error {
func (a *Account) FallbackKey() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded()
}
return keys
}
@ -286,7 +289,7 @@ func (a *Account) FallbackKeyJSON() ([]byte, error) {
func (a *Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded()
}
return keys
}

View file

@ -1,8 +1,8 @@
package crypto
import (
"bytes"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"golang.org/x/crypto/curve25519"
@ -16,6 +16,12 @@ const (
Curve25519PublicKeyLength = 32
)
// Curve25519KeyPair stores both parts of a curve25519 key.
type Curve25519KeyPair struct {
PrivateKey Curve25519PrivateKey `json:"private,omitempty"`
PublicKey Curve25519PublicKey `json:"public,omitempty"`
}
// Curve25519GenerateKey creates a new curve25519 key pair.
func Curve25519GenerateKey() (Curve25519KeyPair, error) {
privateKeyByte := make([]byte, Curve25519PrivateKeyLength)
@ -34,19 +40,10 @@ func Curve25519GenerateKey() (Curve25519KeyPair, error) {
// Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given.
func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) {
publicKey, err := private.PubKey()
if err != nil {
return Curve25519KeyPair{}, err
}
return Curve25519KeyPair{
PrivateKey: private,
PublicKey: Curve25519PublicKey(publicKey),
}, nil
}
// Curve25519KeyPair stores both parts of a curve25519 key.
type Curve25519KeyPair struct {
PrivateKey Curve25519PrivateKey `json:"private,omitempty"`
PublicKey Curve25519PublicKey `json:"public,omitempty"`
}, err
}
// B64Encoded returns a base64 encoded string of the public key.
@ -86,7 +83,7 @@ type Curve25519PrivateKey []byte
// Equal compares the private key to the given private key.
func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool {
return bytes.Equal(c, x)
return subtle.ConstantTimeCompare(c, x) == 1
}
// PubKey returns the public key derived from the private key.
@ -104,7 +101,7 @@ type Curve25519PublicKey []byte
// Equal compares the public key to the given public key.
func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool {
return bytes.Equal(c, x)
return subtle.ConstantTimeCompare(c, x) == 1
}
// B64Encoded returns a base64 encoded string of the public key.

View file

@ -9,7 +9,7 @@ import (
)
const (
ED25519SignatureSize = ed25519.SignatureSize //The length of a signature
Ed25519SignatureSize = ed25519.SignatureSize //The length of a signature
)
// Ed25519GenerateKey creates a new ed25519 key pair.
@ -50,7 +50,7 @@ func (c Ed25519KeyPair) B64Encoded() id.Ed25519 {
}
// Sign returns the signature for the message.
func (c Ed25519KeyPair) Sign(message []byte) []byte {
func (c Ed25519KeyPair) Sign(message []byte) ([]byte, error) {
return c.PrivateKey.Sign(message)
}
@ -96,12 +96,8 @@ func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey {
}
// Sign returns the signature for the message.
func (c Ed25519PrivateKey) Sign(message []byte) []byte {
signature, err := ed25519.PrivateKey(c).Sign(nil, message, &ed25519.Options{})
if err != nil {
panic(err)
}
return signature
func (c Ed25519PrivateKey) Sign(message []byte) ([]byte, error) {
return ed25519.PrivateKey(c).Sign(nil, message, &ed25519.Options{})
}
// Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper.

View file

@ -4,6 +4,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"maunium.net/go/mautrix/crypto/ed25519"
"maunium.net/go/mautrix/crypto/goolm/crypto"
@ -17,7 +18,8 @@ func TestEd25519(t *testing.T) {
keypair, err := crypto.Ed25519GenerateKey()
assert.NoError(t, err)
message := []byte("test message")
signature := keypair.Sign(message)
signature, err := keypair.Sign(message)
require.NoError(t, err)
assert.True(t, keypair.Verify(message, signature))
}
@ -29,7 +31,8 @@ func TestEd25519Case1(t *testing.T) {
keyPair2 := crypto.Ed25519GenerateFromPrivate(keyPair.PrivateKey)
assert.Equal(t, keyPair, keyPair2, "not equal key pairs")
signature := keyPair.Sign(message)
signature, err := keyPair.Sign(message)
require.NoError(t, err)
verified := keyPair.Verify(message, signature)
assert.True(t, verified, "message did not verify although it should")

View file

@ -5,7 +5,6 @@ import (
"encoding/binary"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/id"
)
// OneTimeKey stores the information about a one time key.
@ -16,20 +15,11 @@ type OneTimeKey struct {
}
// Equal compares the one time key to the given one.
func (otk OneTimeKey) Equal(s OneTimeKey) bool {
if otk.ID != s.ID {
return false
}
if otk.Published != s.Published {
return false
}
if !otk.Key.PrivateKey.Equal(s.Key.PrivateKey) {
return false
}
if !otk.Key.PublicKey.Equal(s.Key.PublicKey) {
return false
}
return true
func (otk OneTimeKey) Equal(other OneTimeKey) bool {
return otk.ID == other.ID &&
otk.Published == other.Published &&
otk.Key.PrivateKey.Equal(other.Key.PrivateKey) &&
otk.Key.PublicKey.Equal(other.Key.PublicKey)
}
// PickleLibOlm pickles the key pair into the encoder.
@ -50,14 +40,7 @@ func (c *OneTimeKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) {
return c.Key.UnpickleLibOlm(decoder)
}
// KeyIDEncoded returns the base64 encoded id.
// KeyIDEncoded returns the base64 encoded key ID.
func (c OneTimeKey) KeyIDEncoded() string {
resSlice := make([]byte, 4)
binary.BigEndian.PutUint32(resSlice, c.ID)
return base64.RawStdEncoding.EncodeToString(resSlice)
}
// PublicKeyEncoded returns the base64 encoded public key
func (c OneTimeKey) PublicKeyEncoded() id.Curve25519 {
return c.Key.PublicKey.B64Encoded()
return base64.RawStdEncoding.EncodeToString(binary.BigEndian.AppendUint32(nil, c.ID))
}

View file

@ -161,8 +161,8 @@ func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error
m := message.MegolmSessionSharing{}
m.Counter = r.Counter
m.RatchetData = r.Data
encoded := m.EncodeAndSign(key)
return goolmbase64.Encode(encoded), nil
encoded, err := m.EncodeAndSign(key)
return goolmbase64.Encode(encoded), err
}
// SessionExportMessage creates a message in the session export format.

View file

@ -32,7 +32,7 @@ func (r *GroupMessage) Decode(input []byte) error {
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input)-countMACBytesGroupMessage-crypto.ED25519SignatureSize {
for curPos < len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
@ -98,7 +98,10 @@ func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher,
out = append(out, mac[:countMACBytesGroupMessage]...)
}
if signKey != nil {
signature := signKey.Sign(out)
signature, err := signKey.Sign(out)
if err != nil {
return nil, err
}
out = append(out, signature...)
}
return out, nil
@ -120,8 +123,8 @@ func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, giv
// VerifySignature verifies the signature taken from the message to the calculated signature of the message.
func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool {
signature := message[len(message)-crypto.ED25519SignatureSize:]
message = message[:len(message)-crypto.ED25519SignatureSize]
signature := message[len(message)-crypto.Ed25519SignatureSize:]
message = message[:len(message)-crypto.Ed25519SignatureSize]
return key.Verify(message, signature)
}
@ -136,7 +139,7 @@ func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, give
// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message.
func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) {
startMAC := len(message) - countMACBytesGroupMessage - crypto.ED25519SignatureSize
startMAC := len(message) - countMACBytesGroupMessage - crypto.Ed25519SignatureSize
endMAC := startMAC + countMACBytesGroupMessage
suplMac := message[startMAC:endMAC]
message = message[:startMAC]

View file

@ -20,15 +20,15 @@ type MegolmSessionSharing struct {
}
// Encode returns the encoded message in the correct format with the signature by key appended.
func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte {
func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) ([]byte, error) {
output := make([]byte, 229)
output[0] = sessionSharingVersion
binary.BigEndian.PutUint32(output[1:], s.Counter)
copy(output[5:], s.RatchetData[:])
copy(output[133:], key.PublicKey)
signature := key.Sign(output[:165])
signature, err := key.Sign(output[:165])
copy(output[165:], signature)
return output
return output, err
}
// VerifyAndDecode verifies the input and populates the struct with the data encoded in input.

View file

@ -48,8 +48,8 @@ func (s Signing) PublicKey() id.Ed25519 {
// Sign returns the signature of the message base64 encoded.
func (s Signing) Sign(message []byte) ([]byte, error) {
signature := s.keyPair.Sign(message)
return goolmbase64.Encode(signature), nil
signature, err := s.keyPair.Sign(message)
return goolmbase64.Encode(signature), err
}
// SignJSON creates a signature for the given object after encoding it to

View file

@ -48,16 +48,8 @@ func init() {
}
return MegolmOutboundSessionFromPickled(pickled, key)
}
olm.InitNewOutboundGroupSession = func() olm.OutboundGroupSession {
session, err := NewMegolmOutboundSession()
if err != nil {
panic(err)
}
return session
}
olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession {
return &MegolmOutboundSession{}
}
olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewMegolmOutboundSession() }
olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return &MegolmOutboundSession{} }
// Olm Session
olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) {

View file

@ -28,32 +28,28 @@ func init() {
s := NewBlankOutboundGroupSession()
return s, s.Unpickle(pickled, key)
}
olm.InitNewOutboundGroupSession = func() olm.OutboundGroupSession {
return NewOutboundGroupSession()
}
olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession {
return NewBlankOutboundGroupSession()
}
olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() }
olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() }
}
// Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession].
var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil)
func NewOutboundGroupSession() *OutboundGroupSession {
func NewOutboundGroupSession() (*OutboundGroupSession, error) {
s := NewBlankOutboundGroupSession()
random := make([]byte, s.createRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
panic(olm.NotEnoughGoRandom)
return nil, err
}
r := C.olm_init_outbound_group_session(
(*C.OlmOutboundGroupSession)(s.int),
(*C.uint8_t)(&random[0]),
C.size_t(len(random)))
if r == errorVal() {
panic(s.lastError())
return nil, s.lastError()
}
return s
return s, nil
}
// outboundGroupSessionSize is the size of an outbound group session object in

View file

@ -31,7 +31,8 @@ func TestEncryptDecrypt_GoolmToLibolm(t *testing.T) {
}
func TestEncryptDecrypt_LibolmToGoolm(t *testing.T) {
libolmOutbound := libolm.NewOutboundGroupSession()
libolmOutbound, err := libolm.NewOutboundGroupSession()
require.NoError(t, err)
goolmInbound, err := session.NewMegolmInboundSession([]byte(libolmOutbound.Key()))
require.NoError(t, err)

View file

@ -34,7 +34,7 @@ type OutboundGroupSession interface {
}
var InitNewOutboundGroupSessionFromPickled func(pickled, key []byte) (OutboundGroupSession, error)
var InitNewOutboundGroupSession func() OutboundGroupSession
var InitNewOutboundGroupSession func() (OutboundGroupSession, error)
var InitNewBlankOutboundGroupSession func() OutboundGroupSession
// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled
@ -47,7 +47,7 @@ func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession,
}
// NewOutboundGroupSession creates a new outbound group session.
func NewOutboundGroupSession() OutboundGroupSession {
func NewOutboundGroupSession() (OutboundGroupSession, error) {
return InitNewOutboundGroupSession()
}

View file

@ -12,7 +12,8 @@ import (
)
func TestMegolmOutboundSessionPickle_RoundtripThroughGoolm(t *testing.T) {
libolmSession := libolm.NewOutboundGroupSession()
libolmSession, err := libolm.NewOutboundGroupSession()
require.NoError(t, err)
libolmPickled, err := libolmSession.Pickle([]byte("test"))
require.NoError(t, err)
@ -24,7 +25,8 @@ func TestMegolmOutboundSessionPickle_RoundtripThroughGoolm(t *testing.T) {
assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same")
libolmSession2 := libolm.NewOutboundGroupSession()
libolmSession2, err := libolm.NewOutboundGroupSession()
require.NoError(t, err)
err = libolmSession2.Unpickle(bytes.Clone(goolmPickled), []byte("test"))
require.NoError(t, err)
@ -38,7 +40,8 @@ func TestMegolmOutboundSessionPickle_RoundtripThroughLibolm(t *testing.T) {
goolmPickled, err := goolmSession.Pickle([]byte("test"))
require.NoError(t, err)
libolmSession := libolm.NewOutboundGroupSession()
libolmSession, err := libolm.NewOutboundGroupSession()
require.NoError(t, err)
err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test"))
require.NoError(t, err)
@ -55,7 +58,8 @@ func TestMegolmOutboundSessionPickle_RoundtripThroughLibolm(t *testing.T) {
}
func TestMegolmOutboundSessionPickleLibolm(t *testing.T) {
libolmSession := libolm.NewOutboundGroupSession()
libolmSession, err := libolm.NewOutboundGroupSession()
require.NoError(t, err)
libolmPickled, err := libolmSession.Pickle([]byte("test"))
require.NoError(t, err)
@ -77,7 +81,8 @@ func TestMegolmOutboundSessionPickleGoolm(t *testing.T) {
goolmPickled, err := goolmSession.Pickle([]byte("test"))
require.NoError(t, err)
libolmSession := libolm.NewOutboundGroupSession()
libolmSession, err := libolm.NewOutboundGroupSession()
require.NoError(t, err)
err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test"))
require.NoError(t, err)
libolmPickled, err := libolmSession.Pickle([]byte("test"))
@ -98,7 +103,8 @@ func FuzzMegolmOutboundSession_Encrypt(f *testing.F) {
t.Skip("empty plaintext is not supported")
}
libolmSession := libolm.NewOutboundGroupSession()
libolmSession, err := libolm.NewOutboundGroupSession()
require.NoError(t, err)
libolmPickled, err := libolmSession.Pickle([]byte("test"))
require.NoError(t, err)

View file

@ -180,9 +180,13 @@ type OutboundGroupSession struct {
content *event.RoomKeyEventContent
}
func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession {
func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) (*OutboundGroupSession, error) {
internal, err := olm.NewOutboundGroupSession()
if err != nil {
return nil, err
}
ogs := &OutboundGroupSession{
Internal: olm.NewOutboundGroupSession(),
Internal: internal,
ExpirationMixin: ExpirationMixin{
TimeMixin: TimeMixin{
CreationTime: time.Now(),
@ -206,7 +210,7 @@ func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.Encrypti
ogs.MaxMessages = min(max(encryptionContent.RotationPeriodMessages, 1), 10000)
}
}
return ogs
return ogs, nil
}
func (ogs *OutboundGroupSession) ShareContent() event.Content {

View file

@ -13,6 +13,7 @@ import (
"testing"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/crypto/olm"
@ -195,7 +196,8 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
t.Errorf("Error retrieving outbound session: %v", err)
}
outbound := NewOutboundGroupSession("room1", nil)
outbound, err := NewOutboundGroupSession("room1", nil)
require.NoError(t, err)
err = store.AddOutboundGroupSession(context.TODO(), outbound)
if err != nil {
t.Errorf("Error inserting outbound session: %v", err)