diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 1fdbcc70..c47eea71 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -82,7 +82,7 @@ type VerificationHelper struct { var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan, supportsSAS bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } @@ -106,9 +106,13 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor helper.verificationDone = c.VerificationDone } - if c, ok := callbacks.(ShowSASCallbacks); ok { - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) - helper.showSAS = c.ShowSAS + if supportsSAS { + if c, ok := callbacks.(ShowSASCallbacks); !ok { + panic("callbacks must implement showSAS if supportsSAS is true") + } else { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) + helper.showSAS = c.ShowSAS + } } if supportsQRShow { if c, ok := callbacks.(ShowQRCodeCallbacks); !ok { diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index e192508b..b4c21c18 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -71,7 +71,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB) require.NoError(t, err) - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true, true) + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true, true, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() @@ -79,7 +79,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece require.NoError(t, err) receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB) require.NoError(t, err) - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true, true) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true, true, true) require.NoError(t, receivingHelper.Init(ctx)) return } @@ -91,25 +91,27 @@ func TestVerification_Start(t *testing.T) { testCases := []struct { supportsShow bool supportsScan bool + supportsSAS bool callbacks MockVerificationCallbacks startVerificationErrMsg string expectedVerificationMethods []event.VerificationMethod }{ - {false, false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, - {false, true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, false, false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, + {false, true, false, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {false, false, newShowQRCodeVerificationCallbacks(), "no supported verification methods", nil}, - {true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, true, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {true, true, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, false, false, newShowQRCodeVerificationCallbacks(), "no supported verification methods", nil}, + {true, false, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {false, false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {false, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, false, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {false, false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {true, false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, false, false, newAllVerificationCallbacks(), "no supported verification methods", nil}, + {false, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {true, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, } for i, tc := range testCases { @@ -122,7 +124,7 @@ func TestVerification_Start(t *testing.T) { addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsShow, tc.supportsScan) + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsShow, tc.supportsScan, tc.supportsSAS) err := senderHelper.Init(ctx) require.NoError(t, err) @@ -169,7 +171,7 @@ func TestVerification_StartThenCancel(t *testing.T) { bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true, true) + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true, true, true) require.NoError(t, bystanderHelper.Init(ctx)) require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) @@ -259,12 +261,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, cache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true, true) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true, true, true) err = sendingHelper.Init(ctx) require.NoError(t, err) receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false, false) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false, false, false) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -288,23 +290,25 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { sendingSupportsShow bool receivingSupportsScan bool receivingSupportsShow bool + sendingSupportsSAS bool + receivingSupportsSAS bool sendingCallbacks MockVerificationCallbacks receivingCallbacks MockVerificationCallbacks expectedVerificationMethods []event.VerificationMethod }{ // TODO - {false, false, false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, true, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, false, false, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, false, true, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, false, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, - {false, true, true, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, - {true, false, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, - {false, true, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, - {true, true, true, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, - {true, true, false, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, - {true, true, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, + {true, false, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, false, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {true, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, - {true, true, true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, + {true, true, true, true, true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, } for i, tc := range testCases { @@ -317,11 +321,11 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, sendingCrossSigningKeysCache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsShow, tc.sendingSupportsScan) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsShow, tc.sendingSupportsScan, tc.sendingSupportsSAS) err = sendingHelper.Init(ctx) require.NoError(t, err) - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsShow, tc.receivingSupportsScan) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsShow, tc.receivingSupportsScan, tc.receivingSupportsSAS) err = receivingHelper.Init(ctx) require.NoError(t, err)