verificationhelper: don't request QR scan if not enabled
Some checks are pending
Go / Lint (latest) (push) Waiting to run
Go / Build (old, libolm) (push) Waiting to run
Go / Build (latest, libolm) (push) Waiting to run
Go / Build (old, goolm) (push) Waiting to run
Go / Build (latest, goolm) (push) Waiting to run

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
Sumner Evans 2025-02-05 12:22:02 -07:00
commit 890db20d8e
No known key found for this signature in database
3 changed files with 86 additions and 33 deletions

View file

@ -32,6 +32,8 @@ type baseVerificationCallbacks struct {
decimalsShown map[id.VerificationTransactionID][]int
}
var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil)
func newBaseVerificationCallbacks() *baseVerificationCallbacks {
return &baseVerificationCallbacks{
verificationsRequested: map[id.UserID][]id.VerificationTransactionID{},
@ -98,6 +100,8 @@ type sasVerificationCallbacks struct {
*baseVerificationCallbacks
}
var _ verificationhelper.ShowSASCallbacks = (*sasVerificationCallbacks)(nil)
func newSASVerificationCallbacks() *sasVerificationCallbacks {
return &sasVerificationCallbacks{newBaseVerificationCallbacks()}
}
@ -112,34 +116,76 @@ func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.Verific
c.decimalsShown[txnID] = decimals
}
type qrCodeVerificationCallbacks struct {
type scanQRCodeVerificationCallbacks struct {
*baseVerificationCallbacks
}
func newQRCodeVerificationCallbacks() *qrCodeVerificationCallbacks {
return &qrCodeVerificationCallbacks{newBaseVerificationCallbacks()}
var _ verificationhelper.ScanQRCodeCallbacks = (*scanQRCodeVerificationCallbacks)(nil)
func newScanQRCodeVerificationCallbacks() *scanQRCodeVerificationCallbacks {
return &scanQRCodeVerificationCallbacks{newBaseVerificationCallbacks()}
}
func newQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *qrCodeVerificationCallbacks {
return &qrCodeVerificationCallbacks{base}
func newScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *scanQRCodeVerificationCallbacks {
return &scanQRCodeVerificationCallbacks{base}
}
func (c *qrCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) {
func (c *scanQRCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) {
c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID)
}
func (c *qrCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) {
type showQRCodeVerificationCallbacks struct {
*baseVerificationCallbacks
}
var _ verificationhelper.ShowQRCodeCallbacks = (*showQRCodeVerificationCallbacks)(nil)
func newShowQRCodeVerificationCallbacks() *showQRCodeVerificationCallbacks {
return &showQRCodeVerificationCallbacks{newBaseVerificationCallbacks()}
}
func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showQRCodeVerificationCallbacks {
return &showQRCodeVerificationCallbacks{base}
}
func (c *showQRCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) {
c.qrCodesShown[txnID] = qrCode
}
func (c *qrCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) {
func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) {
c.qrCodesScanned[txnID] = struct{}{}
}
type showAndScanQRCodeVerificationCallbacks struct {
*baseVerificationCallbacks
*showQRCodeVerificationCallbacks
*scanQRCodeVerificationCallbacks
}
var _ verificationhelper.ScanQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)
var _ verificationhelper.ShowQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)
func newShowAndScanQRCodeVerificationCallbacks() *showAndScanQRCodeVerificationCallbacks {
base := newBaseVerificationCallbacks()
return &showAndScanQRCodeVerificationCallbacks{
base,
newShowQRCodeVerificationCallbacks(),
newScanQRCodeVerificationCallbacks(),
}
}
func newShowAndScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showAndScanQRCodeVerificationCallbacks {
return &showAndScanQRCodeVerificationCallbacks{
base,
newShowQRCodeVerificationCallbacks(),
newScanQRCodeVerificationCallbacks(),
}
}
type allVerificationCallbacks struct {
*baseVerificationCallbacks
*sasVerificationCallbacks
*qrCodeVerificationCallbacks
*scanQRCodeVerificationCallbacks
*showQRCodeVerificationCallbacks
}
func newAllVerificationCallbacks() *allVerificationCallbacks {
@ -147,6 +193,7 @@ func newAllVerificationCallbacks() *allVerificationCallbacks {
return &allVerificationCallbacks{
base,
newSASVerificationCallbacksWithBase(base),
newQRCodeVerificationCallbacksWithBase(base),
newScanQRCodeVerificationCallbacksWithBase(base),
newShowQRCodeVerificationCallbacksWithBase(base),
}
}

View file

@ -15,6 +15,7 @@ import (
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/exslices"
"go.mau.fi/util/jsontime"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
@ -47,12 +48,14 @@ type ShowSASCallbacks interface {
ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int)
}
type ShowQRCodeCallbacks interface {
type ScanQRCodeCallbacks interface {
// ScanQRCode is called when another device has sent a
// m.key.verification.ready event and indicated that they are capable of
// showing a QR code.
ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID)
}
type ShowQRCodeCallbacks interface {
// ShowQRCode is called when the verification has been accepted and a QR
// code should be shown to the user.
ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode)
@ -108,24 +111,22 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor
helper.verificationDone = c.VerificationDone
}
supportedMethods := map[event.VerificationMethod]struct{}{}
if c, ok := callbacks.(ShowSASCallbacks); ok {
supportedMethods[event.VerificationMethodSAS] = struct{}{}
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS)
helper.showSAS = c.ShowSAS
}
if c, ok := callbacks.(ShowQRCodeCallbacks); ok {
supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{}
supportedMethods[event.VerificationMethodReciprocate] = struct{}{}
helper.scanQRCode = c.ScanQRCode
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow)
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
helper.showQRCode = c.ShowQRCode
helper.qrCodeScaned = c.QRCodeScanned
}
if supportsScan {
supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{}
supportedMethods[event.VerificationMethodReciprocate] = struct{}{}
if c, ok := callbacks.(ScanQRCodeCallbacks); ok && supportsScan {
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan)
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
helper.scanQRCode = c.ScanQRCode
}
helper.supportedMethods = maps.Keys(supportedMethods)
helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods)
return &helper
}
@ -420,7 +421,9 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V
}
txn.VerificationState = VerificationStateReady
if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
if vh.scanQRCode != nil &&
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
vh.scanQRCode(ctx, txn.TransactionID)
}
@ -734,7 +737,9 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif
}
}
if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
if vh.scanQRCode != nil &&
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
vh.scanQRCode(ctx, txn.TransactionID)
}

View file

@ -95,11 +95,12 @@ func TestVerification_Start(t *testing.T) {
expectedVerificationMethods []event.VerificationMethod
}{
{false, newBaseVerificationCallbacks(), "no supported verification methods", nil},
{true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
{false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}},
{true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
{true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{true, newScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
{false, newScanQRCodeVerificationCallbacks(), "no supported verification methods", nil},
{false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{true, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
{true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
}
@ -124,7 +125,7 @@ func TestVerification_Start(t *testing.T) {
return
}
assert.NoError(t, err)
require.NoError(t, err)
assert.NotEmpty(t, txnID)
toDeviceInbox := ts.DeviceInbox[aliceUserID]
@ -283,8 +284,8 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
expectedVerificationMethods []event.VerificationMethod
}{
{false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}},
{true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
{true, false, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, true, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
{true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
{true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
}
@ -321,10 +322,10 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
_, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks)
_, sendingIsQRCallbacks := tc.sendingCallbacks.(*showQRCodeVerificationCallbacks)
_, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks)
sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks
_, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks)
_, receivingIsQRCallbacks := tc.receivingCallbacks.(*showQRCodeVerificationCallbacks)
_, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks)
receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks