mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Add crypto store tests
* Add tests for the functionality of the Gob and SQL crypto stores * Change `RemoveOutboundGroupSession` in the Gob store to not check if a session is shared, to be consistent with the SQL store Signed-off-by: Nikos Filippakis <me@nfil.dev>
This commit is contained in:
parent
e9aa55cc66
commit
47ab76e49d
4 changed files with 276 additions and 1 deletions
|
|
@ -285,7 +285,7 @@ func (gs *GobStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSes
|
|||
func (gs *GobStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.OutGroupSessions[roomID]
|
||||
if !ok || session == nil || !session.Shared {
|
||||
if !ok || session == nil {
|
||||
gs.lock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
267
crypto/store_test.go
Normal file
267
crypto/store_test.go
Normal file
|
|
@ -0,0 +1,267 @@
|
|||
// Copyright (c) 2020 Nikos Filippakis
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type emptyLogger struct{}
|
||||
|
||||
func (emptyLogger) Error(message string, args ...interface{}) {}
|
||||
func (emptyLogger) Warn(message string, args ...interface{}) {}
|
||||
func (emptyLogger) Debug(message string, args ...interface{}) {}
|
||||
func (emptyLogger) Trace(message string, args ...interface{}) {}
|
||||
|
||||
const olmSessID = "sJlikQQKXp7UQjmS9/lyZCNUVJ2AmKyHbufPBaC7tpk"
|
||||
const olmPickled = "L6cdv3JYO9OzhXbcjNSwl7ldN5bDvwmGyin+hISePETE6bO71DIlhqTC9YIhg21RDqRPH2HNl1MCyCw0hEXICWQyeJ9S7JLie" +
|
||||
"5PYxhqSSaTYaybvlvw34jvuSgEx0iotM6WNuWu5ocrsOo5Ye/3Nz7lBvxaw2rpS0jZnn7eV1n9GbINZk4YEVWrHOn7OxYfaGECJHDeAk/ameStiy" +
|
||||
"o1Gru0a/cmR0O3oKMyYnlXir0jS7oETMCsWk59GeVlz++j4aK0FK4g8/3fCMmLDXSatFjE9hoWDmeRwal58Y+XwX76Te/PiWtrFrinvCDEQJcZTa" +
|
||||
"qcCwp6sZrgLbmfBUBb0zJCogCmYw8m2"
|
||||
const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4fA5tovoPfdpk5QLh7PfDyjmgOcO9sSA37maJyzCy6Ap+uBZLAXp6VLJ0mjSvxi+PAbzGKDMqpn+pa+oeEIH6SFPG/2GGDSRoXVi5fttAClCIoav5RflWiMypKqnQRfkZR2Gx8glOaBiTzAd7m0X6XGfYIPol41JUIHfBLuJBfXQ0Uu5GScV4eKUWdJP2J6zzC2Hx8cZAhiBBzAza0CbGcnUK+YJXMYaJg92HiIo++l317LlsYUJ/P+gKOLafYR9/l8bAzxH7j5s31PnRs7mD1Bl6G1LFM+dPsGXUOLx6PlvlTlYYM/opai0uKKzT0Wk6zPoq9fN/smlXEPBtKlw2fqcytL4gOF0MrBPEca"
|
||||
|
||||
func getCryptoStores(t *testing.T) (map[string]Store, func()) {
|
||||
db, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
sqlStore := NewSQLCryptoStore(db, "sqlite3", id.DeviceID("dev"), []byte("test"), emptyLogger{})
|
||||
if err = sqlStore.CreateTables(); err != nil {
|
||||
t.Fatalf("Error creating tables: %v", err)
|
||||
}
|
||||
|
||||
os.Remove("gob_store_test.gob")
|
||||
gobStore, err := NewGobStore("gob_store_test.gob")
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating Gob store: %v", err)
|
||||
}
|
||||
|
||||
return map[string]Store{
|
||||
"sql": sqlStore,
|
||||
"gob": gobStore,
|
||||
}, func() {
|
||||
os.Remove("gob_store_test.gob")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutNextBatch(t *testing.T) {
|
||||
stores, cleanup := getCryptoStores(t)
|
||||
defer cleanup()
|
||||
store := stores["sql"].(*SQLCryptoStore)
|
||||
store.PutNextBatch("batch1")
|
||||
if batch := store.GetNextBatch(); batch != "batch1" {
|
||||
t.Errorf("Expected batch1, got %v", batch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutAccount(t *testing.T) {
|
||||
stores, cleanup := getCryptoStores(t)
|
||||
defer cleanup()
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
acc := NewOlmAccount()
|
||||
store.PutAccount(acc)
|
||||
retrieved, err := store.GetAccount()
|
||||
if err != nil {
|
||||
t.Fatalf("Error retrieving account: %v", err)
|
||||
}
|
||||
if acc.IdentityKey() != retrieved.IdentityKey() {
|
||||
t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey())
|
||||
}
|
||||
if acc.SigningKey() != retrieved.SigningKey() {
|
||||
t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMessageIndex(t *testing.T) {
|
||||
stores, cleanup := getCryptoStores(t)
|
||||
defer cleanup()
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
acc := NewOlmAccount()
|
||||
if !store.ValidateMessageIndex(acc.IdentityKey(), "sess1", "event1", 0, 1000) {
|
||||
t.Error("First message not validated successfully")
|
||||
}
|
||||
if store.ValidateMessageIndex(acc.IdentityKey(), "sess1", "event1", 0, 1001) {
|
||||
t.Error("First message validated successfully after changing timestamp")
|
||||
}
|
||||
if store.ValidateMessageIndex(acc.IdentityKey(), "sess1", "event2", 0, 1000) {
|
||||
t.Error("First message validated successfully after changing event ID")
|
||||
}
|
||||
if !store.ValidateMessageIndex(acc.IdentityKey(), "sess1", "event1", 0, 1000) {
|
||||
t.Error("First message not validated successfully for a second time")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreOlmSession(t *testing.T) {
|
||||
stores, cleanup := getCryptoStores(t)
|
||||
defer cleanup()
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
if store.HasSession(olmSessID) {
|
||||
t.Error("Found Olm session before inserting it")
|
||||
}
|
||||
olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating internal Olm session: %v", err)
|
||||
}
|
||||
|
||||
olmSess := OlmSession{
|
||||
id: olmSessID,
|
||||
Internal: *olmInternal,
|
||||
}
|
||||
err = store.AddSession(olmSessID, &olmSess)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing Olm session: %v", err)
|
||||
}
|
||||
if !store.HasSession(olmSessID) {
|
||||
t.Error("Not found Olm session after inserting it")
|
||||
}
|
||||
|
||||
retrieved, err := store.GetLatestSession(olmSessID)
|
||||
if err != nil {
|
||||
t.Errorf("Failed retrieving Olm session: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.ID() != olmSessID {
|
||||
t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID())
|
||||
}
|
||||
if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != olmPickled {
|
||||
t.Error("Pickled Olm session does not match original")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreMegolmSession(t *testing.T) {
|
||||
stores, cleanup := getCryptoStores(t)
|
||||
defer cleanup()
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
acc := NewOlmAccount()
|
||||
|
||||
internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating internal inbound group session: %v", err)
|
||||
}
|
||||
|
||||
igs := &InboundGroupSession{
|
||||
Internal: *internal,
|
||||
SigningKey: acc.SigningKey(),
|
||||
SenderKey: acc.IdentityKey(),
|
||||
RoomID: "room1",
|
||||
}
|
||||
|
||||
err = store.PutGroupSession("room1", acc.IdentityKey(), igs.ID(), igs)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing inbound group session: %v", err)
|
||||
}
|
||||
|
||||
retrieved, err := store.GetGroupSession("room1", acc.IdentityKey(), igs.ID())
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving inbound group session: %v", err)
|
||||
}
|
||||
|
||||
if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != groupSession {
|
||||
t.Error("Pickled inbound group session does not match original")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreOutboundMegolmSession(t *testing.T) {
|
||||
stores, cleanup := getCryptoStores(t)
|
||||
defer cleanup()
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
sess, err := store.GetOutboundGroupSession("room1")
|
||||
if sess != nil {
|
||||
t.Error("Got outbound session before inserting")
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving outbound session: %v", err)
|
||||
}
|
||||
|
||||
outbound := NewOutboundGroupSession("room1")
|
||||
err = store.AddOutboundGroupSession(outbound)
|
||||
if err != nil {
|
||||
t.Errorf("Error inserting outbound session: %v", err)
|
||||
}
|
||||
|
||||
sess, err = store.GetOutboundGroupSession("room1")
|
||||
if sess == nil {
|
||||
t.Error("Did not get outbound session after inserting")
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving outbound session: %v", err)
|
||||
}
|
||||
|
||||
err = store.RemoveOutboundGroupSession("room1")
|
||||
if err != nil {
|
||||
t.Errorf("Error deleting outbound session: %v", err)
|
||||
}
|
||||
|
||||
sess, err = store.GetOutboundGroupSession("room1")
|
||||
if sess != nil {
|
||||
t.Error("Got outbound session after deleting")
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving outbound session: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreDevices(t *testing.T) {
|
||||
stores, cleanup := getCryptoStores(t)
|
||||
defer cleanup()
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
acc1 := NewOlmAccount()
|
||||
acc2 := NewOlmAccount()
|
||||
err := store.PutDevices("user1", map[id.DeviceID]*DeviceIdentity{
|
||||
"dev1": {
|
||||
UserID: "user1",
|
||||
DeviceID: "dev1",
|
||||
IdentityKey: acc1.IdentityKey(),
|
||||
SigningKey: acc1.SigningKey(),
|
||||
},
|
||||
"dev2": {
|
||||
UserID: "user2",
|
||||
DeviceID: "dev2",
|
||||
IdentityKey: acc2.IdentityKey(),
|
||||
SigningKey: acc2.SigningKey(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Error string devices: %v", err)
|
||||
}
|
||||
devs, err := store.GetDevices("user1")
|
||||
if err != nil {
|
||||
t.Errorf("Error getting devices: %v", err)
|
||||
}
|
||||
if len(devs) != 2 {
|
||||
t.Errorf("Stored 2 devices, got back %v", len(devs))
|
||||
}
|
||||
|
||||
filtered := store.FilterTrackedUsers([]id.UserID{"user0", "user1", "user2"})
|
||||
if len(filtered) != 1 || filtered[0] != "user1" {
|
||||
t.Errorf("Expected to get 'user1' from filter, got %v", filtered)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1
go.mod
1
go.mod
|
|
@ -5,6 +5,7 @@ go 1.14
|
|||
require (
|
||||
github.com/gorilla/mux v1.7.4
|
||||
github.com/lib/pq v1.7.0
|
||||
github.com/mattn/go-sqlite3 v1.14.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/russross/blackfriday/v2 v2.0.1
|
||||
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
|
||||
|
|
|
|||
7
go.sum
7
go.sum
|
|
@ -1,9 +1,13 @@
|
|||
github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=
|
||||
github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc=
|
||||
github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/lib/pq v1.7.0 h1:h93mCPfUSkaul3Ka/VG8uZdmW1uMHDGxzu0NWHuJmHY=
|
||||
github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA=
|
||||
github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
|
@ -25,6 +29,9 @@ github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhV
|
|||
github.com/tidwall/sjson v1.1.1 h1:7h1vk049Jnd5EH9NyzNiEuwYW4b5qgreBbqRC19AS3U=
|
||||
github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200602114024-627f9648deb9 h1:pNX+40auqi2JqRfOP1akLGtYcn15TUbkhwuCO3foqqM=
|
||||
golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue