mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 22:35:52 +01:00
hicli: add media cache entries when receiving events
This commit is contained in:
parent
bb6aaf79a9
commit
014ea70762
2 changed files with 82 additions and 7 deletions
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/hicli/database"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
|
@ -52,11 +53,13 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro
|
|||
continue
|
||||
}
|
||||
|
||||
evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
|
||||
var mautrixEvt *event.Event
|
||||
mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
|
||||
} else {
|
||||
decrypted = append(decrypted, evt)
|
||||
h.cacheMedia(ctx, mautrixEvt, evt.RowID)
|
||||
}
|
||||
}
|
||||
if len(decrypted) > 0 {
|
||||
|
|
|
|||
|
|
@ -223,20 +223,86 @@ func removeReplyFallback(evt *event.Event) []byte {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, string, error) {
|
||||
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
withoutFallback := removeReplyFallback(decrypted)
|
||||
if withoutFallback != nil {
|
||||
return withoutFallback, decrypted.Type.Type, nil
|
||||
return decrypted, withoutFallback, decrypted.Type.Type, nil
|
||||
}
|
||||
return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) addMediaCache(
|
||||
ctx context.Context,
|
||||
eventRowID database.EventRowID,
|
||||
uri id.ContentURIString,
|
||||
file *event.EncryptedFileInfo,
|
||||
info *event.FileInfo,
|
||||
fileName string,
|
||||
) {
|
||||
parsedMXC := uri.ParseOrIgnore()
|
||||
if !parsedMXC.IsValid() {
|
||||
return
|
||||
}
|
||||
cm := &database.CachedMedia{
|
||||
MXC: parsedMXC,
|
||||
EventRowID: eventRowID,
|
||||
FileName: fileName,
|
||||
}
|
||||
if file != nil {
|
||||
cm.EncFile = &file.EncryptedFile
|
||||
}
|
||||
if info != nil {
|
||||
cm.MimeType = info.MimeType
|
||||
}
|
||||
err := h.DB.CachedMedia.Put(ctx, cm)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Warn().Err(err).
|
||||
Stringer("mxc", parsedMXC).
|
||||
Int64("event_rowid", int64(eventRowID)).
|
||||
Msg("Failed to add cached media entry")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) {
|
||||
switch evt.Type {
|
||||
case event.EventMessage, event.EventSticker:
|
||||
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if content.File != nil {
|
||||
h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName())
|
||||
} else if content.URL != "" {
|
||||
h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName())
|
||||
}
|
||||
if content.GetInfo().ThumbnailFile != nil {
|
||||
h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "")
|
||||
} else if content.GetInfo().ThumbnailURL != "" {
|
||||
h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "")
|
||||
}
|
||||
case event.StateRoomAvatar:
|
||||
_ = evt.Content.ParseRaw(evt.Type)
|
||||
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
h.addMediaCache(ctx, rowID, content.URL, nil, nil, "")
|
||||
case event.StateMember:
|
||||
_ = evt.Content.ParseRaw(evt.Type)
|
||||
content, ok := evt.Content.Parsed.(*event.MemberEventContent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "")
|
||||
}
|
||||
return decrypted.Content.VeryRaw, decrypted.Type.Type, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) {
|
||||
|
|
@ -254,8 +320,9 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio
|
|||
dbEvt.Content = contentWithoutFallback
|
||||
}
|
||||
var decryptionErr error
|
||||
var decryptedMautrixEvt *event.Event
|
||||
if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
|
||||
dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
|
||||
decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
|
||||
if decryptionErr != nil {
|
||||
dbEvt.DecryptionError = decryptionErr.Error()
|
||||
}
|
||||
|
|
@ -272,6 +339,11 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio
|
|||
if err != nil {
|
||||
return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
|
||||
}
|
||||
if decryptedMautrixEvt != nil {
|
||||
h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID)
|
||||
} else {
|
||||
h.cacheMedia(ctx, evt, dbEvt.RowID)
|
||||
}
|
||||
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
|
||||
req, ok := decryptionQueue[dbEvt.MegolmSessionID]
|
||||
if !ok {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue