mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Use a callback to receive secret response
To properly receive and store a requested secret, we usually need to validate it against something like a public key to ensure we got the correct one. This changes the API so that we instead use a callback to receive any incoming secret matching our request but we'll fail when we hit the specified timeout if we never receive anything that is accepted.
This commit is contained in:
parent
a7bf485893
commit
fad4448ab7
1 changed files with 39 additions and 23 deletions
|
|
@ -16,13 +16,26 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, timeout time.Duration) (secret string, err error) {
|
||||
secret, err = mach.CryptoStore.GetSecret(ctx, name)
|
||||
if err != nil || secret != "" {
|
||||
return
|
||||
// Callback function to process a received secret.
|
||||
//
|
||||
// Returning true or an error will immediately return from the wait loop, returning false will continue waiting for new responses.
|
||||
type SecretReceiverFunc func(string) (bool, error)
|
||||
|
||||
func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, receiver SecretReceiverFunc, timeout time.Duration) (err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// always offer our stored secret first, if any
|
||||
secret, err := mach.CryptoStore.GetSecret(ctx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if secret != "" {
|
||||
if ok, err := receiver(secret); ok || err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
requestID, secretChan := random.String(64), make(chan string, 1)
|
||||
requestID, secretChan := random.String(64), make(chan string, 5)
|
||||
mach.secretLock.Lock()
|
||||
mach.secretListeners[requestID] = secretChan
|
||||
mach.secretLock.Unlock()
|
||||
|
|
@ -43,17 +56,27 @@ func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret,
|
|||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
case <-time.After(timeout):
|
||||
case secret = <-secretChan:
|
||||
}
|
||||
// best effort cancel request from all devices when returning
|
||||
defer func() {
|
||||
go mach.sendToOneDevice(context.Background(), mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
|
||||
Action: event.SecretRequestCancellation,
|
||||
RequestID: requestID,
|
||||
RequestingDeviceID: mach.Client.DeviceID,
|
||||
})
|
||||
}()
|
||||
|
||||
if secret != "" {
|
||||
err = mach.CryptoStore.PutSecret(ctx, name, secret)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case secret = <-secretChan:
|
||||
if ok, err := receiver(secret); err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
return mach.CryptoStore.PutSecret(ctx, name, secret)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserID, content *event.SecretRequestEventContent) {
|
||||
|
|
@ -159,17 +182,10 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven
|
|||
return
|
||||
}
|
||||
|
||||
// secret channel is buffered and we don't want to block
|
||||
// at worst we drop _some_ of the responses
|
||||
select {
|
||||
case secretChan <- content.Secret:
|
||||
default:
|
||||
}
|
||||
|
||||
// best effort cancel this for all other targets
|
||||
go func() {
|
||||
mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
|
||||
Action: event.SecretRequestCancellation,
|
||||
RequestID: content.RequestID,
|
||||
RequestingDeviceID: mach.Client.DeviceID,
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue