mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
verificationhelper: don't request QR scan if not enabled
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
parent
475c4bf39d
commit
890db20d8e
3 changed files with 86 additions and 33 deletions
|
|
@ -32,6 +32,8 @@ type baseVerificationCallbacks struct {
|
|||
decimalsShown map[id.VerificationTransactionID][]int
|
||||
}
|
||||
|
||||
var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil)
|
||||
|
||||
func newBaseVerificationCallbacks() *baseVerificationCallbacks {
|
||||
return &baseVerificationCallbacks{
|
||||
verificationsRequested: map[id.UserID][]id.VerificationTransactionID{},
|
||||
|
|
@ -98,6 +100,8 @@ type sasVerificationCallbacks struct {
|
|||
*baseVerificationCallbacks
|
||||
}
|
||||
|
||||
var _ verificationhelper.ShowSASCallbacks = (*sasVerificationCallbacks)(nil)
|
||||
|
||||
func newSASVerificationCallbacks() *sasVerificationCallbacks {
|
||||
return &sasVerificationCallbacks{newBaseVerificationCallbacks()}
|
||||
}
|
||||
|
|
@ -112,34 +116,76 @@ func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.Verific
|
|||
c.decimalsShown[txnID] = decimals
|
||||
}
|
||||
|
||||
type qrCodeVerificationCallbacks struct {
|
||||
type scanQRCodeVerificationCallbacks struct {
|
||||
*baseVerificationCallbacks
|
||||
}
|
||||
|
||||
func newQRCodeVerificationCallbacks() *qrCodeVerificationCallbacks {
|
||||
return &qrCodeVerificationCallbacks{newBaseVerificationCallbacks()}
|
||||
var _ verificationhelper.ScanQRCodeCallbacks = (*scanQRCodeVerificationCallbacks)(nil)
|
||||
|
||||
func newScanQRCodeVerificationCallbacks() *scanQRCodeVerificationCallbacks {
|
||||
return &scanQRCodeVerificationCallbacks{newBaseVerificationCallbacks()}
|
||||
}
|
||||
|
||||
func newQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *qrCodeVerificationCallbacks {
|
||||
return &qrCodeVerificationCallbacks{base}
|
||||
func newScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *scanQRCodeVerificationCallbacks {
|
||||
return &scanQRCodeVerificationCallbacks{base}
|
||||
}
|
||||
|
||||
func (c *qrCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) {
|
||||
func (c *scanQRCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) {
|
||||
c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID)
|
||||
}
|
||||
|
||||
func (c *qrCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) {
|
||||
type showQRCodeVerificationCallbacks struct {
|
||||
*baseVerificationCallbacks
|
||||
}
|
||||
|
||||
var _ verificationhelper.ShowQRCodeCallbacks = (*showQRCodeVerificationCallbacks)(nil)
|
||||
|
||||
func newShowQRCodeVerificationCallbacks() *showQRCodeVerificationCallbacks {
|
||||
return &showQRCodeVerificationCallbacks{newBaseVerificationCallbacks()}
|
||||
}
|
||||
|
||||
func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showQRCodeVerificationCallbacks {
|
||||
return &showQRCodeVerificationCallbacks{base}
|
||||
}
|
||||
|
||||
func (c *showQRCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) {
|
||||
c.qrCodesShown[txnID] = qrCode
|
||||
}
|
||||
|
||||
func (c *qrCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) {
|
||||
func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) {
|
||||
c.qrCodesScanned[txnID] = struct{}{}
|
||||
}
|
||||
|
||||
type showAndScanQRCodeVerificationCallbacks struct {
|
||||
*baseVerificationCallbacks
|
||||
*showQRCodeVerificationCallbacks
|
||||
*scanQRCodeVerificationCallbacks
|
||||
}
|
||||
|
||||
var _ verificationhelper.ScanQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)
|
||||
var _ verificationhelper.ShowQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)
|
||||
|
||||
func newShowAndScanQRCodeVerificationCallbacks() *showAndScanQRCodeVerificationCallbacks {
|
||||
base := newBaseVerificationCallbacks()
|
||||
return &showAndScanQRCodeVerificationCallbacks{
|
||||
base,
|
||||
newShowQRCodeVerificationCallbacks(),
|
||||
newScanQRCodeVerificationCallbacks(),
|
||||
}
|
||||
}
|
||||
|
||||
func newShowAndScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showAndScanQRCodeVerificationCallbacks {
|
||||
return &showAndScanQRCodeVerificationCallbacks{
|
||||
base,
|
||||
newShowQRCodeVerificationCallbacks(),
|
||||
newScanQRCodeVerificationCallbacks(),
|
||||
}
|
||||
}
|
||||
|
||||
type allVerificationCallbacks struct {
|
||||
*baseVerificationCallbacks
|
||||
*sasVerificationCallbacks
|
||||
*qrCodeVerificationCallbacks
|
||||
*scanQRCodeVerificationCallbacks
|
||||
*showQRCodeVerificationCallbacks
|
||||
}
|
||||
|
||||
func newAllVerificationCallbacks() *allVerificationCallbacks {
|
||||
|
|
@ -147,6 +193,7 @@ func newAllVerificationCallbacks() *allVerificationCallbacks {
|
|||
return &allVerificationCallbacks{
|
||||
base,
|
||||
newSASVerificationCallbacksWithBase(base),
|
||||
newQRCodeVerificationCallbacksWithBase(base),
|
||||
newScanQRCodeVerificationCallbacksWithBase(base),
|
||||
newShowQRCodeVerificationCallbacksWithBase(base),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/exslices"
|
||||
"go.mau.fi/util/jsontime"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
|
|
@ -47,12 +48,14 @@ type ShowSASCallbacks interface {
|
|||
ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int)
|
||||
}
|
||||
|
||||
type ShowQRCodeCallbacks interface {
|
||||
type ScanQRCodeCallbacks interface {
|
||||
// ScanQRCode is called when another device has sent a
|
||||
// m.key.verification.ready event and indicated that they are capable of
|
||||
// showing a QR code.
|
||||
ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID)
|
||||
}
|
||||
|
||||
type ShowQRCodeCallbacks interface {
|
||||
// ShowQRCode is called when the verification has been accepted and a QR
|
||||
// code should be shown to the user.
|
||||
ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode)
|
||||
|
|
@ -108,24 +111,22 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor
|
|||
helper.verificationDone = c.VerificationDone
|
||||
}
|
||||
|
||||
supportedMethods := map[event.VerificationMethod]struct{}{}
|
||||
if c, ok := callbacks.(ShowSASCallbacks); ok {
|
||||
supportedMethods[event.VerificationMethodSAS] = struct{}{}
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS)
|
||||
helper.showSAS = c.ShowSAS
|
||||
}
|
||||
if c, ok := callbacks.(ShowQRCodeCallbacks); ok {
|
||||
supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{}
|
||||
supportedMethods[event.VerificationMethodReciprocate] = struct{}{}
|
||||
helper.scanQRCode = c.ScanQRCode
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow)
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
|
||||
helper.showQRCode = c.ShowQRCode
|
||||
helper.qrCodeScaned = c.QRCodeScanned
|
||||
}
|
||||
if supportsScan {
|
||||
supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{}
|
||||
supportedMethods[event.VerificationMethodReciprocate] = struct{}{}
|
||||
if c, ok := callbacks.(ScanQRCodeCallbacks); ok && supportsScan {
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan)
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
|
||||
helper.scanQRCode = c.ScanQRCode
|
||||
}
|
||||
|
||||
helper.supportedMethods = maps.Keys(supportedMethods)
|
||||
helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods)
|
||||
return &helper
|
||||
}
|
||||
|
||||
|
|
@ -420,7 +421,9 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V
|
|||
}
|
||||
txn.VerificationState = VerificationStateReady
|
||||
|
||||
if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
|
||||
if vh.scanQRCode != nil &&
|
||||
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
|
||||
vh.scanQRCode(ctx, txn.TransactionID)
|
||||
}
|
||||
|
||||
|
|
@ -734,7 +737,9 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif
|
|||
}
|
||||
}
|
||||
|
||||
if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
|
||||
if vh.scanQRCode != nil &&
|
||||
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
|
||||
vh.scanQRCode(ctx, txn.TransactionID)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -95,11 +95,12 @@ func TestVerification_Start(t *testing.T) {
|
|||
expectedVerificationMethods []event.VerificationMethod
|
||||
}{
|
||||
{false, newBaseVerificationCallbacks(), "no supported verification methods", nil},
|
||||
{true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
{false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}},
|
||||
{true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
|
||||
{true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{true, newScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
{false, newScanQRCodeVerificationCallbacks(), "no supported verification methods", nil},
|
||||
{false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{true, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{false, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
|
||||
{true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
|
||||
}
|
||||
|
|
@ -124,7 +125,7 @@ func TestVerification_Start(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, txnID)
|
||||
|
||||
toDeviceInbox := ts.DeviceInbox[aliceUserID]
|
||||
|
|
@ -283,8 +284,8 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
|
|||
expectedVerificationMethods []event.VerificationMethod
|
||||
}{
|
||||
{false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}},
|
||||
{true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
{true, false, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{false, true, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
{true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
|
||||
{true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
|
||||
}
|
||||
|
|
@ -321,10 +322,10 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
|
|||
err = receivingHelper.AcceptVerification(ctx, txnID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks)
|
||||
_, sendingIsQRCallbacks := tc.sendingCallbacks.(*showQRCodeVerificationCallbacks)
|
||||
_, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks)
|
||||
sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks
|
||||
_, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks)
|
||||
_, receivingIsQRCallbacks := tc.receivingCallbacks.(*showQRCodeVerificationCallbacks)
|
||||
_, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks)
|
||||
receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue