diff --git a/crypto/machine_bench_test.go b/crypto/machine_bench_test.go new file mode 100644 index 00000000..fd40d795 --- /dev/null +++ b/crypto/machine_bench_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package crypto_test + +import ( + "context" + "fmt" + "math/rand/v2" + "testing" + + "github.com/rs/zerolog" + globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mockserver" +) + +func randomDeviceCount(r *rand.Rand) int { + k := 1 + for k < 10 && r.IntN(3) > 0 { + k++ + } + return k +} + +func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) { + globallog.Logger = zerolog.Nop() + server := mockserver.Create(b) + server.PopOTKs = false + server.MemoryStore = false + var i int + var shareTargets []id.UserID + r := rand.New(rand.NewPCG(293, 0)) + var totalDeviceCount int + for i = 1; i < 1000; i++ { + userID := id.UserID(fmt.Sprintf("@user%d:localhost", i)) + deviceCount := randomDeviceCount(r) + for j := 0; j < deviceCount; j++ { + client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j))) + mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() + keysCache, err := mach.GenerateCrossSigningKeys() + require.NoError(b, err) + err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) + require.NoError(b, err) + } + totalDeviceCount += deviceCount + shareTargets = append(shareTargets, userID) + } + for b.Loop() { + client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i))) + mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() + keysCache, err := mach.GenerateCrossSigningKeys() + require.NoError(b, err) + err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) + require.NoError(b, err) + err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets) + require.NoError(b, err) + i++ + } + fmt.Println(totalDeviceCount, "devices total") +} diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go index 9f62b567..e52c387a 100644 --- a/mockserver/mockserver.go +++ b/mockserver/mockserver.go @@ -9,7 +9,9 @@ package mockserver import ( "context" "encoding/json" + "fmt" "io" + "maps" "net/http" "net/http/httptest" "strings" @@ -17,6 +19,9 @@ import ( globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/require" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" "go.mau.fi/util/random" "maunium.net/go/mautrix" @@ -26,35 +31,52 @@ import ( "maunium.net/go/mautrix/id" ) +func mustDecode(r *http.Request, data any) { + exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data)) +} + +type userAndDeviceID struct { + UserID id.UserID + DeviceID id.DeviceID +} + type MockServer struct { Router *http.ServeMux Server *httptest.Server - AccessTokenToUserID map[string]id.UserID + AccessTokenToUserID map[string]userAndDeviceID DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event AccountData map[id.UserID]map[event.Type]json.RawMessage DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys + OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey MasterKeys map[id.UserID]mautrix.CrossSigningKeys SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys + + PopOTKs bool + MemoryStore bool } -func Create(t *testing.T) *MockServer { +func Create(t testing.TB) *MockServer { t.Helper() server := MockServer{ - AccessTokenToUserID: map[string]id.UserID{}, + AccessTokenToUserID: map[string]userAndDeviceID{}, DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + PopOTKs: true, + MemoryStore: true, } router := http.NewServeMux() router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) + router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim) router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) @@ -66,7 +88,7 @@ func Create(t *testing.T) *MockServer { return &server } -func (ms *MockServer) getUserID(r *http.Request) id.UserID { +func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID { authHeader := r.Header.Get("Authorization") authHeader = strings.TrimPrefix(authHeader, "Bearer ") userID, ok := ms.AccessTokenToUserID[authHeader] @@ -77,12 +99,12 @@ func (ms *MockServer) getUserID(r *http.Request) id.UserID { } func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("{}")) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { var loginReq mautrix.ReqLogin - json.NewDecoder(r.Body).Decode(&loginReq) + mustDecode(r, &loginReq) deviceID := loginReq.DeviceID if deviceID == "" { @@ -91,9 +113,12 @@ func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { accessToken := random.String(30) userID := id.UserID(loginReq.Identifier.User) - ms.AccessTokenToUserID[accessToken] = userID + ms.AccessTokenToUserID[accessToken] = userAndDeviceID{ + UserID: userID, + DeviceID: deviceID, + } - json.NewEncoder(w).Encode(&mautrix.RespLogin{ + exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{ AccessToken: accessToken, DeviceID: deviceID, UserID: userID, @@ -102,7 +127,7 @@ func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { var req mautrix.ReqSendToDevice - json.NewDecoder(r.Body).Decode(&req) + mustDecode(r, &req) evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} for user, devices := range req.Messages { @@ -112,7 +137,7 @@ func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { } content.ParseRaw(evtType) ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{ - Sender: ms.getUserID(r), + Sender: ms.getUserID(r).UserID, Type: evtType, Content: *content, }) @@ -135,7 +160,7 @@ func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) { func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { var req mautrix.ReqQueryKeys - json.NewDecoder(r.Body).Decode(&req) + mustDecode(r, &req) resp := mautrix.RespQueryKeys{ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, @@ -148,29 +173,68 @@ func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user] resp.DeviceKeys[user] = ms.DeviceKeys[user] } - json.NewEncoder(w).Encode(&resp) + exhttp.WriteJSONResponse(w, http.StatusOK, &resp) +} + +func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqClaimKeys + mustDecode(r, &req) + resp := mautrix.RespClaimKeys{ + OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, + } + for user, devices := range req.OneTimeKeys { + resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} + for device := range devices { + keys := ms.OneTimeKeys[user][device] + for keyID, key := range keys { + if ms.PopOTKs { + delete(keys, keyID) + } + resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{ + keyID: key, + } + break + } + } + } + exhttp.WriteJSONResponse(w, http.StatusOK, &resp) } func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { var req mautrix.ReqUploadKeys - json.NewDecoder(r.Body).Decode(&req) + mustDecode(r, &req) - userID := ms.getUserID(r) + uid := ms.getUserID(r) + userID := uid.UserID if _, ok := ms.DeviceKeys[userID]; !ok { ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} } - ms.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys + if _, ok := ms.OneTimeKeys[userID]; !ok { + ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} + } - json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{ - OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50}, + if req.DeviceKeys != nil { + ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys + } + otks, ok := ms.OneTimeKeys[userID][uid.DeviceID] + if !ok { + otks = map[id.KeyID]mautrix.OneTimeKey{} + ms.OneTimeKeys[userID][uid.DeviceID] = otks + } + if req.OneTimeKeys != nil { + maps.Copy(otks, req.OneTimeKeys) + } + + exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{ + OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)}, }) } func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { var req mautrix.UploadCrossSigningKeysReq - json.NewDecoder(r.Body).Decode(&req) + mustDecode(r, &req) - userID := ms.getUserID(r) + userID := ms.getUserID(r).UserID ms.MasterKeys[userID] = req.Master ms.SelfSigningKeys[userID] = req.SelfSigning ms.UserSigningKeys[userID] = req.UserSigning @@ -178,11 +242,14 @@ func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Req ms.emptyResp(w, r) } -func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { +func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { t.Helper() + if ctx == nil { + ctx = context.TODO() + } client, err := mautrix.NewClient(ms.Server.URL, "", "") require.NoError(t, err) - client.StateStore = mautrix.NewMemoryStateStore() + client.Client = ms.Server.Client() _, err = client.Login(ctx, &mautrix.ReqLogin{ Type: mautrix.AuthTypePassword, @@ -196,8 +263,22 @@ func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, }) require.NoError(t, err) - cryptoStore := crypto.NewMemoryStore(nil) - cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore) + var store any + if ms.MemoryStore { + store = crypto.NewMemoryStore(nil) + client.StateStore = mautrix.NewMemoryStateStore() + } else { + store, err = dbutil.NewFromConfig("", dbutil.Config{ + PoolConfig: dbutil.PoolConfig{ + Type: "sqlite3-fk-wal", + URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)), + MaxOpenConns: 5, + MaxIdleConns: 1, + }, + }, nil) + require.NoError(t, err) + } + cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store) require.NoError(t, err) client.Crypto = cryptoHelper @@ -213,10 +294,10 @@ func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, err = cryptoHelper.Machine().ShareKeys(ctx, 50) require.NoError(t, err) - return client, cryptoStore + return client, cryptoHelper.Machine().CryptoStore } -func (ms *MockServer) DispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) { +func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) { t.Helper() for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { diff --git a/responses_test.go b/responses_test.go index b23d85ad..73d82635 100644 --- a/responses_test.go +++ b/responses_test.go @@ -8,7 +8,6 @@ package mautrix_test import ( "encoding/json" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) { var caps mautrix.RespCapabilities err := json.Unmarshal([]byte(sampleData), &caps) require.NoError(t, err) - fmt.Println(caps) require.NotNil(t, caps.RoomVersions) assert.Equal(t, "9", caps.RoomVersions.Default)