From 014ea707622a188ccf592a6e0bc537b7d668d4d3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 6 Oct 2024 15:07:04 +0300 Subject: [PATCH] hicli: add media cache entries when receiving events --- hicli/decryptionqueue.go | 5 ++- hicli/sync.go | 84 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 82 insertions(+), 7 deletions(-) diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go index 70ea9f23..87b6b8b2 100644 --- a/hicli/decryptionqueue.go +++ b/hicli/decryptionqueue.go @@ -14,6 +14,7 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" ) @@ -52,11 +53,13 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro continue } - evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) + var mautrixEvt *event.Event + mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) if err != nil { log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") } else { decrypted = append(decrypted, evt) + h.cacheMedia(ctx, mautrixEvt, evt.RowID) } } if len(decrypted) > 0 { diff --git a/hicli/sync.go b/hicli/sync.go index 086f6dd1..3b40af9f 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -223,20 +223,86 @@ func removeReplyFallback(evt *event.Event) []byte { return nil } -func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, string, error) { +func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) { err := evt.Content.ParseRaw(evt.Type) if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { - return nil, "", err + return nil, nil, "", err } decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt) if err != nil { - return nil, "", err + return nil, nil, "", err } withoutFallback := removeReplyFallback(decrypted) if withoutFallback != nil { - return withoutFallback, decrypted.Type.Type, nil + return decrypted, withoutFallback, decrypted.Type.Type, nil + } + return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil +} + +func (h *HiClient) addMediaCache( + ctx context.Context, + eventRowID database.EventRowID, + uri id.ContentURIString, + file *event.EncryptedFileInfo, + info *event.FileInfo, + fileName string, +) { + parsedMXC := uri.ParseOrIgnore() + if !parsedMXC.IsValid() { + return + } + cm := &database.CachedMedia{ + MXC: parsedMXC, + EventRowID: eventRowID, + FileName: fileName, + } + if file != nil { + cm.EncFile = &file.EncryptedFile + } + if info != nil { + cm.MimeType = info.MimeType + } + err := h.DB.CachedMedia.Put(ctx, cm) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("mxc", parsedMXC). + Int64("event_rowid", int64(eventRowID)). + Msg("Failed to add cached media entry") + } +} + +func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) { + switch evt.Type { + case event.EventMessage, event.EventSticker: + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok { + return + } + if content.File != nil { + h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName()) + } else if content.URL != "" { + h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName()) + } + if content.GetInfo().ThumbnailFile != nil { + h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "") + } else if content.GetInfo().ThumbnailURL != "" { + h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "") + } + case event.StateRoomAvatar: + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) + if !ok { + return + } + h.addMediaCache(ctx, rowID, content.URL, nil, nil, "") + case event.StateMember: + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.MemberEventContent) + if !ok { + return + } + h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "") } - return decrypted.Content.VeryRaw, decrypted.Type.Type, nil } func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) { @@ -254,8 +320,9 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio dbEvt.Content = contentWithoutFallback } var decryptionErr error + var decryptedMautrixEvt *event.Event if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" { - dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) + decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) if decryptionErr != nil { dbEvt.DecryptionError = decryptionErr.Error() } @@ -272,6 +339,11 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio if err != nil { return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err) } + if decryptedMautrixEvt != nil { + h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID) + } else { + h.cacheMedia(ctx, evt, dbEvt.RowID) + } if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { req, ok := decryptionQueue[dbEvt.MegolmSessionID] if !ok {