mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
goolm: simplify tests using testify
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
parent
7cc46f1ff3
commit
eb632a9994
17 changed files with 488 additions and 1271 deletions
|
|
@ -1,13 +1,10 @@
|
|||
package account_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
|
|
@ -18,75 +15,42 @@ import (
|
|||
|
||||
func TestAccount(t *testing.T) {
|
||||
firstAccount, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
err = firstAccount.GenFallbackKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
err = firstAccount.GenOneTimeKeys(2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
encryptionKey := []byte("testkey")
|
||||
|
||||
//now pickle account in JSON format
|
||||
pickled, err := firstAccount.PickleAsJSON(encryptionKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
//now unpickle into new Account
|
||||
unpickledAccount, err := account.AccountFromJSONPickled(pickled, encryptionKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//check if accounts are the same
|
||||
if firstAccount.NextOneTimeKeyID != unpickledAccount.NextOneTimeKeyID {
|
||||
t.Fatal("NextOneTimeKeyID unequal")
|
||||
}
|
||||
if !firstAccount.CurrentFallbackKey.Equal(unpickledAccount.CurrentFallbackKey) {
|
||||
t.Fatal("CurrentFallbackKey unequal")
|
||||
}
|
||||
if !firstAccount.PrevFallbackKey.Equal(unpickledAccount.PrevFallbackKey) {
|
||||
t.Fatal("PrevFallbackKey unequal")
|
||||
}
|
||||
if len(firstAccount.OTKeys) != len(unpickledAccount.OTKeys) {
|
||||
t.Fatal("OneTimeKeysunequal")
|
||||
}
|
||||
for i := range firstAccount.OTKeys {
|
||||
if !firstAccount.OTKeys[i].Equal(unpickledAccount.OTKeys[i]) {
|
||||
t.Fatalf("OneTimeKeys %d unequal", i)
|
||||
}
|
||||
}
|
||||
if !firstAccount.IdKeys.Curve25519.PrivateKey.Equal(unpickledAccount.IdKeys.Curve25519.PrivateKey) {
|
||||
t.Fatal("IdentityKeys Curve25519 private unequal")
|
||||
}
|
||||
if !firstAccount.IdKeys.Curve25519.PublicKey.Equal(unpickledAccount.IdKeys.Curve25519.PublicKey) {
|
||||
t.Fatal("IdentityKeys Curve25519 public unequal")
|
||||
}
|
||||
if !firstAccount.IdKeys.Ed25519.PrivateKey.Equal(unpickledAccount.IdKeys.Ed25519.PrivateKey) {
|
||||
t.Fatal("IdentityKeys Ed25519 private unequal")
|
||||
}
|
||||
if !firstAccount.IdKeys.Ed25519.PublicKey.Equal(unpickledAccount.IdKeys.Ed25519.PublicKey) {
|
||||
t.Fatal("IdentityKeys Ed25519 public unequal")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 2 {
|
||||
t.Fatal("should get 2 unpublished oneTimeKeys")
|
||||
}
|
||||
if len(firstAccount.FallbackKeyUnpublished()) == 0 {
|
||||
t.Fatal("should get fallbackKey")
|
||||
}
|
||||
//check if accounts are the same
|
||||
assert.Equal(t, firstAccount.NextOneTimeKeyID, unpickledAccount.NextOneTimeKeyID)
|
||||
assert.Equal(t, firstAccount.CurrentFallbackKey, unpickledAccount.CurrentFallbackKey)
|
||||
assert.Equal(t, firstAccount.PrevFallbackKey, unpickledAccount.PrevFallbackKey)
|
||||
assert.Equal(t, firstAccount.OTKeys, unpickledAccount.OTKeys)
|
||||
assert.Equal(t, firstAccount.IdKeys, unpickledAccount.IdKeys)
|
||||
|
||||
// Ensure that all of the keys are unpublished right now
|
||||
otks, err := firstAccount.OneTimeKeys()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, otks, 2)
|
||||
assert.Len(t, firstAccount.FallbackKeyUnpublished(), 1)
|
||||
|
||||
// Now, publish the key and make sure that they are published
|
||||
firstAccount.MarkKeysAsPublished()
|
||||
if len(firstAccount.FallbackKey()) == 0 {
|
||||
t.Fatal("should get fallbackKey")
|
||||
}
|
||||
if len(firstAccount.FallbackKeyUnpublished()) != 0 {
|
||||
t.Fatal("should get no fallbackKey")
|
||||
}
|
||||
if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 0 {
|
||||
t.Fatal("should get no oneTimeKeys")
|
||||
}
|
||||
|
||||
assert.Len(t, firstAccount.FallbackKeyUnpublished(), 0)
|
||||
assert.Len(t, firstAccount.FallbackKey(), 1)
|
||||
otks, err = firstAccount.OneTimeKeys()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, otks, 0)
|
||||
}
|
||||
|
||||
func TestAccountPickleJSON(t *testing.T) {
|
||||
|
|
@ -104,109 +68,49 @@ func TestAccountPickleJSON(t *testing.T) {
|
|||
|
||||
pickledData := []byte("6POkBWwbNl20fwvZWsOu0jgbHy4jkA5h0Ji+XCag59+ifWIRPDrqtgQi9HmkLiSF6wUhhYaV4S73WM+Hh+dlCuZRuXhTQr8yGPTifjcjq8birdAhObbEqHrYEdqaQkrgBLr/rlS5sibXeDqbkhVu4LslvootU9DkcCbd4b/0Flh7iugxqkcCs5GDndTEx9IzTVJzmK82Y0Q1Z1Z9Vuc2Iw746PtBJLtZjite6fSMp2NigPX/ZWWJ3OnwcJo0Vvjy8hgptZEWkamOHdWbUtelbHyjDIZlvxOC25D3rFif0zzPkF9qdpBPqVCWPPzGFmgnqKau6CHrnPfq7GLsM3BrprD7sHN1Js28ex14gXQPjBT7KTUo6H0e4gQMTMRp4qb8btNXDeId8xIFIElTh2SXZBTDmSq/ziVNJinEvYV8mGPvJZjDQQU+SyoS/HZ8uMc41tH0BOGDbFMHbfLMiz61E429gOrx2klu5lqyoyet7//HKi0ed5w2dQ")
|
||||
account, err := account.AccountFromJSONPickled(pickledData, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
expectedJSON := `{"ed25519":"qWvNB6Ztov5/AOsP073op0O32KJ8/tgSNarT7MaYgQE","curve25519":"TFUB6M6zwgyWhBEp2m1aUodl2AsnsrIuBr8l9AvwGS8"}`
|
||||
jsonData, err := account.IdentityKeysJSON()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(jsonData, []byte(expectedJSON)) {
|
||||
t.Fatalf("Expected '%s' but got '%s'", expectedJSON, jsonData)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedJSON, string(jsonData))
|
||||
}
|
||||
|
||||
func TestSessions(t *testing.T) {
|
||||
aliceAccount, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
err = aliceAccount.GenOneTimeKeys(5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
bobAccount, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
err = bobAccount.GenOneTimeKeys(5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
aliceSession, err := aliceAccount.NewOutboundSession(bobAccount.IdKeys.Curve25519.B64Encoded(), bobAccount.OTKeys[2].Key.B64Encoded())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
plaintext := []byte("test message")
|
||||
msgType, crypttext, err := aliceSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType != id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
|
||||
|
||||
bobSession, err := bobAccount.NewInboundSession(string(crypttext))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
decodedText, err := bobSession.Decrypt(string(crypttext), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, decodedText) {
|
||||
t.Fatalf("expected '%s' but got '%s'", string(plaintext), string(decodedText))
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decodedText)
|
||||
}
|
||||
|
||||
func TestAccountPickle(t *testing.T) {
|
||||
pickleKey := []byte("secret_key")
|
||||
account, err := account.AccountFromPickled(pickledDataFromLibOlm, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !expectedEd25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Ed25519.PrivateKey) {
|
||||
t.Fatal("keys not equal")
|
||||
}
|
||||
if !expectedEd25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Ed25519.PublicKey) {
|
||||
t.Fatal("keys not equal")
|
||||
}
|
||||
if !expectedCurve25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Curve25519.PrivateKey) {
|
||||
t.Fatal("keys not equal")
|
||||
}
|
||||
if !expectedCurve25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Curve25519.PublicKey) {
|
||||
t.Fatal("keys not equal")
|
||||
}
|
||||
if account.NextOneTimeKeyID != 42 {
|
||||
t.Fatal("wrong next otKey id")
|
||||
}
|
||||
if len(account.OTKeys) != len(expectedOTKeysPickleLibOLM) {
|
||||
t.Fatal("wrong number of otKeys")
|
||||
}
|
||||
if account.NumFallbackKeys != 0 {
|
||||
t.Fatal("fallback keys set but not in pickle")
|
||||
}
|
||||
for curIndex, curValue := range account.OTKeys {
|
||||
curExpected := expectedOTKeysPickleLibOLM[curIndex]
|
||||
if curExpected.ID != curValue.ID {
|
||||
t.Fatal("OTKey id not correct")
|
||||
}
|
||||
if !curExpected.Key.PublicKey.Equal(curValue.Key.PublicKey) {
|
||||
t.Fatal("OTKey public key not correct")
|
||||
}
|
||||
if !curExpected.Key.PrivateKey.Equal(curValue.Key.PrivateKey) {
|
||||
t.Fatal("OTKey private key not correct")
|
||||
}
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedEd25519KeyPairPickleLibOLM, account.IdKeys.Ed25519)
|
||||
assert.Equal(t, expectedCurve25519KeyPairPickleLibOLM, account.IdKeys.Curve25519)
|
||||
assert.EqualValues(t, 42, account.NextOneTimeKeyID)
|
||||
assert.Equal(t, account.OTKeys, expectedOTKeysPickleLibOLM)
|
||||
assert.EqualValues(t, 0, account.NumFallbackKeys)
|
||||
|
||||
targetPickled, err := account.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(targetPickled, pickledDataFromLibOlm) {
|
||||
t.Fatal("repickled value does not equal given value")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pickledDataFromLibOlm, targetPickled)
|
||||
}
|
||||
|
||||
func TestOldAccountPickle(t *testing.T) {
|
||||
|
|
@ -218,355 +122,212 @@ func TestOldAccountPickle(t *testing.T) {
|
|||
"O5TmXua1FcU")
|
||||
pickleKey := []byte("")
|
||||
account, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
err = account.Unpickle(pickled, pickleKey)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
} else {
|
||||
if !errors.Is(err, olm.ErrBadVersion) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
assert.ErrorIs(t, err, olm.ErrBadVersion)
|
||||
}
|
||||
|
||||
func TestLoopback(t *testing.T) {
|
||||
accountA, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
accountB, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = accountB.GenOneTimeKeys(42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
err = accountB.GenOneTimeKeys( 42)
|
||||
assert.NoError(t, err)
|
||||
|
||||
aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
plainText := []byte("Hello, World")
|
||||
msgType, message1, err := aliceSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType != id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
|
||||
|
||||
bobSession, err := accountB.NewInboundSession(string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
// Check that the inbound session matches the message it was created from.
|
||||
sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sessionIsOK {
|
||||
t.Fatal("session was not detected to be valid")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, sessionIsOK, "session was not detected to be valid")
|
||||
|
||||
// Check that the inbound session matches the key this message is supposed to be from.
|
||||
aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded()
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sessionIsOK {
|
||||
t.Fatal("session is sad to be not from a but it should")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, sessionIsOK, "session is sad to be not from a but it should")
|
||||
|
||||
// Check that the inbound session isn't from a different user.
|
||||
bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded()
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sessionIsOK {
|
||||
t.Fatal("session is sad to be from b but is from a")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, sessionIsOK, "session is sad to be from b but is from a")
|
||||
|
||||
// Check that we can decrypt the message.
|
||||
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(decryptedMessage, plainText) {
|
||||
t.Fatal("messages are not the same")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decryptedMessage)
|
||||
|
||||
msgTyp2, message2, err := bobSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgTyp2 == id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypeMsg, msgTyp2)
|
||||
|
||||
decryptedMessage2, err := aliceSession.Decrypt(string(message2), msgTyp2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(decryptedMessage2, plainText) {
|
||||
t.Fatal("messages are not the same")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decryptedMessage2)
|
||||
|
||||
//decrypting again should fail, as the chain moved on
|
||||
_, err = aliceSession.Decrypt(string(message2), msgTyp2)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound)
|
||||
|
||||
//compare sessionIDs
|
||||
if aliceSession.ID() != bobSession.ID() {
|
||||
t.Fatal("sessionIDs are not equal")
|
||||
}
|
||||
assert.Equal(t, aliceSession.ID(), bobSession.ID())
|
||||
}
|
||||
|
||||
func TestMoreMessages(t *testing.T) {
|
||||
accountA, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
accountB, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = accountB.GenOneTimeKeys(42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
err = accountB.GenOneTimeKeys( 42)
|
||||
assert.NoError(t, err)
|
||||
|
||||
aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
plainText := []byte("Hello, World")
|
||||
msgType, message1, err := aliceSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType != id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
|
||||
|
||||
bobSession, err := accountB.NewInboundSession(string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(decryptedMessage, plainText) {
|
||||
t.Fatal("messages are not the same")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decryptedMessage)
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
//alice sends, bob reveices
|
||||
msgType, message, err := aliceSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
if i == 0 {
|
||||
//The first time should still be a preKeyMessage as bob has not yet send a message to alice
|
||||
if msgType != id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
|
||||
} else {
|
||||
if msgType == id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.Equal(t, id.OlmMsgTypeMsg, msgType)
|
||||
}
|
||||
|
||||
decryptedMessage, err := bobSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(decryptedMessage, plainText) {
|
||||
t.Fatal("messages are not the same")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decryptedMessage)
|
||||
|
||||
//now bob sends, alice receives
|
||||
msgType, message, err = bobSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType == id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypeMsg, msgType)
|
||||
|
||||
decryptedMessage, err = aliceSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(decryptedMessage, plainText) {
|
||||
t.Fatal("messages are not the same")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decryptedMessage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackKey(t *testing.T) {
|
||||
accountA, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
accountB, err := account.NewAccount()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
err = accountB.GenFallbackKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
fallBackKeys := accountB.FallbackKeyUnpublished()
|
||||
var fallbackKey id.Curve25519
|
||||
for _, fbKey := range fallBackKeys {
|
||||
fallbackKey = fbKey
|
||||
}
|
||||
aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
plainText := []byte("Hello, World")
|
||||
msgType, message1, err := aliceSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType != id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
|
||||
|
||||
bobSession, err := accountB.NewInboundSession(string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
// Check that the inbound session matches the message it was created from.
|
||||
sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sessionIsOK {
|
||||
t.Fatal("session was not detected to be valid")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, sessionIsOK, "session was not detected to be valid")
|
||||
|
||||
// Check that the inbound session matches the key this message is supposed to be from.
|
||||
aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded()
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sessionIsOK {
|
||||
t.Fatal("session is sad to be not from a but it should")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, sessionIsOK, "session is sad to be not from a but it should")
|
||||
|
||||
// Check that the inbound session isn't from a different user.
|
||||
bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded()
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sessionIsOK {
|
||||
t.Fatal("session is sad to be from b but is from a")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, sessionIsOK, "session is sad to be from b but is from a")
|
||||
|
||||
// Check that we can decrypt the message.
|
||||
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(decryptedMessage, plainText) {
|
||||
t.Fatal("messages are not the same")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decryptedMessage)
|
||||
|
||||
// create a new fallback key for B (the old fallback should still be usable)
|
||||
err = accountB.GenFallbackKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
// start another session and encrypt a message
|
||||
aliceSession2, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
msgType2, message2, err := aliceSession2.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType2 != id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypePreKey, msgType2)
|
||||
|
||||
// bobSession should not be valid for the message2
|
||||
// Check that the inbound session matches the message it was created from.
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom("", string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sessionIsOK {
|
||||
t.Fatal("session was detected to be valid but should not")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, sessionIsOK, "session was detected to be valid but should not")
|
||||
|
||||
bobSession2, err := accountB.NewInboundSession(string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
// Check that the inbound session matches the message it was created from.
|
||||
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom("", string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sessionIsOK {
|
||||
t.Fatal("session was not detected to be valid")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, sessionIsOK, "session was not detected to be valid")
|
||||
|
||||
// Check that the inbound session matches the key this message is supposed to be from.
|
||||
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(aIDKey), string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !sessionIsOK {
|
||||
t.Fatal("session is sad to be not from a but it should")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, sessionIsOK, "session is sad to be not from a but it should")
|
||||
|
||||
// Check that the inbound session isn't from a different user.
|
||||
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(bIDKey), string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sessionIsOK {
|
||||
t.Fatal("session is sad to be from b but is from a")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, sessionIsOK, "session is sad to be from b but is from a")
|
||||
|
||||
// Check that we can decrypt the message.
|
||||
decryptedMessage2, err := bobSession2.Decrypt(string(message2), msgType2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(decryptedMessage2, plainText) {
|
||||
t.Fatal("messages are not the same")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decryptedMessage2)
|
||||
|
||||
//Forget the old fallback key -- creating a new session should fail now
|
||||
accountB.ForgetOldFallbackKey()
|
||||
// start another session and encrypt a message
|
||||
aliceSession3, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
msgType3, message3, err := aliceSession3.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType3 != id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypePreKey, msgType3)
|
||||
_, err = accountB.NewInboundSession(string(message3))
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !errors.Is(err, olm.ErrBadMessageKeyID) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.ErrorIs(t, err, olm.ErrBadMessageKeyID)
|
||||
}
|
||||
|
||||
func TestOldV3AccountPickle(t *testing.T) {
|
||||
|
|
@ -582,33 +343,23 @@ func TestOldV3AccountPickle(t *testing.T) {
|
|||
expectedUnpublishedFallbackJSON := []byte("{\"curve25519\":{}}")
|
||||
|
||||
account, err := account.AccountFromPickled(pickledData, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
fallbackJSON, err := account.FallbackKeyJSON()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(fallbackJSON, expectedFallbackJSON) {
|
||||
t.Fatalf("expected not as result:\n%s\n%s\n", expectedFallbackJSON, fallbackJSON)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedFallbackJSON, fallbackJSON)
|
||||
fallbackJSONUnpublished, err := account.FallbackKeyUnpublishedJSON()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(fallbackJSONUnpublished, expectedUnpublishedFallbackJSON) {
|
||||
t.Fatalf("expected not as result:\n%s\n%s\n", expectedUnpublishedFallbackJSON, fallbackJSONUnpublished)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedUnpublishedFallbackJSON, fallbackJSONUnpublished)
|
||||
}
|
||||
|
||||
func TestAccountSign(t *testing.T) {
|
||||
accountA, err := account.NewAccount()
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
plainText := []byte("Hello, World")
|
||||
signatureB64, err := accountA.Sign(plainText)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
signature, err := base64.RawStdEncoding.DecodeString(string(signatureB64))
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
verified, err := signatures.VerifySignature(plainText, accountA.IdKeys.Ed25519.B64Encoded(), signature)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
|||
|
|
@ -1,52 +1,44 @@
|
|||
package cipher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeriveAESKeys(t *testing.T) {
|
||||
kdfInfo := []byte("test")
|
||||
key := []byte("test key")
|
||||
derivedKeys, err := deriveAESKeys(kdfInfo, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
derivedKeys2, err := deriveAESKeys(kdfInfo, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
//derivedKeys and derivedKeys2 should be identical
|
||||
if !bytes.Equal(derivedKeys.key, derivedKeys2.key) ||
|
||||
!bytes.Equal(derivedKeys.iv, derivedKeys2.iv) ||
|
||||
!bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) {
|
||||
t.Fail()
|
||||
}
|
||||
assert.Equal(t, derivedKeys.key, derivedKeys2.key)
|
||||
assert.Equal(t, derivedKeys.iv, derivedKeys2.iv)
|
||||
assert.Equal(t, derivedKeys.hmacKey, derivedKeys2.hmacKey)
|
||||
|
||||
//changing kdfInfo
|
||||
kdfInfo = []byte("other kdf")
|
||||
derivedKeys2, err = deriveAESKeys(kdfInfo, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
//derivedKeys and derivedKeys2 should now be different
|
||||
if bytes.Equal(derivedKeys.key, derivedKeys2.key) ||
|
||||
bytes.Equal(derivedKeys.iv, derivedKeys2.iv) ||
|
||||
bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) {
|
||||
t.Fail()
|
||||
}
|
||||
assert.NotEqual(t, derivedKeys.key, derivedKeys2.key)
|
||||
assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv)
|
||||
assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey)
|
||||
|
||||
//changing key
|
||||
key = []byte("other test key")
|
||||
derivedKeys, err = deriveAESKeys(kdfInfo, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
//derivedKeys and derivedKeys2 should now be different
|
||||
if bytes.Equal(derivedKeys.key, derivedKeys2.key) ||
|
||||
bytes.Equal(derivedKeys.iv, derivedKeys2.iv) ||
|
||||
bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) {
|
||||
t.Fail()
|
||||
}
|
||||
assert.NotEqual(t, derivedKeys.key, derivedKeys2.key)
|
||||
assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv)
|
||||
assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey)
|
||||
}
|
||||
|
||||
func TestCipherAESSha256(t *testing.T) {
|
||||
|
|
@ -58,26 +50,15 @@ func TestCipherAESSha256(t *testing.T) {
|
|||
message = append(message, []byte("-")...)
|
||||
}
|
||||
encrypted, err := cipher.Encrypt(key, []byte(message))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mac, err := cipher.MAC(key, encrypted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
verified, err := cipher.Verify(key, encrypted, mac[:8])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !verified {
|
||||
t.Fatal("signature verification failed")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, verified, "signature verification failed")
|
||||
|
||||
resultPlainText, err := cipher.Decrypt(key, encrypted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(message, resultPlainText) {
|
||||
t.Fail()
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, message, resultPlainText)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
package cipher_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
)
|
||||
|
||||
|
|
@ -19,15 +20,9 @@ func TestEncoding(t *testing.T) {
|
|||
copy(toEncrypt, input)
|
||||
}
|
||||
encoded, err := cipher.Pickle(key, toEncrypt)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
decoded, err := cipher.Unpickle(key, encoded)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(decoded, toEncrypt) {
|
||||
t.Fatalf("Expected '%s' but got '%s'", toEncrypt, decoded)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, toEncrypt, decoded)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,39 +1,26 @@
|
|||
package crypto_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
func TestCurve25519(t *testing.T) {
|
||||
firstKeypair, err := crypto.Curve25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
secondKeypair, err := crypto.Curve25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
sharedSecretFromFirst, err := firstKeypair.SharedSecret(secondKeypair.PublicKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
sharedSecretFromSecond, err := secondKeypair.SharedSecret(firstKeypair.PublicKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(sharedSecretFromFirst, sharedSecretFromSecond) {
|
||||
t.Fatal("shared secret not equal")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sharedSecretFromFirst, sharedSecretFromSecond, "shared secret not equal")
|
||||
fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(fromPrivate.PublicKey, firstKeypair.PublicKey) {
|
||||
t.Fatal("public keys not equal")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, fromPrivate, firstKeypair)
|
||||
}
|
||||
|
||||
func TestCurve25519Case1(t *testing.T) {
|
||||
|
|
@ -76,112 +63,59 @@ func TestCurve25519Case1(t *testing.T) {
|
|||
PublicKey: bobPublic,
|
||||
}
|
||||
agreementFromAlice, err := aliceKeyPair.SharedSecret(bobKeyPair.PublicKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(agreementFromAlice, expectedAgreement) {
|
||||
t.Fatal("expected agreement does not match agreement from Alice's view")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedAgreement, agreementFromAlice, "expected agreement does not match agreement from Alice's view")
|
||||
agreementFromBob, err := bobKeyPair.SharedSecret(aliceKeyPair.PublicKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(agreementFromBob, expectedAgreement) {
|
||||
t.Fatal("expected agreement does not match agreement from Bob's view")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedAgreement, agreementFromBob, "expected agreement does not match agreement from Bob's view")
|
||||
}
|
||||
|
||||
func TestCurve25519Pickle(t *testing.T) {
|
||||
//create keypair
|
||||
keyPair, err := crypto.Curve25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
target := make([]byte, keyPair.PickleLen())
|
||||
writtenBytes, err := keyPair.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writtenBytes != len(target) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, writtenBytes)
|
||||
|
||||
unpickledKeyPair := crypto.Curve25519KeyPair{}
|
||||
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readBytes != len(target) {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
|
||||
t.Fatal("private keys not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
|
||||
t.Fatal("public keys not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, readBytes)
|
||||
assert.Equal(t, keyPair, unpickledKeyPair)
|
||||
}
|
||||
|
||||
func TestCurve25519PicklePubKeyOnly(t *testing.T) {
|
||||
//create keypair
|
||||
keyPair, err := crypto.Curve25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
//Remove privateKey
|
||||
keyPair.PrivateKey = nil
|
||||
target := make([]byte, keyPair.PickleLen())
|
||||
writtenBytes, err := keyPair.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writtenBytes != len(target) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, writtenBytes)
|
||||
unpickledKeyPair := crypto.Curve25519KeyPair{}
|
||||
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readBytes != len(target) {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
|
||||
t.Fatal("private keys not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
|
||||
t.Fatal("public keys not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, readBytes)
|
||||
assert.Equal(t, keyPair, unpickledKeyPair)
|
||||
}
|
||||
|
||||
func TestCurve25519PicklePrivKeyOnly(t *testing.T) {
|
||||
//create keypair
|
||||
keyPair, err := crypto.Curve25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
//Remove public
|
||||
keyPair.PublicKey = nil
|
||||
target := make([]byte, keyPair.PickleLen())
|
||||
writtenBytes, err := keyPair.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writtenBytes != len(target) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, writtenBytes)
|
||||
unpickledKeyPair := crypto.Curve25519KeyPair{}
|
||||
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readBytes != len(target) {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
|
||||
t.Fatal("private keys not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
|
||||
t.Fatal("public keys not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, readBytes)
|
||||
assert.Equal(t, keyPair, unpickledKeyPair)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,140 +1,87 @@
|
|||
package crypto_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
func TestEd25519(t *testing.T) {
|
||||
keypair, err := crypto.Ed25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
message := []byte("test message")
|
||||
signature := keypair.Sign(message)
|
||||
if !keypair.Verify(message, signature) {
|
||||
t.Fail()
|
||||
}
|
||||
assert.True(t, keypair.Verify(message, signature))
|
||||
}
|
||||
|
||||
func TestEd25519Case1(t *testing.T) {
|
||||
//64 bytes for ed25519 package
|
||||
keyPair, err := crypto.Ed25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
message := []byte("Hello, World")
|
||||
|
||||
keyPair2 := crypto.Ed25519GenerateFromPrivate(keyPair.PrivateKey)
|
||||
if !bytes.Equal(keyPair.PublicKey, keyPair2.PublicKey) {
|
||||
t.Fatal("not equal key pairs")
|
||||
}
|
||||
assert.Equal(t, keyPair, keyPair2, "not equal key pairs")
|
||||
signature := keyPair.Sign(message)
|
||||
verified := keyPair.Verify(message, signature)
|
||||
if !verified {
|
||||
t.Fatal("message did not verify although it should")
|
||||
}
|
||||
assert.True(t, verified, "message did not verify although it should")
|
||||
|
||||
//Now change the message and verify again
|
||||
message = append(message, []byte("a")...)
|
||||
verified = keyPair.Verify(message, signature)
|
||||
if verified {
|
||||
t.Fatal("message did verify although it should not")
|
||||
}
|
||||
assert.False(t, verified, "message did verify although it should not")
|
||||
}
|
||||
|
||||
func TestEd25519Pickle(t *testing.T) {
|
||||
//create keypair
|
||||
keyPair, err := crypto.Ed25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
target := make([]byte, keyPair.PickleLen())
|
||||
writtenBytes, err := keyPair.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writtenBytes != len(target) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, writtenBytes)
|
||||
|
||||
unpickledKeyPair := crypto.Ed25519KeyPair{}
|
||||
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readBytes != len(target) {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
|
||||
t.Fatal("private keys not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
|
||||
t.Fatal("public keys not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, readBytes, "read bytes not correct")
|
||||
assert.Equal(t, keyPair, unpickledKeyPair)
|
||||
}
|
||||
|
||||
func TestEd25519PicklePubKeyOnly(t *testing.T) {
|
||||
//create keypair
|
||||
keyPair, err := crypto.Ed25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
//Remove privateKey
|
||||
keyPair.PrivateKey = nil
|
||||
target := make([]byte, keyPair.PickleLen())
|
||||
writtenBytes, err := keyPair.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writtenBytes != len(target) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, writtenBytes)
|
||||
|
||||
unpickledKeyPair := crypto.Ed25519KeyPair{}
|
||||
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readBytes != len(target) {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
|
||||
t.Fatal("private keys not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
|
||||
t.Fatal("public keys not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, readBytes, "read bytes not correct")
|
||||
assert.Equal(t, keyPair, unpickledKeyPair)
|
||||
}
|
||||
|
||||
func TestEd25519PicklePrivKeyOnly(t *testing.T) {
|
||||
//create keypair
|
||||
keyPair, err := crypto.Ed25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
//Remove public
|
||||
keyPair.PublicKey = nil
|
||||
target := make([]byte, keyPair.PickleLen())
|
||||
writtenBytes, err := keyPair.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if writtenBytes != len(target) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, writtenBytes)
|
||||
|
||||
unpickledKeyPair := crypto.Ed25519KeyPair{}
|
||||
readBytes, err := unpickledKeyPair.UnpickleLibOlm(target)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readBytes != len(target) {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) {
|
||||
t.Fatal("private keys not correct")
|
||||
}
|
||||
if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) {
|
||||
t.Fatal("public keys not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, target, readBytes, "read bytes not correct")
|
||||
assert.Equal(t, keyPair, unpickledKeyPair)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,49 +1,44 @@
|
|||
package crypto_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
func TestHMACSha256(t *testing.T) {
|
||||
func TestHMACSHA256(t *testing.T) {
|
||||
key := []byte("test key")
|
||||
message := []byte("test message")
|
||||
hash := crypto.HMACSHA256(key, message)
|
||||
if !bytes.Equal(hash, crypto.HMACSHA256(key, message)) {
|
||||
t.Fail()
|
||||
}
|
||||
assert.Equal(t, hash, crypto.HMACSHA256(key, message))
|
||||
|
||||
str := "A4M0ovdiWHaZ5msdDFbrvtChFwZIoIaRSVGmv8bmPtc"
|
||||
result, err := base64.RawStdEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(result, hash) {
|
||||
t.Fail()
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result, hash)
|
||||
}
|
||||
|
||||
func TestHKDFSha256(t *testing.T) {
|
||||
func TestHKDFSHA256(t *testing.T) {
|
||||
message := []byte("test content")
|
||||
|
||||
hkdf := crypto.HKDFSHA256(message, nil, nil)
|
||||
hkdf2 := crypto.HKDFSHA256(message, nil, nil)
|
||||
result := make([]byte, 32)
|
||||
if _, err := io.ReadFull(hkdf, result); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err := io.ReadFull(hkdf, result)
|
||||
assert.NoError(t, err)
|
||||
|
||||
hkdf2 := crypto.HKDFSHA256(message, nil, nil)
|
||||
result2 := make([]byte, 32)
|
||||
if _, err := io.ReadFull(hkdf2, result2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(result, result2) {
|
||||
t.Fail()
|
||||
}
|
||||
_, err = io.ReadFull(hkdf2, result2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, result, result2)
|
||||
}
|
||||
|
||||
func TestSha256Case1(t *testing.T) {
|
||||
func TestSHA256Case1(t *testing.T) {
|
||||
input := make([]byte, 0)
|
||||
expected := []byte{
|
||||
0xE3, 0xB0, 0xC4, 0x42, 0x98, 0xFC, 0x1C, 0x14,
|
||||
|
|
@ -52,9 +47,7 @@ func TestSha256Case1(t *testing.T) {
|
|||
0xA4, 0x95, 0x99, 0x1B, 0x78, 0x52, 0xB8, 0x55,
|
||||
}
|
||||
result := crypto.SHA256(input)
|
||||
if !bytes.Equal(expected, result) {
|
||||
t.Fatalf("result not as expected:\n%v\n%v\n", result, expected)
|
||||
}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestHMACCase1(t *testing.T) {
|
||||
|
|
@ -66,9 +59,7 @@ func TestHMACCase1(t *testing.T) {
|
|||
0xc6, 0xc7, 0x12, 0x14, 0x42, 0x92, 0xc5, 0xad,
|
||||
}
|
||||
result := crypto.HMACSHA256(input, input)
|
||||
if !bytes.Equal(expected, result) {
|
||||
t.Fatalf("result not as expected:\n%v\n%v\n", result, expected)
|
||||
}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestHDKFCase1(t *testing.T) {
|
||||
|
|
@ -92,9 +83,8 @@ func TestHDKFCase1(t *testing.T) {
|
|||
0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5,
|
||||
}
|
||||
result := crypto.HMACSHA256(salt, input)
|
||||
if !bytes.Equal(expectedHMAC, result) {
|
||||
t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHMAC)
|
||||
}
|
||||
assert.Equal(t, expectedHMAC, result)
|
||||
|
||||
expectedHDKF := []byte{
|
||||
0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a,
|
||||
0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, 0x2f, 0x2a,
|
||||
|
|
@ -105,10 +95,7 @@ func TestHDKFCase1(t *testing.T) {
|
|||
}
|
||||
resultReader := crypto.HKDFSHA256(input, salt, info)
|
||||
result = make([]byte, len(expectedHDKF))
|
||||
if _, err := io.ReadFull(resultReader, result); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(expectedHDKF, result) {
|
||||
t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHDKF)
|
||||
}
|
||||
_, err := io.ReadFull(resultReader, result)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedHDKF, result)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
package libolmpickle_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
)
|
||||
|
||||
|
|
@ -23,12 +24,8 @@ func TestPickleUInt32(t *testing.T) {
|
|||
for curIndex := range values {
|
||||
response := make([]byte, 4)
|
||||
resPLen := libolmpickle.PickleUInt32(values[curIndex], response)
|
||||
if resPLen != libolmpickle.PickleUInt32Len(values[curIndex]) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(response, expected[curIndex]) {
|
||||
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
|
||||
}
|
||||
assert.Equal(t, libolmpickle.PickleUInt32Len(values[curIndex]), resPLen)
|
||||
assert.Equal(t, expected[curIndex], response)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -44,12 +41,8 @@ func TestPickleBool(t *testing.T) {
|
|||
for curIndex := range values {
|
||||
response := make([]byte, 1)
|
||||
resPLen := libolmpickle.PickleBool(values[curIndex], response)
|
||||
if resPLen != libolmpickle.PickleBoolLen(values[curIndex]) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(response, expected[curIndex]) {
|
||||
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
|
||||
}
|
||||
assert.Equal(t, libolmpickle.PickleBoolLen(values[curIndex]), resPLen)
|
||||
assert.Equal(t, expected[curIndex], response)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -65,12 +58,8 @@ func TestPickleUInt8(t *testing.T) {
|
|||
for curIndex := range values {
|
||||
response := make([]byte, 1)
|
||||
resPLen := libolmpickle.PickleUInt8(values[curIndex], response)
|
||||
if resPLen != libolmpickle.PickleUInt8Len(values[curIndex]) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(response, expected[curIndex]) {
|
||||
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
|
||||
}
|
||||
assert.Equal(t, libolmpickle.PickleUInt8Len(values[curIndex]), resPLen)
|
||||
assert.Equal(t, expected[curIndex], response)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -88,11 +77,7 @@ func TestPickleBytes(t *testing.T) {
|
|||
for curIndex := range values {
|
||||
response := make([]byte, len(values[curIndex]))
|
||||
resPLen := libolmpickle.PickleBytes(values[curIndex], response)
|
||||
if resPLen != libolmpickle.PickleBytesLen(values[curIndex]) {
|
||||
t.Fatal("written bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(response, expected[curIndex]) {
|
||||
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
|
||||
}
|
||||
assert.Equal(t, libolmpickle.PickleBytesLen(values[curIndex]), resPLen)
|
||||
assert.Equal(t, expected[curIndex], response)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
package libolmpickle_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
)
|
||||
|
||||
|
|
@ -20,15 +21,9 @@ func TestUnpickleUInt32(t *testing.T) {
|
|||
}
|
||||
for curIndex := range values {
|
||||
response, readLength, err := libolmpickle.UnpickleUInt32(values[curIndex])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readLength != 4 {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if response != expected[curIndex] {
|
||||
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4, readLength)
|
||||
assert.Equal(t, expected[curIndex], response)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -45,15 +40,9 @@ func TestUnpickleBool(t *testing.T) {
|
|||
}
|
||||
for curIndex := range values {
|
||||
response, readLength, err := libolmpickle.UnpickleBool(values[curIndex])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readLength != 1 {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if response != expected[curIndex] {
|
||||
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, readLength)
|
||||
assert.Equal(t, expected[curIndex], response)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -68,15 +57,9 @@ func TestUnpickleUInt8(t *testing.T) {
|
|||
}
|
||||
for curIndex := range values {
|
||||
response, readLength, err := libolmpickle.UnpickleUInt8(values[curIndex])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readLength != 1 {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if response != expected[curIndex] {
|
||||
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, readLength)
|
||||
assert.Equal(t, expected[curIndex], response)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -93,14 +76,8 @@ func TestUnpickleBytes(t *testing.T) {
|
|||
}
|
||||
for curIndex := range values {
|
||||
response, readLength, err := libolmpickle.UnpickleBytes(values[curIndex], 4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readLength != 4 {
|
||||
t.Fatal("read bytes not correct")
|
||||
}
|
||||
if !bytes.Equal(response, expected[curIndex]) {
|
||||
t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex])
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4, readLength)
|
||||
assert.Equal(t, expected[curIndex], response)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
package megolm_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/megolm"
|
||||
)
|
||||
|
||||
|
|
@ -19,9 +20,7 @@ func init() {
|
|||
|
||||
func TestAdvance(t *testing.T) {
|
||||
m, err := megolm.New(0, startData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedData := [megolm.RatchetParts * megolm.RatchetPartLength]byte{
|
||||
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46,
|
||||
|
|
@ -34,9 +33,7 @@ func TestAdvance(t *testing.T) {
|
|||
0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a,
|
||||
}
|
||||
m.Advance()
|
||||
if !bytes.Equal(m.Data[:], expectedData[:]) {
|
||||
t.Fatal("result after advancing the ratchet is not as expected")
|
||||
}
|
||||
assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected")
|
||||
|
||||
//repeat with complex advance
|
||||
m.Data = startData
|
||||
|
|
@ -51,9 +48,8 @@ func TestAdvance(t *testing.T) {
|
|||
0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a,
|
||||
}
|
||||
m.AdvanceTo(0x1000000)
|
||||
if !bytes.Equal(m.Data[:], expectedData[:]) {
|
||||
t.Fatal("result after advancing the ratchet is not as expected")
|
||||
}
|
||||
assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected")
|
||||
|
||||
expectedData = [megolm.RatchetParts * megolm.RatchetPartLength]byte{
|
||||
0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9,
|
||||
0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c,
|
||||
|
|
@ -65,77 +61,45 @@ func TestAdvance(t *testing.T) {
|
|||
0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a,
|
||||
}
|
||||
m.AdvanceTo(0x1041506)
|
||||
if !bytes.Equal(m.Data[:], expectedData[:]) {
|
||||
t.Fatal("result after advancing the ratchet is not as expected")
|
||||
}
|
||||
assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected")
|
||||
}
|
||||
|
||||
func TestAdvanceWraparound(t *testing.T) {
|
||||
m, err := megolm.New(0xffffffff, startData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
m.AdvanceTo(0x1000000)
|
||||
if m.Counter != 0x1000000 {
|
||||
t.Fatal("counter not correct")
|
||||
}
|
||||
assert.EqualValues(t, 0x1000000, m.Counter, "counter not correct")
|
||||
|
||||
m2, err := megolm.New(0, startData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
m2.AdvanceTo(0x2000000)
|
||||
if m2.Counter != 0x2000000 {
|
||||
t.Fatal("counter not correct")
|
||||
}
|
||||
if !bytes.Equal(m.Data[:], m2.Data[:]) {
|
||||
t.Fatal("result after wrapping the ratchet is not as expected")
|
||||
}
|
||||
assert.EqualValues(t, 0x2000000, m2.Counter, "counter not correct")
|
||||
assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected")
|
||||
}
|
||||
|
||||
func TestAdvanceOverflowByOne(t *testing.T) {
|
||||
m, err := megolm.New(0xffffffff, startData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
m.AdvanceTo(0x0)
|
||||
if m.Counter != 0x0 {
|
||||
t.Fatal("counter not correct")
|
||||
}
|
||||
assert.EqualValues(t, 0x0, m.Counter, "counter not correct")
|
||||
|
||||
m2, err := megolm.New(0xffffffff, startData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
m2.Advance()
|
||||
if m2.Counter != 0x0 {
|
||||
t.Fatal("counter not correct")
|
||||
}
|
||||
if !bytes.Equal(m.Data[:], m2.Data[:]) {
|
||||
t.Fatal("result after wrapping the ratchet is not as expected")
|
||||
}
|
||||
assert.EqualValues(t, 0x0, m2.Counter, "counter not correct")
|
||||
assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected")
|
||||
}
|
||||
|
||||
func TestAdvanceOverflow(t *testing.T) {
|
||||
m, err := megolm.New(0x1, startData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
m.AdvanceTo(0x80000000)
|
||||
m.AdvanceTo(0x0)
|
||||
if m.Counter != 0x0 {
|
||||
t.Fatal("counter not correct")
|
||||
}
|
||||
assert.EqualValues(t, 0x0, m.Counter, "counter not correct")
|
||||
|
||||
m2, err := megolm.New(0x1, startData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
m2.AdvanceTo(0x0)
|
||||
if m2.Counter != 0x0 {
|
||||
t.Fatal("counter not correct")
|
||||
}
|
||||
if !bytes.Equal(m.Data[:], m2.Data[:]) {
|
||||
t.Fatal("result after wrapping the ratchet is not as expected")
|
||||
}
|
||||
assert.EqualValues(t, 0x0, m2.Counter, "counter not correct")
|
||||
assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,17 +1,16 @@
|
|||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEncodeLengthInt(t *testing.T) {
|
||||
numbers := []uint32{127, 128, 16383, 16384, 32767}
|
||||
expected := []int{1, 2, 2, 3, 3}
|
||||
for curIndex := range numbers {
|
||||
if result := encodeVarIntByteLength(numbers[curIndex]); result != expected[curIndex] {
|
||||
t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result)
|
||||
}
|
||||
assert.Equal(t, expected[curIndex], encodeVarIntByteLength(numbers[curIndex]))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -25,9 +24,7 @@ func TestEncodeLengthString(t *testing.T) {
|
|||
strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------"))
|
||||
expected = append(expected, 2+155)
|
||||
for curIndex := range strings {
|
||||
if result := encodeVarStringByteLength(strings[curIndex]); result != expected[curIndex] {
|
||||
t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result)
|
||||
}
|
||||
assert.Equal(t, expected[curIndex], encodeVarStringByteLength(strings[curIndex]))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -43,9 +40,7 @@ func TestEncodeInt(t *testing.T) {
|
|||
ints = append(ints, 16383)
|
||||
expected = append(expected, []byte{0b11111111, 0b01111111})
|
||||
for curIndex := range ints {
|
||||
if result := encodeVarInt(ints[curIndex]); !bytes.Equal(result, expected[curIndex]) {
|
||||
t.Fatalf("expected byte of %b but got %b", expected[curIndex], result)
|
||||
}
|
||||
assert.Equal(t, expected[curIndex], encodeVarInt(ints[curIndex]))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -75,8 +70,6 @@ func TestEncodeString(t *testing.T) {
|
|||
res = append(res, curTest...) //Add string itself
|
||||
expected = append(expected, res)
|
||||
for curIndex := range strings {
|
||||
if result := encodeVarString(strings[curIndex]); !bytes.Equal(result, expected[curIndex]) {
|
||||
t.Fatalf("expected byte of %b but got %b", expected[curIndex], result)
|
||||
}
|
||||
assert.Equal(t, expected[curIndex], encodeVarString(strings[curIndex]))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
package message_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/message"
|
||||
)
|
||||
|
||||
|
|
@ -16,18 +17,10 @@ func TestGroupMessageDecode(t *testing.T) {
|
|||
|
||||
msg := message.GroupMessage{}
|
||||
err := msg.Decode(messageRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg.Version != 3 {
|
||||
t.Fatalf("Expected Version to be 3 but go %d", msg.Version)
|
||||
}
|
||||
if msg.MessageIndex != expectedMessageIndex {
|
||||
t.Fatalf("Expected message index to be %d but got %d", expectedMessageIndex, msg.MessageIndex)
|
||||
}
|
||||
if !bytes.Equal(msg.Ciphertext, expectedCipherText) {
|
||||
t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, msg.Version)
|
||||
assert.Equal(t, expectedMessageIndex, msg.MessageIndex)
|
||||
assert.Equal(t, expectedCipherText, msg.Ciphertext)
|
||||
}
|
||||
|
||||
func TestGroupMessageEncode(t *testing.T) {
|
||||
|
|
@ -40,12 +33,8 @@ func TestGroupMessageEncode(t *testing.T) {
|
|||
Ciphertext: []byte("ciphertext"),
|
||||
}
|
||||
encoded, err := msg.EncodeAndMacAndSign(nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
encoded = append(encoded, hmacsha256...)
|
||||
encoded = append(encoded, sign...)
|
||||
if !bytes.Equal(encoded, expectedRaw) {
|
||||
t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded)
|
||||
}
|
||||
assert.Equal(t, expectedRaw, encoded)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
package message_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/message"
|
||||
)
|
||||
|
||||
|
|
@ -14,24 +15,12 @@ func TestMessageDecode(t *testing.T) {
|
|||
|
||||
msg := message.Message{}
|
||||
err := msg.Decode(messageRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg.Version != 3 {
|
||||
t.Fatalf("Expected Version to be 3 but go %d", msg.Version)
|
||||
}
|
||||
if !msg.HasCounter {
|
||||
t.Fatal("Expected to have counter")
|
||||
}
|
||||
if msg.Counter != 1 {
|
||||
t.Fatalf("Expected counter to be 1 but got %d", msg.Counter)
|
||||
}
|
||||
if !bytes.Equal(msg.Ciphertext, expectedCipherText) {
|
||||
t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext)
|
||||
}
|
||||
if !bytes.Equal(msg.RatchetKey, expectedRatchetKey) {
|
||||
t.Fatalf("expected '%s' but got '%s'", expectedRatchetKey, msg.RatchetKey)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, msg.Version)
|
||||
assert.True(t, msg.HasCounter)
|
||||
assert.EqualValues(t, 1, msg.Counter)
|
||||
assert.Equal(t, expectedCipherText, msg.Ciphertext)
|
||||
assert.EqualValues(t, expectedRatchetKey, msg.RatchetKey)
|
||||
}
|
||||
|
||||
func TestMessageEncode(t *testing.T) {
|
||||
|
|
@ -44,11 +33,7 @@ func TestMessageEncode(t *testing.T) {
|
|||
Ciphertext: []byte("ciphertext"),
|
||||
}
|
||||
encoded, err := msg.EncodeAndMAC(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
encoded = append(encoded, hmacsha256...)
|
||||
if !bytes.Equal(encoded, expectedRaw) {
|
||||
t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded)
|
||||
}
|
||||
assert.Equal(t, expectedRaw, encoded)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
package message_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/message"
|
||||
)
|
||||
|
|
@ -19,29 +20,14 @@ func TestPreKeyMessageDecode(t *testing.T) {
|
|||
|
||||
msg := message.PreKeyMessage{}
|
||||
err := msg.Decode(messageRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg.Version != 3 {
|
||||
t.Fatalf("Expected Version to be 3 but go %d", msg.Version)
|
||||
}
|
||||
if !bytes.Equal(msg.OneTimeKey, expectedOneTimeKey) {
|
||||
t.Fatalf("expected '%s' but got '%s'", expectedOneTimeKey, msg.OneTimeKey)
|
||||
}
|
||||
if !bytes.Equal(msg.IdentityKey, expectedIdKey) {
|
||||
t.Fatalf("expected '%s' but got '%s'", expectedIdKey, msg.IdentityKey)
|
||||
}
|
||||
if !bytes.Equal(msg.BaseKey, expectedbaseKey) {
|
||||
t.Fatalf("expected '%s' but got '%s'", expectedbaseKey, msg.BaseKey)
|
||||
}
|
||||
if !bytes.Equal(msg.Message, expectedmessage) {
|
||||
t.Fatalf("expected '%s' but got '%s'", expectedmessage, msg.Message)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, msg.Version)
|
||||
assert.EqualValues(t, expectedOneTimeKey, msg.OneTimeKey)
|
||||
assert.EqualValues(t, expectedIdKey, msg.IdentityKey)
|
||||
assert.EqualValues(t, expectedbaseKey, msg.BaseKey)
|
||||
assert.Equal(t, expectedmessage, msg.Message)
|
||||
theirIDKey := crypto.Curve25519PublicKey(expectedIdKey)
|
||||
checked := msg.CheckFields(&theirIDKey)
|
||||
if !checked {
|
||||
t.Fatal("field check failed")
|
||||
}
|
||||
assert.True(t, msg.CheckFields(&theirIDKey), "field check failed")
|
||||
}
|
||||
|
||||
func TestPreKeyMessageEncode(t *testing.T) {
|
||||
|
|
@ -54,10 +40,6 @@ func TestPreKeyMessageEncode(t *testing.T) {
|
|||
Message: []byte("message"),
|
||||
}
|
||||
encoded, err := msg.Encode()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(encoded, expectedRaw) {
|
||||
t.Fatalf("got other than expected:\nExpected:\n%v\nGot:\n%v", expectedRaw, encoded)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedRaw, encoded)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
package pk_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/goolmbase64"
|
||||
"maunium.net/go/mautrix/crypto/goolm/pk"
|
||||
)
|
||||
|
||||
|
|
@ -26,34 +26,20 @@ func TestEncryptionDecryption(t *testing.T) {
|
|||
}
|
||||
bobPublic := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08")
|
||||
decryption, err := pk.NewDecryptionFromPrivate(alicePrivate)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) {
|
||||
t.Fatal("public key not correct")
|
||||
}
|
||||
if !bytes.Equal(decryption.PrivateKey(), alicePrivate) {
|
||||
t.Fatal("private key not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct")
|
||||
assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct")
|
||||
|
||||
encryption, err := pk.NewEncryption(decryption.PublicKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
plaintext := []byte("This is a test")
|
||||
|
||||
ciphertext, mac, err := encryption.Encrypt(plaintext, bobPrivate)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
decrypted, err := decryption.Decrypt(bobPublic, mac, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(decrypted, plaintext) {
|
||||
t.Fatal("message not equal")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, plaintext, decrypted, "message not equal")
|
||||
}
|
||||
|
||||
func TestSigning(t *testing.T) {
|
||||
|
|
@ -66,29 +52,20 @@ func TestSigning(t *testing.T) {
|
|||
message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.")
|
||||
signing, _ := pk.NewSigningFromSeed(seed)
|
||||
signature, err := signing.Sign(message)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
signatureDecoded, err := goolmbase64.Decode(signature)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
signatureDecoded, err := base64.RawStdEncoding.DecodeString(string(signature))
|
||||
assert.NoError(t, err)
|
||||
pubKeyEncoded := signing.PublicKey()
|
||||
pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKeyEncoded))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
pubKey := crypto.Ed25519PublicKey(pubKeyDecoded)
|
||||
|
||||
verified := pubKey.Verify(message, signatureDecoded)
|
||||
if !verified {
|
||||
t.Fatal("signature did not verify")
|
||||
}
|
||||
assert.True(t, verified, "signature did not verify")
|
||||
|
||||
copy(signatureDecoded[0:], []byte("m"))
|
||||
verified = pubKey.Verify(message, signatureDecoded)
|
||||
if verified {
|
||||
t.Fatal("signature did verify")
|
||||
}
|
||||
assert.False(t, verified, "signature verified with wrong message")
|
||||
}
|
||||
|
||||
func TestDecryptionPickling(t *testing.T) {
|
||||
|
|
@ -100,37 +77,19 @@ func TestDecryptionPickling(t *testing.T) {
|
|||
}
|
||||
alicePublic := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo")
|
||||
decryption, err := pk.NewDecryptionFromPrivate(alicePrivate)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) {
|
||||
t.Fatal("public key not correct")
|
||||
}
|
||||
if !bytes.Equal(decryption.PrivateKey(), alicePrivate) {
|
||||
t.Fatal("private key not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct")
|
||||
assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct")
|
||||
pickleKey := []byte("secret_key")
|
||||
expectedPickle := []byte("qx37WTQrjZLz5tId/uBX9B3/okqAbV1ofl9UnHKno1eipByCpXleAAlAZoJgYnCDOQZDQWzo3luTSfkF9pU1mOILCbbouubs6TVeDyPfgGD9i86J8irHjA")
|
||||
pickled, err := decryption.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(expectedPickle, pickled) {
|
||||
t.Fatalf("pickle not as expected:\n%v\n%v\n", pickled, expectedPickle)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, expectedPickle, pickled, "pickle not as expected")
|
||||
|
||||
newDecription, err := pk.NewDecryption()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
err = newDecription.Unpickle(pickled, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal([]byte(newDecription.PublicKey()), alicePublic) {
|
||||
t.Fatal("public key not correct")
|
||||
}
|
||||
if !bytes.Equal(newDecription.PrivateKey(), alicePrivate) {
|
||||
t.Fatal("private key not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, alicePublic, newDecription.PublicKey(), "public key not correct")
|
||||
assert.EqualValues(t, alicePrivate, newDecription.PrivateKey(), "private key not correct")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
package ratchet_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/ratchet"
|
||||
|
|
@ -38,149 +39,90 @@ func initializeRatchets() (*ratchet.Ratchet, *ratchet.Ratchet, error) {
|
|||
|
||||
func TestSendReceive(t *testing.T) {
|
||||
aliceRatchet, bobRatchet, err := initializeRatchets()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
plainText := []byte("Hello Bob")
|
||||
|
||||
//Alice sends Bob a message
|
||||
encryptedMessage, err := aliceRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
decrypted, err := bobRatchet.Decrypt(encryptedMessage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plainText, decrypted) {
|
||||
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decrypted)
|
||||
|
||||
//Bob sends Alice a message
|
||||
plainText = []byte("Hello Alice")
|
||||
encryptedMessage, err = bobRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
decrypted, err = aliceRatchet.Decrypt(encryptedMessage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plainText, decrypted) {
|
||||
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decrypted)
|
||||
}
|
||||
|
||||
func TestOutOfOrder(t *testing.T) {
|
||||
aliceRatchet, bobRatchet, err := initializeRatchets()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
plainText1 := []byte("First Message")
|
||||
plainText2 := []byte("Second Messsage. A bit longer than the first.")
|
||||
|
||||
/* Alice sends Bob two messages and they arrive out of order */
|
||||
message1Encrypted, err := aliceRatchet.Encrypt(plainText1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
message2Encrypted, err := aliceRatchet.Encrypt(plainText2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
decrypted2, err := bobRatchet.Decrypt(message2Encrypted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
decrypted1, err := bobRatchet.Decrypt(message1Encrypted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plainText1, decrypted1) {
|
||||
t.Fatalf("expected '%v' from decryption but got '%v'", plainText1, decrypted1)
|
||||
}
|
||||
if !bytes.Equal(plainText2, decrypted2) {
|
||||
t.Fatalf("expected '%v' from decryption but got '%v'", plainText2, decrypted2)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText1, decrypted1)
|
||||
assert.Equal(t, plainText2, decrypted2)
|
||||
}
|
||||
|
||||
func TestMoreMessages(t *testing.T) {
|
||||
aliceRatchet, bobRatchet, err := initializeRatchets()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
plainText := []byte("These 15 bytes")
|
||||
for i := 0; i < 8; i++ {
|
||||
messageEncrypted, err := aliceRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
decrypted, err := bobRatchet.Decrypt(messageEncrypted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plainText, decrypted) {
|
||||
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decrypted)
|
||||
}
|
||||
for i := 0; i < 8; i++ {
|
||||
messageEncrypted, err := bobRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
decrypted, err := aliceRatchet.Decrypt(messageEncrypted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plainText, decrypted) {
|
||||
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decrypted)
|
||||
}
|
||||
messageEncrypted, err := aliceRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
decrypted, err := bobRatchet.Decrypt(messageEncrypted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plainText, decrypted) {
|
||||
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decrypted)
|
||||
}
|
||||
|
||||
func TestJSONEncoding(t *testing.T) {
|
||||
aliceRatchet, bobRatchet, err := initializeRatchets()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
marshaled, err := json.Marshal(aliceRatchet)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
newRatcher := ratchet.Ratchet{}
|
||||
err = json.Unmarshal(marshaled, &newRatcher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
plainText := []byte("These 15 bytes")
|
||||
|
||||
messageEncrypted, err := newRatcher.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
decrypted, err := bobRatchet.Decrypt(messageEncrypted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plainText, decrypted) {
|
||||
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decrypted)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
package session_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/megolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
|
|
@ -15,78 +16,42 @@ import (
|
|||
func TestOutboundPickleJSON(t *testing.T) {
|
||||
pickleKey := []byte("secretKey")
|
||||
sess, err := session.NewMegolmOutboundSession()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
kp, err := crypto.Ed25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
sess.SigningKey = kp
|
||||
pickled, err := sess.PickleAsJSON(pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
newSession := session.MegolmOutboundSession{}
|
||||
err = newSession.UnpickleAsJSON(pickled, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sess.ID() != newSession.ID() {
|
||||
t.Fatal("session ids not equal")
|
||||
}
|
||||
if !bytes.Equal(sess.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) {
|
||||
t.Fatal("private keys not equal")
|
||||
}
|
||||
if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) {
|
||||
t.Fatal("ratchet data not equal")
|
||||
}
|
||||
if sess.Ratchet.Counter != newSession.Ratchet.Counter {
|
||||
t.Fatal("ratchet counter not equal")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sess.ID(), newSession.ID())
|
||||
assert.Equal(t, sess.SigningKey, newSession.SigningKey)
|
||||
assert.Equal(t, sess.Ratchet, newSession.Ratchet)
|
||||
}
|
||||
|
||||
func TestInboundPickleJSON(t *testing.T) {
|
||||
pickleKey := []byte("secretKey")
|
||||
sess := session.MegolmInboundSession{}
|
||||
kp, err := crypto.Ed25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
sess.SigningKey = kp.PublicKey
|
||||
var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte
|
||||
_, err = rand.Read(randomData[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
ratchet, err := megolm.New(0, randomData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
sess.Ratchet = *ratchet
|
||||
pickled, err := sess.PickleAsJSON(pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
newSession := session.MegolmInboundSession{}
|
||||
err = newSession.UnpickleAsJSON(pickled, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sess.ID() != newSession.ID() {
|
||||
t.Fatal("sess ids not equal")
|
||||
}
|
||||
if !bytes.Equal(sess.SigningKey, newSession.SigningKey) {
|
||||
t.Fatal("private keys not equal")
|
||||
}
|
||||
if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) {
|
||||
t.Fatal("ratchet data not equal")
|
||||
}
|
||||
if sess.Ratchet.Counter != newSession.Ratchet.Counter {
|
||||
t.Fatal("ratchet counter not equal")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sess.ID(), newSession.ID())
|
||||
assert.Equal(t, sess.SigningKey, newSession.SigningKey)
|
||||
assert.Equal(t, sess.Ratchet, newSession.Ratchet)
|
||||
}
|
||||
|
||||
func TestGroupSendReceive(t *testing.T) {
|
||||
|
|
@ -100,46 +65,27 @@ func TestGroupSendReceive(t *testing.T) {
|
|||
)
|
||||
|
||||
outboundSession, err := session.NewMegolmOutboundSession()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
copy(outboundSession.Ratchet.Data[:], randomData)
|
||||
if outboundSession.Ratchet.Counter != 0 {
|
||||
t.Fatal("ratchet counter is not correkt")
|
||||
}
|
||||
assert.EqualValues(t, 0, outboundSession.Ratchet.Counter)
|
||||
|
||||
sessionSharing, err := outboundSession.SessionSharingMessage()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
plainText := []byte("Message")
|
||||
ciphertext, err := outboundSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if outboundSession.Ratchet.Counter != 1 {
|
||||
t.Fatal("ratchet counter is not correkt")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, outboundSession.Ratchet.Counter)
|
||||
|
||||
//build inbound session
|
||||
inboundSession, err := session.NewMegolmInboundSession(sessionSharing)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !inboundSession.SigningKeyVerified {
|
||||
t.Fatal("key not verified")
|
||||
}
|
||||
if inboundSession.ID() != outboundSession.ID() {
|
||||
t.Fatal("session ids not equal")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, inboundSession.SigningKeyVerified)
|
||||
assert.Equal(t, outboundSession.ID(), inboundSession.ID())
|
||||
|
||||
//decode message
|
||||
decoded, _, err := inboundSession.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plainText, decoded) {
|
||||
t.Fatal("messages not equal")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainText, decoded)
|
||||
}
|
||||
|
||||
func TestGroupSessionExportImport(t *testing.T) {
|
||||
|
|
@ -158,45 +104,26 @@ func TestGroupSessionExportImport(t *testing.T) {
|
|||
|
||||
//init inbound
|
||||
inboundSession, err := session.NewMegolmInboundSession(sessionKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !inboundSession.SigningKeyVerified {
|
||||
t.Fatal("signing key not verified")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, inboundSession.SigningKeyVerified)
|
||||
|
||||
decrypted, _, err := inboundSession.Decrypt(message)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Fatal("message is not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
|
||||
//Export the keys
|
||||
exported, err := inboundSession.Export(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
secondInboundSession, err := session.NewMegolmInboundSessionFromExport(exported)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secondInboundSession.SigningKeyVerified {
|
||||
t.Fatal("signing key is verified")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, secondInboundSession.SigningKeyVerified)
|
||||
|
||||
//decrypt with new session
|
||||
decrypted, _, err = secondInboundSession.Decrypt(message)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Fatal("message is not correct")
|
||||
}
|
||||
if !secondInboundSession.SigningKeyVerified {
|
||||
t.Fatal("signing key not verified")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
assert.True(t, secondInboundSession.SigningKeyVerified)
|
||||
}
|
||||
|
||||
func TestBadSignatureGroupMessage(t *testing.T) {
|
||||
|
|
@ -215,70 +142,43 @@ func TestBadSignatureGroupMessage(t *testing.T) {
|
|||
|
||||
//init inbound
|
||||
inboundSession, err := session.NewMegolmInboundSession(sessionKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !inboundSession.SigningKeyVerified {
|
||||
t.Fatal("signing key not verified")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, inboundSession.SigningKeyVerified)
|
||||
|
||||
decrypted, _, err := inboundSession.Decrypt(message)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Fatal("message is not correct")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
|
||||
//Now twiddle the signature
|
||||
copy(message[len(message)-1:], []byte("E"))
|
||||
_, _, err = inboundSession.Decrypt(message)
|
||||
if err == nil {
|
||||
t.Fatal("Signature was changed but did not cause an error")
|
||||
}
|
||||
if !errors.Is(err, olm.ErrBadSignature) {
|
||||
t.Fatalf("wrong error %s", err.Error())
|
||||
}
|
||||
assert.ErrorIs(t, err, olm.ErrBadSignature)
|
||||
}
|
||||
|
||||
func TestOutbountPickle(t *testing.T) {
|
||||
pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItUO3TiOp5I+6PnQka6n8eHTyIEh3tCetilD+BKnHvtakE0eHHvG6pjEsMNN/vs7lkB5rV6XkoUKHLTE1dAfFunYEeHEZuKQpbG385dBwaMJXt4JrC0hU5jnv6jWNqAA0Ud9GxRDvkp04")
|
||||
pickleKey := []byte("secret_key")
|
||||
sess, err := session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
newPickled, err := sess.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(pickledDataFromLibOlm, newPickled) {
|
||||
t.Fatal("pickled version does not equal libolm version")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pickledDataFromLibOlm, newPickled)
|
||||
|
||||
pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...)
|
||||
_, err = session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey)
|
||||
if err == nil {
|
||||
t.Fatal("should have gotten an error")
|
||||
}
|
||||
assert.ErrorIs(t, err, olm.ErrBadMAC)
|
||||
}
|
||||
|
||||
func TestInbountPickle(t *testing.T) {
|
||||
pickledDataFromLibOlm := []byte("1/IPCdtUoQxMba5XT7sjjUW0Hrs7no9duGFnhsEmxzFX2H3qtRc4eaFBRZYXxOBRTGZ6eMgy3IiSrgAQ1gUlSZf5Q4AVKeBkhvN4LZ6hdhQFv91mM+C2C55/4B9/gDjJEbDGiRgLoMqbWPDV+y0F4h0KaR1V1PiTCC7zCi4WdxJQ098nJLgDL4VSsDbnaLcSMO60FOYgRN4KsLaKUGkXiiUBWp4boFMCiuTTOiyH8XlH0e9uWc0vMLyGNUcO8kCbpAnx3v1JTIVan3WGsnGv4K8Qu4M8GAkZewpexrsb2BSNNeLclOV9/cR203Y5KlzXcpiWNXSs8XoB3TLEtHYMnjuakMQfyrcXKIQntg4xPD/+wvfqkcMg9i7pcplQh7X2OK5ylrMZQrZkJ1fAYBGbBz1tykWOjfrZ")
|
||||
pickleKey := []byte("secret_key")
|
||||
sess, err := session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
newPickled, err := sess.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(pickledDataFromLibOlm, newPickled) {
|
||||
t.Fatal("pickled version does not equal libolm version")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pickledDataFromLibOlm, newPickled)
|
||||
|
||||
pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...)
|
||||
_, err = session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey)
|
||||
if err == nil {
|
||||
t.Fatal("should have gotten an error")
|
||||
}
|
||||
assert.ErrorIs(t, err, base64.CorruptInputError(416))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
package session_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
|
|
@ -15,30 +15,18 @@ import (
|
|||
func TestOlmSession(t *testing.T) {
|
||||
pickleKey := []byte("secretKey")
|
||||
aliceKeyPair, err := crypto.Curve25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
bobKeyPair, err := crypto.Curve25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
bobOneTimeKey, err := crypto.Curve25519GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
aliceSession, err := session.NewOutboundOlmSession(aliceKeyPair, bobKeyPair.PublicKey, bobOneTimeKey.PublicKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
//create a message so that there are more keys to marshal
|
||||
plaintext := []byte("Test message from Alice to Bob")
|
||||
msgType, message, err := aliceSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType != id.OlmMsgTypePreKey {
|
||||
t.Fatal("Wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypePreKey, msgType)
|
||||
|
||||
searchFunc := func(target crypto.Curve25519PublicKey) *crypto.OneTimeKey {
|
||||
if target.Equal(bobOneTimeKey.PublicKey) {
|
||||
|
|
@ -52,92 +40,58 @@ func TestOlmSession(t *testing.T) {
|
|||
}
|
||||
//bob receives message
|
||||
bobSession, err := session.NewInboundOlmSession(nil, message, searchFunc, bobKeyPair)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
decryptedMsg, err := bobSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, decryptedMsg) {
|
||||
t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decryptedMsg)
|
||||
|
||||
// Alice pickles session
|
||||
pickled, err := aliceSession.PickleAsJSON(pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
//bob sends a message
|
||||
plaintext = []byte("A message from Bob to Alice")
|
||||
msgType, message, err = bobSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType != id.OlmMsgTypeMsg {
|
||||
t.Fatal("Wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypeMsg, msgType)
|
||||
|
||||
//Alice unpickles session
|
||||
newAliceSession, err := session.OlmSessionFromJSONPickled(pickled, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
//Alice receives message
|
||||
decryptedMsg, err = newAliceSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, decryptedMsg) {
|
||||
t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decryptedMsg)
|
||||
|
||||
//Alice receives message again
|
||||
_, err = newAliceSession.Decrypt(string(message), msgType)
|
||||
if err == nil {
|
||||
t.Fatal("should have gotten an error")
|
||||
}
|
||||
assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound)
|
||||
|
||||
//Alice sends another message
|
||||
plaintext = []byte("A second message to Bob")
|
||||
msgType, message, err = newAliceSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType != id.OlmMsgTypeMsg {
|
||||
t.Fatal("Wrong message type")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id.OlmMsgTypeMsg, msgType)
|
||||
|
||||
//bob receives message
|
||||
decryptedMsg, err = bobSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, decryptedMsg) {
|
||||
t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decryptedMsg)
|
||||
}
|
||||
|
||||
func TestSessionPickle(t *testing.T) {
|
||||
pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT")
|
||||
pickleKey := []byte("secret_key")
|
||||
sess, err := session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
newPickled, err := sess.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(pickledDataFromLibOlm, newPickled) {
|
||||
t.Fatal("pickled version does not equal libolm version")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pickledDataFromLibOlm, newPickled)
|
||||
|
||||
pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...)
|
||||
_, err = session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey)
|
||||
if err == nil {
|
||||
t.Fatal("should have gotten an error")
|
||||
}
|
||||
assert.ErrorIs(t, err, base64.CorruptInputError(224))
|
||||
}
|
||||
|
||||
func TestDecrypts(t *testing.T) {
|
||||
|
|
@ -161,17 +115,9 @@ func TestDecrypts(t *testing.T) {
|
|||
"dGvPXeH8qLeNZA")
|
||||
pickleKey := []byte("")
|
||||
sess, err := session.OlmSessionFromPickled(sessionPickled, pickleKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
for curIndex, curMessage := range messages {
|
||||
_, err := sess.Decrypt(string(curMessage), id.OlmMsgTypePreKey)
|
||||
if err != nil {
|
||||
if !errors.Is(err, expectedErr[curIndex]) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal("error expected")
|
||||
}
|
||||
assert.ErrorIs(t, err, expectedErr[curIndex])
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue