goolm: simplify tests using testify

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
Sumner Evans 2024-09-04 00:31:14 -06:00
commit eb632a9994
No known key found for this signature in database
17 changed files with 488 additions and 1271 deletions

View file

@ -1,13 +1,10 @@
package account_test
import (
"bytes"
"encoding/base64"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"maunium.net/go/mautrix/id"
@ -18,75 +15,42 @@ import (
func TestAccount(t *testing.T) {
firstAccount, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
err = firstAccount.GenFallbackKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
err = firstAccount.GenOneTimeKeys(2)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
encryptionKey := []byte("testkey")
//now pickle account in JSON format
pickled, err := firstAccount.PickleAsJSON(encryptionKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//now unpickle into new Account
unpickledAccount, err := account.AccountFromJSONPickled(pickled, encryptionKey)
if err != nil {
t.Fatal(err)
}
//check if accounts are the same
if firstAccount.NextOneTimeKeyID != unpickledAccount.NextOneTimeKeyID {
t.Fatal("NextOneTimeKeyID unequal")
}
if !firstAccount.CurrentFallbackKey.Equal(unpickledAccount.CurrentFallbackKey) {
t.Fatal("CurrentFallbackKey unequal")
}
if !firstAccount.PrevFallbackKey.Equal(unpickledAccount.PrevFallbackKey) {
t.Fatal("PrevFallbackKey unequal")
}
if len(firstAccount.OTKeys) != len(unpickledAccount.OTKeys) {
t.Fatal("OneTimeKeysunequal")
}
for i := range firstAccount.OTKeys {
if !firstAccount.OTKeys[i].Equal(unpickledAccount.OTKeys[i]) {
t.Fatalf("OneTimeKeys %d unequal", i)
}
}
if !firstAccount.IdKeys.Curve25519.PrivateKey.Equal(unpickledAccount.IdKeys.Curve25519.PrivateKey) {
t.Fatal("IdentityKeys Curve25519 private unequal")
}
if !firstAccount.IdKeys.Curve25519.PublicKey.Equal(unpickledAccount.IdKeys.Curve25519.PublicKey) {
t.Fatal("IdentityKeys Curve25519 public unequal")
}
if !firstAccount.IdKeys.Ed25519.PrivateKey.Equal(unpickledAccount.IdKeys.Ed25519.PrivateKey) {
t.Fatal("IdentityKeys Ed25519 private unequal")
}
if !firstAccount.IdKeys.Ed25519.PublicKey.Equal(unpickledAccount.IdKeys.Ed25519.PublicKey) {
t.Fatal("IdentityKeys Ed25519 public unequal")
}
assert.NoError(t, err)
if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 2 {
t.Fatal("should get 2 unpublished oneTimeKeys")
}
if len(firstAccount.FallbackKeyUnpublished()) == 0 {
t.Fatal("should get fallbackKey")
}
//check if accounts are the same
assert.Equal(t, firstAccount.NextOneTimeKeyID, unpickledAccount.NextOneTimeKeyID)
assert.Equal(t, firstAccount.CurrentFallbackKey, unpickledAccount.CurrentFallbackKey)
assert.Equal(t, firstAccount.PrevFallbackKey, unpickledAccount.PrevFallbackKey)
assert.Equal(t, firstAccount.OTKeys, unpickledAccount.OTKeys)
assert.Equal(t, firstAccount.IdKeys, unpickledAccount.IdKeys)
// Ensure that all of the keys are unpublished right now
otks, err := firstAccount.OneTimeKeys()
assert.NoError(t, err)
assert.Len(t, otks, 2)
assert.Len(t, firstAccount.FallbackKeyUnpublished(), 1)
// Now, publish the key and make sure that they are published
firstAccount.MarkKeysAsPublished()
if len(firstAccount.FallbackKey()) == 0 {
t.Fatal("should get fallbackKey")
}
if len(firstAccount.FallbackKeyUnpublished()) != 0 {
t.Fatal("should get no fallbackKey")
}
if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 0 {
t.Fatal("should get no oneTimeKeys")
}
assert.Len(t, firstAccount.FallbackKeyUnpublished(), 0)
assert.Len(t, firstAccount.FallbackKey(), 1)
otks, err = firstAccount.OneTimeKeys()
assert.NoError(t, err)
assert.Len(t, otks, 0)
}
func TestAccountPickleJSON(t *testing.T) {
@ -104,109 +68,49 @@ func TestAccountPickleJSON(t *testing.T) {
pickledData := []byte("6POkBWwbNl20fwvZWsOu0jgbHy4jkA5h0Ji+XCag59+ifWIRPDrqtgQi9HmkLiSF6wUhhYaV4S73WM+Hh+dlCuZRuXhTQr8yGPTifjcjq8birdAhObbEqHrYEdqaQkrgBLr/rlS5sibXeDqbkhVu4LslvootU9DkcCbd4b/0Flh7iugxqkcCs5GDndTEx9IzTVJzmK82Y0Q1Z1Z9Vuc2Iw746PtBJLtZjite6fSMp2NigPX/ZWWJ3OnwcJo0Vvjy8hgptZEWkamOHdWbUtelbHyjDIZlvxOC25D3rFif0zzPkF9qdpBPqVCWPPzGFmgnqKau6CHrnPfq7GLsM3BrprD7sHN1Js28ex14gXQPjBT7KTUo6H0e4gQMTMRp4qb8btNXDeId8xIFIElTh2SXZBTDmSq/ziVNJinEvYV8mGPvJZjDQQU+SyoS/HZ8uMc41tH0BOGDbFMHbfLMiz61E429gOrx2klu5lqyoyet7//HKi0ed5w2dQ")
account, err := account.AccountFromJSONPickled(pickledData, key)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
expectedJSON := `{"ed25519":"qWvNB6Ztov5/AOsP073op0O32KJ8/tgSNarT7MaYgQE","curve25519":"TFUB6M6zwgyWhBEp2m1aUodl2AsnsrIuBr8l9AvwGS8"}`
jsonData, err := account.IdentityKeysJSON()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(jsonData, []byte(expectedJSON)) {
t.Fatalf("Expected '%s' but got '%s'", expectedJSON, jsonData)
}
assert.NoError(t, err)
assert.Equal(t, expectedJSON, string(jsonData))
}
func TestSessions(t *testing.T) {
aliceAccount, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
err = aliceAccount.GenOneTimeKeys(5)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
bobAccount, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
err = bobAccount.GenOneTimeKeys(5)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
aliceSession, err := aliceAccount.NewOutboundSession(bobAccount.IdKeys.Curve25519.B64Encoded(), bobAccount.OTKeys[2].Key.B64Encoded())
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plaintext := []byte("test message")
msgType, crypttext, err := aliceSession.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
if msgType != id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
bobSession, err := bobAccount.NewInboundSession(string(crypttext))
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decodedText, err := bobSession.Decrypt(string(crypttext), msgType)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plaintext, decodedText) {
t.Fatalf("expected '%s' but got '%s'", string(plaintext), string(decodedText))
}
assert.NoError(t, err)
assert.Equal(t, plaintext, decodedText)
}
func TestAccountPickle(t *testing.T) {
pickleKey := []byte("secret_key")
account, err := account.AccountFromPickled(pickledDataFromLibOlm, pickleKey)
if err != nil {
t.Fatal(err)
}
if !expectedEd25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Ed25519.PrivateKey) {
t.Fatal("keys not equal")
}
if !expectedEd25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Ed25519.PublicKey) {
t.Fatal("keys not equal")
}
if !expectedCurve25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Curve25519.PrivateKey) {
t.Fatal("keys not equal")
}
if !expectedCurve25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Curve25519.PublicKey) {
t.Fatal("keys not equal")
}
if account.NextOneTimeKeyID != 42 {
t.Fatal("wrong next otKey id")
}
if len(account.OTKeys) != len(expectedOTKeysPickleLibOLM) {
t.Fatal("wrong number of otKeys")
}
if account.NumFallbackKeys != 0 {
t.Fatal("fallback keys set but not in pickle")
}
for curIndex, curValue := range account.OTKeys {
curExpected := expectedOTKeysPickleLibOLM[curIndex]
if curExpected.ID != curValue.ID {
t.Fatal("OTKey id not correct")
}
if !curExpected.Key.PublicKey.Equal(curValue.Key.PublicKey) {
t.Fatal("OTKey public key not correct")
}
if !curExpected.Key.PrivateKey.Equal(curValue.Key.PrivateKey) {
t.Fatal("OTKey private key not correct")
}
}
assert.NoError(t, err)
assert.Equal(t, expectedEd25519KeyPairPickleLibOLM, account.IdKeys.Ed25519)
assert.Equal(t, expectedCurve25519KeyPairPickleLibOLM, account.IdKeys.Curve25519)
assert.EqualValues(t, 42, account.NextOneTimeKeyID)
assert.Equal(t, account.OTKeys, expectedOTKeysPickleLibOLM)
assert.EqualValues(t, 0, account.NumFallbackKeys)
targetPickled, err := account.Pickle(pickleKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(targetPickled, pickledDataFromLibOlm) {
t.Fatal("repickled value does not equal given value")
}
assert.NoError(t, err)
assert.Equal(t, pickledDataFromLibOlm, targetPickled)
}
func TestOldAccountPickle(t *testing.T) {
@ -218,355 +122,212 @@ func TestOldAccountPickle(t *testing.T) {
"O5TmXua1FcU")
pickleKey := []byte("")
account, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
err = account.Unpickle(pickled, pickleKey)
if err == nil {
t.Fatal("expected error")
} else {
if !errors.Is(err, olm.ErrBadVersion) {
t.Fatal(err)
}
}
assert.ErrorIs(t, err, olm.ErrBadVersion)
}
func TestLoopback(t *testing.T) {
accountA, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
accountB, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
err = accountB.GenOneTimeKeys(42)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
err = accountB.GenOneTimeKeys( 42)
assert.NoError(t, err)
aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded())
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plainText := []byte("Hello, World")
msgType, message1, err := aliceSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if msgType != id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
bobSession, err := accountB.NewInboundSession(string(message1))
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
// Check that the inbound session matches the message it was created from.
sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1))
if err != nil {
t.Fatal(err)
}
if !sessionIsOK {
t.Fatal("session was not detected to be valid")
}
assert.NoError(t, err)
assert.True(t, sessionIsOK, "session was not detected to be valid")
// Check that the inbound session matches the key this message is supposed to be from.
aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded()
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1))
if err != nil {
t.Fatal(err)
}
if !sessionIsOK {
t.Fatal("session is sad to be not from a but it should")
}
assert.NoError(t, err)
assert.True(t, sessionIsOK, "session is sad to be not from a but it should")
// Check that the inbound session isn't from a different user.
bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded()
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1))
if err != nil {
t.Fatal(err)
}
if sessionIsOK {
t.Fatal("session is sad to be from b but is from a")
}
assert.NoError(t, err)
assert.False(t, sessionIsOK, "session is sad to be from b but is from a")
// Check that we can decrypt the message.
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decryptedMessage, plainText) {
t.Fatal("messages are not the same")
}
assert.NoError(t, err)
assert.Equal(t, plainText, decryptedMessage)
msgTyp2, message2, err := bobSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if msgTyp2 == id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypeMsg, msgTyp2)
decryptedMessage2, err := aliceSession.Decrypt(string(message2), msgTyp2)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decryptedMessage2, plainText) {
t.Fatal("messages are not the same")
}
assert.NoError(t, err)
assert.Equal(t, plainText, decryptedMessage2)
//decrypting again should fail, as the chain moved on
_, err = aliceSession.Decrypt(string(message2), msgTyp2)
if err == nil {
t.Fatal("expected error")
}
assert.Error(t, err)
assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound)
//compare sessionIDs
if aliceSession.ID() != bobSession.ID() {
t.Fatal("sessionIDs are not equal")
}
assert.Equal(t, aliceSession.ID(), bobSession.ID())
}
func TestMoreMessages(t *testing.T) {
accountA, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
accountB, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
err = accountB.GenOneTimeKeys(42)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
err = accountB.GenOneTimeKeys( 42)
assert.NoError(t, err)
aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded())
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plainText := []byte("Hello, World")
msgType, message1, err := aliceSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if msgType != id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
bobSession, err := accountB.NewInboundSession(string(message1))
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decryptedMessage, plainText) {
t.Fatal("messages are not the same")
}
assert.NoError(t, err)
assert.Equal(t, plainText, decryptedMessage)
for i := 0; i < 8; i++ {
//alice sends, bob reveices
msgType, message, err := aliceSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
if i == 0 {
//The first time should still be a preKeyMessage as bob has not yet send a message to alice
if msgType != id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
} else {
if msgType == id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.Equal(t, id.OlmMsgTypeMsg, msgType)
}
decryptedMessage, err := bobSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decryptedMessage, plainText) {
t.Fatal("messages are not the same")
}
assert.NoError(t, err)
assert.Equal(t, plainText, decryptedMessage)
//now bob sends, alice receives
msgType, message, err = bobSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if msgType == id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypeMsg, msgType)
decryptedMessage, err = aliceSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decryptedMessage, plainText) {
t.Fatal("messages are not the same")
}
assert.NoError(t, err)
assert.Equal(t, plainText, decryptedMessage)
}
}
func TestFallbackKey(t *testing.T) {
accountA, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
accountB, err := account.NewAccount()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
err = accountB.GenFallbackKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
fallBackKeys := accountB.FallbackKeyUnpublished()
var fallbackKey id.Curve25519
for _, fbKey := range fallBackKeys {
fallbackKey = fbKey
}
aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plainText := []byte("Hello, World")
msgType, message1, err := aliceSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if msgType != id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
bobSession, err := accountB.NewInboundSession(string(message1))
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
// Check that the inbound session matches the message it was created from.
sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1))
if err != nil {
t.Fatal(err)
}
if !sessionIsOK {
t.Fatal("session was not detected to be valid")
}
assert.NoError(t, err)
assert.True(t, sessionIsOK, "session was not detected to be valid")
// Check that the inbound session matches the key this message is supposed to be from.
aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded()
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1))
if err != nil {
t.Fatal(err)
}
if !sessionIsOK {
t.Fatal("session is sad to be not from a but it should")
}
assert.NoError(t, err)
assert.True(t, sessionIsOK, "session is sad to be not from a but it should")
// Check that the inbound session isn't from a different user.
bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded()
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1))
if err != nil {
t.Fatal(err)
}
if sessionIsOK {
t.Fatal("session is sad to be from b but is from a")
}
assert.NoError(t, err)
assert.False(t, sessionIsOK, "session is sad to be from b but is from a")
// Check that we can decrypt the message.
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decryptedMessage, plainText) {
t.Fatal("messages are not the same")
}
assert.NoError(t, err)
assert.Equal(t, plainText, decryptedMessage)
// create a new fallback key for B (the old fallback should still be usable)
err = accountB.GenFallbackKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
// start another session and encrypt a message
aliceSession2, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
msgType2, message2, err := aliceSession2.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if msgType2 != id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypePreKey, msgType2)
// bobSession should not be valid for the message2
// Check that the inbound session matches the message it was created from.
sessionIsOK, err = bobSession.MatchesInboundSessionFrom("", string(message2))
if err != nil {
t.Fatal(err)
}
if sessionIsOK {
t.Fatal("session was detected to be valid but should not")
}
assert.NoError(t, err)
assert.False(t, sessionIsOK, "session was detected to be valid but should not")
bobSession2, err := accountB.NewInboundSession(string(message2))
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
// Check that the inbound session matches the message it was created from.
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom("", string(message2))
if err != nil {
t.Fatal(err)
}
if !sessionIsOK {
t.Fatal("session was not detected to be valid")
}
assert.NoError(t, err)
assert.True(t, sessionIsOK, "session was not detected to be valid")
// Check that the inbound session matches the key this message is supposed to be from.
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(aIDKey), string(message2))
if err != nil {
t.Fatal(err)
}
if !sessionIsOK {
t.Fatal("session is sad to be not from a but it should")
}
assert.NoError(t, err)
assert.True(t, sessionIsOK, "session is sad to be not from a but it should")
// Check that the inbound session isn't from a different user.
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(bIDKey), string(message2))
if err != nil {
t.Fatal(err)
}
if sessionIsOK {
t.Fatal("session is sad to be from b but is from a")
}
assert.NoError(t, err)
assert.False(t, sessionIsOK, "session is sad to be from b but is from a")
// Check that we can decrypt the message.
decryptedMessage2, err := bobSession2.Decrypt(string(message2), msgType2)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decryptedMessage2, plainText) {
t.Fatal("messages are not the same")
}
assert.NoError(t, err)
assert.Equal(t, plainText, decryptedMessage2)
//Forget the old fallback key -- creating a new session should fail now
accountB.ForgetOldFallbackKey()
// start another session and encrypt a message
aliceSession3, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
msgType3, message3, err := aliceSession3.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if msgType3 != id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypePreKey, msgType3)
_, err = accountB.NewInboundSession(string(message3))
if err == nil {
t.Fatal("expected error")
}
if !errors.Is(err, olm.ErrBadMessageKeyID) {
t.Fatal(err)
}
assert.ErrorIs(t, err, olm.ErrBadMessageKeyID)
}
func TestOldV3AccountPickle(t *testing.T) {
@ -582,33 +343,23 @@ func TestOldV3AccountPickle(t *testing.T) {
expectedUnpublishedFallbackJSON := []byte("{\"curve25519\":{}}")
account, err := account.AccountFromPickled(pickledData, pickleKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
fallbackJSON, err := account.FallbackKeyJSON()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(fallbackJSON, expectedFallbackJSON) {
t.Fatalf("expected not as result:\n%s\n%s\n", expectedFallbackJSON, fallbackJSON)
}
assert.NoError(t, err)
assert.Equal(t, expectedFallbackJSON, fallbackJSON)
fallbackJSONUnpublished, err := account.FallbackKeyUnpublishedJSON()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(fallbackJSONUnpublished, expectedUnpublishedFallbackJSON) {
t.Fatalf("expected not as result:\n%s\n%s\n", expectedUnpublishedFallbackJSON, fallbackJSONUnpublished)
}
assert.NoError(t, err)
assert.Equal(t, expectedUnpublishedFallbackJSON, fallbackJSONUnpublished)
}
func TestAccountSign(t *testing.T) {
accountA, err := account.NewAccount()
require.NoError(t, err)
assert.NoError(t, err)
plainText := []byte("Hello, World")
signatureB64, err := accountA.Sign(plainText)
require.NoError(t, err)
assert.NoError(t, err)
signature, err := base64.RawStdEncoding.DecodeString(string(signatureB64))
require.NoError(t, err)
assert.NoError(t, err)
verified, err := signatures.VerifySignature(plainText, accountA.IdKeys.Ed25519.B64Encoded(), signature)
assert.NoError(t, err)

View file

@ -1,52 +1,44 @@
package cipher
import (
"bytes"
"crypto/aes"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDeriveAESKeys(t *testing.T) {
kdfInfo := []byte("test")
key := []byte("test key")
derivedKeys, err := deriveAESKeys(kdfInfo, key)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
derivedKeys2, err := deriveAESKeys(kdfInfo, key)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//derivedKeys and derivedKeys2 should be identical
if !bytes.Equal(derivedKeys.key, derivedKeys2.key) ||
!bytes.Equal(derivedKeys.iv, derivedKeys2.iv) ||
!bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) {
t.Fail()
}
assert.Equal(t, derivedKeys.key, derivedKeys2.key)
assert.Equal(t, derivedKeys.iv, derivedKeys2.iv)
assert.Equal(t, derivedKeys.hmacKey, derivedKeys2.hmacKey)
//changing kdfInfo
kdfInfo = []byte("other kdf")
derivedKeys2, err = deriveAESKeys(kdfInfo, key)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//derivedKeys and derivedKeys2 should now be different
if bytes.Equal(derivedKeys.key, derivedKeys2.key) ||
bytes.Equal(derivedKeys.iv, derivedKeys2.iv) ||
bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) {
t.Fail()
}
assert.NotEqual(t, derivedKeys.key, derivedKeys2.key)
assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv)
assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey)
//changing key
key = []byte("other test key")
derivedKeys, err = deriveAESKeys(kdfInfo, key)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//derivedKeys and derivedKeys2 should now be different
if bytes.Equal(derivedKeys.key, derivedKeys2.key) ||
bytes.Equal(derivedKeys.iv, derivedKeys2.iv) ||
bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) {
t.Fail()
}
assert.NotEqual(t, derivedKeys.key, derivedKeys2.key)
assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv)
assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey)
}
func TestCipherAESSha256(t *testing.T) {
@ -58,26 +50,15 @@ func TestCipherAESSha256(t *testing.T) {
message = append(message, []byte("-")...)
}
encrypted, err := cipher.Encrypt(key, []byte(message))
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
mac, err := cipher.MAC(key, encrypted)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
verified, err := cipher.Verify(key, encrypted, mac[:8])
if err != nil {
t.Fatal(err)
}
if !verified {
t.Fatal("signature verification failed")
}
assert.NoError(t, err)
assert.True(t, verified, "signature verification failed")
resultPlainText, err := cipher.Decrypt(key, encrypted)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(message, resultPlainText) {
t.Fail()
}
assert.NoError(t, err)
assert.Equal(t, message, resultPlainText)
}

View file

@ -1,10 +1,11 @@
package cipher_test
import (
"bytes"
"crypto/aes"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/cipher"
)
@ -19,15 +20,9 @@ func TestEncoding(t *testing.T) {
copy(toEncrypt, input)
}
encoded, err := cipher.Pickle(key, toEncrypt)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decoded, err := cipher.Unpickle(key, encoded)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decoded, toEncrypt) {
t.Fatalf("Expected '%s' but got '%s'", toEncrypt, decoded)
}
assert.NoError(t, err)
assert.Equal(t, toEncrypt, decoded)
}

View file

@ -1,39 +1,26 @@
package crypto_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
func TestCurve25519(t *testing.T) {
firstKeypair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
secondKeypair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
sharedSecretFromFirst, err := firstKeypair.SharedSecret(secondKeypair.PublicKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
sharedSecretFromSecond, err := secondKeypair.SharedSecret(firstKeypair.PublicKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(sharedSecretFromFirst, sharedSecretFromSecond) {
t.Fatal("shared secret not equal")
}
assert.NoError(t, err)
assert.Equal(t, sharedSecretFromFirst, sharedSecretFromSecond, "shared secret not equal")
fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(fromPrivate.PublicKey, firstKeypair.PublicKey) {
t.Fatal("public keys not equal")
}
assert.NoError(t, err)
assert.Equal(t, fromPrivate, firstKeypair)
}
func TestCurve25519Case1(t *testing.T) {
@ -76,112 +63,59 @@ func TestCurve25519Case1(t *testing.T) {
PublicKey: bobPublic,
}
agreementFromAlice, err := aliceKeyPair.SharedSecret(bobKeyPair.PublicKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(agreementFromAlice, expectedAgreement) {
t.Fatal("expected agreement does not match agreement from Alice's view")
}
assert.NoError(t, err)
assert.Equal(t, expectedAgreement, agreementFromAlice, "expected agreement does not match agreement from Alice's view")
agreementFromBob, err := bobKeyPair.SharedSecret(aliceKeyPair.PublicKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(agreementFromBob, expectedAgreement) {
t.Fatal("expected agreement does not match agreement from Bob's view")
}
assert.NoError(t, err)
assert.Equal(t, expectedAgreement, agreementFromBob, "expected agreement does not match agreement from Bob's view")
}
func TestCurve25519Pickle(t *testing.T) {
//create keypair
keyPair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
target := make([]byte, keyPair.PickleLen())
writtenBytes, err := keyPair.PickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if writtenBytes != len(target) {
t.Fatal("written bytes not correct")
}
assert.NoError(t, err)
assert.Len(t, target, writtenBytes)
unpickledKeyPair := crypto.Curve25519KeyPair{}
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if readBytes != len(target) {
t.Fatal("read bytes not correct")
}
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
t.Fatal("private keys not correct")
}
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
t.Fatal("public keys not correct")
}
assert.NoError(t, err)
assert.Len(t, target, readBytes)
assert.Equal(t, keyPair, unpickledKeyPair)
}
func TestCurve25519PicklePubKeyOnly(t *testing.T) {
//create keypair
keyPair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//Remove privateKey
keyPair.PrivateKey = nil
target := make([]byte, keyPair.PickleLen())
writtenBytes, err := keyPair.PickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if writtenBytes != len(target) {
t.Fatal("written bytes not correct")
}
assert.NoError(t, err)
assert.Len(t, target, writtenBytes)
unpickledKeyPair := crypto.Curve25519KeyPair{}
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if readBytes != len(target) {
t.Fatal("read bytes not correct")
}
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
t.Fatal("private keys not correct")
}
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
t.Fatal("public keys not correct")
}
assert.NoError(t, err)
assert.Len(t, target, readBytes)
assert.Equal(t, keyPair, unpickledKeyPair)
}
func TestCurve25519PicklePrivKeyOnly(t *testing.T) {
//create keypair
keyPair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//Remove public
keyPair.PublicKey = nil
target := make([]byte, keyPair.PickleLen())
writtenBytes, err := keyPair.PickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if writtenBytes != len(target) {
t.Fatal("written bytes not correct")
}
assert.NoError(t, err)
assert.Len(t, target, writtenBytes)
unpickledKeyPair := crypto.Curve25519KeyPair{}
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if readBytes != len(target) {
t.Fatal("read bytes not correct")
}
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
t.Fatal("private keys not correct")
}
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
t.Fatal("public keys not correct")
}
assert.NoError(t, err)
assert.Len(t, target, readBytes)
assert.Equal(t, keyPair, unpickledKeyPair)
}

View file

@ -1,140 +1,87 @@
package crypto_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
func TestEd25519(t *testing.T) {
keypair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
message := []byte("test message")
signature := keypair.Sign(message)
if !keypair.Verify(message, signature) {
t.Fail()
}
assert.True(t, keypair.Verify(message, signature))
}
func TestEd25519Case1(t *testing.T) {
//64 bytes for ed25519 package
keyPair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
message := []byte("Hello, World")
keyPair2 := crypto.Ed25519GenerateFromPrivate(keyPair.PrivateKey)
if !bytes.Equal(keyPair.PublicKey, keyPair2.PublicKey) {
t.Fatal("not equal key pairs")
}
assert.Equal(t, keyPair, keyPair2, "not equal key pairs")
signature := keyPair.Sign(message)
verified := keyPair.Verify(message, signature)
if !verified {
t.Fatal("message did not verify although it should")
}
assert.True(t, verified, "message did not verify although it should")
//Now change the message and verify again
message = append(message, []byte("a")...)
verified = keyPair.Verify(message, signature)
if verified {
t.Fatal("message did verify although it should not")
}
assert.False(t, verified, "message did verify although it should not")
}
func TestEd25519Pickle(t *testing.T) {
//create keypair
keyPair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
target := make([]byte, keyPair.PickleLen())
writtenBytes, err := keyPair.PickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if writtenBytes != len(target) {
t.Fatal("written bytes not correct")
}
assert.NoError(t, err)
assert.Len(t, target, writtenBytes)
unpickledKeyPair := crypto.Ed25519KeyPair{}
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if readBytes != len(target) {
t.Fatal("read bytes not correct")
}
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
t.Fatal("private keys not correct")
}
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
t.Fatal("public keys not correct")
}
assert.NoError(t, err)
assert.Len(t, target, readBytes, "read bytes not correct")
assert.Equal(t, keyPair, unpickledKeyPair)
}
func TestEd25519PicklePubKeyOnly(t *testing.T) {
//create keypair
keyPair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//Remove privateKey
keyPair.PrivateKey = nil
target := make([]byte, keyPair.PickleLen())
writtenBytes, err := keyPair.PickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if writtenBytes != len(target) {
t.Fatal("written bytes not correct")
}
assert.NoError(t, err)
assert.Len(t, target, writtenBytes)
unpickledKeyPair := crypto.Ed25519KeyPair{}
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if readBytes != len(target) {
t.Fatal("read bytes not correct")
}
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
t.Fatal("private keys not correct")
}
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
t.Fatal("public keys not correct")
}
assert.NoError(t, err)
assert.Len(t, target, readBytes, "read bytes not correct")
assert.Equal(t, keyPair, unpickledKeyPair)
}
func TestEd25519PicklePrivKeyOnly(t *testing.T) {
//create keypair
keyPair, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//Remove public
keyPair.PublicKey = nil
target := make([]byte, keyPair.PickleLen())
writtenBytes, err := keyPair.PickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if writtenBytes != len(target) {
t.Fatal("written bytes not correct")
}
assert.NoError(t, err)
assert.Len(t, target, writtenBytes)
unpickledKeyPair := crypto.Ed25519KeyPair{}
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
if err != nil {
t.Fatal(err)
}
if readBytes != len(target) {
t.Fatal("read bytes not correct")
}
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
t.Fatal("private keys not correct")
}
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
t.Fatal("public keys not correct")
}
assert.NoError(t, err)
assert.Len(t, target, readBytes, "read bytes not correct")
assert.Equal(t, keyPair, unpickledKeyPair)
}

View file

@ -1,49 +1,44 @@
package crypto_test
import (
"bytes"
"encoding/base64"
"io"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
func TestHMACSha256(t *testing.T) {
func TestHMACSHA256(t *testing.T) {
key := []byte("test key")
message := []byte("test message")
hash := crypto.HMACSHA256(key, message)
if !bytes.Equal(hash, crypto.HMACSHA256(key, message)) {
t.Fail()
}
assert.Equal(t, hash, crypto.HMACSHA256(key, message))
str := "A4M0ovdiWHaZ5msdDFbrvtChFwZIoIaRSVGmv8bmPtc"
result, err := base64.RawStdEncoding.DecodeString(str)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(result, hash) {
t.Fail()
}
assert.NoError(t, err)
assert.Equal(t, result, hash)
}
func TestHKDFSha256(t *testing.T) {
func TestHKDFSHA256(t *testing.T) {
message := []byte("test content")
hkdf := crypto.HKDFSHA256(message, nil, nil)
hkdf2 := crypto.HKDFSHA256(message, nil, nil)
result := make([]byte, 32)
if _, err := io.ReadFull(hkdf, result); err != nil {
t.Fatal(err)
}
_, err := io.ReadFull(hkdf, result)
assert.NoError(t, err)
hkdf2 := crypto.HKDFSHA256(message, nil, nil)
result2 := make([]byte, 32)
if _, err := io.ReadFull(hkdf2, result2); err != nil {
t.Fatal(err)
}
if !bytes.Equal(result, result2) {
t.Fail()
}
_, err = io.ReadFull(hkdf2, result2)
assert.NoError(t, err)
assert.Equal(t, result, result2)
}
func TestSha256Case1(t *testing.T) {
func TestSHA256Case1(t *testing.T) {
input := make([]byte, 0)
expected := []byte{
0xE3, 0xB0, 0xC4, 0x42, 0x98, 0xFC, 0x1C, 0x14,
@ -52,9 +47,7 @@ func TestSha256Case1(t *testing.T) {
0xA4, 0x95, 0x99, 0x1B, 0x78, 0x52, 0xB8, 0x55,
}
result := crypto.SHA256(input)
if !bytes.Equal(expected, result) {
t.Fatalf("result not as expected:\n%v\n%v\n", result, expected)
}
assert.Equal(t, expected, result)
}
func TestHMACCase1(t *testing.T) {
@ -66,9 +59,7 @@ func TestHMACCase1(t *testing.T) {
0xc6, 0xc7, 0x12, 0x14, 0x42, 0x92, 0xc5, 0xad,
}
result := crypto.HMACSHA256(input, input)
if !bytes.Equal(expected, result) {
t.Fatalf("result not as expected:\n%v\n%v\n", result, expected)
}
assert.Equal(t, expected, result)
}
func TestHDKFCase1(t *testing.T) {
@ -92,9 +83,8 @@ func TestHDKFCase1(t *testing.T) {
0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5,
}
result := crypto.HMACSHA256(salt, input)
if !bytes.Equal(expectedHMAC, result) {
t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHMAC)
}
assert.Equal(t, expectedHMAC, result)
expectedHDKF := []byte{
0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a,
0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, 0x2f, 0x2a,
@ -105,10 +95,7 @@ func TestHDKFCase1(t *testing.T) {
}
resultReader := crypto.HKDFSHA256(input, salt, info)
result = make([]byte, len(expectedHDKF))
if _, err := io.ReadFull(resultReader, result); err != nil {
t.Fatal(err)
}
if !bytes.Equal(expectedHDKF, result) {
t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHDKF)
}
_, err := io.ReadFull(resultReader, result)
assert.NoError(t, err)
assert.Equal(t, expectedHDKF, result)
}

View file

@ -1,9 +1,10 @@
package libolmpickle_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
)
@ -23,12 +24,8 @@ func TestPickleUInt32(t *testing.T) {
for curIndex := range values {
response := make([]byte, 4)
resPLen := libolmpickle.PickleUInt32(values[curIndex], response)
if resPLen != libolmpickle.PickleUInt32Len(values[curIndex]) {
t.Fatal("written bytes not correct")
}
if !bytes.Equal(response, expected[curIndex]) {
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
}
assert.Equal(t, libolmpickle.PickleUInt32Len(values[curIndex]), resPLen)
assert.Equal(t, expected[curIndex], response)
}
}
@ -44,12 +41,8 @@ func TestPickleBool(t *testing.T) {
for curIndex := range values {
response := make([]byte, 1)
resPLen := libolmpickle.PickleBool(values[curIndex], response)
if resPLen != libolmpickle.PickleBoolLen(values[curIndex]) {
t.Fatal("written bytes not correct")
}
if !bytes.Equal(response, expected[curIndex]) {
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
}
assert.Equal(t, libolmpickle.PickleBoolLen(values[curIndex]), resPLen)
assert.Equal(t, expected[curIndex], response)
}
}
@ -65,12 +58,8 @@ func TestPickleUInt8(t *testing.T) {
for curIndex := range values {
response := make([]byte, 1)
resPLen := libolmpickle.PickleUInt8(values[curIndex], response)
if resPLen != libolmpickle.PickleUInt8Len(values[curIndex]) {
t.Fatal("written bytes not correct")
}
if !bytes.Equal(response, expected[curIndex]) {
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
}
assert.Equal(t, libolmpickle.PickleUInt8Len(values[curIndex]), resPLen)
assert.Equal(t, expected[curIndex], response)
}
}
@ -88,11 +77,7 @@ func TestPickleBytes(t *testing.T) {
for curIndex := range values {
response := make([]byte, len(values[curIndex]))
resPLen := libolmpickle.PickleBytes(values[curIndex], response)
if resPLen != libolmpickle.PickleBytesLen(values[curIndex]) {
t.Fatal("written bytes not correct")
}
if !bytes.Equal(response, expected[curIndex]) {
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
}
assert.Equal(t, libolmpickle.PickleBytesLen(values[curIndex]), resPLen)
assert.Equal(t, expected[curIndex], response)
}
}

View file

@ -1,9 +1,10 @@
package libolmpickle_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
)
@ -20,15 +21,9 @@ func TestUnpickleUInt32(t *testing.T) {
}
for curIndex := range values {
response, readLength, err := libolmpickle.UnpickleUInt32(values[curIndex])
if err != nil {
t.Fatal(err)
}
if readLength != 4 {
t.Fatal("read bytes not correct")
}
if response != expected[curIndex] {
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
}
assert.NoError(t, err)
assert.Equal(t, 4, readLength)
assert.Equal(t, expected[curIndex], response)
}
}
@ -45,15 +40,9 @@ func TestUnpickleBool(t *testing.T) {
}
for curIndex := range values {
response, readLength, err := libolmpickle.UnpickleBool(values[curIndex])
if err != nil {
t.Fatal(err)
}
if readLength != 1 {
t.Fatal("read bytes not correct")
}
if response != expected[curIndex] {
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
}
assert.NoError(t, err)
assert.Equal(t, 1, readLength)
assert.Equal(t, expected[curIndex], response)
}
}
@ -68,15 +57,9 @@ func TestUnpickleUInt8(t *testing.T) {
}
for curIndex := range values {
response, readLength, err := libolmpickle.UnpickleUInt8(values[curIndex])
if err != nil {
t.Fatal(err)
}
if readLength != 1 {
t.Fatal("read bytes not correct")
}
if response != expected[curIndex] {
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
}
assert.NoError(t, err)
assert.Equal(t, 1, readLength)
assert.Equal(t, expected[curIndex], response)
}
}
@ -93,14 +76,8 @@ func TestUnpickleBytes(t *testing.T) {
}
for curIndex := range values {
response, readLength, err := libolmpickle.UnpickleBytes(values[curIndex], 4)
if err != nil {
t.Fatal(err)
}
if readLength != 4 {
t.Fatal("read bytes not correct")
}
if !bytes.Equal(response, expected[curIndex]) {
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
}
assert.NoError(t, err)
assert.Equal(t, 4, readLength)
assert.Equal(t, expected[curIndex], response)
}
}

View file

@ -1,9 +1,10 @@
package megolm_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/megolm"
)
@ -19,9 +20,7 @@ func init() {
func TestAdvance(t *testing.T) {
m, err := megolm.New(0, startData)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
expectedData := [megolm.RatchetParts * megolm.RatchetPartLength]byte{
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46,
@ -34,9 +33,7 @@ func TestAdvance(t *testing.T) {
0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a,
}
m.Advance()
if !bytes.Equal(m.Data[:], expectedData[:]) {
t.Fatal("result after advancing the ratchet is not as expected")
}
assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected")
//repeat with complex advance
m.Data = startData
@ -51,9 +48,8 @@ func TestAdvance(t *testing.T) {
0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a,
}
m.AdvanceTo(0x1000000)
if !bytes.Equal(m.Data[:], expectedData[:]) {
t.Fatal("result after advancing the ratchet is not as expected")
}
assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected")
expectedData = [megolm.RatchetParts * megolm.RatchetPartLength]byte{
0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9,
0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c,
@ -65,77 +61,45 @@ func TestAdvance(t *testing.T) {
0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a,
}
m.AdvanceTo(0x1041506)
if !bytes.Equal(m.Data[:], expectedData[:]) {
t.Fatal("result after advancing the ratchet is not as expected")
}
assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected")
}
func TestAdvanceWraparound(t *testing.T) {
m, err := megolm.New(0xffffffff, startData)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
m.AdvanceTo(0x1000000)
if m.Counter != 0x1000000 {
t.Fatal("counter not correct")
}
assert.EqualValues(t, 0x1000000, m.Counter, "counter not correct")
m2, err := megolm.New(0, startData)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
m2.AdvanceTo(0x2000000)
if m2.Counter != 0x2000000 {
t.Fatal("counter not correct")
}
if !bytes.Equal(m.Data[:], m2.Data[:]) {
t.Fatal("result after wrapping the ratchet is not as expected")
}
assert.EqualValues(t, 0x2000000, m2.Counter, "counter not correct")
assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected")
}
func TestAdvanceOverflowByOne(t *testing.T) {
m, err := megolm.New(0xffffffff, startData)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
m.AdvanceTo(0x0)
if m.Counter != 0x0 {
t.Fatal("counter not correct")
}
assert.EqualValues(t, 0x0, m.Counter, "counter not correct")
m2, err := megolm.New(0xffffffff, startData)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
m2.Advance()
if m2.Counter != 0x0 {
t.Fatal("counter not correct")
}
if !bytes.Equal(m.Data[:], m2.Data[:]) {
t.Fatal("result after wrapping the ratchet is not as expected")
}
assert.EqualValues(t, 0x0, m2.Counter, "counter not correct")
assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected")
}
func TestAdvanceOverflow(t *testing.T) {
m, err := megolm.New(0x1, startData)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
m.AdvanceTo(0x80000000)
m.AdvanceTo(0x0)
if m.Counter != 0x0 {
t.Fatal("counter not correct")
}
assert.EqualValues(t, 0x0, m.Counter, "counter not correct")
m2, err := megolm.New(0x1, startData)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
m2.AdvanceTo(0x0)
if m2.Counter != 0x0 {
t.Fatal("counter not correct")
}
if !bytes.Equal(m.Data[:], m2.Data[:]) {
t.Fatal("result after wrapping the ratchet is not as expected")
}
assert.EqualValues(t, 0x0, m2.Counter, "counter not correct")
assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected")
}

View file

@ -1,17 +1,16 @@
package message
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
func TestEncodeLengthInt(t *testing.T) {
numbers := []uint32{127, 128, 16383, 16384, 32767}
expected := []int{1, 2, 2, 3, 3}
for curIndex := range numbers {
if result := encodeVarIntByteLength(numbers[curIndex]); result != expected[curIndex] {
t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result)
}
assert.Equal(t, expected[curIndex], encodeVarIntByteLength(numbers[curIndex]))
}
}
@ -25,9 +24,7 @@ func TestEncodeLengthString(t *testing.T) {
strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------"))
expected = append(expected, 2+155)
for curIndex := range strings {
if result := encodeVarStringByteLength(strings[curIndex]); result != expected[curIndex] {
t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result)
}
assert.Equal(t, expected[curIndex], encodeVarStringByteLength(strings[curIndex]))
}
}
@ -43,9 +40,7 @@ func TestEncodeInt(t *testing.T) {
ints = append(ints, 16383)
expected = append(expected, []byte{0b11111111, 0b01111111})
for curIndex := range ints {
if result := encodeVarInt(ints[curIndex]); !bytes.Equal(result, expected[curIndex]) {
t.Fatalf("expected byte of %b but got %b", expected[curIndex], result)
}
assert.Equal(t, expected[curIndex], encodeVarInt(ints[curIndex]))
}
}
@ -75,8 +70,6 @@ func TestEncodeString(t *testing.T) {
res = append(res, curTest...) //Add string itself
expected = append(expected, res)
for curIndex := range strings {
if result := encodeVarString(strings[curIndex]); !bytes.Equal(result, expected[curIndex]) {
t.Fatalf("expected byte of %b but got %b", expected[curIndex], result)
}
assert.Equal(t, expected[curIndex], encodeVarString(strings[curIndex]))
}
}

View file

@ -1,9 +1,10 @@
package message_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/message"
)
@ -16,18 +17,10 @@ func TestGroupMessageDecode(t *testing.T) {
msg := message.GroupMessage{}
err := msg.Decode(messageRaw)
if err != nil {
t.Fatal(err)
}
if msg.Version != 3 {
t.Fatalf("Expected Version to be 3 but go %d", msg.Version)
}
if msg.MessageIndex != expectedMessageIndex {
t.Fatalf("Expected message index to be %d but got %d", expectedMessageIndex, msg.MessageIndex)
}
if !bytes.Equal(msg.Ciphertext, expectedCipherText) {
t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext)
}
assert.NoError(t, err)
assert.EqualValues(t, 3, msg.Version)
assert.Equal(t, expectedMessageIndex, msg.MessageIndex)
assert.Equal(t, expectedCipherText, msg.Ciphertext)
}
func TestGroupMessageEncode(t *testing.T) {
@ -40,12 +33,8 @@ func TestGroupMessageEncode(t *testing.T) {
Ciphertext: []byte("ciphertext"),
}
encoded, err := msg.EncodeAndMacAndSign(nil, nil, nil)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
encoded = append(encoded, hmacsha256...)
encoded = append(encoded, sign...)
if !bytes.Equal(encoded, expectedRaw) {
t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded)
}
assert.Equal(t, expectedRaw, encoded)
}

View file

@ -1,9 +1,10 @@
package message_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/message"
)
@ -14,24 +15,12 @@ func TestMessageDecode(t *testing.T) {
msg := message.Message{}
err := msg.Decode(messageRaw)
if err != nil {
t.Fatal(err)
}
if msg.Version != 3 {
t.Fatalf("Expected Version to be 3 but go %d", msg.Version)
}
if !msg.HasCounter {
t.Fatal("Expected to have counter")
}
if msg.Counter != 1 {
t.Fatalf("Expected counter to be 1 but got %d", msg.Counter)
}
if !bytes.Equal(msg.Ciphertext, expectedCipherText) {
t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext)
}
if !bytes.Equal(msg.RatchetKey, expectedRatchetKey) {
t.Fatalf("expected '%s' but got '%s'", expectedRatchetKey, msg.RatchetKey)
}
assert.NoError(t, err)
assert.EqualValues(t, 3, msg.Version)
assert.True(t, msg.HasCounter)
assert.EqualValues(t, 1, msg.Counter)
assert.Equal(t, expectedCipherText, msg.Ciphertext)
assert.EqualValues(t, expectedRatchetKey, msg.RatchetKey)
}
func TestMessageEncode(t *testing.T) {
@ -44,11 +33,7 @@ func TestMessageEncode(t *testing.T) {
Ciphertext: []byte("ciphertext"),
}
encoded, err := msg.EncodeAndMAC(nil, nil)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
encoded = append(encoded, hmacsha256...)
if !bytes.Equal(encoded, expectedRaw) {
t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded)
}
assert.Equal(t, expectedRaw, encoded)
}

View file

@ -1,9 +1,10 @@
package message_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/message"
)
@ -19,29 +20,14 @@ func TestPreKeyMessageDecode(t *testing.T) {
msg := message.PreKeyMessage{}
err := msg.Decode(messageRaw)
if err != nil {
t.Fatal(err)
}
if msg.Version != 3 {
t.Fatalf("Expected Version to be 3 but go %d", msg.Version)
}
if !bytes.Equal(msg.OneTimeKey, expectedOneTimeKey) {
t.Fatalf("expected '%s' but got '%s'", expectedOneTimeKey, msg.OneTimeKey)
}
if !bytes.Equal(msg.IdentityKey, expectedIdKey) {
t.Fatalf("expected '%s' but got '%s'", expectedIdKey, msg.IdentityKey)
}
if !bytes.Equal(msg.BaseKey, expectedbaseKey) {
t.Fatalf("expected '%s' but got '%s'", expectedbaseKey, msg.BaseKey)
}
if !bytes.Equal(msg.Message, expectedmessage) {
t.Fatalf("expected '%s' but got '%s'", expectedmessage, msg.Message)
}
assert.NoError(t, err)
assert.EqualValues(t, 3, msg.Version)
assert.EqualValues(t, expectedOneTimeKey, msg.OneTimeKey)
assert.EqualValues(t, expectedIdKey, msg.IdentityKey)
assert.EqualValues(t, expectedbaseKey, msg.BaseKey)
assert.Equal(t, expectedmessage, msg.Message)
theirIDKey := crypto.Curve25519PublicKey(expectedIdKey)
checked := msg.CheckFields(&theirIDKey)
if !checked {
t.Fatal("field check failed")
}
assert.True(t, msg.CheckFields(&theirIDKey), "field check failed")
}
func TestPreKeyMessageEncode(t *testing.T) {
@ -54,10 +40,6 @@ func TestPreKeyMessageEncode(t *testing.T) {
Message: []byte("message"),
}
encoded, err := msg.Encode()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(encoded, expectedRaw) {
t.Fatalf("got other than expected:\nExpected:\n%v\nGot:\n%v", expectedRaw, encoded)
}
assert.NoError(t, err)
assert.Equal(t, expectedRaw, encoded)
}

View file

@ -1,12 +1,12 @@
package pk_test
import (
"bytes"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/goolmbase64"
"maunium.net/go/mautrix/crypto/goolm/pk"
)
@ -26,34 +26,20 @@ func TestEncryptionDecryption(t *testing.T) {
}
bobPublic := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08")
decryption, err := pk.NewDecryptionFromPrivate(alicePrivate)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) {
t.Fatal("public key not correct")
}
if !bytes.Equal(decryption.PrivateKey(), alicePrivate) {
t.Fatal("private key not correct")
}
assert.NoError(t, err)
assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct")
assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct")
encryption, err := pk.NewEncryption(decryption.PublicKey())
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plaintext := []byte("This is a test")
ciphertext, mac, err := encryption.Encrypt(plaintext, bobPrivate)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decrypted, err := decryption.Decrypt(bobPublic, mac, ciphertext)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decrypted, plaintext) {
t.Fatal("message not equal")
}
assert.NoError(t, err)
assert.EqualValues(t, plaintext, decrypted, "message not equal")
}
func TestSigning(t *testing.T) {
@ -66,29 +52,20 @@ func TestSigning(t *testing.T) {
message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.")
signing, _ := pk.NewSigningFromSeed(seed)
signature, err := signing.Sign(message)
if err != nil {
t.Fatal(err)
}
signatureDecoded, err := goolmbase64.Decode(signature)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
signatureDecoded, err := base64.RawStdEncoding.DecodeString(string(signature))
assert.NoError(t, err)
pubKeyEncoded := signing.PublicKey()
pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKeyEncoded))
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
pubKey := crypto.Ed25519PublicKey(pubKeyDecoded)
verified := pubKey.Verify(message, signatureDecoded)
if !verified {
t.Fatal("signature did not verify")
}
assert.True(t, verified, "signature did not verify")
copy(signatureDecoded[0:], []byte("m"))
verified = pubKey.Verify(message, signatureDecoded)
if verified {
t.Fatal("signature did verify")
}
assert.False(t, verified, "signature verified with wrong message")
}
func TestDecryptionPickling(t *testing.T) {
@ -100,37 +77,19 @@ func TestDecryptionPickling(t *testing.T) {
}
alicePublic := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo")
decryption, err := pk.NewDecryptionFromPrivate(alicePrivate)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) {
t.Fatal("public key not correct")
}
if !bytes.Equal(decryption.PrivateKey(), alicePrivate) {
t.Fatal("private key not correct")
}
assert.NoError(t, err)
assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct")
assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct")
pickleKey := []byte("secret_key")
expectedPickle := []byte("qx37WTQrjZLz5tId/uBX9B3/okqAbV1ofl9UnHKno1eipByCpXleAAlAZoJgYnCDOQZDQWzo3luTSfkF9pU1mOILCbbouubs6TVeDyPfgGD9i86J8irHjA")
pickled, err := decryption.Pickle(pickleKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expectedPickle, pickled) {
t.Fatalf("pickle not as expected:\n%v\n%v\n", pickled, expectedPickle)
}
assert.NoError(t, err)
assert.EqualValues(t, expectedPickle, pickled, "pickle not as expected")
newDecription, err := pk.NewDecryption()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
err = newDecription.Unpickle(pickled, pickleKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal([]byte(newDecription.PublicKey()), alicePublic) {
t.Fatal("public key not correct")
}
if !bytes.Equal(newDecription.PrivateKey(), alicePrivate) {
t.Fatal("private key not correct")
}
assert.NoError(t, err)
assert.EqualValues(t, alicePublic, newDecription.PublicKey(), "public key not correct")
assert.EqualValues(t, alicePrivate, newDecription.PrivateKey(), "private key not correct")
}

View file

@ -1,10 +1,11 @@
package ratchet_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/ratchet"
@ -38,149 +39,90 @@ func initializeRatchets() (*ratchet.Ratchet, *ratchet.Ratchet, error) {
func TestSendReceive(t *testing.T) {
aliceRatchet, bobRatchet, err := initializeRatchets()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plainText := []byte("Hello Bob")
//Alice sends Bob a message
encryptedMessage, err := aliceRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decrypted, err := bobRatchet.Decrypt(encryptedMessage)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plainText, decrypted) {
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
}
assert.NoError(t, err)
assert.Equal(t, plainText, decrypted)
//Bob sends Alice a message
plainText = []byte("Hello Alice")
encryptedMessage, err = bobRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decrypted, err = aliceRatchet.Decrypt(encryptedMessage)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plainText, decrypted) {
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
}
assert.NoError(t, err)
assert.Equal(t, plainText, decrypted)
}
func TestOutOfOrder(t *testing.T) {
aliceRatchet, bobRatchet, err := initializeRatchets()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plainText1 := []byte("First Message")
plainText2 := []byte("Second Messsage. A bit longer than the first.")
/* Alice sends Bob two messages and they arrive out of order */
message1Encrypted, err := aliceRatchet.Encrypt(plainText1)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
message2Encrypted, err := aliceRatchet.Encrypt(plainText2)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decrypted2, err := bobRatchet.Decrypt(message2Encrypted)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decrypted1, err := bobRatchet.Decrypt(message1Encrypted)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plainText1, decrypted1) {
t.Fatalf("expected '%v' from decryption but got '%v'", plainText1, decrypted1)
}
if !bytes.Equal(plainText2, decrypted2) {
t.Fatalf("expected '%v' from decryption but got '%v'", plainText2, decrypted2)
}
assert.NoError(t, err)
assert.Equal(t, plainText1, decrypted1)
assert.Equal(t, plainText2, decrypted2)
}
func TestMoreMessages(t *testing.T) {
aliceRatchet, bobRatchet, err := initializeRatchets()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plainText := []byte("These 15 bytes")
for i := 0; i < 8; i++ {
messageEncrypted, err := aliceRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decrypted, err := bobRatchet.Decrypt(messageEncrypted)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plainText, decrypted) {
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
}
assert.NoError(t, err)
assert.Equal(t, plainText, decrypted)
}
for i := 0; i < 8; i++ {
messageEncrypted, err := bobRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decrypted, err := aliceRatchet.Decrypt(messageEncrypted)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plainText, decrypted) {
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
}
assert.NoError(t, err)
assert.Equal(t, plainText, decrypted)
}
messageEncrypted, err := aliceRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decrypted, err := bobRatchet.Decrypt(messageEncrypted)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plainText, decrypted) {
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
}
assert.NoError(t, err)
assert.Equal(t, plainText, decrypted)
}
func TestJSONEncoding(t *testing.T) {
aliceRatchet, bobRatchet, err := initializeRatchets()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
marshaled, err := json.Marshal(aliceRatchet)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
newRatcher := ratchet.Ratchet{}
err = json.Unmarshal(marshaled, &newRatcher)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plainText := []byte("These 15 bytes")
messageEncrypted, err := newRatcher.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decrypted, err := bobRatchet.Decrypt(messageEncrypted)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plainText, decrypted) {
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
}
assert.NoError(t, err)
assert.Equal(t, plainText, decrypted)
}

View file

@ -1,11 +1,12 @@
package session_test
import (
"bytes"
"crypto/rand"
"errors"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/megolm"
"maunium.net/go/mautrix/crypto/goolm/session"
@ -15,78 +16,42 @@ import (
func TestOutboundPickleJSON(t *testing.T) {
pickleKey := []byte("secretKey")
sess, err := session.NewMegolmOutboundSession()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
kp, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
sess.SigningKey = kp
pickled, err := sess.PickleAsJSON(pickleKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
newSession := session.MegolmOutboundSession{}
err = newSession.UnpickleAsJSON(pickled, pickleKey)
if err != nil {
t.Fatal(err)
}
if sess.ID() != newSession.ID() {
t.Fatal("session ids not equal")
}
if !bytes.Equal(sess.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) {
t.Fatal("private keys not equal")
}
if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) {
t.Fatal("ratchet data not equal")
}
if sess.Ratchet.Counter != newSession.Ratchet.Counter {
t.Fatal("ratchet counter not equal")
}
assert.NoError(t, err)
assert.Equal(t, sess.ID(), newSession.ID())
assert.Equal(t, sess.SigningKey, newSession.SigningKey)
assert.Equal(t, sess.Ratchet, newSession.Ratchet)
}
func TestInboundPickleJSON(t *testing.T) {
pickleKey := []byte("secretKey")
sess := session.MegolmInboundSession{}
kp, err := crypto.Ed25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
sess.SigningKey = kp.PublicKey
var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte
_, err = rand.Read(randomData[:])
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
ratchet, err := megolm.New(0, randomData)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
sess.Ratchet = *ratchet
pickled, err := sess.PickleAsJSON(pickleKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
newSession := session.MegolmInboundSession{}
err = newSession.UnpickleAsJSON(pickled, pickleKey)
if err != nil {
t.Fatal(err)
}
if sess.ID() != newSession.ID() {
t.Fatal("sess ids not equal")
}
if !bytes.Equal(sess.SigningKey, newSession.SigningKey) {
t.Fatal("private keys not equal")
}
if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) {
t.Fatal("ratchet data not equal")
}
if sess.Ratchet.Counter != newSession.Ratchet.Counter {
t.Fatal("ratchet counter not equal")
}
assert.NoError(t, err)
assert.Equal(t, sess.ID(), newSession.ID())
assert.Equal(t, sess.SigningKey, newSession.SigningKey)
assert.Equal(t, sess.Ratchet, newSession.Ratchet)
}
func TestGroupSendReceive(t *testing.T) {
@ -100,46 +65,27 @@ func TestGroupSendReceive(t *testing.T) {
)
outboundSession, err := session.NewMegolmOutboundSession()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
copy(outboundSession.Ratchet.Data[:], randomData)
if outboundSession.Ratchet.Counter != 0 {
t.Fatal("ratchet counter is not correkt")
}
assert.EqualValues(t, 0, outboundSession.Ratchet.Counter)
sessionSharing, err := outboundSession.SessionSharingMessage()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
plainText := []byte("Message")
ciphertext, err := outboundSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if outboundSession.Ratchet.Counter != 1 {
t.Fatal("ratchet counter is not correkt")
}
assert.NoError(t, err)
assert.EqualValues(t, 1, outboundSession.Ratchet.Counter)
//build inbound session
inboundSession, err := session.NewMegolmInboundSession(sessionSharing)
if err != nil {
t.Fatal(err)
}
if !inboundSession.SigningKeyVerified {
t.Fatal("key not verified")
}
if inboundSession.ID() != outboundSession.ID() {
t.Fatal("session ids not equal")
}
assert.NoError(t, err)
assert.True(t, inboundSession.SigningKeyVerified)
assert.Equal(t, outboundSession.ID(), inboundSession.ID())
//decode message
decoded, _, err := inboundSession.Decrypt(ciphertext)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plainText, decoded) {
t.Fatal("messages not equal")
}
assert.NoError(t, err)
assert.Equal(t, plainText, decoded)
}
func TestGroupSessionExportImport(t *testing.T) {
@ -158,45 +104,26 @@ func TestGroupSessionExportImport(t *testing.T) {
//init inbound
inboundSession, err := session.NewMegolmInboundSession(sessionKey)
if err != nil {
t.Fatal(err)
}
if !inboundSession.SigningKeyVerified {
t.Fatal("signing key not verified")
}
assert.NoError(t, err)
assert.True(t, inboundSession.SigningKeyVerified)
decrypted, _, err := inboundSession.Decrypt(message)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plaintext, decrypted) {
t.Fatal("message is not correct")
}
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
//Export the keys
exported, err := inboundSession.Export(0)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
secondInboundSession, err := session.NewMegolmInboundSessionFromExport(exported)
if err != nil {
t.Fatal(err)
}
if secondInboundSession.SigningKeyVerified {
t.Fatal("signing key is verified")
}
assert.NoError(t, err)
assert.False(t, secondInboundSession.SigningKeyVerified)
//decrypt with new session
decrypted, _, err = secondInboundSession.Decrypt(message)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plaintext, decrypted) {
t.Fatal("message is not correct")
}
if !secondInboundSession.SigningKeyVerified {
t.Fatal("signing key not verified")
}
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
assert.True(t, secondInboundSession.SigningKeyVerified)
}
func TestBadSignatureGroupMessage(t *testing.T) {
@ -215,70 +142,43 @@ func TestBadSignatureGroupMessage(t *testing.T) {
//init inbound
inboundSession, err := session.NewMegolmInboundSession(sessionKey)
if err != nil {
t.Fatal(err)
}
if !inboundSession.SigningKeyVerified {
t.Fatal("signing key not verified")
}
assert.NoError(t, err)
assert.True(t, inboundSession.SigningKeyVerified)
decrypted, _, err := inboundSession.Decrypt(message)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plaintext, decrypted) {
t.Fatal("message is not correct")
}
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
//Now twiddle the signature
copy(message[len(message)-1:], []byte("E"))
_, _, err = inboundSession.Decrypt(message)
if err == nil {
t.Fatal("Signature was changed but did not cause an error")
}
if !errors.Is(err, olm.ErrBadSignature) {
t.Fatalf("wrong error %s", err.Error())
}
assert.ErrorIs(t, err, olm.ErrBadSignature)
}
func TestOutbountPickle(t *testing.T) {
pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItUO3TiOp5I+6PnQka6n8eHTyIEh3tCetilD+BKnHvtakE0eHHvG6pjEsMNN/vs7lkB5rV6XkoUKHLTE1dAfFunYEeHEZuKQpbG385dBwaMJXt4JrC0hU5jnv6jWNqAA0Ud9GxRDvkp04")
pickleKey := []byte("secret_key")
sess, err := session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
newPickled, err := sess.Pickle(pickleKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(pickledDataFromLibOlm, newPickled) {
t.Fatal("pickled version does not equal libolm version")
}
assert.NoError(t, err)
assert.Equal(t, pickledDataFromLibOlm, newPickled)
pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...)
_, err = session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey)
if err == nil {
t.Fatal("should have gotten an error")
}
assert.ErrorIs(t, err, olm.ErrBadMAC)
}
func TestInbountPickle(t *testing.T) {
pickledDataFromLibOlm := []byte("1/IPCdtUoQxMba5XT7sjjUW0Hrs7no9duGFnhsEmxzFX2H3qtRc4eaFBRZYXxOBRTGZ6eMgy3IiSrgAQ1gUlSZf5Q4AVKeBkhvN4LZ6hdhQFv91mM+C2C55/4B9/gDjJEbDGiRgLoMqbWPDV+y0F4h0KaR1V1PiTCC7zCi4WdxJQ098nJLgDL4VSsDbnaLcSMO60FOYgRN4KsLaKUGkXiiUBWp4boFMCiuTTOiyH8XlH0e9uWc0vMLyGNUcO8kCbpAnx3v1JTIVan3WGsnGv4K8Qu4M8GAkZewpexrsb2BSNNeLclOV9/cR203Y5KlzXcpiWNXSs8XoB3TLEtHYMnjuakMQfyrcXKIQntg4xPD/+wvfqkcMg9i7pcplQh7X2OK5ylrMZQrZkJ1fAYBGbBz1tykWOjfrZ")
pickleKey := []byte("secret_key")
sess, err := session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
newPickled, err := sess.Pickle(pickleKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(pickledDataFromLibOlm, newPickled) {
t.Fatal("pickled version does not equal libolm version")
}
assert.NoError(t, err)
assert.Equal(t, pickledDataFromLibOlm, newPickled)
pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...)
_, err = session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey)
if err == nil {
t.Fatal("should have gotten an error")
}
assert.ErrorIs(t, err, base64.CorruptInputError(416))
}

View file

@ -1,11 +1,11 @@
package session_test
import (
"bytes"
"encoding/base64"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/olm"
@ -15,30 +15,18 @@ import (
func TestOlmSession(t *testing.T) {
pickleKey := []byte("secretKey")
aliceKeyPair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
bobKeyPair, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
bobOneTimeKey, err := crypto.Curve25519GenerateKey()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
aliceSession, err := session.NewOutboundOlmSession(aliceKeyPair, bobKeyPair.PublicKey, bobOneTimeKey.PublicKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//create a message so that there are more keys to marshal
plaintext := []byte("Test message from Alice to Bob")
msgType, message, err := aliceSession.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
if msgType != id.OlmMsgTypePreKey {
t.Fatal("Wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
searchFunc := func(target crypto.Curve25519PublicKey) *crypto.OneTimeKey {
if target.Equal(bobOneTimeKey.PublicKey) {
@ -52,92 +40,58 @@ func TestOlmSession(t *testing.T) {
}
//bob receives message
bobSession, err := session.NewInboundOlmSession(nil, message, searchFunc, bobKeyPair)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
decryptedMsg, err := bobSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plaintext, decryptedMsg) {
t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg)
}
assert.NoError(t, err)
assert.Equal(t, plaintext, decryptedMsg)
// Alice pickles session
pickled, err := aliceSession.PickleAsJSON(pickleKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//bob sends a message
plaintext = []byte("A message from Bob to Alice")
msgType, message, err = bobSession.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
if msgType != id.OlmMsgTypeMsg {
t.Fatal("Wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypeMsg, msgType)
//Alice unpickles session
newAliceSession, err := session.OlmSessionFromJSONPickled(pickled, pickleKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
//Alice receives message
decryptedMsg, err = newAliceSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plaintext, decryptedMsg) {
t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg)
}
assert.NoError(t, err)
assert.Equal(t, plaintext, decryptedMsg)
//Alice receives message again
_, err = newAliceSession.Decrypt(string(message), msgType)
if err == nil {
t.Fatal("should have gotten an error")
}
assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound)
//Alice sends another message
plaintext = []byte("A second message to Bob")
msgType, message, err = newAliceSession.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
if msgType != id.OlmMsgTypeMsg {
t.Fatal("Wrong message type")
}
assert.NoError(t, err)
assert.Equal(t, id.OlmMsgTypeMsg, msgType)
//bob receives message
decryptedMsg, err = bobSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plaintext, decryptedMsg) {
t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg)
}
assert.NoError(t, err)
assert.Equal(t, plaintext, decryptedMsg)
}
func TestSessionPickle(t *testing.T) {
pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT")
pickleKey := []byte("secret_key")
sess, err := session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
newPickled, err := sess.Pickle(pickleKey)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(pickledDataFromLibOlm, newPickled) {
t.Fatal("pickled version does not equal libolm version")
}
assert.NoError(t, err)
assert.Equal(t, pickledDataFromLibOlm, newPickled)
pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...)
_, err = session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey)
if err == nil {
t.Fatal("should have gotten an error")
}
assert.ErrorIs(t, err, base64.CorruptInputError(224))
}
func TestDecrypts(t *testing.T) {
@ -161,17 +115,9 @@ func TestDecrypts(t *testing.T) {
"dGvPXeH8qLeNZA")
pickleKey := []byte("")
sess, err := session.OlmSessionFromPickled(sessionPickled, pickleKey)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
for curIndex, curMessage := range messages {
_, err := sess.Decrypt(string(curMessage), id.OlmMsgTypePreKey)
if err != nil {
if !errors.Is(err, expectedErr[curIndex]) {
t.Fatal(err)
}
} else {
t.Fatal("error expected")
}
assert.ErrorIs(t, err, expectedErr[curIndex])
}
}