diff --git a/crypto/aescbc/aes_cbc_test.go b/crypto/aescbc/aes_cbc_test.go index bb03f706..d6611dc9 100644 --- a/crypto/aescbc/aes_cbc_test.go +++ b/crypto/aescbc/aes_cbc_test.go @@ -7,11 +7,13 @@ package aescbc_test import ( - "bytes" "crypto/aes" "crypto/rand" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "maunium.net/go/mautrix/crypto/aescbc" ) @@ -22,32 +24,23 @@ func TestAESCBC(t *testing.T) { // The key length can be 32, 24, 16 bytes (OR in bits: 128, 192 or 256) key := make([]byte, 32) _, err = rand.Read(key) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) iv := make([]byte, aes.BlockSize) _, err = rand.Read(iv) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) plaintext = []byte("secret message for testing") //increase to next block size for len(plaintext)%8 != 0 { plaintext = append(plaintext, []byte("-")...) } - if ciphertext, err = aescbc.Encrypt(key, iv, plaintext); err != nil { - t.Fatal(err) - } + ciphertext, err = aescbc.Encrypt(key, iv, plaintext) + require.NoError(t, err) resultPlainText, err := aescbc.Decrypt(key, iv, ciphertext) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if string(resultPlainText) != string(plaintext) { - t.Fatalf("message '%s' (length %d) != '%s'", resultPlainText, len(resultPlainText), plaintext) - } + assert.Equal(t, string(resultPlainText), string(plaintext)) } func TestAESCBCCase1(t *testing.T) { @@ -61,18 +54,10 @@ func TestAESCBCCase1(t *testing.T) { key := make([]byte, 32) iv := make([]byte, aes.BlockSize) encrypted, err := aescbc.Encrypt(key, iv, input) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(expected, encrypted) { - t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected) - } + require.NoError(t, err) + assert.Equal(t, expected, encrypted, "encrypted output does not match expected") decrypted, err := aescbc.Decrypt(key, iv, encrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(input, decrypted) { - t.Fatalf("decrypted did not match expected:\n%v\n%v\n", decrypted, input) - } + require.NoError(t, err) + assert.Equal(t, input, decrypted, "decrypted output does not match input") } diff --git a/crypto/canonicaljson/json_test.go b/crypto/canonicaljson/json_test.go index d1a7f0a5..36476aa4 100644 --- a/crypto/canonicaljson/json_test.go +++ b/crypto/canonicaljson/json_test.go @@ -17,31 +17,43 @@ package canonicaljson import ( "testing" + + "github.com/stretchr/testify/assert" ) -func testSortJSON(t *testing.T, input, want string) { - got := SortJSON([]byte(input), nil) - - // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. - if string(CompactJSON(got, nil)) != want { - t.Errorf("SortJSON(%q): want %q got %q", input, want, got) - } -} - func TestSortJSON(t *testing.T) { - testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`) - testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, - `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`) - testSortJSON(t, `[true,false,null]`, `[true,false,null]`) - testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`) - testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`) + var tests = []struct { + input string + want string + }{ + {"{}", "{}"}, + {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, + {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[9007199254740991]`, `[9007199254740991]`}, + {"\t\n[9007199254740991]", `[9007199254740991]`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, + {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[9007199254740991]`, `[9007199254740991]`}, + {"\t\n[9007199254740991]", `[9007199254740991]`}, + {`[true,false,null]`, `[true,false,null]`}, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + got := SortJSON([]byte(test.input), nil) + + // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. + assert.EqualValues(t, test.want, string(CompactJSON(got, nil))) + }) + } } func testCompactJSON(t *testing.T, input, want string) { + t.Helper() got := string(CompactJSON([]byte(input), nil)) - if got != want { - t.Errorf("CompactJSON(%q): want %q got %q", input, want, got) - } + assert.EqualValues(t, want, got) } func TestCompactJSON(t *testing.T) { @@ -74,18 +86,23 @@ func TestCompactJSON(t *testing.T) { testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`) } -func testReadHex(t *testing.T, input string, want uint32) { - got := readHexDigits([]byte(input)) - if want != got { - t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got) +func TestReadHex(t *testing.T) { + tests := []struct { + input string + want uint32 + }{ + + {"0123", 0x0123}, + {"4567", 0x4567}, + {"89AB", 0x89AB}, + {"CDEF", 0xCDEF}, + {"89ab", 0x89AB}, + {"cdef", 0xCDEF}, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + got := readHexDigits([]byte(test.input)) + assert.Equal(t, test.want, got) + }) } } - -func TestReadHex(t *testing.T) { - testReadHex(t, "0123", 0x0123) - testReadHex(t, "4567", 0x4567) - testReadHex(t, "89AB", 0x89AB) - testReadHex(t, "CDEF", 0xCDEF) - testReadHex(t, "89ab", 0x89AB) - testReadHex(t, "cdef", 0xCDEF) -} diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go index 5e1ffd50..b70370a2 100644 --- a/crypto/cross_sign_test.go +++ b/crypto/cross_sign_test.go @@ -13,6 +13,8 @@ import ( "testing" "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" @@ -24,17 +26,12 @@ var noopLogger = zerolog.Nop() func getOlmMachine(t *testing.T) *OlmMachine { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error opening raw database") db, err := dbutil.NewWithDB(rawDB, "sqlite3") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error creating database wrapper") sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { - t.Fatalf("Error creating tables: %v", err) - } + err = sqlStore.DB.Upgrade(context.TODO()) + require.NoError(t, err, "Error upgrading database") userID := id.UserID("@mautrix") mk, _ := olm.NewPKSigning() @@ -66,29 +63,25 @@ func TestTrustOwnDevice(t *testing.T) { DeviceID: "device", SigningKey: id.Ed25519("deviceKey"), } - if m.IsDeviceTrusted(context.TODO(), ownDevice) { - t.Error("Own device trusted while it shouldn't be") - } + assert.False(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device trusted while it shouldn't be") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey, ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted { - t.Error("Own user not trusted while they should be") - } - if !m.IsDeviceTrusted(context.TODO(), ownDevice) { - t.Error("Own device not trusted while it should be") - } + trusted, err := m.IsUserTrusted(context.TODO(), ownDevice.UserID) + require.NoError(t, err, "Error checking if own user is trusted") + assert.True(t, trusted, "Own user not trusted while they should be") + assert.True(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device not trusted while it should be") } func TestTrustOtherUser(t *testing.T) { m := getOlmMachine(t) otherUser := id.UserID("@user") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted while they shouldn't be") - } + trusted, err := m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted while they shouldn't be") theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -100,16 +93,16 @@ func TestTrustOtherUser(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted before their master key has been signed with our user-signing key") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted before their master key has been signed with our user-signing key") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { - t.Error("Other user not trusted while they should be") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.True(t, trusted, "Other user not trusted while they should be") } func TestTrustOtherDevice(t *testing.T) { @@ -120,12 +113,11 @@ func TestTrustOtherDevice(t *testing.T) { DeviceID: "theirDevice", SigningKey: id.Ed25519("theirDeviceKey"), } - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted while they shouldn't be") - } - if m.IsDeviceTrusted(context.TODO(), theirDevice) { - t.Error("Other device trusted while it shouldn't be") - } + + trusted, err := m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted while they shouldn't be") + assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted while it shouldn't be") theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -137,21 +129,17 @@ func TestTrustOtherDevice(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { - t.Error("Other user not trusted while they should be") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.True(t, trusted, "Other user not trusted while they should be") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(), otherUser, theirMasterKey.PublicKey(), "sig3") - if m.IsDeviceTrusted(context.TODO(), theirDevice) { - t.Error("Other device trusted before it has been signed with user's SSK") - } + assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted before it has been signed with user's SSK") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey, otherUser, theirSSK.PublicKey(), "sig4") - if !m.IsDeviceTrusted(context.TODO(), theirDevice) { - t.Error("Other device not trusted while it should be") - } + assert.True(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device not trusted after it has been signed with user's SSK") } diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 59c86236..872c3ac4 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -36,20 +36,15 @@ func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, func newMachine(t *testing.T, userID id.UserID) *OlmMachine { client, err := mautrix.NewClient("http://localhost", userID, "token") - if err != nil { - t.Fatalf("Error creating client: %v", err) - } + require.NoError(t, err, "Error creating client") client.DeviceID = "device1" gobStore := NewMemoryStore(nil) - if err != nil { - t.Fatalf("Error creating Gob store: %v", err) - } + require.NoError(t, err, "Error creating Gob store") machine := NewOlmMachine(client, nil, gobStore, mockStateStore{}) - if err := machine.Load(context.TODO()); err != nil { - t.Fatalf("Error creating account: %v", err) - } + err = machine.Load(context.TODO()) + require.NoError(t, err, "Error creating account") return machine } @@ -82,9 +77,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // create outbound olm session for sending machine using OTK olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key) - if err != nil { - t.Errorf("Failed to create outbound olm session: %v", err) - } + require.NoError(t, err, "Error creating outbound olm session") // store sender device identity in receiving machine store machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{ @@ -121,29 +114,21 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { Type: event.ToDeviceEncrypted, Sender: "user1", }, senderKey, content.Type, content.Body) - if err != nil { - t.Errorf("Error decrypting olm content: %v", err) - } + require.NoError(t, err, "Error decrypting olm ciphertext") + // store room key in new inbound group session roomKeyEvt := decrypted.Content.AsRoomKey() igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0, false) - if err != nil { - t.Errorf("Error creating inbound megolm session: %v", err) - } - if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil { - t.Errorf("Error storing inbound megolm session: %v", err) - } + require.NoError(t, err, "Error creating inbound group session") + err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs) + require.NoError(t, err, "Error storing inbound group session") } // encrypt event with megolm session in sending machine eventContent := map[string]string{"hello": "world"} encryptedEvtContent, err := machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if err != nil { - t.Errorf("Error encrypting megolm event: %v", err) - } - if megolmOutSession.MessageCount != 1 { - t.Errorf("Megolm outbound session message count is not 1 but %d", megolmOutSession.MessageCount) - } + require.NoError(t, err, "Error encrypting megolm event") + assert.Equal(t, 1, megolmOutSession.MessageCount) encryptedEvt := &event.Event{ Content: event.Content{Parsed: encryptedEvtContent}, @@ -155,22 +140,12 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // decrypt event on receiving machine and confirm decryptedEvt, err := machineIn.DecryptMegolmEvent(context.TODO(), encryptedEvt) - if err != nil { - t.Errorf("Error decrypting megolm event: %v", err) - } - if decryptedEvt.Type != event.EventMessage { - t.Errorf("Expected event type %v, got %v", event.EventMessage, decryptedEvt.Type) - } - if decryptedEvt.Content.Raw["hello"] != "world" { - t.Errorf("Expected event content %v, got %v", eventContent, decryptedEvt.Content.Raw) - } + require.NoError(t, err, "Error decrypting megolm event") + assert.Equal(t, event.EventMessage, decryptedEvt.Type) + assert.Equal(t, "world", decryptedEvt.Content.Raw["hello"]) machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if megolmOutSession.Expired() { - t.Error("Megolm outbound session expired before 3rd message") - } + assert.False(t, megolmOutSession.Expired(), "Megolm outbound session expired before 3rd message") machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if !megolmOutSession.Expired() { - t.Error("Megolm outbound session not expired after 3rd message") - } + assert.True(t, megolmOutSession.Expired(), "Megolm outbound session not expired after 3rd message") } diff --git a/crypto/store_test.go b/crypto/store_test.go index a7c4d75a..8aeae7af 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -13,6 +13,7 @@ import ( "testing" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" @@ -29,22 +30,14 @@ const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4 func getCryptoStores(t *testing.T) map[string]Store { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error opening raw database") db, err := dbutil.NewWithDB(rawDB, "sqlite3") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error creating database wrapper") sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { - t.Fatalf("Error creating tables: %v", err) - } + err = sqlStore.DB.Upgrade(context.TODO()) + require.NoError(t, err, "Error upgrading database") gobStore := NewMemoryStore(nil) - if err != nil { - t.Fatalf("Error creating Gob store: %v", err) - } return map[string]Store{ "sql": sqlStore, @@ -56,9 +49,10 @@ func TestPutNextBatch(t *testing.T) { stores := getCryptoStores(t) store := stores["sql"].(*SQLCryptoStore) store.PutNextBatch(context.Background(), "batch1") - if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" { - t.Errorf("Expected batch1, got %v", batch) - } + + batch, err := store.GetNextBatch(context.Background()) + require.NoError(t, err, "Error retrieving next batch") + assert.Equal(t, "batch1", batch) } func TestPutAccount(t *testing.T) { @@ -68,15 +62,9 @@ func TestPutAccount(t *testing.T) { acc := NewOlmAccount() store.PutAccount(context.TODO(), acc) retrieved, err := store.GetAccount(context.TODO()) - if err != nil { - t.Fatalf("Error retrieving account: %v", err) - } - if acc.IdentityKey() != retrieved.IdentityKey() { - t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey()) - } - if acc.SigningKey() != retrieved.SigningKey() { - t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey()) - } + require.NoError(t, err, "Error retrieving account") + assert.Equal(t, acc.IdentityKey(), retrieved.IdentityKey(), "Identity key does not match") + assert.Equal(t, acc.SigningKey(), retrieved.SigningKey(), "Signing key does not match") }) } } @@ -86,18 +74,26 @@ func TestValidateMessageIndex(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { - t.Error("First message not validated successfully") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001); ok { - t.Error("First message validated successfully after changing timestamp") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000); ok { - t.Error("First message validated successfully after changing event ID") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { - t.Error("First message not validated successfully for a second time") - } + + // First message should validate successfully + ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + + // Edit the timestamp and ensure validate fails + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001) + require.NoError(t, err, "Error validating message index after timestamp change") + assert.False(t, ok, "First message validation should fail after timestamp change") + + // Edit the event ID and ensure validate fails + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000) + require.NoError(t, err, "Error validating message index after event ID change") + assert.False(t, ok, "First message validation should fail after event ID change") + + // Validate again with the original parameters and ensure that it still passes + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") }) } } @@ -106,43 +102,26 @@ func TestStoreOlmSession(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { - if store.HasSession(context.TODO(), olmSessID) { - t.Error("Found Olm session before inserting it") - } + require.False(t, store.HasSession(context.TODO(), olmSessID), "Found Olm session before inserting it") + olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test")) - if err != nil { - t.Fatalf("Error creating internal Olm session: %v", err) - } + require.NoError(t, err, "Error creating internal Olm session") olmSess := OlmSession{ id: olmSessID, Internal: olmInternal, } err = store.AddSession(context.TODO(), olmSessID, &olmSess) - if err != nil { - t.Errorf("Error storing Olm session: %v", err) - } - if !store.HasSession(context.TODO(), olmSessID) { - t.Error("Not found Olm session after inserting it") - } + require.NoError(t, err, "Error storing Olm session") + assert.True(t, store.HasSession(context.TODO(), olmSessID), "Olm session not found after inserting it") retrieved, err := store.GetLatestSession(context.TODO(), olmSessID) - if err != nil { - t.Errorf("Failed retrieving Olm session: %v", err) - } - - if retrieved.ID() != olmSessID { - t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID()) - } + require.NoError(t, err, "Error retrieving Olm session") + assert.EqualValues(t, olmSessID, retrieved.ID()) pickled, err := retrieved.Internal.Pickle([]byte("test")) - if err != nil { - t.Fatalf("Error pickling Olm session: %v", err) - } - - if string(pickled) != olmPickled { - t.Error("Pickled Olm session does not match original") - } + require.NoError(t, err, "Error pickling Olm session") + assert.EqualValues(t, pickled, olmPickled, "Pickled Olm session does not match original") }) } } @@ -154,9 +133,7 @@ func TestStoreMegolmSession(t *testing.T) { acc := NewOlmAccount() internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test")) - if err != nil { - t.Fatalf("Error creating internal inbound group session: %v", err) - } + require.NoError(t, err, "Error creating internal inbound group session") igs := &InboundGroupSession{ Internal: internal, @@ -166,20 +143,14 @@ func TestStoreMegolmSession(t *testing.T) { } err = store.PutGroupSession(context.TODO(), igs) - if err != nil { - t.Errorf("Error storing inbound group session: %v", err) - } + require.NoError(t, err, "Error storing inbound group session") retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID()) - if err != nil { - t.Errorf("Error retrieving inbound group session: %v", err) - } + require.NoError(t, err, "Error retrieving inbound group session") - if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil { - t.Fatalf("Error pickling inbound group session: %v", err) - } else if string(pickled) != groupSession { - t.Error("Pickled inbound group session does not match original") - } + pickled, err := retrieved.Internal.Pickle([]byte("test")) + require.NoError(t, err, "Error pickling inbound group session") + assert.EqualValues(t, pickled, groupSession, "Pickled inbound group session does not match original") }) } } @@ -189,40 +160,24 @@ func TestStoreOutboundMegolmSession(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { sess, err := store.GetOutboundGroupSession(context.TODO(), "room1") - if sess != nil { - t.Error("Got outbound session before inserting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session") + require.Nil(t, sess, "Got outbound session before inserting") outbound, err := NewOutboundGroupSession("room1", nil) require.NoError(t, err) err = store.AddOutboundGroupSession(context.TODO(), outbound) - if err != nil { - t.Errorf("Error inserting outbound session: %v", err) - } + require.NoError(t, err, "Error inserting outbound session") sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - if sess == nil { - t.Error("Did not get outbound session after inserting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session") + assert.NotNil(t, sess, "Did not get outbound session after inserting") err = store.RemoveOutboundGroupSession(context.TODO(), "room1") - if err != nil { - t.Errorf("Error deleting outbound session: %v", err) - } + require.NoError(t, err, "Error deleting outbound session") sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - if sess != nil { - t.Error("Got outbound session after deleting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session after deletion") + assert.Nil(t, sess, "Got outbound session after deleting") }) } } @@ -244,58 +199,41 @@ func TestStoreOutboundMegolmSessionSharing(t *testing.T) { t.Run(storeName, func(t *testing.T) { device := resetDevice() err := store.PutDevice(context.TODO(), "user1", device) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing device") shared, err := store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if shared { - t.Errorf("Outbound group session shared when it shouldn't") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.False(t, shared, "Outbound group session should not be shared initially") err = store.MarkOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error marking outbound group session as shared: %v", err) - } + require.NoError(t, err, "Error marking outbound group session as shared") shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if !shared { - t.Errorf("Outbound group session not shared when it should") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.True(t, shared, "Outbound group session should be shared after marking it as such") device = resetDevice() err = store.PutDevice(context.TODO(), "user1", device) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing device after resetting") shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if shared { - t.Errorf("Outbound group session shared when it shouldn't") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.False(t, shared, "Outbound group session should not be shared after resetting device") }) } } func TestStoreDevices(t *testing.T) { + devicesToCreate := 17 stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { outdated, err := store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users initially") + deviceMap := make(map[id.DeviceID]*id.Device) - for i := 0; i < 17; i++ { + for i := 0; i < devicesToCreate; i++ { iStr := strconv.Itoa(i) acc := NewOlmAccount() deviceMap[id.DeviceID("dev"+iStr)] = &id.Device{ @@ -306,59 +244,33 @@ func TestStoreDevices(t *testing.T) { } } err = store.PutDevices(context.TODO(), "user1", deviceMap) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing devices") devs, err := store.GetDevices(context.TODO(), "user1") - if err != nil { - t.Errorf("Error getting devices: %v", err) - } - if len(devs) != 17 { - t.Errorf("Stored 17 devices, got back %v", len(devs)) - } - if devs["dev0"].IdentityKey != deviceMap["dev0"].IdentityKey { - t.Errorf("First device identity key does not match") - } - if devs["dev16"].IdentityKey != deviceMap["dev16"].IdentityKey { - t.Errorf("Last device identity key does not match") - } + require.NoError(t, err, "Error getting devices") + assert.Len(t, devs, devicesToCreate, "Expected to get %d devices back", devicesToCreate) + assert.Equal(t, deviceMap, devs, "Stored devices do not match retrieved devices") filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"}) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } else if len(filtered) != 1 || filtered[0] != "user1" { - t.Errorf("Expected to get 'user1' from filter, got %v", filtered) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Equal(t, []id.UserID{"user1"}, filtered, "Expected to get 'user1' from filter") outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users after initial storage") + err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"}) - if err != nil { - t.Errorf("Error marking tracked users outdated: %v", err) - } + require.NoError(t, err, "Error marking tracked users outdated") + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) != 1 || outdated[0] != id.UserID("user1") { - t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Equal(t, []id.UserID{"user1"}, outdated, "Expected 'user1' to be marked as outdated") + err = store.PutDevices(context.TODO(), "user1", deviceMap) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing devices again") + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got outdated tracked users %v when expected none", outdated) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users after re-storing devices") }) } } @@ -369,16 +281,11 @@ func TestStoreSecrets(t *testing.T) { t.Run(storeName, func(t *testing.T) { storedSecret := "trustno1" err := store.PutSecret(context.TODO(), id.SecretMegolmBackupV1, storedSecret) - if err != nil { - t.Errorf("Error storing secret: %v", err) - } + require.NoError(t, err, "Error storing secret") secret, err := store.GetSecret(context.TODO(), id.SecretMegolmBackupV1) - if err != nil { - t.Errorf("Error storing secret: %v", err) - } else if secret != storedSecret { - t.Errorf("Stored secret did not match: '%s' != '%s'", secret, storedSecret) - } + require.NoError(t, err, "Error retrieving secret") + assert.Equal(t, storedSecret, secret, "Retrieved secret does not match stored secret") }) } } diff --git a/crypto/utils/utils_test.go b/crypto/utils/utils_test.go index c4f01a68..b12fd9e2 100644 --- a/crypto/utils/utils_test.go +++ b/crypto/utils/utils_test.go @@ -9,6 +9,9 @@ package utils import ( "encoding/base64" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAES256Ctr(t *testing.T) { @@ -16,9 +19,7 @@ func TestAES256Ctr(t *testing.T) { key, iv := GenAttachmentA256CTR() enc := XorA256CTR([]byte(expected), key, iv) dec := XorA256CTR(enc, key, iv) - if string(dec) != expected { - t.Errorf("Expected decrypted using generated key/iv to be `%v`, got %v", expected, string(dec)) - } + assert.EqualValues(t, expected, dec, "Decrypted text should match original") var key2 [AESCTRKeyLength]byte var iv2 [AESCTRIVLength]byte @@ -29,9 +30,7 @@ func TestAES256Ctr(t *testing.T) { iv2[i] = byte(i) + 32 } dec2 := XorA256CTR([]byte{0x29, 0xc3, 0xff, 0x02, 0x21, 0xaf, 0x67, 0x73, 0x6e, 0xad, 0x9d}, key2, iv2) - if string(dec2) != expected { - t.Errorf("Expected decrypted using constant key/iv to be `%v`, got %v", expected, string(dec2)) - } + assert.EqualValues(t, expected, dec2, "Decrypted text with constant key/iv should match original") } func TestPBKDF(t *testing.T) { @@ -42,9 +41,7 @@ func TestPBKDF(t *testing.T) { key := PBKDF2SHA512([]byte("Hello world"), salt, 1000, 256) expected := "ffk9YdbVE1cgqOWgDaec0lH+rJzO+MuCcxpIn3Z6D0E=" keyB64 := base64.StdEncoding.EncodeToString([]byte(key)) - if keyB64 != expected { - t.Errorf("Expected base64 of generated key to be `%v`, got `%v`", expected, keyB64) - } + assert.Equal(t, expected, keyB64) } func TestDecodeSSSSKey(t *testing.T) { @@ -53,13 +50,10 @@ func TestDecodeSSSSKey(t *testing.T) { expected := "QCFDrXZYLEFnwf4NikVm62rYGJS2mNBEmAWLC3CgNPw=" decodedB64 := base64.StdEncoding.EncodeToString(decoded[:]) - if expected != decodedB64 { - t.Errorf("Expected decoded recovery key b64 to be `%v`, got `%v`", expected, decodedB64) - } + assert.Equal(t, expected, decodedB64) - if encoded := EncodeBase58RecoveryKey(decoded); encoded != recoveryKey { - t.Errorf("Expected recovery key to be `%v`, got `%v`", recoveryKey, encoded) - } + encoded := EncodeBase58RecoveryKey(decoded) + assert.Equal(t, recoveryKey, encoded) } func TestKeyDerivationAndHMAC(t *testing.T) { @@ -69,15 +63,11 @@ func TestKeyDerivationAndHMAC(t *testing.T) { aesKey, hmacKey := DeriveKeysSHA256(decoded[:], "m.cross_signing.master") ciphertextBytes, err := base64.StdEncoding.DecodeString("Fx16KlJ9vkd3Dd6CafIq5spaH5QmK5BALMzbtFbQznG2j1VARKK+klc4/Qo=") - if err != nil { - t.Error(err) - } + require.NoError(t, err) calcMac := HMACSHA256B64(ciphertextBytes, hmacKey) expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E" - if calcMac != expectedMac { - t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac) - } + assert.Equal(t, expectedMac, calcMac) var ivBytes [AESCTRIVLength]byte decodedIV, _ := base64.StdEncoding.DecodeString("zxT/W5LpZ0Q819pfju6hZw==") @@ -85,7 +75,5 @@ func TestKeyDerivationAndHMAC(t *testing.T) { decrypted := string(XorA256CTR(ciphertextBytes, aesKey, ivBytes)) expectedDec := "Ec8eZDyvVkO3EDsEG6ej5c0cCHnX7PINqFXZjnaTV2s=" - if expectedDec != decrypted { - t.Errorf("Expected decrypted text to be `%v`, got `%v`", expectedDec, decrypted) - } + assert.Equal(t, expectedDec, decrypted) }