mautrix-go/mockserver/mockserver.go
2025-09-26 20:37:58 +03:00

307 lines
9.5 KiB
Go

// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package mockserver
import (
"context"
"encoding/json"
"fmt"
"io"
"maps"
"net/http"
"net/http/httptest"
"strings"
"testing"
globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
"github.com/stretchr/testify/require"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/random"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/cryptohelper"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// mustDecode parses the JSON body of r into data and panics if decoding fails.
// data must be a pointer to the value that should receive the decoded body.
func mustDecode(r *http.Request, data any) {
	decoder := json.NewDecoder(r.Body)
	err := decoder.Decode(data)
	exerrors.PanicIfNotNil(err)
}
// userAndDeviceID pairs a user ID with a device ID. It is the value stored
// per access token in MockServer.AccessTokenToUserID.
type userAndDeviceID struct {
UserID id.UserID
DeviceID id.DeviceID
}
// MockServer is an in-memory mock homeserver for tests. It keeps all state in
// plain maps that are read and written by the HTTP handlers registered in
// Create.
type MockServer struct {
// Router is the mux on which the mock endpoints are registered.
Router *http.ServeMux
// Server is the httptest server that serves Router.
Server *httptest.Server
// AccessTokenToUserID maps access tokens issued by the login endpoint to
// the user and device they were issued for.
AccessTokenToUserID map[string]userAndDeviceID
// DeviceInbox holds to-device events received via the sendToDevice
// endpoint until DispatchToDevice delivers them.
DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
// AccountData holds the raw request bodies stored via the account data
// endpoint, keyed by user and event type.
AccountData map[id.UserID]map[event.Type]json.RawMessage
// DeviceKeys holds device keys stored by the keys upload endpoint.
DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
// OneTimeKeys holds one-time keys stored by the keys upload endpoint and
// handed out by the keys claim endpoint.
OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey
// MasterKeys, SelfSigningKeys and UserSigningKeys hold the cross-signing
// keys stored by the device signing upload endpoint.
MasterKeys map[id.UserID]mautrix.CrossSigningKeys
SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
// PopOTKs controls whether a claimed one-time key is deleted from
// OneTimeKeys. Create sets it to true.
PopOTKs bool
// MemoryStore selects the crypto store used by Login: an in-memory store
// when true, an in-memory SQLite database otherwise. Create sets it to true.
MemoryStore bool
}
// Create builds a MockServer with empty state, registers the mock endpoints
// on a fresh mux and starts an httptest server for it. The server is closed
// automatically through t.Cleanup.
//
// PopOTKs and MemoryStore both default to true.
func Create(t testing.TB) *MockServer {
	t.Helper()

	ms := &MockServer{
		Router: http.NewServeMux(),

		AccessTokenToUserID: make(map[string]userAndDeviceID),
		DeviceInbox:         make(map[id.UserID]map[id.DeviceID][]event.Event),
		AccountData:         make(map[id.UserID]map[event.Type]json.RawMessage),
		DeviceKeys:          make(map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys),
		OneTimeKeys:         make(map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey),
		MasterKeys:          make(map[id.UserID]mautrix.CrossSigningKeys),
		SelfSigningKeys:     make(map[id.UserID]mautrix.CrossSigningKeys),
		UserSigningKeys:     make(map[id.UserID]mautrix.CrossSigningKeys),

		PopOTKs:     true,
		MemoryStore: true,
	}

	// Endpoint table: method + path pattern and the handler serving it.
	routes := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"POST /_matrix/client/v3/login", ms.postLogin},
		{"POST /_matrix/client/v3/keys/query", ms.postKeysQuery},
		{"POST /_matrix/client/v3/keys/claim", ms.postKeysClaim},
		{"PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", ms.putSendToDevice},
		{"PUT /_matrix/client/v3/user/{userID}/account_data/{type}", ms.putAccountData},
		{"POST /_matrix/client/v3/keys/device_signing/upload", ms.postDeviceSigningUpload},
		{"POST /_matrix/client/v3/keys/signatures/upload", ms.emptyResp},
		{"POST /_matrix/client/v3/keys/upload", ms.postKeysUpload},
	}
	for _, route := range routes {
		ms.Router.HandleFunc(route.pattern, route.handler)
	}

	ms.Server = httptest.NewServer(ms.Router)
	t.Cleanup(ms.Server.Close)
	return ms
}
// getUserID resolves the caller of r from its Authorization header. The
// "Bearer " prefix is stripped and the remainder is looked up in
// AccessTokenToUserID. It panics if the token is unknown.
func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID {
	token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
	if identity, known := ms.AccessTokenToUserID[token]; known {
		return identity
	}
	panic("no user ID found for access token " + token)
}
// emptyResp answers any request with HTTP 200 and an empty JSON response.
// The request itself is not inspected.
func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
	const status = http.StatusOK
	exhttp.WriteEmptyJSONResponse(w, status)
}
// postLogin handles the login endpoint. No credentials are verified: the user
// named in the request identifier is always logged in. A random device ID is
// generated when the request does not specify one, and a fresh random access
// token is registered in AccessTokenToUserID and returned to the caller.
func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
	var body mautrix.ReqLogin
	mustDecode(r, &body)

	session := userAndDeviceID{
		UserID:   id.UserID(body.Identifier.User),
		DeviceID: body.DeviceID,
	}
	if session.DeviceID == "" {
		session.DeviceID = id.DeviceID(random.String(10))
	}

	token := random.String(30)
	ms.AccessTokenToUserID[token] = session

	exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{
		AccessToken: token,
		DeviceID:    session.DeviceID,
		UserID:      session.UserID,
	})
}
// putSendToDevice handles the sendToDevice endpoint. Every message in the
// request is appended to the inbox of its target user and device as an event
// whose type is taken from the {type} path value and whose sender is the
// authenticated caller. Queued events stay in DeviceInbox until they are
// delivered with DispatchToDevice.
func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
	var body mautrix.ReqSendToDevice
	mustDecode(r, &body)

	toDeviceType := event.Type{Class: event.ToDeviceEventType, Type: r.PathValue("type")}
	for targetUser, perDevice := range body.Messages {
		for targetDevice, payload := range perDevice {
			// Create the per-user inbox map on first use.
			inbox, exists := ms.DeviceInbox[targetUser]
			if !exists {
				inbox = make(map[id.DeviceID][]event.Event)
				ms.DeviceInbox[targetUser] = inbox
			}
			// The result of ParseRaw is deliberately not checked here.
			payload.ParseRaw(toDeviceType)
			queued := event.Event{
				Sender:  ms.getUserID(r).UserID,
				Type:    toDeviceType,
				Content: *payload,
			}
			inbox[targetDevice] = append(inbox[targetDevice], queued)
		}
	}
	ms.emptyResp(w, r)
}
// putAccountData handles the account data endpoint. The raw request body is
// stored unparsed in AccountData under the user from the {userID} path value
// and the event type from the {type} path value.
//
// It panics if the request body cannot be read, matching how mustDecode
// treats undecodable bodies in the other handlers.
func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
	userID := id.UserID(r.PathValue("userID"))
	eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
	jsonData, err := io.ReadAll(r.Body)
	// Fail loudly rather than storing a truncated or empty body.
	exerrors.PanicIfNotNil(err)
	if _, ok := ms.AccountData[userID]; !ok {
		ms.AccountData[userID] = map[event.Type]json.RawMessage{}
	}
	ms.AccountData[userID][eventType] = json.RawMessage(jsonData)
	ms.emptyResp(w, r)
}
// postKeysQuery handles the keys query endpoint. For every user listed in the
// request it copies whatever the server has stored (cross-signing keys and
// device keys) into the response. Users without stored keys still get an
// entry, holding the zero value.
func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
	var query mautrix.ReqQueryKeys
	mustDecode(r, &query)

	master := make(map[id.UserID]mautrix.CrossSigningKeys)
	userSigning := make(map[id.UserID]mautrix.CrossSigningKeys)
	selfSigning := make(map[id.UserID]mautrix.CrossSigningKeys)
	devices := make(map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys)

	// Only the requested user IDs matter; the per-user device list in the
	// request is not consulted.
	for userID := range query.DeviceKeys {
		master[userID] = ms.MasterKeys[userID]
		userSigning[userID] = ms.UserSigningKeys[userID]
		selfSigning[userID] = ms.SelfSigningKeys[userID]
		devices[userID] = ms.DeviceKeys[userID]
	}

	exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespQueryKeys{
		MasterKeys:      master,
		UserSigningKeys: userSigning,
		SelfSigningKeys: selfSigning,
		DeviceKeys:      devices,
	})
}
// postKeysClaim handles the keys claim endpoint. For each requested user and
// device it returns at most one stored one-time key, picked in map iteration
// order. When PopOTKs is set, the returned key is also removed from the
// server's store. Devices without stored keys get no entry in the response.
func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) {
	var claim mautrix.ReqClaimKeys
	mustDecode(r, &claim)

	claimed := make(map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey)
	for userID, wanted := range claim.OneTimeKeys {
		perDevice := make(map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey)
		claimed[userID] = perDevice
		for deviceID := range wanted {
			available := ms.OneTimeKeys[userID][deviceID]
			// Take a single key: the loop body always exits after the first
			// iteration.
			for keyID, key := range available {
				perDevice[deviceID] = map[id.KeyID]mautrix.OneTimeKey{keyID: key}
				if ms.PopOTKs {
					delete(available, keyID)
				}
				break
			}
		}
	}

	exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespClaimKeys{OneTimeKeys: claimed})
}
// postKeysUpload handles the keys upload endpoint for the authenticated
// caller. Device keys, when present, replace the stored keys for the caller's
// device. One-time keys are merged into the stored set for that device. The
// response reports the resulting number of stored one-time keys as the
// signed curve25519 count.
func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
	var upload mautrix.ReqUploadKeys
	mustDecode(r, &upload)
	caller := ms.getUserID(r)

	// Make sure the per-user containers exist before writing into them.
	userDeviceKeys, found := ms.DeviceKeys[caller.UserID]
	if !found {
		userDeviceKeys = make(map[id.DeviceID]mautrix.DeviceKeys)
		ms.DeviceKeys[caller.UserID] = userDeviceKeys
	}
	userOTKs, found := ms.OneTimeKeys[caller.UserID]
	if !found {
		userOTKs = make(map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey)
		ms.OneTimeKeys[caller.UserID] = userOTKs
	}

	if upload.DeviceKeys != nil {
		userDeviceKeys[caller.DeviceID] = *upload.DeviceKeys
	}

	deviceOTKs, found := userOTKs[caller.DeviceID]
	if !found {
		deviceOTKs = make(map[id.KeyID]mautrix.OneTimeKey)
		userOTKs[caller.DeviceID] = deviceOTKs
	}
	if upload.OneTimeKeys != nil {
		maps.Copy(deviceOTKs, upload.OneTimeKeys)
	}

	exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{
		OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(deviceOTKs)},
	})
}
// postDeviceSigningUpload handles the cross-signing key upload endpoint. The
// master, self-signing and user-signing keys from the request replace any
// previously stored keys for the authenticated caller.
func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
	var keys mautrix.UploadCrossSigningKeysReq
	mustDecode(r, &keys)

	owner := ms.getUserID(r).UserID
	ms.MasterKeys[owner], ms.SelfSigningKeys[owner], ms.UserSigningKeys[owner] =
		keys.Master, keys.SelfSigning, keys.UserSigning

	ms.emptyResp(w, r)
}
// Login creates a client pointed at the mock server, logs it in as userID and
// sets up end-to-end encryption for it.
//
// A nil ctx is replaced with context.TODO(). If deviceID is empty, the mock
// server generates one during login. The crypto store is an in-memory store
// when ms.MemoryStore is set and an in-memory SQLite database otherwise.
// After initialising the crypto helper, the client's keys are shared with the
// server via ShareKeys.
//
// Any failure aborts the test through require. The returned values are the
// logged-in client and the crypto store used by its Olm machine.
func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
	t.Helper()
	if ctx == nil {
		ctx = context.TODO()
	}
	client, err := mautrix.NewClient(ms.Server.URL, "", "")
	require.NoError(t, err)
	// Use the test server's HTTP client so requests reach the mock server.
	client.Client = ms.Server.Client()
	_, err = client.Login(ctx, &mautrix.ReqLogin{
		Type: mautrix.AuthTypePassword,
		Identifier: mautrix.UserIdentifier{
			Type: mautrix.IdentifierTypeUser,
			User: userID.String(),
		},
		DeviceID:         deviceID,
		Password:         "password",
		StoreCredentials: true,
	})
	require.NoError(t, err)

	var store any
	if ms.MemoryStore {
		store = crypto.NewMemoryStore(nil)
		client.StateStore = mautrix.NewMemoryStateStore()
	} else {
		// A random database name keeps each logged-in client's in-memory
		// SQLite database separate.
		// NOTE(review): the database handle is not explicitly closed here —
		// confirm whether the test or the crypto helper is expected to do so.
		store, err = dbutil.NewFromConfig("", dbutil.Config{
			PoolConfig: dbutil.PoolConfig{
				Type:         "sqlite3-fk-wal",
				URI:          fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)),
				MaxOpenConns: 5,
				MaxIdleConns: 1,
			},
		}, nil)
		require.NoError(t, err)
	}

	cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store)
	require.NoError(t, err)
	client.Crypto = cryptoHelper
	err = cryptoHelper.Init(ctx)
	require.NoError(t, err)

	// Tag the machine's logs with the device ID actually in use: when the
	// deviceID argument is empty the server assigns one, and only the client
	// knows it after login.
	machineLog := globallog.Logger.With().
		Stringer("my_user_id", userID).
		Stringer("my_device_id", client.DeviceID).
		Logger()
	cryptoHelper.Machine().Log = &machineLog

	err = cryptoHelper.Machine().ShareKeys(ctx, 50)
	require.NoError(t, err)
	return client, cryptoHelper.Machine().CryptoStore
}
// DispatchToDevice delivers every event that was queued for the client's user
// and device at the time of the call, in order, by passing it to the client's
// syncer. Each event is removed from the inbox after it has been dispatched.
//
// The client's Syncer must be a *mautrix.DefaultSyncer; otherwise the type
// assertion panics as soon as there is an event to deliver.
func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) {
	t.Helper()
	// Snapshot the queue: only events present now are dispatched, even if
	// more are queued while dispatching.
	pending := ms.DeviceInbox[client.UserID][client.DeviceID]
	for i := 0; i < len(pending); i++ {
		evt := pending[i]
		client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt)
		queue := ms.DeviceInbox[client.UserID][client.DeviceID]
		ms.DeviceInbox[client.UserID][client.DeviceID] = queue[1:]
	}
}