mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
hicli: add html sanitization and push rule evaluation
This commit is contained in:
parent
1d4c2d2554
commit
758e80a5f0
16 changed files with 823 additions and 71 deletions
1
go.mod
1
go.mod
|
|
@ -28,6 +28,7 @@ require (
|
|||
golang.org/x/sync v0.8.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
maunium.net/go/mauflag v1.0.0
|
||||
mvdan.cc/xurls/v2 v2.5.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -83,3 +83,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
|
||||
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
|
||||
mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8=
|
||||
mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE=
|
||||
|
|
|
|||
|
|
@ -25,9 +25,10 @@ import (
|
|||
|
||||
const (
|
||||
getEventBaseQuery = `
|
||||
SELECT rowid, -1, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned,
|
||||
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error,
|
||||
reactions, last_edit_rowid
|
||||
SELECT rowid, -1,
|
||||
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
|
||||
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
||||
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
||||
FROM event
|
||||
`
|
||||
getEventByRowID = getEventBaseQuery + `WHERE rowid = $1`
|
||||
|
|
@ -36,10 +37,11 @@ const (
|
|||
getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL`
|
||||
insertEventBaseQuery = `
|
||||
INSERT INTO event (
|
||||
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned,
|
||||
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error
|
||||
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
|
||||
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
||||
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
|
||||
`
|
||||
insertEventQuery = insertEventBaseQuery + `RETURNING rowid`
|
||||
upsertEventQuery = insertEventBaseQuery + `
|
||||
|
|
@ -50,7 +52,8 @@ const (
|
|||
decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END,
|
||||
send_error=excluded.send_error,
|
||||
timestamp=excluded.timestamp,
|
||||
unsigned=COALESCE(excluded.unsigned, event.unsigned)
|
||||
unsigned=COALESCE(excluded.unsigned, event.unsigned),
|
||||
local_content=COALESCE(excluded.local_content, event.local_content)
|
||||
ON CONFLICT (transaction_id) DO UPDATE
|
||||
SET event_id=excluded.event_id,
|
||||
timestamp=excluded.timestamp,
|
||||
|
|
@ -59,7 +62,7 @@ const (
|
|||
`
|
||||
updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1`
|
||||
updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1`
|
||||
updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3`
|
||||
updateEventDecryptedQuery = `UPDATE event SET decrypted = $2, decrypted_type = $3, decryption_error = NULL, unread_type = $4, local_content = $5 WHERE rowid = $1`
|
||||
getEventReactionsQuery = getEventBaseQuery + `
|
||||
WHERE room_id = ?
|
||||
AND type = 'm.reaction'
|
||||
|
|
@ -131,8 +134,16 @@ func (eq *EventQuery) UpdateSendError(ctx context.Context, rowID EventRowID, sen
|
|||
return eq.Exec(ctx, updateEventSendErrorQuery, rowID, sendError)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID EventRowID, decrypted json.RawMessage, decryptedType string) error {
|
||||
return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID)
|
||||
func (eq *EventQuery) UpdateDecrypted(ctx context.Context, evt *Event) error {
|
||||
return eq.Exec(
|
||||
ctx,
|
||||
updateEventDecryptedQuery,
|
||||
evt.RowID,
|
||||
unsafeJSONString(evt.Decrypted),
|
||||
evt.DecryptedType,
|
||||
evt.UnreadType,
|
||||
dbutil.JSONPtr(evt.LocalContent),
|
||||
)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error {
|
||||
|
|
@ -264,6 +275,24 @@ func (m EventRowID) GetMassInsertValues() [1]any {
|
|||
return [1]any{m}
|
||||
}
|
||||
|
||||
type LocalContent struct {
|
||||
SanitizedHTML string `json:"sanitized_html,omitempty"`
|
||||
}
|
||||
|
||||
type UnreadType int
|
||||
|
||||
func (ut UnreadType) Is(flag UnreadType) bool {
|
||||
return ut&flag != 0
|
||||
}
|
||||
|
||||
const (
|
||||
UnreadTypeNone UnreadType = 0b0000
|
||||
UnreadTypeNormal UnreadType = 0b0001
|
||||
UnreadTypeNotify UnreadType = 0b0010
|
||||
UnreadTypeHighlight UnreadType = 0b0100
|
||||
UnreadTypeSound UnreadType = 0b1000
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
RowID EventRowID `json:"rowid"`
|
||||
TimelineRowID TimelineRowID `json:"timeline_rowid"`
|
||||
|
|
@ -279,6 +308,7 @@ type Event struct {
|
|||
Decrypted json.RawMessage `json:"decrypted,omitempty"`
|
||||
DecryptedType string `json:"decrypted_type,omitempty"`
|
||||
Unsigned json.RawMessage `json:"unsigned,omitempty"`
|
||||
LocalContent *LocalContent `json:"local_content,omitempty"`
|
||||
|
||||
TransactionID string `json:"transaction_id,omitempty"`
|
||||
|
||||
|
|
@ -292,6 +322,7 @@ type Event struct {
|
|||
|
||||
Reactions map[string]int `json:"reactions,omitempty"`
|
||||
LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"`
|
||||
UnreadType UnreadType `json:"unread_type,omitempty"`
|
||||
}
|
||||
|
||||
func MautrixToEvent(evt *event.Event) *Event {
|
||||
|
|
@ -318,6 +349,9 @@ func MautrixToEvent(evt *event.Event) *Event {
|
|||
}
|
||||
|
||||
func (e *Event) AsRawMautrix() *event.Event {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
evt := &event.Event{
|
||||
RoomID: e.RoomID,
|
||||
ID: e.ID,
|
||||
|
|
@ -355,6 +389,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
|
|||
(*[]byte)(&e.Decrypted),
|
||||
&decryptedType,
|
||||
(*[]byte)(&e.Unsigned),
|
||||
dbutil.JSON{Data: &e.LocalContent},
|
||||
&transactionID,
|
||||
&redactedBy,
|
||||
&relatesTo,
|
||||
|
|
@ -364,6 +399,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
|
|||
&sendError,
|
||||
dbutil.JSON{Data: &e.Reactions},
|
||||
&e.LastEditRowID,
|
||||
&e.UnreadType,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -425,6 +461,7 @@ func (e *Event) sqlVariables() []any {
|
|||
unsafeJSONString(e.Decrypted),
|
||||
dbutil.StrPtr(e.DecryptedType),
|
||||
unsafeJSONString(e.Unsigned),
|
||||
dbutil.JSONPtr(e.LocalContent),
|
||||
dbutil.StrPtr(e.TransactionID),
|
||||
dbutil.StrPtr(e.RedactedBy),
|
||||
dbutil.StrPtr(e.RelatesTo),
|
||||
|
|
@ -434,9 +471,26 @@ func (e *Event) sqlVariables() []any {
|
|||
dbutil.StrPtr(e.SendError),
|
||||
dbutil.JSON{Data: reactions},
|
||||
e.LastEditRowID,
|
||||
e.UnreadType,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) GetNonPushUnreadType() UnreadType {
|
||||
if e.RelationType == event.RelReplace {
|
||||
return UnreadTypeNone
|
||||
}
|
||||
switch e.Type {
|
||||
case event.EventMessage.Type, event.EventSticker.Type:
|
||||
return UnreadTypeNormal
|
||||
case event.EventEncrypted.Type:
|
||||
switch e.DecryptedType {
|
||||
case event.EventMessage.Type, event.EventSticker.Type:
|
||||
return UnreadTypeNormal
|
||||
}
|
||||
}
|
||||
return UnreadTypeNone
|
||||
}
|
||||
|
||||
func (e *Event) CanUseForPreview() bool {
|
||||
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type ||
|
||||
(e.Type == event.EventEncrypted.Type &&
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ import (
|
|||
const (
|
||||
getRoomBaseQuery = `
|
||||
SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias,
|
||||
lazy_load_summary, encryption_event, has_member_list,
|
||||
preview_event_rowid, sorting_timestamp, prev_batch
|
||||
lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp,
|
||||
unread_highlights, unread_notifications, unread_messages, prev_batch
|
||||
FROM room
|
||||
`
|
||||
getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2`
|
||||
|
|
@ -47,7 +47,10 @@ const (
|
|||
has_member_list = room.has_member_list OR $11,
|
||||
preview_event_rowid = COALESCE($12, room.preview_event_rowid),
|
||||
sorting_timestamp = COALESCE($13, room.sorting_timestamp),
|
||||
prev_batch = COALESCE($14, room.prev_batch)
|
||||
unread_highlights = COALESCE($14, room.unread_highlights),
|
||||
unread_notifications = COALESCE($15, room.unread_notifications),
|
||||
unread_messages = COALESCE($16, room.unread_messages),
|
||||
prev_batch = COALESCE($17, room.prev_batch)
|
||||
WHERE room_id = $1
|
||||
`
|
||||
setRoomPrevBatchQuery = `
|
||||
|
|
@ -143,8 +146,11 @@ type Room struct {
|
|||
EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"`
|
||||
HasMemberList bool `json:"has_member_list"`
|
||||
|
||||
PreviewEventRowID EventRowID `json:"preview_event_rowid"`
|
||||
SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"`
|
||||
PreviewEventRowID EventRowID `json:"preview_event_rowid"`
|
||||
SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"`
|
||||
UnreadHighlights int `json:"unread_highlights"`
|
||||
UnreadNotifications int `json:"unread_notifications"`
|
||||
UnreadMessages int `json:"unread_messages"`
|
||||
|
||||
PrevBatch string `json:"prev_batch"`
|
||||
}
|
||||
|
|
@ -188,6 +194,18 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) {
|
|||
other.SortingTimestamp = r.SortingTimestamp
|
||||
hasChanges = true
|
||||
}
|
||||
if r.UnreadHighlights != other.UnreadHighlights {
|
||||
other.UnreadHighlights = r.UnreadHighlights
|
||||
hasChanges = true
|
||||
}
|
||||
if r.UnreadNotifications != other.UnreadNotifications {
|
||||
other.UnreadNotifications = r.UnreadNotifications
|
||||
hasChanges = true
|
||||
}
|
||||
if r.UnreadMessages != other.UnreadMessages {
|
||||
other.UnreadMessages = r.UnreadMessages
|
||||
hasChanges = true
|
||||
}
|
||||
if r.PrevBatch != "" && other.PrevBatch == "" {
|
||||
other.PrevBatch = r.PrevBatch
|
||||
hasChanges = true
|
||||
|
|
@ -212,6 +230,9 @@ func (r *Room) Scan(row dbutil.Scannable) (*Room, error) {
|
|||
&r.HasMemberList,
|
||||
&previewEventRowID,
|
||||
&sortingTimestamp,
|
||||
&r.UnreadHighlights,
|
||||
&r.UnreadNotifications,
|
||||
&r.UnreadMessages,
|
||||
&prevBatch,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -238,6 +259,9 @@ func (r *Room) sqlVariables() []any {
|
|||
r.HasMemberList,
|
||||
dbutil.NumPtr(r.PreviewEventRowID),
|
||||
dbutil.UnixMilliPtr(r.SortingTimestamp.Time),
|
||||
r.UnreadHighlights,
|
||||
r.UnreadNotifications,
|
||||
r.UnreadMessages,
|
||||
dbutil.StrPtr(r.PrevBatch),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,8 +30,10 @@ const (
|
|||
DELETE FROM current_state WHERE room_id = $1
|
||||
`
|
||||
getCurrentRoomStateQuery = `
|
||||
SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned,
|
||||
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid
|
||||
SELECT event.rowid, -1,
|
||||
event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type,
|
||||
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
||||
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
||||
FROM current_state cs
|
||||
JOIN event ON cs.event_rowid = event.rowid
|
||||
WHERE cs.room_id = $1
|
||||
|
|
|
|||
|
|
@ -34,8 +34,10 @@ const (
|
|||
`
|
||||
findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline`
|
||||
getTimelineQuery = `
|
||||
SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned,
|
||||
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid
|
||||
SELECT event.rowid, timeline.rowid,
|
||||
event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
|
||||
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
||||
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
||||
FROM timeline
|
||||
JOIN event ON event.rowid = timeline.event_rowid
|
||||
WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
-- v0 -> v2 (compatible with v1+): Latest revision
|
||||
-- v0 -> v3 (compatible with v1+): Latest revision
|
||||
CREATE TABLE account (
|
||||
user_id TEXT NOT NULL PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
|
|
@ -9,29 +9,34 @@ CREATE TABLE account (
|
|||
) STRICT;
|
||||
|
||||
CREATE TABLE room (
|
||||
room_id TEXT NOT NULL PRIMARY KEY,
|
||||
creation_content TEXT,
|
||||
room_id TEXT NOT NULL PRIMARY KEY,
|
||||
creation_content TEXT,
|
||||
|
||||
name TEXT,
|
||||
name_quality INTEGER NOT NULL DEFAULT 0,
|
||||
avatar TEXT,
|
||||
explicit_avatar INTEGER NOT NULL DEFAULT 0,
|
||||
topic TEXT,
|
||||
canonical_alias TEXT,
|
||||
lazy_load_summary TEXT,
|
||||
name TEXT,
|
||||
name_quality INTEGER NOT NULL DEFAULT 0,
|
||||
avatar TEXT,
|
||||
explicit_avatar INTEGER NOT NULL DEFAULT 0,
|
||||
topic TEXT,
|
||||
canonical_alias TEXT,
|
||||
lazy_load_summary TEXT,
|
||||
|
||||
encryption_event TEXT,
|
||||
has_member_list INTEGER NOT NULL DEFAULT false,
|
||||
encryption_event TEXT,
|
||||
has_member_list INTEGER NOT NULL DEFAULT false,
|
||||
|
||||
preview_event_rowid INTEGER,
|
||||
sorting_timestamp INTEGER,
|
||||
preview_event_rowid INTEGER,
|
||||
sorting_timestamp INTEGER,
|
||||
unread_highlights INTEGER NOT NULL DEFAULT 0,
|
||||
unread_notifications INTEGER NOT NULL DEFAULT 0,
|
||||
unread_messages INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
prev_batch TEXT,
|
||||
prev_batch TEXT,
|
||||
|
||||
CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL
|
||||
) STRICT;
|
||||
CREATE INDEX room_type_idx ON room (creation_content ->> 'type');
|
||||
CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC);
|
||||
-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_notifications > 0);
|
||||
-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_messages > 0);
|
||||
|
||||
CREATE TABLE account_data (
|
||||
user_id TEXT NOT NULL,
|
||||
|
|
@ -66,6 +71,7 @@ CREATE TABLE event (
|
|||
decrypted TEXT,
|
||||
decrypted_type TEXT,
|
||||
unsigned TEXT NOT NULL,
|
||||
local_content TEXT,
|
||||
|
||||
transaction_id TEXT,
|
||||
|
||||
|
|
@ -79,6 +85,7 @@ CREATE TABLE event (
|
|||
|
||||
reactions TEXT,
|
||||
last_edit_rowid INTEGER,
|
||||
unread_type INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
CONSTRAINT event_id_unique_key UNIQUE (event_id),
|
||||
CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id),
|
||||
|
|
|
|||
6
hicli/database/upgrades/03-more-event-fields.sql
Normal file
6
hicli/database/upgrades/03-more-event-fields.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- v3 (compatible with v1+): Add more fields to events
|
||||
ALTER TABLE event ADD COLUMN local_content TEXT;
|
||||
ALTER TABLE event ADD COLUMN unread_type INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE room ADD COLUMN unread_highlights INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE room ADD COLUMN unread_notifications INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE room ADD COLUMN unread_messages INTEGER NOT NULL DEFAULT 0;
|
||||
|
|
@ -59,14 +59,14 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro
|
|||
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
|
||||
} else {
|
||||
decrypted = append(decrypted, evt)
|
||||
h.cacheMedia(ctx, mautrixEvt, evt.RowID)
|
||||
h.postDecryptProcess(ctx, nil, evt, mautrixEvt)
|
||||
}
|
||||
}
|
||||
if len(decrypted) > 0 {
|
||||
var newPreview database.EventRowID
|
||||
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
for _, evt := range decrypted {
|
||||
err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType)
|
||||
err = h.DB.Event.UpdateDecrypted(ctx, evt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,11 +13,17 @@ import (
|
|||
)
|
||||
|
||||
type SyncRoom struct {
|
||||
Meta *database.Room `json:"meta"`
|
||||
Timeline []database.TimelineRowTuple `json:"timeline"`
|
||||
State map[event.Type]map[string]database.EventRowID `json:"state"`
|
||||
Events []*database.Event `json:"events"`
|
||||
Reset bool `json:"reset"`
|
||||
Meta *database.Room `json:"meta"`
|
||||
Timeline []database.TimelineRowTuple `json:"timeline"`
|
||||
State map[event.Type]map[string]database.EventRowID `json:"state"`
|
||||
Events []*database.Event `json:"events"`
|
||||
Reset bool `json:"reset"`
|
||||
Notifications []SyncNotification `json:"notifications"`
|
||||
}
|
||||
|
||||
type SyncNotification struct {
|
||||
RowID database.EventRowID `json:"event_rowid"`
|
||||
Sound bool `json:"sound"`
|
||||
}
|
||||
|
||||
type SyncComplete struct {
|
||||
|
|
|
|||
476
hicli/html.go
Normal file
476
hicli/html.go
Normal file
|
|
@ -0,0 +1,476 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
"golang.org/x/net/html/atom"
|
||||
"mvdan.cc/xurls/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func tagIsAllowed(tag atom.Atom) bool {
|
||||
switch tag {
|
||||
case atom.Del, atom.H1, atom.H2, atom.H3, atom.H4, atom.H5, atom.H6, atom.Blockquote, atom.P,
|
||||
atom.A, atom.Ul, atom.Ol, atom.Sup, atom.Sub, atom.Li, atom.B, atom.I, atom.U, atom.Strong,
|
||||
atom.Em, atom.S, atom.Code, atom.Hr, atom.Br, atom.Div, atom.Table, atom.Thead, atom.Tbody,
|
||||
atom.Tr, atom.Th, atom.Td, atom.Caption, atom.Pre, atom.Span, atom.Font, atom.Img,
|
||||
atom.Details, atom.Summary:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isSelfClosing(tag atom.Atom) bool {
|
||||
switch tag {
|
||||
case atom.Img, atom.Br, atom.Hr:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
var languageRegex = regexp.MustCompile(`^language-[a-zA-Z0-9-]+$`)
|
||||
var allowedColorRegex = regexp.MustCompile(`^#[0-9a-fA-F]{6}$`)
|
||||
|
||||
// This is approximately a mirror of web/src/util/mediasize.ts in gomuks
|
||||
func calculateMediaSize(widthInt, heightInt int) (width, height float64, ok bool) {
|
||||
if widthInt <= 0 || heightInt <= 0 {
|
||||
return
|
||||
}
|
||||
width = float64(widthInt)
|
||||
height = float64(heightInt)
|
||||
const imageContainerWidth float64 = 320
|
||||
const imageContainerHeight float64 = 240
|
||||
const imageContainerAspectRatio = imageContainerWidth / imageContainerHeight
|
||||
if width > imageContainerWidth || height > imageContainerHeight {
|
||||
aspectRatio := width / height
|
||||
if aspectRatio > imageContainerAspectRatio {
|
||||
width = imageContainerWidth
|
||||
height = imageContainerWidth / aspectRatio
|
||||
} else if aspectRatio < imageContainerAspectRatio {
|
||||
width = imageContainerHeight * aspectRatio
|
||||
height = imageContainerHeight
|
||||
} else {
|
||||
width = imageContainerWidth
|
||||
height = imageContainerHeight
|
||||
}
|
||||
}
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func parseImgAttributes(attrs []html.Attribute) (src, alt, title string, isCustomEmoji bool, width, height int) {
|
||||
for _, attr := range attrs {
|
||||
switch attr.Key {
|
||||
case "src":
|
||||
src = attr.Val
|
||||
case "alt":
|
||||
alt = attr.Val
|
||||
case "title":
|
||||
title = attr.Val
|
||||
case "data-mx-emoticon":
|
||||
isCustomEmoji = true
|
||||
case "width":
|
||||
width, _ = strconv.Atoi(attr.Val)
|
||||
case "height":
|
||||
height, _ = strconv.Atoi(attr.Val)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseSpanAttributes(attrs []html.Attribute) (bgColor, textColor, spoiler, maths string, isSpoiler bool) {
|
||||
for _, attr := range attrs {
|
||||
switch attr.Key {
|
||||
case "data-mx-bg-color":
|
||||
if allowedColorRegex.MatchString(attr.Val) {
|
||||
bgColor = attr.Val
|
||||
}
|
||||
case "data-mx-color", "color":
|
||||
if allowedColorRegex.MatchString(attr.Val) {
|
||||
textColor = attr.Val
|
||||
}
|
||||
case "data-mx-spoiler":
|
||||
spoiler = attr.Val
|
||||
isSpoiler = true
|
||||
case "data-mx-maths":
|
||||
maths = attr.Val
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseAAttributes(attrs []html.Attribute) (href string) {
|
||||
for _, attr := range attrs {
|
||||
switch attr.Key {
|
||||
case "href":
|
||||
href = strings.TrimSpace(attr.Val)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func attributeIsAllowed(tag atom.Atom, attr html.Attribute) bool {
|
||||
switch tag {
|
||||
case atom.Ol:
|
||||
switch attr.Key {
|
||||
case "start":
|
||||
_, err := strconv.Atoi(attr.Val)
|
||||
return err == nil
|
||||
}
|
||||
case atom.Code:
|
||||
switch attr.Key {
|
||||
case "class":
|
||||
return languageRegex.MatchString(attr.Val)
|
||||
}
|
||||
case atom.Div:
|
||||
switch attr.Key {
|
||||
case "data-mx-maths":
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Funny user IDs will just need to be linkified by the sender, no auto-linkification for them.
|
||||
var plainUserOrAliasMentionRegex = regexp.MustCompile(`[@#][a-zA-Z0-9._=/+-]{0,254}:[a-zA-Z0-9.-]+(?:\d{1,5})?`)
|
||||
|
||||
func getNextItem(items [][]int, minIndex int) (index, start, end int, ok bool) {
|
||||
for i, item := range items {
|
||||
if item[0] >= minIndex {
|
||||
return i, item[0], item[1], true
|
||||
}
|
||||
}
|
||||
return -1, -1, -1, false
|
||||
}
|
||||
|
||||
func writeMention(w *strings.Builder, mention []byte) {
|
||||
w.WriteString(`<a class="hicli-matrix-uri"`)
|
||||
writeAttribute(w, "href", (&id.MatrixURI{
|
||||
Sigil1: rune(mention[0]),
|
||||
MXID1: string(mention[1:]),
|
||||
}).String())
|
||||
w.WriteByte('>')
|
||||
writeEscapedBytes(w, mention)
|
||||
w.WriteString("</a>")
|
||||
}
|
||||
|
||||
func writeURL(w *strings.Builder, addr []byte) {
|
||||
parsedURL, err := url.Parse(string(addr))
|
||||
if err != nil {
|
||||
writeEscapedBytes(w, addr)
|
||||
return
|
||||
}
|
||||
if parsedURL.Scheme == "" {
|
||||
parsedURL.Scheme = "https"
|
||||
}
|
||||
w.WriteString(`<a target="_blank" rel="noreferrer noopener"`)
|
||||
writeAttribute(w, "href", parsedURL.String())
|
||||
w.WriteByte('>')
|
||||
writeEscapedBytes(w, addr)
|
||||
w.WriteString("</a>")
|
||||
}
|
||||
|
||||
func linkifyAndWriteBytes(w *strings.Builder, s []byte) {
|
||||
mentions := plainUserOrAliasMentionRegex.FindAllIndex(s, -1)
|
||||
urls := xurls.Relaxed().FindAllIndex(s, -1)
|
||||
minIndex := 0
|
||||
for {
|
||||
mentionIdx, nextMentionStart, nextMentionEnd, hasMention := getNextItem(mentions, minIndex)
|
||||
urlIdx, nextURLStart, nextURLEnd, hasURL := getNextItem(urls, minIndex)
|
||||
if hasMention && (!hasURL || nextMentionStart <= nextURLStart) {
|
||||
writeEscapedBytes(w, s[minIndex:nextMentionStart])
|
||||
writeMention(w, s[nextMentionStart:nextMentionEnd])
|
||||
minIndex = nextMentionEnd
|
||||
mentions = mentions[mentionIdx:]
|
||||
} else if hasURL && (!hasMention || nextURLStart < nextMentionStart) {
|
||||
writeEscapedBytes(w, s[minIndex:nextURLStart])
|
||||
writeURL(w, s[nextURLStart:nextURLEnd])
|
||||
minIndex = nextURLEnd
|
||||
urls = urls[urlIdx:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
writeEscapedBytes(w, s[minIndex:])
|
||||
}
|
||||
|
||||
const escapedChars = "&'<>\"\r"
|
||||
|
||||
func writeEscapedBytes(w *strings.Builder, s []byte) {
|
||||
i := bytes.IndexAny(s, escapedChars)
|
||||
for i != -1 {
|
||||
w.Write(s[:i])
|
||||
var esc string
|
||||
switch s[i] {
|
||||
case '&':
|
||||
esc = "&"
|
||||
case '\'':
|
||||
// "'" is shorter than "'" and apos was not in HTML until HTML5.
|
||||
esc = "'"
|
||||
case '<':
|
||||
esc = "<"
|
||||
case '>':
|
||||
esc = ">"
|
||||
case '"':
|
||||
// """ is shorter than """.
|
||||
esc = """
|
||||
case '\r':
|
||||
esc = " "
|
||||
default:
|
||||
panic("unrecognized escape character")
|
||||
}
|
||||
s = s[i+1:]
|
||||
w.WriteString(esc)
|
||||
i = bytes.IndexAny(s, escapedChars)
|
||||
}
|
||||
w.Write(s)
|
||||
}
|
||||
|
||||
func writeEscapedString(w *strings.Builder, s string) {
|
||||
i := strings.IndexAny(s, escapedChars)
|
||||
for i != -1 {
|
||||
w.WriteString(s[:i])
|
||||
var esc string
|
||||
switch s[i] {
|
||||
case '&':
|
||||
esc = "&"
|
||||
case '\'':
|
||||
// "'" is shorter than "'" and apos was not in HTML until HTML5.
|
||||
esc = "'"
|
||||
case '<':
|
||||
esc = "<"
|
||||
case '>':
|
||||
esc = ">"
|
||||
case '"':
|
||||
// """ is shorter than """.
|
||||
esc = """
|
||||
case '\r':
|
||||
esc = " "
|
||||
default:
|
||||
panic("unrecognized escape character")
|
||||
}
|
||||
s = s[i+1:]
|
||||
w.WriteString(esc)
|
||||
i = strings.IndexAny(s, escapedChars)
|
||||
}
|
||||
w.WriteString(s)
|
||||
}
|
||||
|
||||
func writeAttribute(w *strings.Builder, key, value string) {
|
||||
w.WriteByte(' ')
|
||||
w.WriteString(key)
|
||||
w.WriteString(`="`)
|
||||
writeEscapedString(w, value)
|
||||
w.WriteByte('"')
|
||||
}
|
||||
|
||||
func writeA(w *strings.Builder, attr []html.Attribute) {
|
||||
w.WriteString("<a")
|
||||
href := parseAAttributes(attr)
|
||||
if href == "" {
|
||||
return
|
||||
}
|
||||
parsedURL, err := url.Parse(href)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
newTab := true
|
||||
switch parsedURL.Scheme {
|
||||
case "bitcoin", "ftp", "geo", "http", "im", "irc", "ircs", "magnet", "mailto",
|
||||
"mms", "news", "nntp", "openpgp4fpr", "sip", "sftp", "sms", "smsto", "ssh",
|
||||
"tel", "urn", "webcal", "wtai", "xmpp":
|
||||
// allowed
|
||||
case "https":
|
||||
if parsedURL.Host == "matrix.to" {
|
||||
uri, err := id.ProcessMatrixToURL(parsedURL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
href = uri.String()
|
||||
newTab = false
|
||||
writeAttribute(w, "class", "hicli-matrix-uri")
|
||||
}
|
||||
case "matrix":
|
||||
uri, err := id.ProcessMatrixURI(parsedURL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
href = uri.String()
|
||||
newTab = false
|
||||
writeAttribute(w, "class", "hicli-matrix-uri")
|
||||
case "mxc":
|
||||
mxc := id.ContentURIString(href).ParseOrIgnore()
|
||||
if !mxc.IsValid() {
|
||||
return
|
||||
}
|
||||
href = fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID)
|
||||
default:
|
||||
return
|
||||
}
|
||||
writeAttribute(w, "href", href)
|
||||
if newTab {
|
||||
writeAttribute(w, "target", "_blank")
|
||||
writeAttribute(w, "rel", "noreferrer noopener")
|
||||
}
|
||||
}
|
||||
|
||||
var HTMLSanitizerImgSrcTemplate = "mxc://%s/%s"
|
||||
|
||||
func writeImg(w *strings.Builder, attr []html.Attribute) {
|
||||
src, alt, title, isCustomEmoji, width, height := parseImgAttributes(attr)
|
||||
w.WriteString("<img")
|
||||
writeAttribute(w, "alt", alt)
|
||||
if title != "" {
|
||||
writeAttribute(w, "title", title)
|
||||
}
|
||||
mxc := id.ContentURIString(src).ParseOrIgnore()
|
||||
if !mxc.IsValid() {
|
||||
return
|
||||
}
|
||||
writeAttribute(w, "src", fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID))
|
||||
writeAttribute(w, "loading", "lazy")
|
||||
if isCustomEmoji {
|
||||
writeAttribute(w, "class", "hicli-custom-emoji")
|
||||
} else if cWidth, cHeight, sizeOK := calculateMediaSize(width, height); sizeOK {
|
||||
writeAttribute(w, "class", "hicli-sized-inline-img")
|
||||
writeAttribute(w, "style", fmt.Sprintf("width: %.2fpx; height: %.2fpx;", cWidth, cHeight))
|
||||
} else {
|
||||
writeAttribute(w, "class", "hicli-sizeless-inline-img")
|
||||
}
|
||||
}
|
||||
|
||||
func writeSpan(w *strings.Builder, attr []html.Attribute) {
|
||||
bgColor, textColor, spoiler, _, isSpoiler := parseSpanAttributes(attr)
|
||||
if isSpoiler && spoiler != "" {
|
||||
w.WriteString(`<span class="spoiler-reason">`)
|
||||
w.WriteString(spoiler)
|
||||
w.WriteString(" </span>")
|
||||
}
|
||||
w.WriteByte('<')
|
||||
w.WriteString("span")
|
||||
if isSpoiler {
|
||||
writeAttribute(w, "class", "hicli-spoiler")
|
||||
}
|
||||
var style string
|
||||
if bgColor != "" {
|
||||
style += fmt.Sprintf("background-color: %s;", bgColor)
|
||||
}
|
||||
if textColor != "" {
|
||||
style += fmt.Sprintf("color: %s;", textColor)
|
||||
}
|
||||
if style != "" {
|
||||
writeAttribute(w, "style", style)
|
||||
}
|
||||
}
|
||||
|
||||
type tagStack []atom.Atom
|
||||
|
||||
func (ts *tagStack) contains(tags ...atom.Atom) bool {
|
||||
for i := len(*ts) - 1; i >= 0; i-- {
|
||||
for _, tag := range tags {
|
||||
if (*ts)[i] == tag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ts *tagStack) push(tag atom.Atom) {
|
||||
*ts = append(*ts, tag)
|
||||
}
|
||||
|
||||
func (ts *tagStack) pop(tag atom.Atom) bool {
|
||||
if len(*ts) > 0 && (*ts)[len(*ts)-1] == tag {
|
||||
*ts = (*ts)[:len(*ts)-1]
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sanitizeAndLinkifyHTML(body string) (string, error) {
|
||||
tz := html.NewTokenizer(strings.NewReader(body))
|
||||
var built strings.Builder
|
||||
ts := make(tagStack, 2)
|
||||
Loop:
|
||||
for {
|
||||
switch tz.Next() {
|
||||
case html.ErrorToken:
|
||||
err := tz.Err()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break Loop
|
||||
}
|
||||
return "", err
|
||||
case html.StartTagToken, html.SelfClosingTagToken:
|
||||
token := tz.Token()
|
||||
if !tagIsAllowed(token.DataAtom) {
|
||||
continue
|
||||
}
|
||||
tagIsSelfClosing := isSelfClosing(token.DataAtom)
|
||||
if token.Type == html.SelfClosingTagToken && !tagIsSelfClosing {
|
||||
continue
|
||||
}
|
||||
switch token.DataAtom {
|
||||
case atom.A:
|
||||
writeA(&built, token.Attr)
|
||||
case atom.Img:
|
||||
writeImg(&built, token.Attr)
|
||||
case atom.Span, atom.Font:
|
||||
writeSpan(&built, token.Attr)
|
||||
default:
|
||||
built.WriteByte('<')
|
||||
built.WriteString(token.Data)
|
||||
for _, attr := range token.Attr {
|
||||
if attributeIsAllowed(token.DataAtom, attr) {
|
||||
writeAttribute(&built, attr.Key, attr.Val)
|
||||
}
|
||||
}
|
||||
}
|
||||
built.WriteByte('>')
|
||||
if !tagIsSelfClosing {
|
||||
ts.push(token.DataAtom)
|
||||
}
|
||||
case html.EndTagToken:
|
||||
tagName, _ := tz.TagName()
|
||||
tag := atom.Lookup(tagName)
|
||||
if tagIsAllowed(tag) && ts.pop(tag) {
|
||||
built.WriteString("</")
|
||||
built.Write(tagName)
|
||||
built.WriteByte('>')
|
||||
}
|
||||
case html.TextToken:
|
||||
if ts.contains(atom.Pre, atom.Code, atom.A) {
|
||||
writeEscapedBytes(&built, tz.Text())
|
||||
} else {
|
||||
linkifyAndWriteBytes(&built, tz.Text())
|
||||
}
|
||||
case html.DoctypeToken, html.CommentToken:
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
slices.Reverse(ts)
|
||||
for _, t := range ts {
|
||||
built.WriteString("</")
|
||||
built.WriteString(t.String())
|
||||
built.WriteByte('>')
|
||||
}
|
||||
return built.String(), nil
|
||||
}
|
||||
|
|
@ -59,7 +59,7 @@ func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev
|
|||
} else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil {
|
||||
return nil, fmt.Errorf("failed to get event from server: %w", err)
|
||||
} else {
|
||||
return h.processEvent(ctx, serverEvt, nil, false)
|
||||
return h.processEvent(ctx, serverEvt, nil, nil, false)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -90,7 +90,7 @@ func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMemb
|
|||
}
|
||||
entries := make([]*database.CurrentStateEntry, len(evts))
|
||||
for i, evt := range evts {
|
||||
dbEvt, err := h.processEvent(ctx, evt, nil, false)
|
||||
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process event %s: %w", evt.ID, err)
|
||||
}
|
||||
|
|
@ -186,7 +186,7 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i
|
|||
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
|
||||
iOffset := 0
|
||||
for i, evt := range resp.Chunk {
|
||||
dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, true)
|
||||
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil {
|
||||
|
|
|
|||
80
hicli/pushrules.go
Normal file
80
hicli/pushrules.go
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/hicli/database"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/pushrules"
|
||||
)
|
||||
|
||||
type pushRoom struct {
|
||||
ctx context.Context
|
||||
roomID id.RoomID
|
||||
h *HiClient
|
||||
ll *mautrix.LazyLoadSummary
|
||||
}
|
||||
|
||||
func (p *pushRoom) GetOwnDisplayname() string {
|
||||
// TODO implement
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *pushRoom) GetMemberCount() int {
|
||||
if p.ll == nil {
|
||||
room, err := p.h.DB.Room.Get(p.ctx, p.roomID)
|
||||
if err != nil {
|
||||
zerolog.Ctx(p.ctx).Err(err).
|
||||
Stringer("room_id", p.roomID).
|
||||
Msg("Failed to get room by ID in push rule evaluator")
|
||||
} else if room != nil {
|
||||
p.ll = room.LazyLoadSummary
|
||||
}
|
||||
}
|
||||
if p.ll != nil && p.ll.JoinedMemberCount != nil {
|
||||
return *p.ll.JoinedMemberCount
|
||||
}
|
||||
// TODO query db?
|
||||
return 0
|
||||
}
|
||||
|
||||
func (p *pushRoom) GetEvent(id id.EventID) *event.Event {
|
||||
evt, err := p.h.DB.Event.GetByID(p.ctx, id)
|
||||
if err != nil {
|
||||
zerolog.Ctx(p.ctx).Err(err).
|
||||
Stringer("event_id", id).
|
||||
Msg("Failed to get event by ID in push rule evaluator")
|
||||
}
|
||||
return evt.AsRawMautrix()
|
||||
}
|
||||
|
||||
var _ pushrules.EventfulRoom = (*pushRoom)(nil)
|
||||
|
||||
func (h *HiClient) evaluatePushRules(ctx context.Context, llSummary *mautrix.LazyLoadSummary, baseType database.UnreadType, evt *event.Event) database.UnreadType {
|
||||
should := h.PushRules.Load().GetMatchingRule(&pushRoom{
|
||||
ctx: ctx,
|
||||
roomID: evt.RoomID,
|
||||
h: h,
|
||||
ll: llSummary,
|
||||
}, evt).GetActions().Should()
|
||||
if should.Notify {
|
||||
baseType |= database.UnreadTypeNotify
|
||||
}
|
||||
if should.Highlight {
|
||||
baseType |= database.UnreadTypeHighlight
|
||||
}
|
||||
if should.PlaySound {
|
||||
baseType |= database.UnreadTypeSound
|
||||
}
|
||||
return baseType
|
||||
}
|
||||
|
|
@ -218,7 +218,7 @@ func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error {
|
|||
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
entries := make([]*database.CurrentStateEntry, len(resp.Chunk))
|
||||
for i, evt := range resp.Chunk {
|
||||
dbEvt, err := h.processEvent(ctx, evt, nil, true)
|
||||
dbEvt, err := h.processEvent(ctx, evt, nil, nil, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
141
hicli/sync.go
141
hicli/sync.go
|
|
@ -134,11 +134,15 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy
|
|||
return nil
|
||||
}
|
||||
|
||||
func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt {
|
||||
func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) {
|
||||
receiptList := make([]*database.Receipt, 0)
|
||||
var newOwnReceipts []id.EventID
|
||||
for eventID, receipts := range *content {
|
||||
for receiptType, users := range receipts {
|
||||
for userID, receiptInfo := range users {
|
||||
if userID == h.Account.UserID {
|
||||
newOwnReceipts = append(newOwnReceipts, eventID)
|
||||
}
|
||||
receiptList = append(receiptList, &database.Receipt{
|
||||
UserID: userID,
|
||||
ReceiptType: receiptType,
|
||||
|
|
@ -149,7 +153,12 @@ func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt {
|
|||
}
|
||||
}
|
||||
}
|
||||
return receiptList
|
||||
return receiptList, newOwnReceipts
|
||||
}
|
||||
|
||||
type receiptsToSave struct {
|
||||
roomID id.RoomID
|
||||
receipts []*database.Receipt
|
||||
}
|
||||
|
||||
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
|
||||
|
|
@ -172,10 +181,8 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID,
|
|||
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
|
||||
}
|
||||
}
|
||||
err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var receipts []receiptsToSave
|
||||
var newOwnReceipts []id.EventID
|
||||
for _, evt := range room.Ephemeral.Events {
|
||||
evt.Type.Class = event.EphemeralEventType
|
||||
err = evt.Content.ParseRaw(evt.Type)
|
||||
|
|
@ -185,18 +192,24 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID,
|
|||
}
|
||||
switch evt.Type {
|
||||
case event.EphemeralEventReceipt:
|
||||
err = h.DB.Receipt.PutMany(ctx, roomID, receiptsToList(evt.Content.AsReceipt())...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save receipts: %w", err)
|
||||
}
|
||||
var receiptsList []*database.Receipt
|
||||
receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt())
|
||||
receipts = append(receipts, receiptsToSave{roomID, receiptsList})
|
||||
case event.EphemeralEventTyping:
|
||||
go h.EventHandler(&Typing{
|
||||
RoomID: roomID,
|
||||
TypingEventContent: *evt.Content.AsTyping(),
|
||||
})
|
||||
}
|
||||
if evt.Type != event.EphemeralEventReceipt {
|
||||
continue
|
||||
}
|
||||
err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rs := range receipts {
|
||||
err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save receipts: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
@ -209,7 +222,8 @@ func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, ro
|
|||
} else if existingRoomData == nil {
|
||||
return nil
|
||||
}
|
||||
return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
|
||||
// TODO delete room instead of processing?
|
||||
return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil)
|
||||
}
|
||||
|
||||
func isDecryptionErrorRetryable(err error) bool {
|
||||
|
|
@ -318,7 +332,47 @@ func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID datab
|
|||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) {
|
||||
func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent {
|
||||
if evt.Type != event.EventMessage && evt.Type != event.EventSticker {
|
||||
return nil
|
||||
}
|
||||
_ = evt.Content.ParseRaw(evt.Type)
|
||||
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if dbEvt.RelationType == event.RelReplace && content.NewContent != nil {
|
||||
content = content.NewContent
|
||||
}
|
||||
if content != nil {
|
||||
var sanitizedHTML string
|
||||
if content.Format == event.FormatHTML {
|
||||
sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody)
|
||||
} else {
|
||||
var builder strings.Builder
|
||||
linkifyAndWriteBytes(&builder, []byte(content.Body))
|
||||
sanitizedHTML = builder.String()
|
||||
}
|
||||
return &database.LocalContent{SanitizedHTML: sanitizedHTML}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) {
|
||||
if dbEvt.RowID != 0 {
|
||||
h.cacheMedia(ctx, evt, dbEvt.RowID)
|
||||
}
|
||||
dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt)
|
||||
dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt)
|
||||
}
|
||||
|
||||
func (h *HiClient) processEvent(
|
||||
ctx context.Context,
|
||||
evt *event.Event,
|
||||
llSummary *mautrix.LazyLoadSummary,
|
||||
decryptionQueue map[id.SessionID]*database.SessionRequest,
|
||||
checkDB bool,
|
||||
) (*database.Event, error) {
|
||||
if checkDB {
|
||||
dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID)
|
||||
if err != nil {
|
||||
|
|
@ -350,6 +404,11 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio
|
|||
evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
|
||||
}
|
||||
}
|
||||
if decryptedMautrixEvt != nil {
|
||||
h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt)
|
||||
} else {
|
||||
h.postDecryptProcess(ctx, llSummary, dbEvt, evt)
|
||||
}
|
||||
_, err := h.DB.Event.Upsert(ctx, dbEvt)
|
||||
if err != nil {
|
||||
return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
|
||||
|
|
@ -386,12 +445,27 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio
|
|||
return dbEvt, err
|
||||
}
|
||||
|
||||
func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error {
|
||||
func (h *HiClient) processStateAndTimeline(
|
||||
ctx context.Context,
|
||||
room *database.Room,
|
||||
state *mautrix.SyncEventsList,
|
||||
timeline *mautrix.SyncTimeline,
|
||||
summary *mautrix.LazyLoadSummary,
|
||||
newOwnReceipts []id.EventID,
|
||||
serverNotificationCounts *mautrix.UnreadNotificationCounts,
|
||||
) error {
|
||||
updatedRoom := &database.Room{
|
||||
ID: room.ID,
|
||||
|
||||
SortingTimestamp: room.SortingTimestamp,
|
||||
NameQuality: room.NameQuality,
|
||||
SortingTimestamp: room.SortingTimestamp,
|
||||
NameQuality: room.NameQuality,
|
||||
UnreadHighlights: room.UnreadHighlights,
|
||||
UnreadNotifications: room.UnreadNotifications,
|
||||
UnreadMessages: room.UnreadMessages,
|
||||
}
|
||||
if serverNotificationCounts != nil {
|
||||
updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount
|
||||
updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount
|
||||
}
|
||||
heroesChanged := false
|
||||
if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil {
|
||||
|
|
@ -405,6 +479,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
|
|||
}
|
||||
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
|
||||
allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events))
|
||||
newNotifications := make([]SyncNotification, 0)
|
||||
recalculatePreviewEvent := false
|
||||
addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) {
|
||||
if rowID != 0 {
|
||||
|
|
@ -440,12 +515,18 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
|
|||
}
|
||||
return nil
|
||||
}
|
||||
processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) {
|
||||
processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) {
|
||||
evt.RoomID = room.ID
|
||||
dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, false)
|
||||
dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, false)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) {
|
||||
newNotifications = append(newNotifications, SyncNotification{
|
||||
RowID: dbEvt.RowID,
|
||||
Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound),
|
||||
})
|
||||
}
|
||||
if isTimeline {
|
||||
if dbEvt.CanUseForPreview() {
|
||||
updatedRoom.PreviewEventRowID = dbEvt.RowID
|
||||
|
|
@ -492,7 +573,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
|
|||
}
|
||||
for _, evt := range state.Events {
|
||||
evt.Type.Class = event.StateEventType
|
||||
rowID, err := processNewEvent(evt, false)
|
||||
rowID, err := processNewEvent(evt, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -502,13 +583,20 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
|
|||
var err error
|
||||
if len(timeline.Events) > 0 {
|
||||
timelineIDs := make([]database.EventRowID, len(timeline.Events))
|
||||
readUpToIndex := -1
|
||||
for i := len(timeline.Events) - 1; i >= 0; i-- {
|
||||
if slices.Contains(newOwnReceipts, timeline.Events[i].ID) {
|
||||
readUpToIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
for i, evt := range timeline.Events {
|
||||
if evt.StateKey != nil {
|
||||
evt.Type.Class = event.StateEventType
|
||||
} else {
|
||||
evt.Type.Class = event.MessageEventType
|
||||
}
|
||||
timelineIDs[i], err = processNewEvent(evt, true)
|
||||
timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -578,11 +666,12 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
|
|||
}
|
||||
if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
|
||||
ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
|
||||
Meta: room,
|
||||
Timeline: timelineRowTuples,
|
||||
State: changedState,
|
||||
Reset: timeline.Limited,
|
||||
Events: allNewEvents,
|
||||
Meta: room,
|
||||
Timeline: timelineRowTuples,
|
||||
State: changedState,
|
||||
Reset: timeline.Limited,
|
||||
Events: allNewEvents,
|
||||
Notifications: newNotifications,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -68,6 +68,9 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) {
|
|||
var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}}
|
||||
|
||||
func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) {
|
||||
if rs == nil {
|
||||
return nil
|
||||
}
|
||||
// Add push rule collections to array in priority order
|
||||
arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride}
|
||||
// Loop until one of the push rule collections matches the room/event combo.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue