mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
crypto: replace t.Fatal and t.Error with require and assert
Signed-off-by: Sumner Evans <me@sumnerevans.com>
This commit is contained in:
parent
09e4706fdb
commit
654b6b1d45
6 changed files with 209 additions and 349 deletions
|
|
@ -7,11 +7,13 @@
|
|||
package aescbc_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/aescbc"
|
||||
)
|
||||
|
||||
|
|
@ -22,32 +24,23 @@ func TestAESCBC(t *testing.T) {
|
|||
// The key length can be 32, 24, 16 bytes (OR in bits: 128, 192 or 256)
|
||||
key := make([]byte, 32)
|
||||
_, err = rand.Read(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
iv := make([]byte, aes.BlockSize)
|
||||
_, err = rand.Read(iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
plaintext = []byte("secret message for testing")
|
||||
//increase to next block size
|
||||
for len(plaintext)%8 != 0 {
|
||||
plaintext = append(plaintext, []byte("-")...)
|
||||
}
|
||||
|
||||
if ciphertext, err = aescbc.Encrypt(key, iv, plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext, err = aescbc.Encrypt(key, iv, plaintext)
|
||||
require.NoError(t, err)
|
||||
|
||||
resultPlainText, err := aescbc.Decrypt(key, iv, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
if string(resultPlainText) != string(plaintext) {
|
||||
t.Fatalf("message '%s' (length %d) != '%s'", resultPlainText, len(resultPlainText), plaintext)
|
||||
}
|
||||
assert.Equal(t, string(resultPlainText), string(plaintext))
|
||||
}
|
||||
|
||||
func TestAESCBCCase1(t *testing.T) {
|
||||
|
|
@ -61,18 +54,10 @@ func TestAESCBCCase1(t *testing.T) {
|
|||
key := make([]byte, 32)
|
||||
iv := make([]byte, aes.BlockSize)
|
||||
encrypted, err := aescbc.Encrypt(key, iv, input)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(expected, encrypted) {
|
||||
t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, encrypted, "encrypted output does not match expected")
|
||||
|
||||
decrypted, err := aescbc.Decrypt(key, iv, encrypted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(input, decrypted) {
|
||||
t.Fatalf("decrypted did not match expected:\n%v\n%v\n", decrypted, input)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, input, decrypted, "decrypted output does not match input")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,31 +17,43 @@ package canonicaljson
|
|||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testSortJSON(t *testing.T, input, want string) {
|
||||
got := SortJSON([]byte(input), nil)
|
||||
|
||||
// Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace.
|
||||
if string(CompactJSON(got, nil)) != want {
|
||||
t.Errorf("SortJSON(%q): want %q got %q", input, want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortJSON(t *testing.T) {
|
||||
testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`)
|
||||
testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`,
|
||||
`{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`)
|
||||
testSortJSON(t, `[true,false,null]`, `[true,false,null]`)
|
||||
testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`)
|
||||
testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`)
|
||||
var tests = []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"{}", "{}"},
|
||||
{`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`},
|
||||
{`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`},
|
||||
{`[true,false,null]`, `[true,false,null]`},
|
||||
{`[9007199254740991]`, `[9007199254740991]`},
|
||||
{"\t\n[9007199254740991]", `[9007199254740991]`},
|
||||
{`[true,false,null]`, `[true,false,null]`},
|
||||
{`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`},
|
||||
{`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`},
|
||||
{`[true,false,null]`, `[true,false,null]`},
|
||||
{`[9007199254740991]`, `[9007199254740991]`},
|
||||
{"\t\n[9007199254740991]", `[9007199254740991]`},
|
||||
{`[true,false,null]`, `[true,false,null]`},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
got := SortJSON([]byte(test.input), nil)
|
||||
|
||||
// Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace.
|
||||
assert.EqualValues(t, test.want, string(CompactJSON(got, nil)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testCompactJSON(t *testing.T, input, want string) {
|
||||
t.Helper()
|
||||
got := string(CompactJSON([]byte(input), nil))
|
||||
if got != want {
|
||||
t.Errorf("CompactJSON(%q): want %q got %q", input, want, got)
|
||||
}
|
||||
assert.EqualValues(t, want, got)
|
||||
}
|
||||
|
||||
func TestCompactJSON(t *testing.T) {
|
||||
|
|
@ -74,18 +86,23 @@ func TestCompactJSON(t *testing.T) {
|
|||
testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`)
|
||||
}
|
||||
|
||||
func testReadHex(t *testing.T, input string, want uint32) {
|
||||
got := readHexDigits([]byte(input))
|
||||
if want != got {
|
||||
t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got)
|
||||
func TestReadHex(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want uint32
|
||||
}{
|
||||
|
||||
{"0123", 0x0123},
|
||||
{"4567", 0x4567},
|
||||
{"89AB", 0x89AB},
|
||||
{"CDEF", 0xCDEF},
|
||||
{"89ab", 0x89AB},
|
||||
{"cdef", 0xCDEF},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
got := readHexDigits([]byte(test.input))
|
||||
assert.Equal(t, test.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadHex(t *testing.T) {
|
||||
testReadHex(t, "0123", 0x0123)
|
||||
testReadHex(t, "4567", 0x4567)
|
||||
testReadHex(t, "89AB", 0x89AB)
|
||||
testReadHex(t, "CDEF", 0xCDEF)
|
||||
testReadHex(t, "89ab", 0x89AB)
|
||||
testReadHex(t, "cdef", 0xCDEF)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
|
|
@ -24,17 +26,12 @@ var noopLogger = zerolog.Nop()
|
|||
|
||||
func getOlmMachine(t *testing.T) *OlmMachine {
|
||||
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error opening raw database")
|
||||
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error creating database wrapper")
|
||||
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
|
||||
if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
|
||||
t.Fatalf("Error creating tables: %v", err)
|
||||
}
|
||||
err = sqlStore.DB.Upgrade(context.TODO())
|
||||
require.NoError(t, err, "Error upgrading database")
|
||||
|
||||
userID := id.UserID("@mautrix")
|
||||
mk, _ := olm.NewPKSigning()
|
||||
|
|
@ -66,29 +63,25 @@ func TestTrustOwnDevice(t *testing.T) {
|
|||
DeviceID: "device",
|
||||
SigningKey: id.Ed25519("deviceKey"),
|
||||
}
|
||||
if m.IsDeviceTrusted(context.TODO(), ownDevice) {
|
||||
t.Error("Own device trusted while it shouldn't be")
|
||||
}
|
||||
assert.False(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device trusted while it shouldn't be")
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(),
|
||||
ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1")
|
||||
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey,
|
||||
ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted {
|
||||
t.Error("Own user not trusted while they should be")
|
||||
}
|
||||
if !m.IsDeviceTrusted(context.TODO(), ownDevice) {
|
||||
t.Error("Own device not trusted while it should be")
|
||||
}
|
||||
trusted, err := m.IsUserTrusted(context.TODO(), ownDevice.UserID)
|
||||
require.NoError(t, err, "Error checking if own user is trusted")
|
||||
assert.True(t, trusted, "Own user not trusted while they should be")
|
||||
assert.True(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device not trusted while it should be")
|
||||
}
|
||||
|
||||
func TestTrustOtherUser(t *testing.T) {
|
||||
m := getOlmMachine(t)
|
||||
otherUser := id.UserID("@user")
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
|
||||
t.Error("Other user trusted while they shouldn't be")
|
||||
}
|
||||
trusted, err := m.IsUserTrusted(context.TODO(), otherUser)
|
||||
require.NoError(t, err, "Error checking if other user is trusted")
|
||||
assert.False(t, trusted, "Other user trusted while they shouldn't be")
|
||||
|
||||
theirMasterKey, _ := olm.NewPKSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
|
||||
|
|
@ -100,16 +93,16 @@ func TestTrustOtherUser(t *testing.T) {
|
|||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
|
||||
m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
|
||||
t.Error("Other user trusted before their master key has been signed with our user-signing key")
|
||||
}
|
||||
trusted, err = m.IsUserTrusted(context.TODO(), otherUser)
|
||||
require.NoError(t, err, "Error checking if other user is trusted")
|
||||
assert.False(t, trusted, "Other user trusted before their master key has been signed with our user-signing key")
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
|
||||
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
|
||||
t.Error("Other user not trusted while they should be")
|
||||
}
|
||||
trusted, err = m.IsUserTrusted(context.TODO(), otherUser)
|
||||
require.NoError(t, err, "Error checking if other user is trusted")
|
||||
assert.True(t, trusted, "Other user not trusted while they should be")
|
||||
}
|
||||
|
||||
func TestTrustOtherDevice(t *testing.T) {
|
||||
|
|
@ -120,12 +113,11 @@ func TestTrustOtherDevice(t *testing.T) {
|
|||
DeviceID: "theirDevice",
|
||||
SigningKey: id.Ed25519("theirDeviceKey"),
|
||||
}
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
|
||||
t.Error("Other user trusted while they shouldn't be")
|
||||
}
|
||||
if m.IsDeviceTrusted(context.TODO(), theirDevice) {
|
||||
t.Error("Other device trusted while it shouldn't be")
|
||||
}
|
||||
|
||||
trusted, err := m.IsUserTrusted(context.TODO(), otherUser)
|
||||
require.NoError(t, err, "Error checking if other user is trusted")
|
||||
assert.False(t, trusted, "Other user trusted while they shouldn't be")
|
||||
assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted while it shouldn't be")
|
||||
|
||||
theirMasterKey, _ := olm.NewPKSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
|
||||
|
|
@ -137,21 +129,17 @@ func TestTrustOtherDevice(t *testing.T) {
|
|||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
|
||||
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
|
||||
t.Error("Other user not trusted while they should be")
|
||||
}
|
||||
trusted, err = m.IsUserTrusted(context.TODO(), otherUser)
|
||||
require.NoError(t, err, "Error checking if other user is trusted")
|
||||
assert.True(t, trusted, "Other user not trusted while they should be")
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(),
|
||||
otherUser, theirMasterKey.PublicKey(), "sig3")
|
||||
|
||||
if m.IsDeviceTrusted(context.TODO(), theirDevice) {
|
||||
t.Error("Other device trusted before it has been signed with user's SSK")
|
||||
}
|
||||
assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted before it has been signed with user's SSK")
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey,
|
||||
otherUser, theirSSK.PublicKey(), "sig4")
|
||||
|
||||
if !m.IsDeviceTrusted(context.TODO(), theirDevice) {
|
||||
t.Error("Other device not trusted while it should be")
|
||||
}
|
||||
assert.True(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device not trusted after it has been signed with user's SSK")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,20 +36,15 @@ func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID,
|
|||
|
||||
func newMachine(t *testing.T, userID id.UserID) *OlmMachine {
|
||||
client, err := mautrix.NewClient("http://localhost", userID, "token")
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating client: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error creating client")
|
||||
client.DeviceID = "device1"
|
||||
|
||||
gobStore := NewMemoryStore(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating Gob store: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error creating Gob store")
|
||||
|
||||
machine := NewOlmMachine(client, nil, gobStore, mockStateStore{})
|
||||
if err := machine.Load(context.TODO()); err != nil {
|
||||
t.Fatalf("Error creating account: %v", err)
|
||||
}
|
||||
err = machine.Load(context.TODO())
|
||||
require.NoError(t, err, "Error creating account")
|
||||
|
||||
return machine
|
||||
}
|
||||
|
|
@ -82,9 +77,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
|
|||
|
||||
// create outbound olm session for sending machine using OTK
|
||||
olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create outbound olm session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error creating outbound olm session")
|
||||
|
||||
// store sender device identity in receiving machine store
|
||||
machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{
|
||||
|
|
@ -121,29 +114,21 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
|
|||
Type: event.ToDeviceEncrypted,
|
||||
Sender: "user1",
|
||||
}, senderKey, content.Type, content.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Error decrypting olm content: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error decrypting olm ciphertext")
|
||||
|
||||
// store room key in new inbound group session
|
||||
roomKeyEvt := decrypted.Content.AsRoomKey()
|
||||
igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0, false)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating inbound megolm session: %v", err)
|
||||
}
|
||||
if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil {
|
||||
t.Errorf("Error storing inbound megolm session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error creating inbound group session")
|
||||
err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs)
|
||||
require.NoError(t, err, "Error storing inbound group session")
|
||||
}
|
||||
|
||||
// encrypt event with megolm session in sending machine
|
||||
eventContent := map[string]string{"hello": "world"}
|
||||
encryptedEvtContent, err := machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent)
|
||||
if err != nil {
|
||||
t.Errorf("Error encrypting megolm event: %v", err)
|
||||
}
|
||||
if megolmOutSession.MessageCount != 1 {
|
||||
t.Errorf("Megolm outbound session message count is not 1 but %d", megolmOutSession.MessageCount)
|
||||
}
|
||||
require.NoError(t, err, "Error encrypting megolm event")
|
||||
assert.Equal(t, 1, megolmOutSession.MessageCount)
|
||||
|
||||
encryptedEvt := &event.Event{
|
||||
Content: event.Content{Parsed: encryptedEvtContent},
|
||||
|
|
@ -155,22 +140,12 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
|
|||
|
||||
// decrypt event on receiving machine and confirm
|
||||
decryptedEvt, err := machineIn.DecryptMegolmEvent(context.TODO(), encryptedEvt)
|
||||
if err != nil {
|
||||
t.Errorf("Error decrypting megolm event: %v", err)
|
||||
}
|
||||
if decryptedEvt.Type != event.EventMessage {
|
||||
t.Errorf("Expected event type %v, got %v", event.EventMessage, decryptedEvt.Type)
|
||||
}
|
||||
if decryptedEvt.Content.Raw["hello"] != "world" {
|
||||
t.Errorf("Expected event content %v, got %v", eventContent, decryptedEvt.Content.Raw)
|
||||
}
|
||||
require.NoError(t, err, "Error decrypting megolm event")
|
||||
assert.Equal(t, event.EventMessage, decryptedEvt.Type)
|
||||
assert.Equal(t, "world", decryptedEvt.Content.Raw["hello"])
|
||||
|
||||
machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent)
|
||||
if megolmOutSession.Expired() {
|
||||
t.Error("Megolm outbound session expired before 3rd message")
|
||||
}
|
||||
assert.False(t, megolmOutSession.Expired(), "Megolm outbound session expired before 3rd message")
|
||||
machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent)
|
||||
if !megolmOutSession.Expired() {
|
||||
t.Error("Megolm outbound session not expired after 3rd message")
|
||||
}
|
||||
assert.True(t, megolmOutSession.Expired(), "Megolm outbound session not expired after 3rd message")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
|
|
@ -29,22 +30,14 @@ const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4
|
|||
|
||||
func getCryptoStores(t *testing.T) map[string]Store {
|
||||
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error opening raw database")
|
||||
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error creating database wrapper")
|
||||
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
|
||||
if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
|
||||
t.Fatalf("Error creating tables: %v", err)
|
||||
}
|
||||
err = sqlStore.DB.Upgrade(context.TODO())
|
||||
require.NoError(t, err, "Error upgrading database")
|
||||
|
||||
gobStore := NewMemoryStore(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating Gob store: %v", err)
|
||||
}
|
||||
|
||||
return map[string]Store{
|
||||
"sql": sqlStore,
|
||||
|
|
@ -56,9 +49,10 @@ func TestPutNextBatch(t *testing.T) {
|
|||
stores := getCryptoStores(t)
|
||||
store := stores["sql"].(*SQLCryptoStore)
|
||||
store.PutNextBatch(context.Background(), "batch1")
|
||||
if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" {
|
||||
t.Errorf("Expected batch1, got %v", batch)
|
||||
}
|
||||
|
||||
batch, err := store.GetNextBatch(context.Background())
|
||||
require.NoError(t, err, "Error retrieving next batch")
|
||||
assert.Equal(t, "batch1", batch)
|
||||
}
|
||||
|
||||
func TestPutAccount(t *testing.T) {
|
||||
|
|
@ -68,15 +62,9 @@ func TestPutAccount(t *testing.T) {
|
|||
acc := NewOlmAccount()
|
||||
store.PutAccount(context.TODO(), acc)
|
||||
retrieved, err := store.GetAccount(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("Error retrieving account: %v", err)
|
||||
}
|
||||
if acc.IdentityKey() != retrieved.IdentityKey() {
|
||||
t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey())
|
||||
}
|
||||
if acc.SigningKey() != retrieved.SigningKey() {
|
||||
t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey())
|
||||
}
|
||||
require.NoError(t, err, "Error retrieving account")
|
||||
assert.Equal(t, acc.IdentityKey(), retrieved.IdentityKey(), "Identity key does not match")
|
||||
assert.Equal(t, acc.SigningKey(), retrieved.SigningKey(), "Signing key does not match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -86,18 +74,26 @@ func TestValidateMessageIndex(t *testing.T) {
|
|||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
acc := NewOlmAccount()
|
||||
if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok {
|
||||
t.Error("First message not validated successfully")
|
||||
}
|
||||
if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001); ok {
|
||||
t.Error("First message validated successfully after changing timestamp")
|
||||
}
|
||||
if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000); ok {
|
||||
t.Error("First message validated successfully after changing event ID")
|
||||
}
|
||||
if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok {
|
||||
t.Error("First message not validated successfully for a second time")
|
||||
}
|
||||
|
||||
// First message should validate successfully
|
||||
ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000)
|
||||
require.NoError(t, err, "Error validating message index")
|
||||
assert.True(t, ok, "First message validation should be valid")
|
||||
|
||||
// Edit the timestamp and ensure validate fails
|
||||
ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001)
|
||||
require.NoError(t, err, "Error validating message index after timestamp change")
|
||||
assert.False(t, ok, "First message validation should fail after timestamp change")
|
||||
|
||||
// Edit the event ID and ensure validate fails
|
||||
ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000)
|
||||
require.NoError(t, err, "Error validating message index after event ID change")
|
||||
assert.False(t, ok, "First message validation should fail after event ID change")
|
||||
|
||||
// Validate again with the original parameters and ensure that it still passes
|
||||
ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000)
|
||||
require.NoError(t, err, "Error validating message index")
|
||||
assert.True(t, ok, "First message validation should be valid")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -106,43 +102,26 @@ func TestStoreOlmSession(t *testing.T) {
|
|||
stores := getCryptoStores(t)
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
if store.HasSession(context.TODO(), olmSessID) {
|
||||
t.Error("Found Olm session before inserting it")
|
||||
}
|
||||
require.False(t, store.HasSession(context.TODO(), olmSessID), "Found Olm session before inserting it")
|
||||
|
||||
olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating internal Olm session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error creating internal Olm session")
|
||||
|
||||
olmSess := OlmSession{
|
||||
id: olmSessID,
|
||||
Internal: olmInternal,
|
||||
}
|
||||
err = store.AddSession(context.TODO(), olmSessID, &olmSess)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing Olm session: %v", err)
|
||||
}
|
||||
if !store.HasSession(context.TODO(), olmSessID) {
|
||||
t.Error("Not found Olm session after inserting it")
|
||||
}
|
||||
require.NoError(t, err, "Error storing Olm session")
|
||||
assert.True(t, store.HasSession(context.TODO(), olmSessID), "Olm session not found after inserting it")
|
||||
|
||||
retrieved, err := store.GetLatestSession(context.TODO(), olmSessID)
|
||||
if err != nil {
|
||||
t.Errorf("Failed retrieving Olm session: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.ID() != olmSessID {
|
||||
t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID())
|
||||
}
|
||||
require.NoError(t, err, "Error retrieving Olm session")
|
||||
assert.EqualValues(t, olmSessID, retrieved.ID())
|
||||
|
||||
pickled, err := retrieved.Internal.Pickle([]byte("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("Error pickling Olm session: %v", err)
|
||||
}
|
||||
|
||||
if string(pickled) != olmPickled {
|
||||
t.Error("Pickled Olm session does not match original")
|
||||
}
|
||||
require.NoError(t, err, "Error pickling Olm session")
|
||||
assert.EqualValues(t, pickled, olmPickled, "Pickled Olm session does not match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -154,9 +133,7 @@ func TestStoreMegolmSession(t *testing.T) {
|
|||
acc := NewOlmAccount()
|
||||
|
||||
internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating internal inbound group session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error creating internal inbound group session")
|
||||
|
||||
igs := &InboundGroupSession{
|
||||
Internal: internal,
|
||||
|
|
@ -166,20 +143,14 @@ func TestStoreMegolmSession(t *testing.T) {
|
|||
}
|
||||
|
||||
err = store.PutGroupSession(context.TODO(), igs)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing inbound group session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error storing inbound group session")
|
||||
|
||||
retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID())
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving inbound group session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error retrieving inbound group session")
|
||||
|
||||
if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil {
|
||||
t.Fatalf("Error pickling inbound group session: %v", err)
|
||||
} else if string(pickled) != groupSession {
|
||||
t.Error("Pickled inbound group session does not match original")
|
||||
}
|
||||
pickled, err := retrieved.Internal.Pickle([]byte("test"))
|
||||
require.NoError(t, err, "Error pickling inbound group session")
|
||||
assert.EqualValues(t, pickled, groupSession, "Pickled inbound group session does not match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -189,40 +160,24 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
|
|||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
sess, err := store.GetOutboundGroupSession(context.TODO(), "room1")
|
||||
if sess != nil {
|
||||
t.Error("Got outbound session before inserting")
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving outbound session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error retrieving outbound session")
|
||||
require.Nil(t, sess, "Got outbound session before inserting")
|
||||
|
||||
outbound, err := NewOutboundGroupSession("room1", nil)
|
||||
require.NoError(t, err)
|
||||
err = store.AddOutboundGroupSession(context.TODO(), outbound)
|
||||
if err != nil {
|
||||
t.Errorf("Error inserting outbound session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error inserting outbound session")
|
||||
|
||||
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
|
||||
if sess == nil {
|
||||
t.Error("Did not get outbound session after inserting")
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving outbound session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error retrieving outbound session")
|
||||
assert.NotNil(t, sess, "Did not get outbound session after inserting")
|
||||
|
||||
err = store.RemoveOutboundGroupSession(context.TODO(), "room1")
|
||||
if err != nil {
|
||||
t.Errorf("Error deleting outbound session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error deleting outbound session")
|
||||
|
||||
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
|
||||
if sess != nil {
|
||||
t.Error("Got outbound session after deleting")
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving outbound session: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error retrieving outbound session after deletion")
|
||||
assert.Nil(t, sess, "Got outbound session after deleting")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -244,58 +199,41 @@ func TestStoreOutboundMegolmSessionSharing(t *testing.T) {
|
|||
t.Run(storeName, func(t *testing.T) {
|
||||
device := resetDevice()
|
||||
err := store.PutDevice(context.TODO(), "user1", device)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing devices: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error storing device")
|
||||
|
||||
shared, err := store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
|
||||
if err != nil {
|
||||
t.Errorf("Error checking if outbound group session is shared: %v", err)
|
||||
} else if shared {
|
||||
t.Errorf("Outbound group session shared when it shouldn't")
|
||||
}
|
||||
require.NoError(t, err, "Error checking if outbound group session is shared")
|
||||
assert.False(t, shared, "Outbound group session should not be shared initially")
|
||||
|
||||
err = store.MarkOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
|
||||
if err != nil {
|
||||
t.Errorf("Error marking outbound group session as shared: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error marking outbound group session as shared")
|
||||
|
||||
shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
|
||||
if err != nil {
|
||||
t.Errorf("Error checking if outbound group session is shared: %v", err)
|
||||
} else if !shared {
|
||||
t.Errorf("Outbound group session not shared when it should")
|
||||
}
|
||||
require.NoError(t, err, "Error checking if outbound group session is shared")
|
||||
assert.True(t, shared, "Outbound group session should be shared after marking it as such")
|
||||
|
||||
device = resetDevice()
|
||||
err = store.PutDevice(context.TODO(), "user1", device)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing devices: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error storing device after resetting")
|
||||
|
||||
shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
|
||||
if err != nil {
|
||||
t.Errorf("Error checking if outbound group session is shared: %v", err)
|
||||
} else if shared {
|
||||
t.Errorf("Outbound group session shared when it shouldn't")
|
||||
}
|
||||
require.NoError(t, err, "Error checking if outbound group session is shared")
|
||||
assert.False(t, shared, "Outbound group session should not be shared after resetting device")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreDevices(t *testing.T) {
|
||||
devicesToCreate := 17
|
||||
stores := getCryptoStores(t)
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
outdated, err := store.GetOutdatedTrackedUsers(context.TODO())
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
}
|
||||
if len(outdated) > 0 {
|
||||
t.Errorf("Got %d outdated tracked users when expected none", len(outdated))
|
||||
}
|
||||
require.NoError(t, err, "Error filtering tracked users")
|
||||
assert.Empty(t, outdated, "Expected no outdated tracked users initially")
|
||||
|
||||
deviceMap := make(map[id.DeviceID]*id.Device)
|
||||
for i := 0; i < 17; i++ {
|
||||
for i := 0; i < devicesToCreate; i++ {
|
||||
iStr := strconv.Itoa(i)
|
||||
acc := NewOlmAccount()
|
||||
deviceMap[id.DeviceID("dev"+iStr)] = &id.Device{
|
||||
|
|
@ -306,59 +244,33 @@ func TestStoreDevices(t *testing.T) {
|
|||
}
|
||||
}
|
||||
err = store.PutDevices(context.TODO(), "user1", deviceMap)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing devices: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error storing devices")
|
||||
devs, err := store.GetDevices(context.TODO(), "user1")
|
||||
if err != nil {
|
||||
t.Errorf("Error getting devices: %v", err)
|
||||
}
|
||||
if len(devs) != 17 {
|
||||
t.Errorf("Stored 17 devices, got back %v", len(devs))
|
||||
}
|
||||
if devs["dev0"].IdentityKey != deviceMap["dev0"].IdentityKey {
|
||||
t.Errorf("First device identity key does not match")
|
||||
}
|
||||
if devs["dev16"].IdentityKey != deviceMap["dev16"].IdentityKey {
|
||||
t.Errorf("Last device identity key does not match")
|
||||
}
|
||||
require.NoError(t, err, "Error getting devices")
|
||||
assert.Len(t, devs, devicesToCreate, "Expected to get %d devices back", devicesToCreate)
|
||||
assert.Equal(t, deviceMap, devs, "Stored devices do not match retrieved devices")
|
||||
|
||||
filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"})
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
} else if len(filtered) != 1 || filtered[0] != "user1" {
|
||||
t.Errorf("Expected to get 'user1' from filter, got %v", filtered)
|
||||
}
|
||||
require.NoError(t, err, "Error filtering tracked users")
|
||||
assert.Equal(t, []id.UserID{"user1"}, filtered, "Expected to get 'user1' from filter")
|
||||
|
||||
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
}
|
||||
if len(outdated) > 0 {
|
||||
t.Errorf("Got %d outdated tracked users when expected none", len(outdated))
|
||||
}
|
||||
require.NoError(t, err, "Error filtering tracked users")
|
||||
assert.Empty(t, outdated, "Expected no outdated tracked users after initial storage")
|
||||
|
||||
err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"})
|
||||
if err != nil {
|
||||
t.Errorf("Error marking tracked users outdated: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error marking tracked users outdated")
|
||||
|
||||
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
}
|
||||
if len(outdated) != 1 || outdated[0] != id.UserID("user1") {
|
||||
t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated)
|
||||
}
|
||||
require.NoError(t, err, "Error filtering tracked users")
|
||||
assert.Equal(t, []id.UserID{"user1"}, outdated, "Expected 'user1' to be marked as outdated")
|
||||
|
||||
err = store.PutDevices(context.TODO(), "user1", deviceMap)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing devices: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error storing devices again")
|
||||
|
||||
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
}
|
||||
if len(outdated) > 0 {
|
||||
t.Errorf("Got outdated tracked users %v when expected none", outdated)
|
||||
}
|
||||
require.NoError(t, err, "Error filtering tracked users")
|
||||
assert.Empty(t, outdated, "Expected no outdated tracked users after re-storing devices")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -369,16 +281,11 @@ func TestStoreSecrets(t *testing.T) {
|
|||
t.Run(storeName, func(t *testing.T) {
|
||||
storedSecret := "trustno1"
|
||||
err := store.PutSecret(context.TODO(), id.SecretMegolmBackupV1, storedSecret)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing secret: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "Error storing secret")
|
||||
|
||||
secret, err := store.GetSecret(context.TODO(), id.SecretMegolmBackupV1)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing secret: %v", err)
|
||||
} else if secret != storedSecret {
|
||||
t.Errorf("Stored secret did not match: '%s' != '%s'", secret, storedSecret)
|
||||
}
|
||||
require.NoError(t, err, "Error retrieving secret")
|
||||
assert.Equal(t, storedSecret, secret, "Retrieved secret does not match stored secret")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@ package utils
|
|||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAES256Ctr(t *testing.T) {
|
||||
|
|
@ -16,9 +19,7 @@ func TestAES256Ctr(t *testing.T) {
|
|||
key, iv := GenAttachmentA256CTR()
|
||||
enc := XorA256CTR([]byte(expected), key, iv)
|
||||
dec := XorA256CTR(enc, key, iv)
|
||||
if string(dec) != expected {
|
||||
t.Errorf("Expected decrypted using generated key/iv to be `%v`, got %v", expected, string(dec))
|
||||
}
|
||||
assert.EqualValues(t, expected, dec, "Decrypted text should match original")
|
||||
|
||||
var key2 [AESCTRKeyLength]byte
|
||||
var iv2 [AESCTRIVLength]byte
|
||||
|
|
@ -29,9 +30,7 @@ func TestAES256Ctr(t *testing.T) {
|
|||
iv2[i] = byte(i) + 32
|
||||
}
|
||||
dec2 := XorA256CTR([]byte{0x29, 0xc3, 0xff, 0x02, 0x21, 0xaf, 0x67, 0x73, 0x6e, 0xad, 0x9d}, key2, iv2)
|
||||
if string(dec2) != expected {
|
||||
t.Errorf("Expected decrypted using constant key/iv to be `%v`, got %v", expected, string(dec2))
|
||||
}
|
||||
assert.EqualValues(t, expected, dec2, "Decrypted text with constant key/iv should match original")
|
||||
}
|
||||
|
||||
func TestPBKDF(t *testing.T) {
|
||||
|
|
@ -42,9 +41,7 @@ func TestPBKDF(t *testing.T) {
|
|||
key := PBKDF2SHA512([]byte("Hello world"), salt, 1000, 256)
|
||||
expected := "ffk9YdbVE1cgqOWgDaec0lH+rJzO+MuCcxpIn3Z6D0E="
|
||||
keyB64 := base64.StdEncoding.EncodeToString([]byte(key))
|
||||
if keyB64 != expected {
|
||||
t.Errorf("Expected base64 of generated key to be `%v`, got `%v`", expected, keyB64)
|
||||
}
|
||||
assert.Equal(t, expected, keyB64)
|
||||
}
|
||||
|
||||
func TestDecodeSSSSKey(t *testing.T) {
|
||||
|
|
@ -53,13 +50,10 @@ func TestDecodeSSSSKey(t *testing.T) {
|
|||
|
||||
expected := "QCFDrXZYLEFnwf4NikVm62rYGJS2mNBEmAWLC3CgNPw="
|
||||
decodedB64 := base64.StdEncoding.EncodeToString(decoded[:])
|
||||
if expected != decodedB64 {
|
||||
t.Errorf("Expected decoded recovery key b64 to be `%v`, got `%v`", expected, decodedB64)
|
||||
}
|
||||
assert.Equal(t, expected, decodedB64)
|
||||
|
||||
if encoded := EncodeBase58RecoveryKey(decoded); encoded != recoveryKey {
|
||||
t.Errorf("Expected recovery key to be `%v`, got `%v`", recoveryKey, encoded)
|
||||
}
|
||||
encoded := EncodeBase58RecoveryKey(decoded)
|
||||
assert.Equal(t, recoveryKey, encoded)
|
||||
}
|
||||
|
||||
func TestKeyDerivationAndHMAC(t *testing.T) {
|
||||
|
|
@ -69,15 +63,11 @@ func TestKeyDerivationAndHMAC(t *testing.T) {
|
|||
aesKey, hmacKey := DeriveKeysSHA256(decoded[:], "m.cross_signing.master")
|
||||
|
||||
ciphertextBytes, err := base64.StdEncoding.DecodeString("Fx16KlJ9vkd3Dd6CafIq5spaH5QmK5BALMzbtFbQznG2j1VARKK+klc4/Qo=")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
calcMac := HMACSHA256B64(ciphertextBytes, hmacKey)
|
||||
expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E"
|
||||
if calcMac != expectedMac {
|
||||
t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac)
|
||||
}
|
||||
assert.Equal(t, expectedMac, calcMac)
|
||||
|
||||
var ivBytes [AESCTRIVLength]byte
|
||||
decodedIV, _ := base64.StdEncoding.DecodeString("zxT/W5LpZ0Q819pfju6hZw==")
|
||||
|
|
@ -85,7 +75,5 @@ func TestKeyDerivationAndHMAC(t *testing.T) {
|
|||
decrypted := string(XorA256CTR(ciphertextBytes, aesKey, ivBytes))
|
||||
|
||||
expectedDec := "Ec8eZDyvVkO3EDsEG6ej5c0cCHnX7PINqFXZjnaTV2s="
|
||||
if expectedDec != decrypted {
|
||||
t.Errorf("Expected decrypted text to be `%v`, got `%v`", expectedDec, decrypted)
|
||||
}
|
||||
assert.Equal(t, expectedDec, decrypted)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue