mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
verificationhelper: add VerificationReady callback for when verification is accepted
Some checks failed
Some checks failed
This callback supersedes the ScanQRCode and ShowQRCode callbacks. Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
parent
14008caaa4
commit
5600dd4054
5 changed files with 129 additions and 129 deletions
|
|
@ -17,12 +17,14 @@ import (
|
|||
type MockVerificationCallbacks interface {
|
||||
GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID
|
||||
GetScanQRCodeTransactions() []id.VerificationTransactionID
|
||||
GetVerificationsReadyTransactions() []id.VerificationTransactionID
|
||||
GetQRCodeShown(id.VerificationTransactionID) *verificationhelper.QRCode
|
||||
}
|
||||
|
||||
type baseVerificationCallbacks struct {
|
||||
scanQRCodeTransactions []id.VerificationTransactionID
|
||||
verificationsRequested map[id.UserID][]id.VerificationTransactionID
|
||||
verificationsReady []id.VerificationTransactionID
|
||||
qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode
|
||||
qrCodesScanned map[id.VerificationTransactionID]struct{}
|
||||
doneTransactions map[id.VerificationTransactionID]struct{}
|
||||
|
|
@ -33,6 +35,7 @@ type baseVerificationCallbacks struct {
|
|||
}
|
||||
|
||||
var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil)
|
||||
var _ MockVerificationCallbacks = (*baseVerificationCallbacks)(nil)
|
||||
|
||||
func newBaseVerificationCallbacks() *baseVerificationCallbacks {
|
||||
return &baseVerificationCallbacks{
|
||||
|
|
@ -55,6 +58,10 @@ func (c *baseVerificationCallbacks) GetScanQRCodeTransactions() []id.Verificatio
|
|||
return c.scanQRCodeTransactions
|
||||
}
|
||||
|
||||
func (c *baseVerificationCallbacks) GetVerificationsReadyTransactions() []id.VerificationTransactionID {
|
||||
return c.verificationsReady
|
||||
}
|
||||
|
||||
func (c *baseVerificationCallbacks) GetQRCodeShown(txnID id.VerificationTransactionID) *verificationhelper.QRCode {
|
||||
return c.qrCodesShown[txnID]
|
||||
}
|
||||
|
|
@ -85,6 +92,16 @@ func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, t
|
|||
c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID)
|
||||
}
|
||||
|
||||
func (c *baseVerificationCallbacks) VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, allowScanQRCode bool, qrCode *verificationhelper.QRCode) {
|
||||
c.verificationsReady = append(c.verificationsReady, txnID)
|
||||
if allowScanQRCode {
|
||||
c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID)
|
||||
}
|
||||
if qrCode != nil {
|
||||
c.qrCodesShown[txnID] = qrCode
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) {
|
||||
c.verificationCancellation[txnID] = &event.VerificationCancelEventContent{
|
||||
Code: code,
|
||||
|
|
@ -116,23 +133,6 @@ func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.Verific
|
|||
c.decimalsShown[txnID] = decimals
|
||||
}
|
||||
|
||||
type scanQRCodeVerificationCallbacks struct {
|
||||
*baseVerificationCallbacks
|
||||
}
|
||||
|
||||
var _ verificationhelper.ScanQRCodeCallbacks = (*scanQRCodeVerificationCallbacks)(nil)
|
||||
|
||||
func newScanQRCodeVerificationCallbacks() *scanQRCodeVerificationCallbacks {
|
||||
return &scanQRCodeVerificationCallbacks{newBaseVerificationCallbacks()}
|
||||
}
|
||||
|
||||
func newScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *scanQRCodeVerificationCallbacks {
|
||||
return &scanQRCodeVerificationCallbacks{base}
|
||||
}
|
||||
func (c *scanQRCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) {
|
||||
c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID)
|
||||
}
|
||||
|
||||
type showQRCodeVerificationCallbacks struct {
|
||||
*baseVerificationCallbacks
|
||||
}
|
||||
|
|
@ -147,44 +147,13 @@ func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks)
|
|||
return &showQRCodeVerificationCallbacks{base}
|
||||
}
|
||||
|
||||
func (c *showQRCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) {
|
||||
c.qrCodesShown[txnID] = qrCode
|
||||
}
|
||||
|
||||
func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) {
|
||||
c.qrCodesScanned[txnID] = struct{}{}
|
||||
}
|
||||
|
||||
type showAndScanQRCodeVerificationCallbacks struct {
|
||||
*baseVerificationCallbacks
|
||||
*showQRCodeVerificationCallbacks
|
||||
*scanQRCodeVerificationCallbacks
|
||||
}
|
||||
|
||||
var _ verificationhelper.ScanQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)
|
||||
var _ verificationhelper.ShowQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)
|
||||
|
||||
func newShowAndScanQRCodeVerificationCallbacks() *showAndScanQRCodeVerificationCallbacks {
|
||||
base := newBaseVerificationCallbacks()
|
||||
return &showAndScanQRCodeVerificationCallbacks{
|
||||
base,
|
||||
newShowQRCodeVerificationCallbacks(),
|
||||
newScanQRCodeVerificationCallbacks(),
|
||||
}
|
||||
}
|
||||
|
||||
func newShowAndScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showAndScanQRCodeVerificationCallbacks {
|
||||
return &showAndScanQRCodeVerificationCallbacks{
|
||||
base,
|
||||
newShowQRCodeVerificationCallbacks(),
|
||||
newScanQRCodeVerificationCallbacks(),
|
||||
}
|
||||
}
|
||||
|
||||
type allVerificationCallbacks struct {
|
||||
*baseVerificationCallbacks
|
||||
*sasVerificationCallbacks
|
||||
*scanQRCodeVerificationCallbacks
|
||||
*showQRCodeVerificationCallbacks
|
||||
}
|
||||
|
||||
|
|
@ -193,7 +162,6 @@ func newAllVerificationCallbacks() *allVerificationCallbacks {
|
|||
return &allVerificationCallbacks{
|
||||
base,
|
||||
newSASVerificationCallbacksWithBase(base),
|
||||
newScanQRCodeVerificationCallbacksWithBase(base),
|
||||
newShowQRCodeVerificationCallbacksWithBase(base),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,6 +82,10 @@ func NewQRCodeFromBytes(data []byte) (*QRCode, error) {
|
|||
//
|
||||
// [Section 11.12.2.4.1]: https://spec.matrix.org/v1.9/client-server-api/#qr-code-format
|
||||
func (q *QRCode) Bytes() []byte {
|
||||
if q == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("MATRIX") // Header
|
||||
buf.WriteByte(0x02) // Version
|
||||
|
|
|
|||
|
|
@ -270,28 +270,30 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id
|
|||
return nil
|
||||
}
|
||||
|
||||
func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *VerificationTransaction) error {
|
||||
func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *VerificationTransaction) (*QRCode, error) {
|
||||
log := vh.getLog(ctx).With().
|
||||
Str("verification_action", "generate and show QR code").
|
||||
Stringer("transaction_id", txn.TransactionID).
|
||||
Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
if vh.showQRCode == nil {
|
||||
log.Info().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device")
|
||||
return nil
|
||||
|
||||
if !slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) ||
|
||||
!slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) {
|
||||
log.Info().Msg("Ignoring QR code generation request as reciprocating is not supported by both devices")
|
||||
return nil, nil
|
||||
} else if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) {
|
||||
log.Info().Msg("Ignoring QR code generation request as other device cannot scan QR codes")
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx)
|
||||
if ownCrossSigningPublicKeys == nil || len(ownCrossSigningPublicKeys.MasterKey) == 0 {
|
||||
return errors.New("failed to get own cross-signing master public key")
|
||||
return nil, errors.New("failed to get own cross-signing master public key")
|
||||
}
|
||||
|
||||
ownMasterKeyTrusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
mode := QRCodeModeCrossSigning
|
||||
if vh.client.UserID == txn.TheirUserID {
|
||||
|
|
@ -304,7 +306,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve
|
|||
} else {
|
||||
// This is a cross-signing situation.
|
||||
if !ownMasterKeyTrusted {
|
||||
return errors.New("cannot cross-sign other device when own master key is not trusted")
|
||||
return nil, errors.New("cannot cross-sign other device when own master key is not trusted")
|
||||
}
|
||||
mode = QRCodeModeCrossSigning
|
||||
}
|
||||
|
|
@ -318,7 +320,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve
|
|||
// Key 2 is the other user's master signing key.
|
||||
theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
key2 = theirSigningKeys.MasterKey.Bytes()
|
||||
case QRCodeModeSelfVerifyingMasterKeyTrusted:
|
||||
|
|
@ -328,7 +330,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve
|
|||
// Key 2 is the other device's key.
|
||||
theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
key2 = theirDevice.SigningKey.Bytes()
|
||||
case QRCodeModeSelfVerifyingMasterKeyUntrusted:
|
||||
|
|
@ -343,6 +345,5 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve
|
|||
|
||||
qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2))
|
||||
txn.QRCodeSharedSecret = qrCode.SharedSecret
|
||||
vh.showQRCode(ctx, txn.TransactionID, qrCode)
|
||||
return nil
|
||||
return qrCode, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,6 +33,10 @@ type RequiredCallbacks interface {
|
|||
// from another device.
|
||||
VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID)
|
||||
|
||||
// VerificationReady is called when a verification request has been
|
||||
// accepted by both parties.
|
||||
VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode)
|
||||
|
||||
// VerificationCancelled is called when the verification is cancelled.
|
||||
VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string)
|
||||
|
||||
|
|
@ -48,18 +52,7 @@ type ShowSASCallbacks interface {
|
|||
ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int)
|
||||
}
|
||||
|
||||
type ScanQRCodeCallbacks interface {
|
||||
// ScanQRCode is called when another device has sent a
|
||||
// m.key.verification.ready event and indicated that they are capable of
|
||||
// showing a QR code.
|
||||
ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID)
|
||||
}
|
||||
|
||||
type ShowQRCodeCallbacks interface {
|
||||
// ShowQRCode is called when the verification has been accepted and a QR
|
||||
// code should be shown to the user.
|
||||
ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode)
|
||||
|
||||
// QRCodeScanned is called when the other user has scanned the QR code and
|
||||
// sent the m.key.verification.start event.
|
||||
QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID)
|
||||
|
|
@ -71,24 +64,25 @@ type VerificationHelper struct {
|
|||
|
||||
store VerificationStore
|
||||
activeTransactionsLock sync.Mutex
|
||||
// activeTransactions map[id.VerificationTransactionID]*verificationTransaction
|
||||
|
||||
// supportedMethods are the methods that *we* support
|
||||
supportedMethods []event.VerificationMethod
|
||||
verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID)
|
||||
verificationReady func(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode)
|
||||
verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string)
|
||||
verificationDone func(ctx context.Context, txnID id.VerificationTransactionID)
|
||||
|
||||
// showSAS is a callback that will be called after the SAS verification
|
||||
// dance is complete and we want the client to show the emojis/decimals
|
||||
showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int)
|
||||
|
||||
scanQRCode func(ctx context.Context, txnID id.VerificationTransactionID)
|
||||
showQRCode func(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode)
|
||||
// qrCodeScaned is a callback that will be called when the other device
|
||||
// scanned the QR code we are showing
|
||||
qrCodeScaned func(ctx context.Context, txnID id.VerificationTransactionID)
|
||||
}
|
||||
|
||||
var _ mautrix.VerificationHelper = (*VerificationHelper)(nil)
|
||||
|
||||
func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsScan bool) *VerificationHelper {
|
||||
func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan bool) *VerificationHelper {
|
||||
if client.Crypto == nil {
|
||||
panic("client.Crypto is nil")
|
||||
}
|
||||
|
|
@ -107,6 +101,7 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor
|
|||
panic("callbacks must implement RequiredCallbacks")
|
||||
} else {
|
||||
helper.verificationRequested = c.VerificationRequested
|
||||
helper.verificationReady = c.VerificationReady
|
||||
helper.verificationCancelledCallback = c.VerificationCancelled
|
||||
helper.verificationDone = c.VerificationDone
|
||||
}
|
||||
|
|
@ -115,16 +110,18 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor
|
|||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS)
|
||||
helper.showSAS = c.ShowSAS
|
||||
}
|
||||
if c, ok := callbacks.(ShowQRCodeCallbacks); ok {
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow)
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
|
||||
helper.showQRCode = c.ShowQRCode
|
||||
helper.qrCodeScaned = c.QRCodeScanned
|
||||
if supportsQRShow {
|
||||
if c, ok := callbacks.(ShowQRCodeCallbacks); !ok {
|
||||
panic("callbacks must implement ShowQRCodeCallbacks if supportsQRShow is true")
|
||||
} else {
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow)
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
|
||||
helper.qrCodeScaned = c.QRCodeScanned
|
||||
}
|
||||
}
|
||||
if c, ok := callbacks.(ScanQRCodeCallbacks); ok && supportsScan {
|
||||
if supportsQRScan {
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan)
|
||||
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
|
||||
helper.scanQRCode = c.ScanQRCode
|
||||
}
|
||||
helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods)
|
||||
return &helper
|
||||
|
|
@ -421,15 +418,19 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V
|
|||
}
|
||||
txn.VerificationState = VerificationStateReady
|
||||
|
||||
if vh.scanQRCode != nil &&
|
||||
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
|
||||
vh.scanQRCode(ctx, txn.TransactionID)
|
||||
}
|
||||
supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) &&
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS)
|
||||
supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) &&
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate)
|
||||
supportsScanQRCode := supportsReciprocate &&
|
||||
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) &&
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow)
|
||||
|
||||
if err := vh.generateAndShowQRCode(ctx, &txn); err != nil {
|
||||
qrCode, err := vh.generateQRCode(ctx, &txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode)
|
||||
return vh.store.SaveVerificationTransaction(ctx, txn)
|
||||
}
|
||||
|
||||
|
|
@ -737,15 +738,23 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif
|
|||
}
|
||||
}
|
||||
|
||||
if vh.scanQRCode != nil &&
|
||||
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
|
||||
vh.scanQRCode(ctx, txn.TransactionID)
|
||||
supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) &&
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS)
|
||||
supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) &&
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate)
|
||||
supportsScanQRCode := supportsReciprocate &&
|
||||
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) &&
|
||||
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow)
|
||||
|
||||
qrCode, err := vh.generateQRCode(ctx, &txn)
|
||||
if err != nil {
|
||||
vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate QR code: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := vh.generateAndShowQRCode(ctx, &txn); err != nil {
|
||||
vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate and show QR code: %w", err)
|
||||
} else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil {
|
||||
vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode)
|
||||
|
||||
if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil {
|
||||
vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to save verification transaction: %w", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece
|
|||
senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB)
|
||||
require.NoError(t, err)
|
||||
|
||||
sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true)
|
||||
sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true, true)
|
||||
require.NoError(t, sendingHelper.Init(ctx))
|
||||
|
||||
receivingCallbacks = newAllVerificationCallbacks()
|
||||
|
|
@ -79,7 +79,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece
|
|||
require.NoError(t, err)
|
||||
receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB)
|
||||
require.NoError(t, err)
|
||||
receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true)
|
||||
receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true, true)
|
||||
require.NoError(t, receivingHelper.Init(ctx))
|
||||
return
|
||||
}
|
||||
|
|
@ -89,20 +89,27 @@ func TestVerification_Start(t *testing.T) {
|
|||
receivingDeviceID2 := id.DeviceID("receiving2")
|
||||
|
||||
testCases := []struct {
|
||||
supportsShow bool
|
||||
supportsScan bool
|
||||
callbacks MockVerificationCallbacks
|
||||
startVerificationErrMsg string
|
||||
expectedVerificationMethods []event.VerificationMethod
|
||||
}{
|
||||
{false, newBaseVerificationCallbacks(), "no supported verification methods", nil},
|
||||
{false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}},
|
||||
{true, newScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
{false, newScanQRCodeVerificationCallbacks(), "no supported verification methods", nil},
|
||||
{false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{true, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{false, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
|
||||
{true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
|
||||
{false, false, newBaseVerificationCallbacks(), "no supported verification methods", nil},
|
||||
{false, true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
|
||||
{false, false, newShowQRCodeVerificationCallbacks(), "no supported verification methods", nil},
|
||||
{true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{false, true, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
{true, true, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
|
||||
{false, false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}},
|
||||
{false, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
|
||||
{false, false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}},
|
||||
{false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
{true, false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
|
|
@ -115,7 +122,7 @@ func TestVerification_Start(t *testing.T) {
|
|||
addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID)
|
||||
addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2)
|
||||
|
||||
senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsScan)
|
||||
senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsShow, tc.supportsScan)
|
||||
err := senderHelper.Init(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -162,7 +169,7 @@ func TestVerification_StartThenCancel(t *testing.T) {
|
|||
|
||||
bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID)
|
||||
bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine()
|
||||
bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true)
|
||||
bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true, true)
|
||||
require.NoError(t, bystanderHelper.Init(ctx))
|
||||
|
||||
require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity()))
|
||||
|
|
@ -252,12 +259,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) {
|
|||
assert.NotEmpty(t, recoveryKey)
|
||||
assert.NotNil(t, cache)
|
||||
|
||||
sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true)
|
||||
sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true, true)
|
||||
err = sendingHelper.Init(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
receivingCallbacks := newBaseVerificationCallbacks()
|
||||
receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false)
|
||||
receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false, false)
|
||||
err = receivingHelper.Init(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -278,16 +285,26 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
|
|||
|
||||
testCases := []struct {
|
||||
sendingSupportsScan bool
|
||||
sendingSupportsShow bool
|
||||
receivingSupportsScan bool
|
||||
receivingSupportsShow bool
|
||||
sendingCallbacks MockVerificationCallbacks
|
||||
receivingCallbacks MockVerificationCallbacks
|
||||
expectedVerificationMethods []event.VerificationMethod
|
||||
}{
|
||||
{false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}},
|
||||
{true, false, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
|
||||
{false, true, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
|
||||
{true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
|
||||
{true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
|
||||
// TODO
|
||||
{false, false, false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}},
|
||||
{true, false, true, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}},
|
||||
|
||||
{true, false, false, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}},
|
||||
{false, true, true, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}},
|
||||
{true, false, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}},
|
||||
{false, true, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}},
|
||||
{true, true, true, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}},
|
||||
{true, true, false, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}},
|
||||
{true, true, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}},
|
||||
|
||||
{true, true, true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
|
|
@ -300,11 +317,11 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
|
|||
assert.NotEmpty(t, recoveryKey)
|
||||
assert.NotNil(t, sendingCrossSigningKeysCache)
|
||||
|
||||
sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsScan)
|
||||
sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsShow, tc.sendingSupportsScan)
|
||||
err = sendingHelper.Init(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsScan)
|
||||
receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsShow, tc.receivingSupportsScan)
|
||||
err = receivingHelper.Init(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -322,16 +339,13 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
|
|||
err = receivingHelper.AcceptVerification(ctx, txnID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, sendingIsQRCallbacks := tc.sendingCallbacks.(*showQRCodeVerificationCallbacks)
|
||||
_, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks)
|
||||
sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks
|
||||
_, receivingIsQRCallbacks := tc.receivingCallbacks.(*showQRCodeVerificationCallbacks)
|
||||
_, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks)
|
||||
receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks
|
||||
// Ensure that the receiving device get a notification about the
|
||||
// transaction being ready.
|
||||
assert.Contains(t, tc.receivingCallbacks.GetVerificationsReadyTransactions(), txnID)
|
||||
|
||||
// Ensure that if the receiving device should show a QR code that
|
||||
// it has the correct content.
|
||||
if tc.sendingSupportsScan && receivingCanShowQR {
|
||||
if tc.sendingSupportsScan && tc.receivingSupportsShow {
|
||||
receivingShownQRCode := tc.receivingCallbacks.GetQRCodeShown(txnID)
|
||||
require.NotNil(t, receivingShownQRCode)
|
||||
assert.Equal(t, txnID, receivingShownQRCode.TransactionID)
|
||||
|
|
@ -340,7 +354,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
|
|||
|
||||
// Check for whether the receiving device should be scanning a QR
|
||||
// code.
|
||||
if tc.receivingSupportsScan && sendingCanShowQR {
|
||||
if tc.receivingSupportsScan && tc.sendingSupportsShow {
|
||||
assert.Contains(t, tc.receivingCallbacks.GetScanQRCodeTransactions(), txnID)
|
||||
}
|
||||
|
||||
|
|
@ -357,9 +371,13 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
|
|||
// device.
|
||||
ts.dispatchToDevice(t, ctx, sendingClient)
|
||||
|
||||
// Ensure that the sending device got a notification about the
|
||||
// transaction being ready.
|
||||
assert.Contains(t, tc.sendingCallbacks.GetVerificationsReadyTransactions(), txnID)
|
||||
|
||||
// Ensure that if the sending device should show a QR code that it
|
||||
// has the correct content.
|
||||
if tc.receivingSupportsScan && sendingCanShowQR {
|
||||
if tc.receivingSupportsScan && tc.sendingSupportsShow {
|
||||
sendingShownQRCode := tc.sendingCallbacks.GetQRCodeShown(txnID)
|
||||
require.NotNil(t, sendingShownQRCode)
|
||||
assert.Equal(t, txnID, sendingShownQRCode.TransactionID)
|
||||
|
|
@ -368,7 +386,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
|
|||
|
||||
// Check for whether the sending device should be scanning a QR
|
||||
// code.
|
||||
if tc.sendingSupportsScan && receivingCanShowQR {
|
||||
if tc.sendingSupportsScan && tc.receivingSupportsShow {
|
||||
assert.Contains(t, tc.sendingCallbacks.GetScanQRCodeTransactions(), txnID)
|
||||
}
|
||||
})
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue