mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
crypto: propagate more errors
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
parent
9f74b58d84
commit
4a2557ed15
17 changed files with 86 additions and 97 deletions
|
|
@ -179,7 +179,10 @@ func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.R
|
|||
Msg("Failed to get encryption event in room")
|
||||
return nil, fmt.Errorf("failed to get encryption event in room %s: %w", roomID, err)
|
||||
}
|
||||
session := NewOutboundGroupSession(roomID, encryptionEvent)
|
||||
session, err := NewOutboundGroupSession(roomID, encryptionEvent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !mach.DontStoreOutboundKeys {
|
||||
signingKey, idKey := mach.account.Keys()
|
||||
err := mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
|
||||
|
|
|
|||
|
|
@ -111,8 +111,11 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
|
|||
func (a *Account) Sign(message []byte) ([]byte, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, fmt.Errorf("sign: %w", olm.ErrEmptyInput)
|
||||
} else if signature, err := a.IdKeys.Ed25519.Sign(message); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return []byte(base64.RawStdEncoding.EncodeToString(signature)), nil
|
||||
}
|
||||
return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil
|
||||
}
|
||||
|
||||
// OneTimeKeys returns the public parts of the unpublished one time keys of the Account.
|
||||
|
|
@ -122,7 +125,7 @@ func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) {
|
|||
oneTimeKeys := make(map[string]id.Curve25519)
|
||||
for _, curKey := range a.OTKeys {
|
||||
if !curKey.Published {
|
||||
oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded())
|
||||
oneTimeKeys[curKey.KeyIDEncoded()] = curKey.Key.PublicKey.B64Encoded()
|
||||
}
|
||||
}
|
||||
return oneTimeKeys, nil
|
||||
|
|
@ -259,7 +262,7 @@ func (a *Account) GenFallbackKey() error {
|
|||
func (a *Account) FallbackKey() map[string]id.Curve25519 {
|
||||
keys := make(map[string]id.Curve25519)
|
||||
if a.NumFallbackKeys >= 1 {
|
||||
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
|
||||
keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded()
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
|
@ -286,7 +289,7 @@ func (a *Account) FallbackKeyJSON() ([]byte, error) {
|
|||
func (a *Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
|
||||
keys := make(map[string]id.Curve25519)
|
||||
if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published {
|
||||
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
|
||||
keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded()
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
|
|
@ -16,6 +16,12 @@ const (
|
|||
Curve25519PublicKeyLength = 32
|
||||
)
|
||||
|
||||
// Curve25519KeyPair stores both parts of a curve25519 key.
|
||||
type Curve25519KeyPair struct {
|
||||
PrivateKey Curve25519PrivateKey `json:"private,omitempty"`
|
||||
PublicKey Curve25519PublicKey `json:"public,omitempty"`
|
||||
}
|
||||
|
||||
// Curve25519GenerateKey creates a new curve25519 key pair.
|
||||
func Curve25519GenerateKey() (Curve25519KeyPair, error) {
|
||||
privateKeyByte := make([]byte, Curve25519PrivateKeyLength)
|
||||
|
|
@ -34,19 +40,10 @@ func Curve25519GenerateKey() (Curve25519KeyPair, error) {
|
|||
// Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given.
|
||||
func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) {
|
||||
publicKey, err := private.PubKey()
|
||||
if err != nil {
|
||||
return Curve25519KeyPair{}, err
|
||||
}
|
||||
return Curve25519KeyPair{
|
||||
PrivateKey: private,
|
||||
PublicKey: Curve25519PublicKey(publicKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Curve25519KeyPair stores both parts of a curve25519 key.
|
||||
type Curve25519KeyPair struct {
|
||||
PrivateKey Curve25519PrivateKey `json:"private,omitempty"`
|
||||
PublicKey Curve25519PublicKey `json:"public,omitempty"`
|
||||
}, err
|
||||
}
|
||||
|
||||
// B64Encoded returns a base64 encoded string of the public key.
|
||||
|
|
@ -86,7 +83,7 @@ type Curve25519PrivateKey []byte
|
|||
|
||||
// Equal compares the private key to the given private key.
|
||||
func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool {
|
||||
return bytes.Equal(c, x)
|
||||
return subtle.ConstantTimeCompare(c, x) == 1
|
||||
}
|
||||
|
||||
// PubKey returns the public key derived from the private key.
|
||||
|
|
@ -104,7 +101,7 @@ type Curve25519PublicKey []byte
|
|||
|
||||
// Equal compares the public key to the given public key.
|
||||
func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool {
|
||||
return bytes.Equal(c, x)
|
||||
return subtle.ConstantTimeCompare(c, x) == 1
|
||||
}
|
||||
|
||||
// B64Encoded returns a base64 encoded string of the public key.
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
ED25519SignatureSize = ed25519.SignatureSize //The length of a signature
|
||||
Ed25519SignatureSize = ed25519.SignatureSize //The length of a signature
|
||||
)
|
||||
|
||||
// Ed25519GenerateKey creates a new ed25519 key pair.
|
||||
|
|
@ -50,7 +50,7 @@ func (c Ed25519KeyPair) B64Encoded() id.Ed25519 {
|
|||
}
|
||||
|
||||
// Sign returns the signature for the message.
|
||||
func (c Ed25519KeyPair) Sign(message []byte) []byte {
|
||||
func (c Ed25519KeyPair) Sign(message []byte) ([]byte, error) {
|
||||
return c.PrivateKey.Sign(message)
|
||||
}
|
||||
|
||||
|
|
@ -96,12 +96,8 @@ func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey {
|
|||
}
|
||||
|
||||
// Sign returns the signature for the message.
|
||||
func (c Ed25519PrivateKey) Sign(message []byte) []byte {
|
||||
signature, err := ed25519.PrivateKey(c).Sign(nil, message, &ed25519.Options{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return signature
|
||||
func (c Ed25519PrivateKey) Sign(message []byte) ([]byte, error) {
|
||||
return ed25519.PrivateKey(c).Sign(nil, message, &ed25519.Options{})
|
||||
}
|
||||
|
||||
// Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper.
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/ed25519"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
|
|
@ -17,7 +18,8 @@ func TestEd25519(t *testing.T) {
|
|||
keypair, err := crypto.Ed25519GenerateKey()
|
||||
assert.NoError(t, err)
|
||||
message := []byte("test message")
|
||||
signature := keypair.Sign(message)
|
||||
signature, err := keypair.Sign(message)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, keypair.Verify(message, signature))
|
||||
}
|
||||
|
||||
|
|
@ -29,7 +31,8 @@ func TestEd25519Case1(t *testing.T) {
|
|||
|
||||
keyPair2 := crypto.Ed25519GenerateFromPrivate(keyPair.PrivateKey)
|
||||
assert.Equal(t, keyPair, keyPair2, "not equal key pairs")
|
||||
signature := keyPair.Sign(message)
|
||||
signature, err := keyPair.Sign(message)
|
||||
require.NoError(t, err)
|
||||
verified := keyPair.Verify(message, signature)
|
||||
assert.True(t, verified, "message did not verify although it should")
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/binary"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// OneTimeKey stores the information about a one time key.
|
||||
|
|
@ -16,20 +15,11 @@ type OneTimeKey struct {
|
|||
}
|
||||
|
||||
// Equal compares the one time key to the given one.
|
||||
func (otk OneTimeKey) Equal(s OneTimeKey) bool {
|
||||
if otk.ID != s.ID {
|
||||
return false
|
||||
}
|
||||
if otk.Published != s.Published {
|
||||
return false
|
||||
}
|
||||
if !otk.Key.PrivateKey.Equal(s.Key.PrivateKey) {
|
||||
return false
|
||||
}
|
||||
if !otk.Key.PublicKey.Equal(s.Key.PublicKey) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
func (otk OneTimeKey) Equal(other OneTimeKey) bool {
|
||||
return otk.ID == other.ID &&
|
||||
otk.Published == other.Published &&
|
||||
otk.Key.PrivateKey.Equal(other.Key.PrivateKey) &&
|
||||
otk.Key.PublicKey.Equal(other.Key.PublicKey)
|
||||
}
|
||||
|
||||
// PickleLibOlm pickles the key pair into the encoder.
|
||||
|
|
@ -50,14 +40,7 @@ func (c *OneTimeKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) {
|
|||
return c.Key.UnpickleLibOlm(decoder)
|
||||
}
|
||||
|
||||
// KeyIDEncoded returns the base64 encoded id.
|
||||
// KeyIDEncoded returns the base64 encoded key ID.
|
||||
func (c OneTimeKey) KeyIDEncoded() string {
|
||||
resSlice := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(resSlice, c.ID)
|
||||
return base64.RawStdEncoding.EncodeToString(resSlice)
|
||||
}
|
||||
|
||||
// PublicKeyEncoded returns the base64 encoded public key
|
||||
func (c OneTimeKey) PublicKeyEncoded() id.Curve25519 {
|
||||
return c.Key.PublicKey.B64Encoded()
|
||||
return base64.RawStdEncoding.EncodeToString(binary.BigEndian.AppendUint32(nil, c.ID))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -161,8 +161,8 @@ func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error
|
|||
m := message.MegolmSessionSharing{}
|
||||
m.Counter = r.Counter
|
||||
m.RatchetData = r.Data
|
||||
encoded := m.EncodeAndSign(key)
|
||||
return goolmbase64.Encode(encoded), nil
|
||||
encoded, err := m.EncodeAndSign(key)
|
||||
return goolmbase64.Encode(encoded), err
|
||||
}
|
||||
|
||||
// SessionExportMessage creates a message in the session export format.
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func (r *GroupMessage) Decode(input []byte) error {
|
|||
//first Byte is always version
|
||||
r.Version = input[0]
|
||||
curPos := 1
|
||||
for curPos < len(input)-countMACBytesGroupMessage-crypto.ED25519SignatureSize {
|
||||
for curPos < len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize {
|
||||
//Read Key
|
||||
curKey, readBytes := decodeVarInt(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
|
|
@ -98,7 +98,10 @@ func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher,
|
|||
out = append(out, mac[:countMACBytesGroupMessage]...)
|
||||
}
|
||||
if signKey != nil {
|
||||
signature := signKey.Sign(out)
|
||||
signature, err := signKey.Sign(out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, signature...)
|
||||
}
|
||||
return out, nil
|
||||
|
|
@ -120,8 +123,8 @@ func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, giv
|
|||
|
||||
// VerifySignature verifies the signature taken from the message to the calculated signature of the message.
|
||||
func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool {
|
||||
signature := message[len(message)-crypto.ED25519SignatureSize:]
|
||||
message = message[:len(message)-crypto.ED25519SignatureSize]
|
||||
signature := message[len(message)-crypto.Ed25519SignatureSize:]
|
||||
message = message[:len(message)-crypto.Ed25519SignatureSize]
|
||||
return key.Verify(message, signature)
|
||||
}
|
||||
|
||||
|
|
@ -136,7 +139,7 @@ func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, give
|
|||
|
||||
// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message.
|
||||
func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) {
|
||||
startMAC := len(message) - countMACBytesGroupMessage - crypto.ED25519SignatureSize
|
||||
startMAC := len(message) - countMACBytesGroupMessage - crypto.Ed25519SignatureSize
|
||||
endMAC := startMAC + countMACBytesGroupMessage
|
||||
suplMac := message[startMAC:endMAC]
|
||||
message = message[:startMAC]
|
||||
|
|
|
|||
|
|
@ -20,15 +20,15 @@ type MegolmSessionSharing struct {
|
|||
}
|
||||
|
||||
// Encode returns the encoded message in the correct format with the signature by key appended.
|
||||
func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte {
|
||||
func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) ([]byte, error) {
|
||||
output := make([]byte, 229)
|
||||
output[0] = sessionSharingVersion
|
||||
binary.BigEndian.PutUint32(output[1:], s.Counter)
|
||||
copy(output[5:], s.RatchetData[:])
|
||||
copy(output[133:], key.PublicKey)
|
||||
signature := key.Sign(output[:165])
|
||||
signature, err := key.Sign(output[:165])
|
||||
copy(output[165:], signature)
|
||||
return output
|
||||
return output, err
|
||||
}
|
||||
|
||||
// VerifyAndDecode verifies the input and populates the struct with the data encoded in input.
|
||||
|
|
|
|||
|
|
@ -48,8 +48,8 @@ func (s Signing) PublicKey() id.Ed25519 {
|
|||
|
||||
// Sign returns the signature of the message base64 encoded.
|
||||
func (s Signing) Sign(message []byte) ([]byte, error) {
|
||||
signature := s.keyPair.Sign(message)
|
||||
return goolmbase64.Encode(signature), nil
|
||||
signature, err := s.keyPair.Sign(message)
|
||||
return goolmbase64.Encode(signature), err
|
||||
}
|
||||
|
||||
// SignJSON creates a signature for the given object after encoding it to
|
||||
|
|
|
|||
|
|
@ -48,16 +48,8 @@ func init() {
|
|||
}
|
||||
return MegolmOutboundSessionFromPickled(pickled, key)
|
||||
}
|
||||
olm.InitNewOutboundGroupSession = func() olm.OutboundGroupSession {
|
||||
session, err := NewMegolmOutboundSession()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return session
|
||||
}
|
||||
olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession {
|
||||
return &MegolmOutboundSession{}
|
||||
}
|
||||
olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewMegolmOutboundSession() }
|
||||
olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return &MegolmOutboundSession{} }
|
||||
|
||||
// Olm Session
|
||||
olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) {
|
||||
|
|
|
|||
|
|
@ -28,32 +28,28 @@ func init() {
|
|||
s := NewBlankOutboundGroupSession()
|
||||
return s, s.Unpickle(pickled, key)
|
||||
}
|
||||
olm.InitNewOutboundGroupSession = func() olm.OutboundGroupSession {
|
||||
return NewOutboundGroupSession()
|
||||
}
|
||||
olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession {
|
||||
return NewBlankOutboundGroupSession()
|
||||
}
|
||||
olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() }
|
||||
olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() }
|
||||
}
|
||||
|
||||
// Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession].
|
||||
var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil)
|
||||
|
||||
func NewOutboundGroupSession() *OutboundGroupSession {
|
||||
func NewOutboundGroupSession() (*OutboundGroupSession, error) {
|
||||
s := NewBlankOutboundGroupSession()
|
||||
random := make([]byte, s.createRandomLen()+1)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
panic(olm.NotEnoughGoRandom)
|
||||
return nil, err
|
||||
}
|
||||
r := C.olm_init_outbound_group_session(
|
||||
(*C.OlmOutboundGroupSession)(s.int),
|
||||
(*C.uint8_t)(&random[0]),
|
||||
C.size_t(len(random)))
|
||||
if r == errorVal() {
|
||||
panic(s.lastError())
|
||||
return nil, s.lastError()
|
||||
}
|
||||
return s
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// outboundGroupSessionSize is the size of an outbound group session object in
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ func TestEncryptDecrypt_GoolmToLibolm(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestEncryptDecrypt_LibolmToGoolm(t *testing.T) {
|
||||
libolmOutbound := libolm.NewOutboundGroupSession()
|
||||
libolmOutbound, err := libolm.NewOutboundGroupSession()
|
||||
require.NoError(t, err)
|
||||
goolmInbound, err := session.NewMegolmInboundSession([]byte(libolmOutbound.Key()))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ type OutboundGroupSession interface {
|
|||
}
|
||||
|
||||
var InitNewOutboundGroupSessionFromPickled func(pickled, key []byte) (OutboundGroupSession, error)
|
||||
var InitNewOutboundGroupSession func() OutboundGroupSession
|
||||
var InitNewOutboundGroupSession func() (OutboundGroupSession, error)
|
||||
var InitNewBlankOutboundGroupSession func() OutboundGroupSession
|
||||
|
||||
// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled
|
||||
|
|
@ -47,7 +47,7 @@ func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession,
|
|||
}
|
||||
|
||||
// NewOutboundGroupSession creates a new outbound group session.
|
||||
func NewOutboundGroupSession() OutboundGroupSession {
|
||||
func NewOutboundGroupSession() (OutboundGroupSession, error) {
|
||||
return InitNewOutboundGroupSession()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ import (
|
|||
)
|
||||
|
||||
func TestMegolmOutboundSessionPickle_RoundtripThroughGoolm(t *testing.T) {
|
||||
libolmSession := libolm.NewOutboundGroupSession()
|
||||
libolmSession, err := libolm.NewOutboundGroupSession()
|
||||
require.NoError(t, err)
|
||||
libolmPickled, err := libolmSession.Pickle([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -24,7 +25,8 @@ func TestMegolmOutboundSessionPickle_RoundtripThroughGoolm(t *testing.T) {
|
|||
|
||||
assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same")
|
||||
|
||||
libolmSession2 := libolm.NewOutboundGroupSession()
|
||||
libolmSession2, err := libolm.NewOutboundGroupSession()
|
||||
require.NoError(t, err)
|
||||
err = libolmSession2.Unpickle(bytes.Clone(goolmPickled), []byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -38,7 +40,8 @@ func TestMegolmOutboundSessionPickle_RoundtripThroughLibolm(t *testing.T) {
|
|||
goolmPickled, err := goolmSession.Pickle([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
libolmSession := libolm.NewOutboundGroupSession()
|
||||
libolmSession, err := libolm.NewOutboundGroupSession()
|
||||
require.NoError(t, err)
|
||||
err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -55,7 +58,8 @@ func TestMegolmOutboundSessionPickle_RoundtripThroughLibolm(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMegolmOutboundSessionPickleLibolm(t *testing.T) {
|
||||
libolmSession := libolm.NewOutboundGroupSession()
|
||||
libolmSession, err := libolm.NewOutboundGroupSession()
|
||||
require.NoError(t, err)
|
||||
libolmPickled, err := libolmSession.Pickle([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -77,7 +81,8 @@ func TestMegolmOutboundSessionPickleGoolm(t *testing.T) {
|
|||
goolmPickled, err := goolmSession.Pickle([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
libolmSession := libolm.NewOutboundGroupSession()
|
||||
libolmSession, err := libolm.NewOutboundGroupSession()
|
||||
require.NoError(t, err)
|
||||
err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test"))
|
||||
require.NoError(t, err)
|
||||
libolmPickled, err := libolmSession.Pickle([]byte("test"))
|
||||
|
|
@ -98,7 +103,8 @@ func FuzzMegolmOutboundSession_Encrypt(f *testing.F) {
|
|||
t.Skip("empty plaintext is not supported")
|
||||
}
|
||||
|
||||
libolmSession := libolm.NewOutboundGroupSession()
|
||||
libolmSession, err := libolm.NewOutboundGroupSession()
|
||||
require.NoError(t, err)
|
||||
libolmPickled, err := libolmSession.Pickle([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
|||
|
|
@ -180,9 +180,13 @@ type OutboundGroupSession struct {
|
|||
content *event.RoomKeyEventContent
|
||||
}
|
||||
|
||||
func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession {
|
||||
func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) (*OutboundGroupSession, error) {
|
||||
internal, err := olm.NewOutboundGroupSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ogs := &OutboundGroupSession{
|
||||
Internal: olm.NewOutboundGroupSession(),
|
||||
Internal: internal,
|
||||
ExpirationMixin: ExpirationMixin{
|
||||
TimeMixin: TimeMixin{
|
||||
CreationTime: time.Now(),
|
||||
|
|
@ -206,7 +210,7 @@ func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.Encrypti
|
|||
ogs.MaxMessages = min(max(encryptionContent.RotationPeriodMessages, 1), 10000)
|
||||
}
|
||||
}
|
||||
return ogs
|
||||
return ogs, nil
|
||||
}
|
||||
|
||||
func (ogs *OutboundGroupSession) ShareContent() event.Content {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
|
|
@ -195,7 +196,8 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
|
|||
t.Errorf("Error retrieving outbound session: %v", err)
|
||||
}
|
||||
|
||||
outbound := NewOutboundGroupSession("room1", nil)
|
||||
outbound, err := NewOutboundGroupSession("room1", nil)
|
||||
require.NoError(t, err)
|
||||
err = store.AddOutboundGroupSession(context.TODO(), outbound)
|
||||
if err != nil {
|
||||
t.Errorf("Error inserting outbound session: %v", err)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue