hicli: add media cache entries when receiving events

This commit is contained in:
Tulir Asokan 2024-10-06 15:07:04 +03:00
commit 014ea70762
2 changed files with 82 additions and 7 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/hicli/database"
"maunium.net/go/mautrix/id"
)
@ -52,11 +53,13 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro
continue
}
evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
var mautrixEvt *event.Event
mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
if err != nil {
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
} else {
decrypted = append(decrypted, evt)
h.cacheMedia(ctx, mautrixEvt, evt.RowID)
}
}
if len(decrypted) > 0 {

View file

@ -223,20 +223,86 @@ func removeReplyFallback(evt *event.Event) []byte {
return nil
}
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, string, error) {
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
err := evt.Content.ParseRaw(evt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
return nil, "", err
return nil, nil, "", err
}
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
if err != nil {
return nil, "", err
return nil, nil, "", err
}
withoutFallback := removeReplyFallback(decrypted)
if withoutFallback != nil {
return withoutFallback, decrypted.Type.Type, nil
return decrypted, withoutFallback, decrypted.Type.Type, nil
}
return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil
}
func (h *HiClient) addMediaCache(
ctx context.Context,
eventRowID database.EventRowID,
uri id.ContentURIString,
file *event.EncryptedFileInfo,
info *event.FileInfo,
fileName string,
) {
parsedMXC := uri.ParseOrIgnore()
if !parsedMXC.IsValid() {
return
}
cm := &database.CachedMedia{
MXC: parsedMXC,
EventRowID: eventRowID,
FileName: fileName,
}
if file != nil {
cm.EncFile = &file.EncryptedFile
}
if info != nil {
cm.MimeType = info.MimeType
}
err := h.DB.CachedMedia.Put(ctx, cm)
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Stringer("mxc", parsedMXC).
Int64("event_rowid", int64(eventRowID)).
Msg("Failed to add cached media entry")
}
}
func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) {
switch evt.Type {
case event.EventMessage, event.EventSticker:
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok {
return
}
if content.File != nil {
h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName())
} else if content.URL != "" {
h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName())
}
if content.GetInfo().ThumbnailFile != nil {
h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "")
} else if content.GetInfo().ThumbnailURL != "" {
h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "")
}
case event.StateRoomAvatar:
_ = evt.Content.ParseRaw(evt.Type)
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
if !ok {
return
}
h.addMediaCache(ctx, rowID, content.URL, nil, nil, "")
case event.StateMember:
_ = evt.Content.ParseRaw(evt.Type)
content, ok := evt.Content.Parsed.(*event.MemberEventContent)
if !ok {
return
}
h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "")
}
return decrypted.Content.VeryRaw, decrypted.Type.Type, nil
}
func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) {
@ -254,8 +320,9 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio
dbEvt.Content = contentWithoutFallback
}
var decryptionErr error
var decryptedMautrixEvt *event.Event
if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
if decryptionErr != nil {
dbEvt.DecryptionError = decryptionErr.Error()
}
@ -272,6 +339,11 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio
if err != nil {
return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
}
if decryptedMautrixEvt != nil {
h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID)
} else {
h.cacheMedia(ctx, evt, dbEvt.RowID)
}
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
req, ok := decryptionQueue[dbEvt.MegolmSessionID]
if !ok {