From fad4448ab7ef4abdb002be01c9e5ee4fa5e16362 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Tue, 12 Mar 2024 13:00:55 +0200 Subject: [PATCH] Use a callback to receive secret response To properly receive and store a requested secret, we usually need to validate it against something like a public key to ensure we got the correct one. This changes the API so that we instead use a callback to receive any incoming secret matching our request but we'll fail when we hit the specified timeout if we never receive anything that is accepted. --- crypto/sharing.go | 62 +++++++++++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/crypto/sharing.go b/crypto/sharing.go index 18088b8e..c0f3e209 100644 --- a/crypto/sharing.go +++ b/crypto/sharing.go @@ -16,13 +16,26 @@ import ( "maunium.net/go/mautrix/id" ) -func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, timeout time.Duration) (secret string, err error) { - secret, err = mach.CryptoStore.GetSecret(ctx, name) - if err != nil || secret != "" { - return +// Callback function to process a received secret. +// +// Returning true or an error will immediately return from the wait loop, returning false will continue waiting for new responses. +type SecretReceiverFunc func(string) (bool, error) + +func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, receiver SecretReceiverFunc, timeout time.Duration) (err error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // always offer our stored secret first, if any + secret, err := mach.CryptoStore.GetSecret(ctx, name) + if err != nil { + return err + } else if secret != "" { + if ok, err := receiver(secret); ok || err != nil { + return err + } } - requestID, secretChan := random.String(64), make(chan string, 1) + requestID, secretChan := random.String(64), make(chan string, 5) mach.secretLock.Lock() mach.secretListeners[requestID] = secretChan mach.secretLock.Unlock() @@ -43,17 +56,27 @@ func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, return } - select { - case <-ctx.Done(): - err = ctx.Err() - case <-time.After(timeout): - case secret = <-secretChan: - } + // best effort cancel request from all devices when returning + defer func() { + go mach.sendToOneDevice(context.Background(), mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{ + Action: event.SecretRequestCancellation, + RequestID: requestID, + RequestingDeviceID: mach.Client.DeviceID, + }) + }() - if secret != "" { - err = mach.CryptoStore.PutSecret(ctx, name, secret) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case secret = <-secretChan: + if ok, err := receiver(secret); err != nil { + return err + } else if ok { + return mach.CryptoStore.PutSecret(ctx, name, secret) + } + } } - return } func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserID, content *event.SecretRequestEventContent) { @@ -159,17 +182,10 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven return } + // secret channel is buffered and we don't want to block + // at worst we drop _some_ of the responses select { case secretChan <- content.Secret: default: } - - // best effort cancel this for all other targets - go func() { - mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{ - Action: event.SecretRequestCancellation, - RequestID: content.RequestID, - RequestingDeviceID: mach.Client.DeviceID, - }) - }() }