diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index e7ea53c5..2719ea78 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -746,13 +746,8 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri Reason: "The verification was accepted on another device.", }, } - devices, err := vh.mach.CryptoStore.GetDevices(ctx, txn.TheirUser) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to get devices for %s: %w", txn.TheirUser, err) - return - } req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} - for deviceID := range devices { + for _, deviceID := range txn.SentToDeviceIDs { if deviceID == txn.TheirDevice || deviceID == vh.client.DeviceID { // Don't ever send a cancellation to the device that accepted // the request or to our own device (which can happen if this @@ -762,7 +757,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri req.Messages[txn.TheirUser][deviceID] = content } - _, err = vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) + _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { log.Warn().Err(err).Msg("Failed to send cancellation requests") } @@ -878,14 +873,48 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { cancelEvt := evt.Content.AsVerificationCancel() - vh.getLog(ctx).Info(). + log := vh.getLog(ctx).With(). Str("verification_action", "cancel"). Stringer("transaction_id", txn.TransactionID). Str("cancel_code", string(cancelEvt.Code)). Str("reason", cancelEvt.Reason). - Msg("Verification was cancelled") + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Verification was cancelled") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() + + // Element (and at least the old desktop client) send cancellation events + // when the user rejects the verification request. This is really dumb, + // because they should just instead ignore the request and not send a + // cancellation. + // + // The above behavior causes a problem with the other devices that we sent + // the verification request to because they don't know that the request was + // cancelled. + // + // As a workaround, if we receive a cancellation event to a transaction + // that is currently in the REQUESTED state, then we will send + // cancellations to all of the devices that we sent the request to. This + // will ensure that all of the clients know that the request was cancelled. + if txn.VerificationState == verificationStateRequested && len(txn.SentToDeviceIDs) > 0 { + content := &event.Content{ + Parsed: &event.VerificationCancelEventContent{ + ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txn.TransactionID}, + Code: event.VerificationCancelCodeUser, + Reason: "The verification was rejected from another device.", + }, + } + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} + for _, deviceID := range txn.SentToDeviceIDs { + req.Messages[txn.TheirUser][deviceID] = content + } + _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) + if err != nil { + log.Warn().Err(err).Msg("Failed to send cancellation requests") + } + } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index e8be5771..876e90f7 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -141,13 +141,22 @@ func TestVerification_Start(t *testing.T) { func TestVerification_StartThenCancel(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) + bystanderDeviceID := id.DeviceID("bystander") for _, sendingCancels := range []bool{true, false} { t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) { - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) + bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, newAllVerificationCallbacks(), true) + require.NoError(t, bystanderHelper.Init(ctx)) + + require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) @@ -159,7 +168,13 @@ func TestVerification_StartThenCancel(t *testing.T) { assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID) ts.dispatchToDevice(t, ctx, receivingClient) - // Cancel the verification request on the sending device. + // Process the request event on the bystander device. + bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID] + assert.Len(t, bystanderInbox, 1) + assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID) + ts.dispatchToDevice(t, ctx, bystanderClient) + + // Cancel the verification request. var cancelEvt *event.VerificationCancelEventContent if sendingCancels { err = sendingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") @@ -171,6 +186,11 @@ func TestVerification_StartThenCancel(t *testing.T) { // Ensure that the cancellation event was sent to the receiving device. assert.Len(t, ts.DeviceInbox[aliceUserID][receivingDeviceID], 1) cancelEvt = ts.DeviceInbox[aliceUserID][receivingDeviceID][0].Content.AsVerificationCancel() + + // Ensure that the cancellation event was sent to the bystander device. + assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) + bystanderCancelEvt := ts.DeviceInbox[aliceUserID][bystanderDeviceID][0].Content.AsVerificationCancel() + assert.Equal(t, cancelEvt, bystanderCancelEvt) } else { err = receivingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") assert.NoError(t, err) @@ -181,10 +201,25 @@ func TestVerification_StartThenCancel(t *testing.T) { // Ensure that the cancellation event was sent to the sending device. assert.Len(t, ts.DeviceInbox[aliceUserID][sendingDeviceID], 1) cancelEvt = ts.DeviceInbox[aliceUserID][sendingDeviceID][0].Content.AsVerificationCancel() + + // The bystander device should not have a cancellation event. + assert.Empty(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID]) } assert.Equal(t, txnID, cancelEvt.TransactionID) assert.Equal(t, event.VerificationCancelCodeUser, cancelEvt.Code) assert.Equal(t, "Recovery code preferred", cancelEvt.Reason) + + if !sendingCancels { + // Process the cancellation event on the sending device. + ts.dispatchToDevice(t, ctx, sendingClient) + + // Ensure that the cancellation event was sent to the bystander device. + assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) + bystanderCancelEvt := ts.DeviceInbox[aliceUserID][bystanderDeviceID][0].Content.AsVerificationCancel() + assert.Equal(t, txnID, bystanderCancelEvt.TransactionID) + assert.Equal(t, event.VerificationCancelCodeUser, bystanderCancelEvt.Code) + assert.Equal(t, "The verification was rejected from another device.", bystanderCancelEvt.Reason) + } }) } }