crypto: add basic group session sharing benchmark

This commit is contained in:
Tulir Asokan 2025-09-26 20:37:58 +03:00
commit acc449daf4
3 changed files with 173 additions and 27 deletions

View file

@ -0,0 +1,67 @@
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package crypto_test
import (
"context"
"fmt"
"math/rand/v2"
"testing"
"github.com/rs/zerolog"
globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
"github.com/stretchr/testify/require"
"maunium.net/go/mautrix/crypto/cryptohelper"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/mockserver"
)
func randomDeviceCount(r *rand.Rand) int {
k := 1
for k < 10 && r.IntN(3) > 0 {
k++
}
return k
}
func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) {
globallog.Logger = zerolog.Nop()
server := mockserver.Create(b)
server.PopOTKs = false
server.MemoryStore = false
var i int
var shareTargets []id.UserID
r := rand.New(rand.NewPCG(293, 0))
var totalDeviceCount int
for i = 1; i < 1000; i++ {
userID := id.UserID(fmt.Sprintf("@user%d:localhost", i))
deviceCount := randomDeviceCount(r)
for j := 0; j < deviceCount; j++ {
client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j)))
mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine()
keysCache, err := mach.GenerateCrossSigningKeys()
require.NoError(b, err)
err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil)
require.NoError(b, err)
}
totalDeviceCount += deviceCount
shareTargets = append(shareTargets, userID)
}
for b.Loop() {
client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i)))
mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine()
keysCache, err := mach.GenerateCrossSigningKeys()
require.NoError(b, err)
err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil)
require.NoError(b, err)
err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets)
require.NoError(b, err)
i++
}
fmt.Println(totalDeviceCount, "devices total")
}

View file

@ -9,7 +9,9 @@ package mockserver
import (
"context"
"encoding/json"
"fmt"
"io"
"maps"
"net/http"
"net/http/httptest"
"strings"
@ -17,6 +19,9 @@ import (
globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
"github.com/stretchr/testify/require"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/random"
"maunium.net/go/mautrix"
@ -26,35 +31,52 @@ import (
"maunium.net/go/mautrix/id"
)
func mustDecode(r *http.Request, data any) {
exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data))
}
type userAndDeviceID struct {
UserID id.UserID
DeviceID id.DeviceID
}
type MockServer struct {
Router *http.ServeMux
Server *httptest.Server
AccessTokenToUserID map[string]id.UserID
AccessTokenToUserID map[string]userAndDeviceID
DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
AccountData map[id.UserID]map[event.Type]json.RawMessage
DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey
MasterKeys map[id.UserID]mautrix.CrossSigningKeys
SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
PopOTKs bool
MemoryStore bool
}
func Create(t *testing.T) *MockServer {
func Create(t testing.TB) *MockServer {
t.Helper()
server := MockServer{
AccessTokenToUserID: map[string]id.UserID{},
AccessTokenToUserID: map[string]userAndDeviceID{},
DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
PopOTKs: true,
MemoryStore: true,
}
router := http.NewServeMux()
router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim)
router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
@ -66,7 +88,7 @@ func Create(t *testing.T) *MockServer {
return &server
}
func (ms *MockServer) getUserID(r *http.Request) id.UserID {
func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID {
authHeader := r.Header.Get("Authorization")
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
userID, ok := ms.AccessTokenToUserID[authHeader]
@ -77,12 +99,12 @@ func (ms *MockServer) getUserID(r *http.Request) id.UserID {
}
func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("{}"))
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
}
func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
var loginReq mautrix.ReqLogin
json.NewDecoder(r.Body).Decode(&loginReq)
mustDecode(r, &loginReq)
deviceID := loginReq.DeviceID
if deviceID == "" {
@ -91,9 +113,12 @@ func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
accessToken := random.String(30)
userID := id.UserID(loginReq.Identifier.User)
ms.AccessTokenToUserID[accessToken] = userID
ms.AccessTokenToUserID[accessToken] = userAndDeviceID{
UserID: userID,
DeviceID: deviceID,
}
json.NewEncoder(w).Encode(&mautrix.RespLogin{
exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{
AccessToken: accessToken,
DeviceID: deviceID,
UserID: userID,
@ -102,7 +127,7 @@ func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
var req mautrix.ReqSendToDevice
json.NewDecoder(r.Body).Decode(&req)
mustDecode(r, &req)
evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
for user, devices := range req.Messages {
@ -112,7 +137,7 @@ func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
}
content.ParseRaw(evtType)
ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{
Sender: ms.getUserID(r),
Sender: ms.getUserID(r).UserID,
Type: evtType,
Content: *content,
})
@ -135,7 +160,7 @@ func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
var req mautrix.ReqQueryKeys
json.NewDecoder(r.Body).Decode(&req)
mustDecode(r, &req)
resp := mautrix.RespQueryKeys{
MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
@ -148,29 +173,68 @@ func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user]
resp.DeviceKeys[user] = ms.DeviceKeys[user]
}
json.NewEncoder(w).Encode(&resp)
exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
}
func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) {
var req mautrix.ReqClaimKeys
mustDecode(r, &req)
resp := mautrix.RespClaimKeys{
OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
}
for user, devices := range req.OneTimeKeys {
resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
for device := range devices {
keys := ms.OneTimeKeys[user][device]
for keyID, key := range keys {
if ms.PopOTKs {
delete(keys, keyID)
}
resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{
keyID: key,
}
break
}
}
}
exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
}
func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
var req mautrix.ReqUploadKeys
json.NewDecoder(r.Body).Decode(&req)
mustDecode(r, &req)
userID := ms.getUserID(r)
uid := ms.getUserID(r)
userID := uid.UserID
if _, ok := ms.DeviceKeys[userID]; !ok {
ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
}
ms.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys
if _, ok := ms.OneTimeKeys[userID]; !ok {
ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
}
json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{
OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50},
if req.DeviceKeys != nil {
ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys
}
otks, ok := ms.OneTimeKeys[userID][uid.DeviceID]
if !ok {
otks = map[id.KeyID]mautrix.OneTimeKey{}
ms.OneTimeKeys[userID][uid.DeviceID] = otks
}
if req.OneTimeKeys != nil {
maps.Copy(otks, req.OneTimeKeys)
}
exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{
OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)},
})
}
func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
var req mautrix.UploadCrossSigningKeysReq
json.NewDecoder(r.Body).Decode(&req)
mustDecode(r, &req)
userID := ms.getUserID(r)
userID := ms.getUserID(r).UserID
ms.MasterKeys[userID] = req.Master
ms.SelfSigningKeys[userID] = req.SelfSigning
ms.UserSigningKeys[userID] = req.UserSigning
@ -178,11 +242,14 @@ func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Req
ms.emptyResp(w, r)
}
func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
t.Helper()
if ctx == nil {
ctx = context.TODO()
}
client, err := mautrix.NewClient(ms.Server.URL, "", "")
require.NoError(t, err)
client.StateStore = mautrix.NewMemoryStateStore()
client.Client = ms.Server.Client()
_, err = client.Login(ctx, &mautrix.ReqLogin{
Type: mautrix.AuthTypePassword,
@ -196,8 +263,22 @@ func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID,
})
require.NoError(t, err)
cryptoStore := crypto.NewMemoryStore(nil)
cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore)
var store any
if ms.MemoryStore {
store = crypto.NewMemoryStore(nil)
client.StateStore = mautrix.NewMemoryStateStore()
} else {
store, err = dbutil.NewFromConfig("", dbutil.Config{
PoolConfig: dbutil.PoolConfig{
Type: "sqlite3-fk-wal",
URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)),
MaxOpenConns: 5,
MaxIdleConns: 1,
},
}, nil)
require.NoError(t, err)
}
cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store)
require.NoError(t, err)
client.Crypto = cryptoHelper
@ -213,10 +294,10 @@ func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID,
err = cryptoHelper.Machine().ShareKeys(ctx, 50)
require.NoError(t, err)
return client, cryptoStore
return client, cryptoHelper.Machine().CryptoStore
}
func (ms *MockServer) DispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) {
func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) {
t.Helper()
for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {

View file

@ -8,7 +8,6 @@ package mautrix_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) {
var caps mautrix.RespCapabilities
err := json.Unmarshal([]byte(sampleData), &caps)
require.NoError(t, err)
fmt.Println(caps)
require.NotNil(t, caps.RoomVersions)
assert.Equal(t, "9", caps.RoomVersions.Default)