diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index 466a60fc..b5ca9af8 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -32,6 +32,8 @@ type baseVerificationCallbacks struct { decimalsShown map[id.VerificationTransactionID][]int } +var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil) + func newBaseVerificationCallbacks() *baseVerificationCallbacks { return &baseVerificationCallbacks{ verificationsRequested: map[id.UserID][]id.VerificationTransactionID{}, @@ -98,6 +100,8 @@ type sasVerificationCallbacks struct { *baseVerificationCallbacks } +var _ verificationhelper.ShowSASCallbacks = (*sasVerificationCallbacks)(nil) + func newSASVerificationCallbacks() *sasVerificationCallbacks { return &sasVerificationCallbacks{newBaseVerificationCallbacks()} } @@ -112,34 +116,76 @@ func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.Verific c.decimalsShown[txnID] = decimals } -type qrCodeVerificationCallbacks struct { +type scanQRCodeVerificationCallbacks struct { *baseVerificationCallbacks } -func newQRCodeVerificationCallbacks() *qrCodeVerificationCallbacks { - return &qrCodeVerificationCallbacks{newBaseVerificationCallbacks()} +var _ verificationhelper.ScanQRCodeCallbacks = (*scanQRCodeVerificationCallbacks)(nil) + +func newScanQRCodeVerificationCallbacks() *scanQRCodeVerificationCallbacks { + return &scanQRCodeVerificationCallbacks{newBaseVerificationCallbacks()} } -func newQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *qrCodeVerificationCallbacks { - return &qrCodeVerificationCallbacks{base} +func newScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *scanQRCodeVerificationCallbacks { + return &scanQRCodeVerificationCallbacks{base} } - -func (c *qrCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) { +func (c *scanQRCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) { c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) } -func (c *qrCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) { +type showQRCodeVerificationCallbacks struct { + *baseVerificationCallbacks +} + +var _ verificationhelper.ShowQRCodeCallbacks = (*showQRCodeVerificationCallbacks)(nil) + +func newShowQRCodeVerificationCallbacks() *showQRCodeVerificationCallbacks { + return &showQRCodeVerificationCallbacks{newBaseVerificationCallbacks()} +} + +func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showQRCodeVerificationCallbacks { + return &showQRCodeVerificationCallbacks{base} +} + +func (c *showQRCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) { c.qrCodesShown[txnID] = qrCode } -func (c *qrCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { +func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { c.qrCodesScanned[txnID] = struct{}{} } +type showAndScanQRCodeVerificationCallbacks struct { + *baseVerificationCallbacks + *showQRCodeVerificationCallbacks + *scanQRCodeVerificationCallbacks +} + +var _ verificationhelper.ScanQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil) +var _ verificationhelper.ShowQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil) + +func newShowAndScanQRCodeVerificationCallbacks() *showAndScanQRCodeVerificationCallbacks { + base := newBaseVerificationCallbacks() + return &showAndScanQRCodeVerificationCallbacks{ + base, + newShowQRCodeVerificationCallbacks(), + newScanQRCodeVerificationCallbacks(), + } +} + +func newShowAndScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showAndScanQRCodeVerificationCallbacks { + return &showAndScanQRCodeVerificationCallbacks{ + base, + newShowQRCodeVerificationCallbacks(), + newScanQRCodeVerificationCallbacks(), + } +} + type allVerificationCallbacks struct { *baseVerificationCallbacks *sasVerificationCallbacks - *qrCodeVerificationCallbacks + *scanQRCodeVerificationCallbacks + *showQRCodeVerificationCallbacks } func newAllVerificationCallbacks() *allVerificationCallbacks { @@ -147,6 +193,7 @@ func newAllVerificationCallbacks() *allVerificationCallbacks { return &allVerificationCallbacks{ base, newSASVerificationCallbacksWithBase(base), - newQRCodeVerificationCallbacksWithBase(base), + newScanQRCodeVerificationCallbacksWithBase(base), + newShowQRCodeVerificationCallbacksWithBase(base), } } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index be547e7e..92d4de23 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -15,6 +15,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/exslices" "go.mau.fi/util/jsontime" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -47,12 +48,14 @@ type ShowSASCallbacks interface { ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) } -type ShowQRCodeCallbacks interface { +type ScanQRCodeCallbacks interface { // ScanQRCode is called when another device has sent a // m.key.verification.ready event and indicated that they are capable of // showing a QR code. ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) +} +type ShowQRCodeCallbacks interface { // ShowQRCode is called when the verification has been accepted and a QR // code should be shown to the user. ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) @@ -108,24 +111,22 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor helper.verificationDone = c.VerificationDone } - supportedMethods := map[event.VerificationMethod]struct{}{} if c, ok := callbacks.(ShowSASCallbacks); ok { - supportedMethods[event.VerificationMethodSAS] = struct{}{} + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) helper.showSAS = c.ShowSAS } if c, ok := callbacks.(ShowQRCodeCallbacks); ok { - supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{} - supportedMethods[event.VerificationMethodReciprocate] = struct{}{} - helper.scanQRCode = c.ScanQRCode + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow) + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) helper.showQRCode = c.ShowQRCode helper.qrCodeScaned = c.QRCodeScanned } - if supportsScan { - supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{} - supportedMethods[event.VerificationMethodReciprocate] = struct{}{} + if c, ok := callbacks.(ScanQRCodeCallbacks); ok && supportsScan { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan) + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) + helper.scanQRCode = c.ScanQRCode } - - helper.supportedMethods = maps.Keys(supportedMethods) + helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods) return &helper } @@ -420,7 +421,9 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V } txn.VerificationState = VerificationStateReady - if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + if vh.scanQRCode != nil && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) } @@ -734,7 +737,9 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif } } - if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + if vh.scanQRCode != nil && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) } diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index 49c8db07..31bc7d6e 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -95,11 +95,12 @@ func TestVerification_Start(t *testing.T) { expectedVerificationMethods []event.VerificationMethod }{ {false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, - {true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, {false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {true, newScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, newScanQRCodeVerificationCallbacks(), "no supported verification methods", nil}, + {false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {true, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, {false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, {true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, } @@ -124,7 +125,7 @@ func TestVerification_Start(t *testing.T) { return } - assert.NoError(t, err) + require.NoError(t, err) assert.NotEmpty(t, txnID) toDeviceInbox := ts.DeviceInbox[aliceUserID] @@ -283,8 +284,8 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { expectedVerificationMethods []event.VerificationMethod }{ {false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, {true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, {true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, } @@ -321,10 +322,10 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - _, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks) + _, sendingIsQRCallbacks := tc.sendingCallbacks.(*showQRCodeVerificationCallbacks) _, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks) sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks - _, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks) + _, receivingIsQRCallbacks := tc.receivingCallbacks.(*showQRCodeVerificationCallbacks) _, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks) receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks