crypto/verificationhelper: extract mockserver to new package

This commit is contained in:
Tulir Asokan 2025-09-26 16:56:48 +03:00
commit 0685bd7786
5 changed files with 117 additions and 130 deletions

View file

@ -32,7 +32,6 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingScansQR=%t", tc.sendingScansQR), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginAliceBob(t, ctx)
defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@ -51,10 +50,10 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, bobUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID)
require.NotNil(t, receivingShownQRCode)
@ -83,7 +82,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiving device detected that its QR code
// was scanned.
@ -98,7 +97,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = sendingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
} else { // receiving scans QR
// Emulate scanning the QR code shown by the sending device on
// the receiving device.
@ -121,7 +120,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// Ensure that the sending device detected that its QR code was
// scanned.
@ -136,7 +135,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = receivingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
}
// Ensure that both devices have marked the verification as done.

View file

@ -36,7 +36,6 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@ -62,7 +61,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
if tc.expectedAcceptError != "" {
@ -72,7 +71,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) {
require.NoError(t, err)
}
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID)
require.NotNil(t, receivingShownQRCode)
@ -135,7 +134,6 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@ -152,10 +150,10 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID)
require.NotNil(t, receivingShownQRCode)
@ -184,7 +182,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiving device detected that its QR code
// was scanned.
@ -199,7 +197,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = sendingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
} else { // receiving scans QR
// Emulate scanning the QR code shown by the sending device on
// the receiving device.
@ -222,7 +220,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// Ensure that the sending device detected that its QR code was
// scanned.
@ -237,7 +235,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = receivingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
}
// Ensure that both devices have marked the verification as done.
@ -251,7 +249,6 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@ -263,10 +260,10 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes()
sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes()
@ -310,7 +307,6 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@ -327,10 +323,10 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes()
sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes()
@ -348,7 +344,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
// Ensure that the receiving device received a cancellation.
receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID]
assert.Len(t, receivingInbox, 1)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
cancellation := receivingCallbacks.GetVerificationCancellation(txnID)
require.NotNil(t, cancellation)
assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code)
@ -362,7 +358,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
// Ensure that the sending device received a cancellation.
sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID]
assert.Len(t, sendingInbox, 1)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
cancellation := sendingCallbacks.GetVerificationCancellation(txnID)
require.NotNil(t, cancellation)
assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code)

View file

@ -36,7 +36,6 @@ func TestVerification_SAS(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@ -60,10 +59,10 @@ func TestVerification_SAS(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// Test that the start event is correct
var startEvt *event.VerificationStartEventContent
@ -102,7 +101,7 @@ func TestVerification_SAS(t *testing.T) {
if tc.sendingStartsSAS {
// Process the verification start event on the receiving
// device.
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
// Receiving device sent the accept event to the sending device
sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID]
@ -110,7 +109,7 @@ func TestVerification_SAS(t *testing.T) {
acceptEvt = sendingInbox[0].Content.AsVerificationAccept()
} else {
// Process the verification start event on the sending device.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// Sending device sent the accept event to the receiving device
receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID]
@ -129,7 +128,7 @@ func TestVerification_SAS(t *testing.T) {
var firstKeyEvt *event.VerificationKeyEventContent
if tc.sendingStartsSAS {
// Process the verification accept event on the sending device.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// Sending device sends first key event to the receiving
// device.
@ -139,7 +138,7 @@ func TestVerification_SAS(t *testing.T) {
} else {
// Process the verification accept event on the receiving
// device.
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
// Receiving device sends first key event to the sending
// device.
@ -155,7 +154,7 @@ func TestVerification_SAS(t *testing.T) {
var secondKeyEvt *event.VerificationKeyEventContent
if tc.sendingStartsSAS {
// Process the first key event on the receiving device.
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
// Receiving device sends second key event to the sending
// device.
@ -170,7 +169,7 @@ func TestVerification_SAS(t *testing.T) {
assert.Len(t, descriptions, 7)
} else {
// Process the first key event on the sending device.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// Sending device sends second key event to the receiving
// device.
@ -191,10 +190,10 @@ func TestVerification_SAS(t *testing.T) {
// Ensure that the SAS codes are the same.
if tc.sendingStartsSAS {
// Process the second key event on the sending device.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
} else {
// Process the second key event on the receiving device.
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
}
assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID))
sendingEmojis, sendingDescriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID)
@ -274,10 +273,10 @@ func TestVerification_SAS(t *testing.T) {
// Test the transaction is done on both sides. We have to dispatch
// twice to process and drain all of the events.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
assert.True(t, sendingCallbacks.IsVerificationDone(txnID))
assert.True(t, receivingCallbacks.IsVerificationDone(txnID))
})
@ -288,7 +287,6 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@ -305,10 +303,10 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
err = sendingHelper.StartSAS(ctx, txnID)
require.NoError(t, err)
@ -325,7 +323,7 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
assert.Equal(t, txnID, sendingInbox[0].Content.AsVerificationStart().TransactionID)
// Process the start event from the receiving client to the sending client.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
receivingInbox = ts.DeviceInbox[aliceUserID][receivingDeviceID]
assert.Len(t, receivingInbox, 2)
assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID)
@ -333,13 +331,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
// Process the rest of the events until we need to confirm the SAS.
for len(ts.DeviceInbox[aliceUserID][sendingDeviceID]) > 0 || len(ts.DeviceInbox[aliceUserID][receivingDeviceID]) > 0 {
ts.dispatchToDevice(t, ctx, receivingClient)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
}
// Confirm the SAS only the receiving device.
receivingHelper.ConfirmSAS(ctx, txnID)
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// Verification is not done until both devices confirm the SAS.
assert.False(t, sendingCallbacks.IsVerificationDone(txnID))
@ -350,13 +348,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
// Dispatching the events to the receiving device should get us to the done
// state on the receiving device.
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
assert.False(t, sendingCallbacks.IsVerificationDone(txnID))
assert.True(t, receivingCallbacks.IsVerificationDone(txnID))
// Dispatching the events to the sending client should get us to the done
// state on the sending device.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
assert.True(t, sendingCallbacks.IsVerificationDone(txnID))
assert.True(t, receivingCallbacks.IsVerificationDone(txnID))
}

View file

@ -19,6 +19,7 @@ import (
"maunium.net/go/mautrix/crypto/verificationhelper"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/mockserver"
)
var aliceUserID = id.UserID("@alice:example.org")
@ -31,9 +32,19 @@ func init() {
zerolog.DefaultContextLogger = &log.Logger
}
func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) {
err := cryptoStore.PutDevice(ctx, userID, &id.Device{
UserID: userID,
DeviceID: deviceID,
})
if err != nil {
panic(err)
}
}
func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
t.Helper()
ts = createMockServer(t)
ts = mockserver.Create(t)
sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID)
sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine()
@ -47,9 +58,9 @@ func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServ
return
}
func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
t.Helper()
ts = createMockServer(t)
ts = mockserver.Create(t)
sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID)
sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine()
@ -116,8 +127,7 @@ func TestVerification_Start(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
ts := createMockServer(t)
defer ts.Close()
ts := mockserver.Create(t)
client, cryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID)
addDeviceID(ctx, cryptoStore, aliceUserID, sendingDeviceID)
@ -166,7 +176,6 @@ func TestVerification_StartThenCancel(t *testing.T) {
for _, sendingCancels := range []bool{true, false} {
t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) {
ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
_, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID)
@ -186,13 +195,13 @@ func TestVerification_StartThenCancel(t *testing.T) {
receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID]
assert.Len(t, receivingInbox, 1)
assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
// Process the request event on the bystander device.
bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID]
assert.Len(t, bystanderInbox, 1)
assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID)
ts.dispatchToDevice(t, ctx, bystanderClient)
ts.DispatchToDevice(t, ctx, bystanderClient)
// Cancel the verification request.
var cancelEvt *event.VerificationCancelEventContent
@ -231,7 +240,7 @@ func TestVerification_StartThenCancel(t *testing.T) {
if !sendingCancels {
// Process the cancellation event on the sending device.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// Ensure that the cancellation event was sent to the bystander device.
assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1)
@ -247,8 +256,7 @@ func TestVerification_StartThenCancel(t *testing.T) {
func TestVerification_Accept_NoSupportedMethods(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts := createMockServer(t)
defer ts.Close()
ts := mockserver.Create(t)
sendingClient, sendingCryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID)
receivingClient, _ := ts.Login(t, ctx, aliceUserID, receivingDeviceID)
@ -274,7 +282,7 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, txnID)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiver ignored the request because it
// doesn't support any of the verification methods in the
@ -314,7 +322,6 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
assert.NoError(t, err)
@ -333,7 +340,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
require.NoError(t, err)
// Process the verification request on the receiving device.
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiving device received a verification
// request with the correct transaction ID.
@ -373,7 +380,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
// Receive the m.key.verification.ready event on the sending
// device.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// Ensure that the sending device got a notification about the
// transaction being ready.
@ -402,7 +409,6 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
_, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
nonParticipatingDeviceID1 := id.DeviceID("non-participating1")
@ -419,12 +425,12 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) {
// the receiving device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
// Receive the m.key.verification.ready event on the sending device.
ts.dispatchToDevice(t, ctx, sendingClient)
ts.DispatchToDevice(t, ctx, sendingClient)
// The sending and receiving devices should not have any cancellation
// events in their inboxes.
@ -444,7 +450,6 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) {
func TestVerification_ErrorOnDoubleAccept(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
_, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
_, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
@ -452,7 +457,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) {
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
err = receivingHelper.AcceptVerification(ctx, txnID)
@ -472,7 +477,6 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) {
func TestVerification_CancelOnDoubleStart(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
_, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
@ -481,15 +485,15 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) {
// Send and accept the first verification request.
txnID1, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID1)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event
ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event
// Send a second verification request
txnID2, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
ts.dispatchToDevice(t, ctx, receivingClient)
ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the sending device received a cancellation event for both of
// the ongoing transactions.
@ -507,7 +511,7 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) {
assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID1))
assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID2))
ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events
ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events
assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID1))
assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID2))
}

View file

@ -1,10 +1,10 @@
// Copyright (c) 2024 Sumner Evans
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package verificationhelper_test
package mockserver
import (
"context"
@ -15,7 +15,7 @@ import (
"strings"
"testing"
"github.com/rs/zerolog/log" // zerolog-allow-global-log
globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
"github.com/stretchr/testify/require"
"go.mau.fi/util/random"
@ -26,10 +26,9 @@ import (
"maunium.net/go/mautrix/id"
)
// mockServer is a mock Matrix server that wraps an [httptest.Server] to allow
// testing of the interactive verification process.
type mockServer struct {
*httptest.Server
type MockServer struct {
Router *http.ServeMux
Server *httptest.Server
AccessTokenToUserID map[string]id.UserID
DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
@ -40,10 +39,10 @@ type mockServer struct {
UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
}
func createMockServer(t *testing.T) *mockServer {
func Create(t *testing.T) *MockServer {
t.Helper()
server := mockServer{
server := MockServer{
AccessTokenToUserID: map[string]id.UserID{},
DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
@ -61,12 +60,13 @@ func createMockServer(t *testing.T) *mockServer {
router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp)
router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload)
server.Router = router
server.Server = httptest.NewServer(router)
t.Cleanup(server.Server.Close)
return &server
}
func (ms *mockServer) getUserID(r *http.Request) id.UserID {
func (ms *MockServer) getUserID(r *http.Request) id.UserID {
authHeader := r.Header.Get("Authorization")
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
userID, ok := ms.AccessTokenToUserID[authHeader]
@ -76,11 +76,11 @@ func (ms *mockServer) getUserID(r *http.Request) id.UserID {
return userID
}
func (s *mockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("{}"))
}
func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) {
func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
var loginReq mautrix.ReqLogin
json.NewDecoder(r.Body).Decode(&loginReq)
@ -91,7 +91,7 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) {
accessToken := random.String(30)
userID := id.UserID(loginReq.Identifier.User)
s.AccessTokenToUserID[accessToken] = userID
ms.AccessTokenToUserID[accessToken] = userID
json.NewEncoder(w).Encode(&mautrix.RespLogin{
AccessToken: accessToken,
@ -100,40 +100,40 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) {
})
}
func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
var req mautrix.ReqSendToDevice
json.NewDecoder(r.Body).Decode(&req)
evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
for user, devices := range req.Messages {
for device, content := range devices {
if _, ok := s.DeviceInbox[user]; !ok {
s.DeviceInbox[user] = map[id.DeviceID][]event.Event{}
if _, ok := ms.DeviceInbox[user]; !ok {
ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{}
}
content.ParseRaw(evtType)
s.DeviceInbox[user][device] = append(s.DeviceInbox[user][device], event.Event{
Sender: s.getUserID(r),
ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{
Sender: ms.getUserID(r),
Type: evtType,
Content: *content,
})
}
}
s.emptyResp(w, r)
ms.emptyResp(w, r)
}
func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
userID := id.UserID(r.PathValue("userID"))
eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
jsonData, _ := io.ReadAll(r.Body)
if _, ok := s.AccountData[userID]; !ok {
s.AccountData[userID] = map[event.Type]json.RawMessage{}
if _, ok := ms.AccountData[userID]; !ok {
ms.AccountData[userID] = map[event.Type]json.RawMessage{}
}
s.AccountData[userID][eventType] = json.RawMessage(jsonData)
s.emptyResp(w, r)
ms.AccountData[userID][eventType] = json.RawMessage(jsonData)
ms.emptyResp(w, r)
}
func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
var req mautrix.ReqQueryKeys
json.NewDecoder(r.Body).Decode(&req)
resp := mautrix.RespQueryKeys{
@ -143,44 +143,44 @@ func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
}
for user := range req.DeviceKeys {
resp.MasterKeys[user] = s.MasterKeys[user]
resp.UserSigningKeys[user] = s.UserSigningKeys[user]
resp.SelfSigningKeys[user] = s.SelfSigningKeys[user]
resp.DeviceKeys[user] = s.DeviceKeys[user]
resp.MasterKeys[user] = ms.MasterKeys[user]
resp.UserSigningKeys[user] = ms.UserSigningKeys[user]
resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user]
resp.DeviceKeys[user] = ms.DeviceKeys[user]
}
json.NewEncoder(w).Encode(&resp)
}
func (s *mockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
var req mautrix.ReqUploadKeys
json.NewDecoder(r.Body).Decode(&req)
userID := s.getUserID(r)
if _, ok := s.DeviceKeys[userID]; !ok {
s.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
userID := ms.getUserID(r)
if _, ok := ms.DeviceKeys[userID]; !ok {
ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
}
s.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys
ms.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys
json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{
OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50},
})
}
func (s *mockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
var req mautrix.UploadCrossSigningKeysReq
json.NewDecoder(r.Body).Decode(&req)
userID := s.getUserID(r)
s.MasterKeys[userID] = req.Master
s.SelfSigningKeys[userID] = req.SelfSigning
s.UserSigningKeys[userID] = req.UserSigning
userID := ms.getUserID(r)
ms.MasterKeys[userID] = req.Master
ms.SelfSigningKeys[userID] = req.SelfSigning
ms.UserSigningKeys[userID] = req.UserSigning
s.emptyResp(w, r)
ms.emptyResp(w, r)
}
func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
t.Helper()
client, err := mautrix.NewClient(ms.URL, "", "")
client, err := mautrix.NewClient(ms.Server.URL, "", "")
require.NoError(t, err)
client.StateStore = mautrix.NewMemoryStateStore()
@ -204,7 +204,7 @@ func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID,
err = cryptoHelper.Init(ctx)
require.NoError(t, err)
machineLog := log.Logger.With().
machineLog := globallog.Logger.With().
Stringer("my_user_id", userID).
Stringer("my_device_id", deviceID).
Logger()
@ -216,7 +216,7 @@ func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID,
return client, cryptoStore
}
func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) {
func (ms *MockServer) DispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) {
t.Helper()
for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {
@ -224,13 +224,3 @@ func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client
ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:]
}
}
func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) {
err := cryptoStore.PutDevice(ctx, userID, &id.Device{
UserID: userID,
DeviceID: deviceID,
})
if err != nil {
panic(err)
}
}