mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
crypto: add basic group session sharing benchmark
This commit is contained in:
parent
fa90bba820
commit
acc449daf4
3 changed files with 173 additions and 27 deletions
67
crypto/machine_bench_test.go
Normal file
67
crypto/machine_bench_test.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package crypto_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/cryptohelper"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/mockserver"
|
||||
)
|
||||
|
||||
func randomDeviceCount(r *rand.Rand) int {
|
||||
k := 1
|
||||
for k < 10 && r.IntN(3) > 0 {
|
||||
k++
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) {
|
||||
globallog.Logger = zerolog.Nop()
|
||||
server := mockserver.Create(b)
|
||||
server.PopOTKs = false
|
||||
server.MemoryStore = false
|
||||
var i int
|
||||
var shareTargets []id.UserID
|
||||
r := rand.New(rand.NewPCG(293, 0))
|
||||
var totalDeviceCount int
|
||||
for i = 1; i < 1000; i++ {
|
||||
userID := id.UserID(fmt.Sprintf("@user%d:localhost", i))
|
||||
deviceCount := randomDeviceCount(r)
|
||||
for j := 0; j < deviceCount; j++ {
|
||||
client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j)))
|
||||
mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine()
|
||||
keysCache, err := mach.GenerateCrossSigningKeys()
|
||||
require.NoError(b, err)
|
||||
err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
totalDeviceCount += deviceCount
|
||||
shareTargets = append(shareTargets, userID)
|
||||
}
|
||||
for b.Loop() {
|
||||
client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i)))
|
||||
mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine()
|
||||
keysCache, err := mach.GenerateCrossSigningKeys()
|
||||
require.NoError(b, err)
|
||||
err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil)
|
||||
require.NoError(b, err)
|
||||
err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets)
|
||||
require.NoError(b, err)
|
||||
i++
|
||||
}
|
||||
fmt.Println(totalDeviceCount, "devices total")
|
||||
}
|
||||
|
|
@ -9,7 +9,9 @@ package mockserver
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
|
@ -17,6 +19,9 @@ import (
|
|||
|
||||
globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"go.mau.fi/util/exhttp"
|
||||
"go.mau.fi/util/random"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
|
|
@ -26,35 +31,52 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func mustDecode(r *http.Request, data any) {
|
||||
exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data))
|
||||
}
|
||||
|
||||
type userAndDeviceID struct {
|
||||
UserID id.UserID
|
||||
DeviceID id.DeviceID
|
||||
}
|
||||
|
||||
type MockServer struct {
|
||||
Router *http.ServeMux
|
||||
Server *httptest.Server
|
||||
|
||||
AccessTokenToUserID map[string]id.UserID
|
||||
AccessTokenToUserID map[string]userAndDeviceID
|
||||
DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
|
||||
AccountData map[id.UserID]map[event.Type]json.RawMessage
|
||||
DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
|
||||
OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey
|
||||
MasterKeys map[id.UserID]mautrix.CrossSigningKeys
|
||||
SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
|
||||
UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
|
||||
|
||||
PopOTKs bool
|
||||
MemoryStore bool
|
||||
}
|
||||
|
||||
func Create(t *testing.T) *MockServer {
|
||||
func Create(t testing.TB) *MockServer {
|
||||
t.Helper()
|
||||
|
||||
server := MockServer{
|
||||
AccessTokenToUserID: map[string]id.UserID{},
|
||||
AccessTokenToUserID: map[string]userAndDeviceID{},
|
||||
DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
|
||||
AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
|
||||
DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
|
||||
OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
|
||||
MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
|
||||
SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
|
||||
UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
|
||||
PopOTKs: true,
|
||||
MemoryStore: true,
|
||||
}
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
|
||||
router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
|
||||
router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim)
|
||||
router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
|
||||
router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
|
||||
router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
|
||||
|
|
@ -66,7 +88,7 @@ func Create(t *testing.T) *MockServer {
|
|||
return &server
|
||||
}
|
||||
|
||||
func (ms *MockServer) getUserID(r *http.Request) id.UserID {
|
||||
func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
userID, ok := ms.AccessTokenToUserID[authHeader]
|
||||
|
|
@ -77,12 +99,12 @@ func (ms *MockServer) getUserID(r *http.Request) id.UserID {
|
|||
}
|
||||
|
||||
func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Write([]byte("{}"))
|
||||
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
|
||||
}
|
||||
|
||||
func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var loginReq mautrix.ReqLogin
|
||||
json.NewDecoder(r.Body).Decode(&loginReq)
|
||||
mustDecode(r, &loginReq)
|
||||
|
||||
deviceID := loginReq.DeviceID
|
||||
if deviceID == "" {
|
||||
|
|
@ -91,9 +113,12 @@ func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
accessToken := random.String(30)
|
||||
userID := id.UserID(loginReq.Identifier.User)
|
||||
ms.AccessTokenToUserID[accessToken] = userID
|
||||
ms.AccessTokenToUserID[accessToken] = userAndDeviceID{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(&mautrix.RespLogin{
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{
|
||||
AccessToken: accessToken,
|
||||
DeviceID: deviceID,
|
||||
UserID: userID,
|
||||
|
|
@ -102,7 +127,7 @@ func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
|
||||
var req mautrix.ReqSendToDevice
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
mustDecode(r, &req)
|
||||
evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
|
||||
|
||||
for user, devices := range req.Messages {
|
||||
|
|
@ -112,7 +137,7 @@ func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
content.ParseRaw(evtType)
|
||||
ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{
|
||||
Sender: ms.getUserID(r),
|
||||
Sender: ms.getUserID(r).UserID,
|
||||
Type: evtType,
|
||||
Content: *content,
|
||||
})
|
||||
|
|
@ -135,7 +160,7 @@ func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
|
||||
var req mautrix.ReqQueryKeys
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
mustDecode(r, &req)
|
||||
resp := mautrix.RespQueryKeys{
|
||||
MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
|
||||
UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
|
||||
|
|
@ -148,29 +173,68 @@ func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
|
|||
resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user]
|
||||
resp.DeviceKeys[user] = ms.DeviceKeys[user]
|
||||
}
|
||||
json.NewEncoder(w).Encode(&resp)
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
|
||||
}
|
||||
|
||||
func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) {
|
||||
var req mautrix.ReqClaimKeys
|
||||
mustDecode(r, &req)
|
||||
resp := mautrix.RespClaimKeys{
|
||||
OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
|
||||
}
|
||||
for user, devices := range req.OneTimeKeys {
|
||||
resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
|
||||
for device := range devices {
|
||||
keys := ms.OneTimeKeys[user][device]
|
||||
for keyID, key := range keys {
|
||||
if ms.PopOTKs {
|
||||
delete(keys, keyID)
|
||||
}
|
||||
resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{
|
||||
keyID: key,
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
|
||||
}
|
||||
|
||||
func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
|
||||
var req mautrix.ReqUploadKeys
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
mustDecode(r, &req)
|
||||
|
||||
userID := ms.getUserID(r)
|
||||
uid := ms.getUserID(r)
|
||||
userID := uid.UserID
|
||||
if _, ok := ms.DeviceKeys[userID]; !ok {
|
||||
ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
|
||||
}
|
||||
ms.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys
|
||||
if _, ok := ms.OneTimeKeys[userID]; !ok {
|
||||
ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{
|
||||
OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50},
|
||||
if req.DeviceKeys != nil {
|
||||
ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys
|
||||
}
|
||||
otks, ok := ms.OneTimeKeys[userID][uid.DeviceID]
|
||||
if !ok {
|
||||
otks = map[id.KeyID]mautrix.OneTimeKey{}
|
||||
ms.OneTimeKeys[userID][uid.DeviceID] = otks
|
||||
}
|
||||
if req.OneTimeKeys != nil {
|
||||
maps.Copy(otks, req.OneTimeKeys)
|
||||
}
|
||||
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{
|
||||
OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)},
|
||||
})
|
||||
}
|
||||
|
||||
func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
|
||||
var req mautrix.UploadCrossSigningKeysReq
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
mustDecode(r, &req)
|
||||
|
||||
userID := ms.getUserID(r)
|
||||
userID := ms.getUserID(r).UserID
|
||||
ms.MasterKeys[userID] = req.Master
|
||||
ms.SelfSigningKeys[userID] = req.SelfSigning
|
||||
ms.UserSigningKeys[userID] = req.UserSigning
|
||||
|
|
@ -178,11 +242,14 @@ func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Req
|
|||
ms.emptyResp(w, r)
|
||||
}
|
||||
|
||||
func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
|
||||
func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
|
||||
t.Helper()
|
||||
if ctx == nil {
|
||||
ctx = context.TODO()
|
||||
}
|
||||
client, err := mautrix.NewClient(ms.Server.URL, "", "")
|
||||
require.NoError(t, err)
|
||||
client.StateStore = mautrix.NewMemoryStateStore()
|
||||
client.Client = ms.Server.Client()
|
||||
|
||||
_, err = client.Login(ctx, &mautrix.ReqLogin{
|
||||
Type: mautrix.AuthTypePassword,
|
||||
|
|
@ -196,8 +263,22 @@ func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID,
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cryptoStore := crypto.NewMemoryStore(nil)
|
||||
cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore)
|
||||
var store any
|
||||
if ms.MemoryStore {
|
||||
store = crypto.NewMemoryStore(nil)
|
||||
client.StateStore = mautrix.NewMemoryStateStore()
|
||||
} else {
|
||||
store, err = dbutil.NewFromConfig("", dbutil.Config{
|
||||
PoolConfig: dbutil.PoolConfig{
|
||||
Type: "sqlite3-fk-wal",
|
||||
URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)),
|
||||
MaxOpenConns: 5,
|
||||
MaxIdleConns: 1,
|
||||
},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store)
|
||||
require.NoError(t, err)
|
||||
client.Crypto = cryptoHelper
|
||||
|
||||
|
|
@ -213,10 +294,10 @@ func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID,
|
|||
err = cryptoHelper.Machine().ShareKeys(ctx, 50)
|
||||
require.NoError(t, err)
|
||||
|
||||
return client, cryptoStore
|
||||
return client, cryptoHelper.Machine().CryptoStore
|
||||
}
|
||||
|
||||
func (ms *MockServer) DispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) {
|
||||
func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) {
|
||||
t.Helper()
|
||||
|
||||
for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ package mautrix_test
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) {
|
|||
var caps mautrix.RespCapabilities
|
||||
err := json.Unmarshal([]byte(sampleData), &caps)
|
||||
require.NoError(t, err)
|
||||
fmt.Println(caps)
|
||||
|
||||
require.NotNil(t, caps.RoomVersions)
|
||||
assert.Equal(t, "9", caps.RoomVersions.Default)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue