Use a callback to receive secret response

To properly receive and store a requested secret, we usually need to
validate it against something like a public key to ensure we got the
correct one.

This changes the API so that we instead use a callback to receive any
incoming secret matching our request but we'll fail when we hit the
specified timeout if we never receive anything that is accepted.
This commit is contained in:
Toni Spets 2024-03-12 13:00:55 +02:00 committed by Toni Spets
commit fad4448ab7

View file

@ -16,13 +16,26 @@ import (
"maunium.net/go/mautrix/id"
)
func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, timeout time.Duration) (secret string, err error) {
secret, err = mach.CryptoStore.GetSecret(ctx, name)
if err != nil || secret != "" {
return
// Callback function to process a received secret.
//
// Returning true or an error will immediately return from the wait loop, returning false will continue waiting for new responses.
type SecretReceiverFunc func(string) (bool, error)
func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, receiver SecretReceiverFunc, timeout time.Duration) (err error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// always offer our stored secret first, if any
secret, err := mach.CryptoStore.GetSecret(ctx, name)
if err != nil {
return err
} else if secret != "" {
if ok, err := receiver(secret); ok || err != nil {
return err
}
}
requestID, secretChan := random.String(64), make(chan string, 1)
requestID, secretChan := random.String(64), make(chan string, 5)
mach.secretLock.Lock()
mach.secretListeners[requestID] = secretChan
mach.secretLock.Unlock()
@ -43,17 +56,27 @@ func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret,
return
}
select {
case <-ctx.Done():
err = ctx.Err()
case <-time.After(timeout):
case secret = <-secretChan:
}
// best effort cancel request from all devices when returning
defer func() {
go mach.sendToOneDevice(context.Background(), mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
Action: event.SecretRequestCancellation,
RequestID: requestID,
RequestingDeviceID: mach.Client.DeviceID,
})
}()
if secret != "" {
err = mach.CryptoStore.PutSecret(ctx, name, secret)
for {
select {
case <-ctx.Done():
return ctx.Err()
case secret = <-secretChan:
if ok, err := receiver(secret); err != nil {
return err
} else if ok {
return mach.CryptoStore.PutSecret(ctx, name, secret)
}
}
}
return
}
func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserID, content *event.SecretRequestEventContent) {
@ -159,17 +182,10 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven
return
}
// secret channel is buffered and we don't want to block
// at worst we drop _some_ of the responses
select {
case secretChan <- content.Secret:
default:
}
// best effort cancel this for all other targets
go func() {
mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
Action: event.SecretRequestCancellation,
RequestID: content.RequestID,
RequestingDeviceID: mach.Client.DeviceID,
})
}()
}