From 758e80a5f09ad84ab0cea66583d73746ebdd41b8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Oct 2024 00:21:53 +0300 Subject: [PATCH] hicli: add html sanitization and push rule evaluation --- go.mod | 1 + go.sum | 2 + hicli/database/event.go | 74 ++- hicli/database/room.go | 34 +- hicli/database/state.go | 6 +- hicli/database/timeline.go | 6 +- .../database/upgrades/00-latest-revision.sql | 37 +- .../upgrades/03-more-event-fields.sql | 6 + hicli/decryptionqueue.go | 4 +- hicli/events.go | 16 +- hicli/html.go | 476 ++++++++++++++++++ hicli/paginate.go | 6 +- hicli/pushrules.go | 80 +++ hicli/send.go | 2 +- hicli/sync.go | 141 +++++- pushrules/ruleset.go | 3 + 16 files changed, 823 insertions(+), 71 deletions(-) create mode 100644 hicli/database/upgrades/03-more-event-fields.sql create mode 100644 hicli/html.go create mode 100644 hicli/pushrules.go diff --git a/go.mod b/go.mod index f45b8990..ad6dbdc5 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( golang.org/x/sync v0.8.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 + mvdan.cc/xurls/v2 v2.5.0 ) require ( diff --git a/go.sum b/go.sum index e7a58076..955cbb91 100644 --- a/go.sum +++ b/go.sum @@ -83,3 +83,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= +mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8= +mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE= diff --git a/hicli/database/event.go b/hicli/database/event.go index 0c55d84c..b0f64eb3 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -25,9 +25,10 @@ import ( const ( getEventBaseQuery = ` - SELECT rowid, -1, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, - reactions, last_edit_rowid + SELECT rowid, -1, + room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type FROM event ` getEventByRowID = getEventBaseQuery + `WHERE rowid = $1` @@ -36,10 +37,11 @@ const ( getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL` insertEventBaseQuery = ` INSERT INTO event ( - room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error + room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) ` insertEventQuery = insertEventBaseQuery + `RETURNING rowid` upsertEventQuery = insertEventBaseQuery + ` @@ -50,7 +52,8 @@ const ( decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END, send_error=excluded.send_error, timestamp=excluded.timestamp, - unsigned=COALESCE(excluded.unsigned, event.unsigned) + unsigned=COALESCE(excluded.unsigned, event.unsigned), + local_content=COALESCE(excluded.local_content, event.local_content) ON CONFLICT (transaction_id) DO UPDATE SET event_id=excluded.event_id, timestamp=excluded.timestamp, @@ -59,7 +62,7 @@ const ( ` updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1` updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1` - updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3` + updateEventDecryptedQuery = `UPDATE event SET decrypted = $2, decrypted_type = $3, decryption_error = NULL, unread_type = $4, local_content = $5 WHERE rowid = $1` getEventReactionsQuery = getEventBaseQuery + ` WHERE room_id = ? AND type = 'm.reaction' @@ -131,8 +134,16 @@ func (eq *EventQuery) UpdateSendError(ctx context.Context, rowID EventRowID, sen return eq.Exec(ctx, updateEventSendErrorQuery, rowID, sendError) } -func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID EventRowID, decrypted json.RawMessage, decryptedType string) error { - return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID) +func (eq *EventQuery) UpdateDecrypted(ctx context.Context, evt *Event) error { + return eq.Exec( + ctx, + updateEventDecryptedQuery, + evt.RowID, + unsafeJSONString(evt.Decrypted), + evt.DecryptedType, + evt.UnreadType, + dbutil.JSONPtr(evt.LocalContent), + ) } func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error { @@ -264,6 +275,24 @@ func (m EventRowID) GetMassInsertValues() [1]any { return [1]any{m} } +type LocalContent struct { + SanitizedHTML string `json:"sanitized_html,omitempty"` +} + +type UnreadType int + +func (ut UnreadType) Is(flag UnreadType) bool { + return ut&flag != 0 +} + +const ( + UnreadTypeNone UnreadType = 0b0000 + UnreadTypeNormal UnreadType = 0b0001 + UnreadTypeNotify UnreadType = 0b0010 + UnreadTypeHighlight UnreadType = 0b0100 + UnreadTypeSound UnreadType = 0b1000 +) + type Event struct { RowID EventRowID `json:"rowid"` TimelineRowID TimelineRowID `json:"timeline_rowid"` @@ -279,6 +308,7 @@ type Event struct { Decrypted json.RawMessage `json:"decrypted,omitempty"` DecryptedType string `json:"decrypted_type,omitempty"` Unsigned json.RawMessage `json:"unsigned,omitempty"` + LocalContent *LocalContent `json:"local_content,omitempty"` TransactionID string `json:"transaction_id,omitempty"` @@ -292,6 +322,7 @@ type Event struct { Reactions map[string]int `json:"reactions,omitempty"` LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"` + UnreadType UnreadType `json:"unread_type,omitempty"` } func MautrixToEvent(evt *event.Event) *Event { @@ -318,6 +349,9 @@ func MautrixToEvent(evt *event.Event) *Event { } func (e *Event) AsRawMautrix() *event.Event { + if e == nil { + return nil + } evt := &event.Event{ RoomID: e.RoomID, ID: e.ID, @@ -355,6 +389,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { (*[]byte)(&e.Decrypted), &decryptedType, (*[]byte)(&e.Unsigned), + dbutil.JSON{Data: &e.LocalContent}, &transactionID, &redactedBy, &relatesTo, @@ -364,6 +399,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { &sendError, dbutil.JSON{Data: &e.Reactions}, &e.LastEditRowID, + &e.UnreadType, ) if err != nil { return nil, err @@ -425,6 +461,7 @@ func (e *Event) sqlVariables() []any { unsafeJSONString(e.Decrypted), dbutil.StrPtr(e.DecryptedType), unsafeJSONString(e.Unsigned), + dbutil.JSONPtr(e.LocalContent), dbutil.StrPtr(e.TransactionID), dbutil.StrPtr(e.RedactedBy), dbutil.StrPtr(e.RelatesTo), @@ -434,9 +471,26 @@ func (e *Event) sqlVariables() []any { dbutil.StrPtr(e.SendError), dbutil.JSON{Data: reactions}, e.LastEditRowID, + e.UnreadType, } } +func (e *Event) GetNonPushUnreadType() UnreadType { + if e.RelationType == event.RelReplace { + return UnreadTypeNone + } + switch e.Type { + case event.EventMessage.Type, event.EventSticker.Type: + return UnreadTypeNormal + case event.EventEncrypted.Type: + switch e.DecryptedType { + case event.EventMessage.Type, event.EventSticker.Type: + return UnreadTypeNormal + } + } + return UnreadTypeNone +} + func (e *Event) CanUseForPreview() bool { return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || (e.Type == event.EventEncrypted.Type && diff --git a/hicli/database/room.go b/hicli/database/room.go index d9293cf8..42108022 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -23,8 +23,8 @@ import ( const ( getRoomBaseQuery = ` SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias, - lazy_load_summary, encryption_event, has_member_list, - preview_event_rowid, sorting_timestamp, prev_batch + lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp, + unread_highlights, unread_notifications, unread_messages, prev_batch FROM room ` getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2` @@ -47,7 +47,10 @@ const ( has_member_list = room.has_member_list OR $11, preview_event_rowid = COALESCE($12, room.preview_event_rowid), sorting_timestamp = COALESCE($13, room.sorting_timestamp), - prev_batch = COALESCE($14, room.prev_batch) + unread_highlights = COALESCE($14, room.unread_highlights), + unread_notifications = COALESCE($15, room.unread_notifications), + unread_messages = COALESCE($16, room.unread_messages), + prev_batch = COALESCE($17, room.prev_batch) WHERE room_id = $1 ` setRoomPrevBatchQuery = ` @@ -143,8 +146,11 @@ type Room struct { EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"` HasMemberList bool `json:"has_member_list"` - PreviewEventRowID EventRowID `json:"preview_event_rowid"` - SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"` + PreviewEventRowID EventRowID `json:"preview_event_rowid"` + SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"` + UnreadHighlights int `json:"unread_highlights"` + UnreadNotifications int `json:"unread_notifications"` + UnreadMessages int `json:"unread_messages"` PrevBatch string `json:"prev_batch"` } @@ -188,6 +194,18 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { other.SortingTimestamp = r.SortingTimestamp hasChanges = true } + if r.UnreadHighlights != other.UnreadHighlights { + other.UnreadHighlights = r.UnreadHighlights + hasChanges = true + } + if r.UnreadNotifications != other.UnreadNotifications { + other.UnreadNotifications = r.UnreadNotifications + hasChanges = true + } + if r.UnreadMessages != other.UnreadMessages { + other.UnreadMessages = r.UnreadMessages + hasChanges = true + } if r.PrevBatch != "" && other.PrevBatch == "" { other.PrevBatch = r.PrevBatch hasChanges = true @@ -212,6 +230,9 @@ func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { &r.HasMemberList, &previewEventRowID, &sortingTimestamp, + &r.UnreadHighlights, + &r.UnreadNotifications, + &r.UnreadMessages, &prevBatch, ) if err != nil { @@ -238,6 +259,9 @@ func (r *Room) sqlVariables() []any { r.HasMemberList, dbutil.NumPtr(r.PreviewEventRowID), dbutil.UnixMilliPtr(r.SortingTimestamp.Time), + r.UnreadHighlights, + r.UnreadNotifications, + r.UnreadMessages, dbutil.StrPtr(r.PrevBatch), } } diff --git a/hicli/database/state.go b/hicli/database/state.go index c12f9f60..d6fbf53d 100644 --- a/hicli/database/state.go +++ b/hicli/database/state.go @@ -30,8 +30,10 @@ const ( DELETE FROM current_state WHERE room_id = $1 ` getCurrentRoomStateQuery = ` - SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid + SELECT event.rowid, -1, + event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type FROM current_state cs JOIN event ON cs.event_rowid = event.rowid WHERE cs.room_id = $1 diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go index ddebd793..e04eeb88 100644 --- a/hicli/database/timeline.go +++ b/hicli/database/timeline.go @@ -34,8 +34,10 @@ const ( ` findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline` getTimelineQuery = ` - SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid + SELECT event.rowid, timeline.rowid, + event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type FROM timeline JOIN event ON event.rowid = timeline.event_rowid WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2) diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index f8c84a61..0808a6e9 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v2 (compatible with v1+): Latest revision +-- v0 -> v3 (compatible with v1+): Latest revision CREATE TABLE account ( user_id TEXT NOT NULL PRIMARY KEY, device_id TEXT NOT NULL, @@ -9,29 +9,34 @@ CREATE TABLE account ( ) STRICT; CREATE TABLE room ( - room_id TEXT NOT NULL PRIMARY KEY, - creation_content TEXT, + room_id TEXT NOT NULL PRIMARY KEY, + creation_content TEXT, - name TEXT, - name_quality INTEGER NOT NULL DEFAULT 0, - avatar TEXT, - explicit_avatar INTEGER NOT NULL DEFAULT 0, - topic TEXT, - canonical_alias TEXT, - lazy_load_summary TEXT, + name TEXT, + name_quality INTEGER NOT NULL DEFAULT 0, + avatar TEXT, + explicit_avatar INTEGER NOT NULL DEFAULT 0, + topic TEXT, + canonical_alias TEXT, + lazy_load_summary TEXT, - encryption_event TEXT, - has_member_list INTEGER NOT NULL DEFAULT false, + encryption_event TEXT, + has_member_list INTEGER NOT NULL DEFAULT false, - preview_event_rowid INTEGER, - sorting_timestamp INTEGER, + preview_event_rowid INTEGER, + sorting_timestamp INTEGER, + unread_highlights INTEGER NOT NULL DEFAULT 0, + unread_notifications INTEGER NOT NULL DEFAULT 0, + unread_messages INTEGER NOT NULL DEFAULT 0, - prev_batch TEXT, + prev_batch TEXT, CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL ) STRICT; CREATE INDEX room_type_idx ON room (creation_content ->> 'type'); CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC); +-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_notifications > 0); +-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_messages > 0); CREATE TABLE account_data ( user_id TEXT NOT NULL, @@ -66,6 +71,7 @@ CREATE TABLE event ( decrypted TEXT, decrypted_type TEXT, unsigned TEXT NOT NULL, + local_content TEXT, transaction_id TEXT, @@ -79,6 +85,7 @@ CREATE TABLE event ( reactions TEXT, last_edit_rowid INTEGER, + unread_type INTEGER NOT NULL DEFAULT 0, CONSTRAINT event_id_unique_key UNIQUE (event_id), CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id), diff --git a/hicli/database/upgrades/03-more-event-fields.sql b/hicli/database/upgrades/03-more-event-fields.sql new file mode 100644 index 00000000..3e07ad75 --- /dev/null +++ b/hicli/database/upgrades/03-more-event-fields.sql @@ -0,0 +1,6 @@ +-- v3 (compatible with v1+): Add more fields to events +ALTER TABLE event ADD COLUMN local_content TEXT; +ALTER TABLE event ADD COLUMN unread_type INTEGER NOT NULL DEFAULT 0; +ALTER TABLE room ADD COLUMN unread_highlights INTEGER NOT NULL DEFAULT 0; +ALTER TABLE room ADD COLUMN unread_notifications INTEGER NOT NULL DEFAULT 0; +ALTER TABLE room ADD COLUMN unread_messages INTEGER NOT NULL DEFAULT 0; diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go index 87b6b8b2..665ee78a 100644 --- a/hicli/decryptionqueue.go +++ b/hicli/decryptionqueue.go @@ -59,14 +59,14 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") } else { decrypted = append(decrypted, evt) - h.cacheMedia(ctx, mautrixEvt, evt.RowID) + h.postDecryptProcess(ctx, nil, evt, mautrixEvt) } } if len(decrypted) > 0 { var newPreview database.EventRowID err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { for _, evt := range decrypted { - err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType) + err = h.DB.Event.UpdateDecrypted(ctx, evt) if err != nil { return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err) } diff --git a/hicli/events.go b/hicli/events.go index b96fd266..e730475b 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -13,11 +13,17 @@ import ( ) type SyncRoom struct { - Meta *database.Room `json:"meta"` - Timeline []database.TimelineRowTuple `json:"timeline"` - State map[event.Type]map[string]database.EventRowID `json:"state"` - Events []*database.Event `json:"events"` - Reset bool `json:"reset"` + Meta *database.Room `json:"meta"` + Timeline []database.TimelineRowTuple `json:"timeline"` + State map[event.Type]map[string]database.EventRowID `json:"state"` + Events []*database.Event `json:"events"` + Reset bool `json:"reset"` + Notifications []SyncNotification `json:"notifications"` +} + +type SyncNotification struct { + RowID database.EventRowID `json:"event_rowid"` + Sound bool `json:"sound"` } type SyncComplete struct { diff --git a/hicli/html.go b/hicli/html.go new file mode 100644 index 00000000..b0ad824d --- /dev/null +++ b/hicli/html.go @@ -0,0 +1,476 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/url" + "regexp" + "slices" + "strconv" + "strings" + + "golang.org/x/net/html" + "golang.org/x/net/html/atom" + "mvdan.cc/xurls/v2" + + "maunium.net/go/mautrix/id" +) + +func tagIsAllowed(tag atom.Atom) bool { + switch tag { + case atom.Del, atom.H1, atom.H2, atom.H3, atom.H4, atom.H5, atom.H6, atom.Blockquote, atom.P, + atom.A, atom.Ul, atom.Ol, atom.Sup, atom.Sub, atom.Li, atom.B, atom.I, atom.U, atom.Strong, + atom.Em, atom.S, atom.Code, atom.Hr, atom.Br, atom.Div, atom.Table, atom.Thead, atom.Tbody, + atom.Tr, atom.Th, atom.Td, atom.Caption, atom.Pre, atom.Span, atom.Font, atom.Img, + atom.Details, atom.Summary: + return true + default: + return false + } +} + +func isSelfClosing(tag atom.Atom) bool { + switch tag { + case atom.Img, atom.Br, atom.Hr: + return true + default: + return false + } +} + +var languageRegex = regexp.MustCompile(`^language-[a-zA-Z0-9-]+$`) +var allowedColorRegex = regexp.MustCompile(`^#[0-9a-fA-F]{6}$`) + +// This is approximately a mirror of web/src/util/mediasize.ts in gomuks +func calculateMediaSize(widthInt, heightInt int) (width, height float64, ok bool) { + if widthInt <= 0 || heightInt <= 0 { + return + } + width = float64(widthInt) + height = float64(heightInt) + const imageContainerWidth float64 = 320 + const imageContainerHeight float64 = 240 + const imageContainerAspectRatio = imageContainerWidth / imageContainerHeight + if width > imageContainerWidth || height > imageContainerHeight { + aspectRatio := width / height + if aspectRatio > imageContainerAspectRatio { + width = imageContainerWidth + height = imageContainerWidth / aspectRatio + } else if aspectRatio < imageContainerAspectRatio { + width = imageContainerHeight * aspectRatio + height = imageContainerHeight + } else { + width = imageContainerWidth + height = imageContainerHeight + } + } + ok = true + return +} + +func parseImgAttributes(attrs []html.Attribute) (src, alt, title string, isCustomEmoji bool, width, height int) { + for _, attr := range attrs { + switch attr.Key { + case "src": + src = attr.Val + case "alt": + alt = attr.Val + case "title": + title = attr.Val + case "data-mx-emoticon": + isCustomEmoji = true + case "width": + width, _ = strconv.Atoi(attr.Val) + case "height": + height, _ = strconv.Atoi(attr.Val) + } + } + return +} + +func parseSpanAttributes(attrs []html.Attribute) (bgColor, textColor, spoiler, maths string, isSpoiler bool) { + for _, attr := range attrs { + switch attr.Key { + case "data-mx-bg-color": + if allowedColorRegex.MatchString(attr.Val) { + bgColor = attr.Val + } + case "data-mx-color", "color": + if allowedColorRegex.MatchString(attr.Val) { + textColor = attr.Val + } + case "data-mx-spoiler": + spoiler = attr.Val + isSpoiler = true + case "data-mx-maths": + maths = attr.Val + } + } + return +} + +func parseAAttributes(attrs []html.Attribute) (href string) { + for _, attr := range attrs { + switch attr.Key { + case "href": + href = strings.TrimSpace(attr.Val) + } + } + return +} + +func attributeIsAllowed(tag atom.Atom, attr html.Attribute) bool { + switch tag { + case atom.Ol: + switch attr.Key { + case "start": + _, err := strconv.Atoi(attr.Val) + return err == nil + } + case atom.Code: + switch attr.Key { + case "class": + return languageRegex.MatchString(attr.Val) + } + case atom.Div: + switch attr.Key { + case "data-mx-maths": + return true + } + } + return false +} + +// Funny user IDs will just need to be linkified by the sender, no auto-linkification for them. +var plainUserOrAliasMentionRegex = regexp.MustCompile(`[@#][a-zA-Z0-9._=/+-]{0,254}:[a-zA-Z0-9.-]+(?:\d{1,5})?`) + +func getNextItem(items [][]int, minIndex int) (index, start, end int, ok bool) { + for i, item := range items { + if item[0] >= minIndex { + return i, item[0], item[1], true + } + } + return -1, -1, -1, false +} + +func writeMention(w *strings.Builder, mention []byte) { + w.WriteString(`') + writeEscapedBytes(w, mention) + w.WriteString("") +} + +func writeURL(w *strings.Builder, addr []byte) { + parsedURL, err := url.Parse(string(addr)) + if err != nil { + writeEscapedBytes(w, addr) + return + } + if parsedURL.Scheme == "" { + parsedURL.Scheme = "https" + } + w.WriteString(`') + writeEscapedBytes(w, addr) + w.WriteString("") +} + +func linkifyAndWriteBytes(w *strings.Builder, s []byte) { + mentions := plainUserOrAliasMentionRegex.FindAllIndex(s, -1) + urls := xurls.Relaxed().FindAllIndex(s, -1) + minIndex := 0 + for { + mentionIdx, nextMentionStart, nextMentionEnd, hasMention := getNextItem(mentions, minIndex) + urlIdx, nextURLStart, nextURLEnd, hasURL := getNextItem(urls, minIndex) + if hasMention && (!hasURL || nextMentionStart <= nextURLStart) { + writeEscapedBytes(w, s[minIndex:nextMentionStart]) + writeMention(w, s[nextMentionStart:nextMentionEnd]) + minIndex = nextMentionEnd + mentions = mentions[mentionIdx:] + } else if hasURL && (!hasMention || nextURLStart < nextMentionStart) { + writeEscapedBytes(w, s[minIndex:nextURLStart]) + writeURL(w, s[nextURLStart:nextURLEnd]) + minIndex = nextURLEnd + urls = urls[urlIdx:] + } else { + break + } + } + writeEscapedBytes(w, s[minIndex:]) +} + +const escapedChars = "&'<>\"\r" + +func writeEscapedBytes(w *strings.Builder, s []byte) { + i := bytes.IndexAny(s, escapedChars) + for i != -1 { + w.Write(s[:i]) + var esc string + switch s[i] { + case '&': + esc = "&" + case '\'': + // "'" is shorter than "'" and apos was not in HTML until HTML5. + esc = "'" + case '<': + esc = "<" + case '>': + esc = ">" + case '"': + // """ is shorter than """. + esc = """ + case '\r': + esc = " " + default: + panic("unrecognized escape character") + } + s = s[i+1:] + w.WriteString(esc) + i = bytes.IndexAny(s, escapedChars) + } + w.Write(s) +} + +func writeEscapedString(w *strings.Builder, s string) { + i := strings.IndexAny(s, escapedChars) + for i != -1 { + w.WriteString(s[:i]) + var esc string + switch s[i] { + case '&': + esc = "&" + case '\'': + // "'" is shorter than "'" and apos was not in HTML until HTML5. + esc = "'" + case '<': + esc = "<" + case '>': + esc = ">" + case '"': + // """ is shorter than """. + esc = """ + case '\r': + esc = " " + default: + panic("unrecognized escape character") + } + s = s[i+1:] + w.WriteString(esc) + i = strings.IndexAny(s, escapedChars) + } + w.WriteString(s) +} + +func writeAttribute(w *strings.Builder, key, value string) { + w.WriteByte(' ') + w.WriteString(key) + w.WriteString(`="`) + writeEscapedString(w, value) + w.WriteByte('"') +} + +func writeA(w *strings.Builder, attr []html.Attribute) { + w.WriteString("`) + w.WriteString(spoiler) + w.WriteString(" ") + } + w.WriteByte('<') + w.WriteString("span") + if isSpoiler { + writeAttribute(w, "class", "hicli-spoiler") + } + var style string + if bgColor != "" { + style += fmt.Sprintf("background-color: %s;", bgColor) + } + if textColor != "" { + style += fmt.Sprintf("color: %s;", textColor) + } + if style != "" { + writeAttribute(w, "style", style) + } +} + +type tagStack []atom.Atom + +func (ts *tagStack) contains(tags ...atom.Atom) bool { + for i := len(*ts) - 1; i >= 0; i-- { + for _, tag := range tags { + if (*ts)[i] == tag { + return true + } + } + } + return false +} + +func (ts *tagStack) push(tag atom.Atom) { + *ts = append(*ts, tag) +} + +func (ts *tagStack) pop(tag atom.Atom) bool { + if len(*ts) > 0 && (*ts)[len(*ts)-1] == tag { + *ts = (*ts)[:len(*ts)-1] + return true + } + return false +} + +func sanitizeAndLinkifyHTML(body string) (string, error) { + tz := html.NewTokenizer(strings.NewReader(body)) + var built strings.Builder + ts := make(tagStack, 2) +Loop: + for { + switch tz.Next() { + case html.ErrorToken: + err := tz.Err() + if errors.Is(err, io.EOF) { + break Loop + } + return "", err + case html.StartTagToken, html.SelfClosingTagToken: + token := tz.Token() + if !tagIsAllowed(token.DataAtom) { + continue + } + tagIsSelfClosing := isSelfClosing(token.DataAtom) + if token.Type == html.SelfClosingTagToken && !tagIsSelfClosing { + continue + } + switch token.DataAtom { + case atom.A: + writeA(&built, token.Attr) + case atom.Img: + writeImg(&built, token.Attr) + case atom.Span, atom.Font: + writeSpan(&built, token.Attr) + default: + built.WriteByte('<') + built.WriteString(token.Data) + for _, attr := range token.Attr { + if attributeIsAllowed(token.DataAtom, attr) { + writeAttribute(&built, attr.Key, attr.Val) + } + } + } + built.WriteByte('>') + if !tagIsSelfClosing { + ts.push(token.DataAtom) + } + case html.EndTagToken: + tagName, _ := tz.TagName() + tag := atom.Lookup(tagName) + if tagIsAllowed(tag) && ts.pop(tag) { + built.WriteString("') + } + case html.TextToken: + if ts.contains(atom.Pre, atom.Code, atom.A) { + writeEscapedBytes(&built, tz.Text()) + } else { + linkifyAndWriteBytes(&built, tz.Text()) + } + case html.DoctypeToken, html.CommentToken: + // ignore + } + } + slices.Reverse(ts) + for _, t := range ts { + built.WriteString("') + } + return built.String(), nil +} diff --git a/hicli/paginate.go b/hicli/paginate.go index da927b9b..7fc50827 100644 --- a/hicli/paginate.go +++ b/hicli/paginate.go @@ -59,7 +59,7 @@ func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev } else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil { return nil, fmt.Errorf("failed to get event from server: %w", err) } else { - return h.processEvent(ctx, serverEvt, nil, false) + return h.processEvent(ctx, serverEvt, nil, nil, false) } } @@ -90,7 +90,7 @@ func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMemb } entries := make([]*database.CurrentStateEntry, len(evts)) for i, evt := range evts { - dbEvt, err := h.processEvent(ctx, evt, nil, false) + dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false) if err != nil { return fmt.Errorf("failed to process event %s: %w", evt.ID, err) } @@ -186,7 +186,7 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i decryptionQueue := make(map[id.SessionID]*database.SessionRequest) iOffset := 0 for i, evt := range resp.Chunk { - dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, true) + dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true) if err != nil { return err } else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil { diff --git a/hicli/pushrules.go b/hicli/pushrules.go new file mode 100644 index 00000000..74c0e8e4 --- /dev/null +++ b/hicli/pushrules.go @@ -0,0 +1,80 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules" +) + +type pushRoom struct { + ctx context.Context + roomID id.RoomID + h *HiClient + ll *mautrix.LazyLoadSummary +} + +func (p *pushRoom) GetOwnDisplayname() string { + // TODO implement + return "" +} + +func (p *pushRoom) GetMemberCount() int { + if p.ll == nil { + room, err := p.h.DB.Room.Get(p.ctx, p.roomID) + if err != nil { + zerolog.Ctx(p.ctx).Err(err). + Stringer("room_id", p.roomID). + Msg("Failed to get room by ID in push rule evaluator") + } else if room != nil { + p.ll = room.LazyLoadSummary + } + } + if p.ll != nil && p.ll.JoinedMemberCount != nil { + return *p.ll.JoinedMemberCount + } + // TODO query db? + return 0 +} + +func (p *pushRoom) GetEvent(id id.EventID) *event.Event { + evt, err := p.h.DB.Event.GetByID(p.ctx, id) + if err != nil { + zerolog.Ctx(p.ctx).Err(err). + Stringer("event_id", id). + Msg("Failed to get event by ID in push rule evaluator") + } + return evt.AsRawMautrix() +} + +var _ pushrules.EventfulRoom = (*pushRoom)(nil) + +func (h *HiClient) evaluatePushRules(ctx context.Context, llSummary *mautrix.LazyLoadSummary, baseType database.UnreadType, evt *event.Event) database.UnreadType { + should := h.PushRules.Load().GetMatchingRule(&pushRoom{ + ctx: ctx, + roomID: evt.RoomID, + h: h, + ll: llSummary, + }, evt).GetActions().Should() + if should.Notify { + baseType |= database.UnreadTypeNotify + } + if should.Highlight { + baseType |= database.UnreadTypeHighlight + } + if should.PlaySound { + baseType |= database.UnreadTypeSound + } + return baseType +} diff --git a/hicli/send.go b/hicli/send.go index 76852dde..cdb8571b 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -218,7 +218,7 @@ func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error { err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { entries := make([]*database.CurrentStateEntry, len(resp.Chunk)) for i, evt := range resp.Chunk { - dbEvt, err := h.processEvent(ctx, evt, nil, true) + dbEvt, err := h.processEvent(ctx, evt, nil, nil, true) if err != nil { return err } diff --git a/hicli/sync.go b/hicli/sync.go index 16930b59..dcb33637 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -134,11 +134,15 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy return nil } -func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt { +func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) { receiptList := make([]*database.Receipt, 0) + var newOwnReceipts []id.EventID for eventID, receipts := range *content { for receiptType, users := range receipts { for userID, receiptInfo := range users { + if userID == h.Account.UserID { + newOwnReceipts = append(newOwnReceipts, eventID) + } receiptList = append(receiptList, &database.Receipt{ UserID: userID, ReceiptType: receiptType, @@ -149,7 +153,12 @@ func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt { } } } - return receiptList + return receiptList, newOwnReceipts +} + +type receiptsToSave struct { + roomID id.RoomID + receipts []*database.Receipt } func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error { @@ -172,10 +181,8 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) } } - err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) - if err != nil { - return err - } + var receipts []receiptsToSave + var newOwnReceipts []id.EventID for _, evt := range room.Ephemeral.Events { evt.Type.Class = event.EphemeralEventType err = evt.Content.ParseRaw(evt.Type) @@ -185,18 +192,24 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, } switch evt.Type { case event.EphemeralEventReceipt: - err = h.DB.Receipt.PutMany(ctx, roomID, receiptsToList(evt.Content.AsReceipt())...) - if err != nil { - return fmt.Errorf("failed to save receipts: %w", err) - } + var receiptsList []*database.Receipt + receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt()) + receipts = append(receipts, receiptsToSave{roomID, receiptsList}) case event.EphemeralEventTyping: go h.EventHandler(&Typing{ RoomID: roomID, TypingEventContent: *evt.Content.AsTyping(), }) } - if evt.Type != event.EphemeralEventReceipt { - continue + } + err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications) + if err != nil { + return err + } + for _, rs := range receipts { + err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...) + if err != nil { + return fmt.Errorf("failed to save receipts: %w", err) } } return nil @@ -209,7 +222,8 @@ func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, ro } else if existingRoomData == nil { return nil } - return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) + // TODO delete room instead of processing? + return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil) } func isDecryptionErrorRetryable(err error) bool { @@ -318,7 +332,47 @@ func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID datab } } -func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) { +func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent { + if evt.Type != event.EventMessage && evt.Type != event.EventSticker { + return nil + } + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok { + return nil + } + if dbEvt.RelationType == event.RelReplace && content.NewContent != nil { + content = content.NewContent + } + if content != nil { + var sanitizedHTML string + if content.Format == event.FormatHTML { + sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody) + } else { + var builder strings.Builder + linkifyAndWriteBytes(&builder, []byte(content.Body)) + sanitizedHTML = builder.String() + } + return &database.LocalContent{SanitizedHTML: sanitizedHTML} + } + return nil +} + +func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) { + if dbEvt.RowID != 0 { + h.cacheMedia(ctx, evt, dbEvt.RowID) + } + dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt) + dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt) +} + +func (h *HiClient) processEvent( + ctx context.Context, + evt *event.Event, + llSummary *mautrix.LazyLoadSummary, + decryptionQueue map[id.SessionID]*database.SessionRequest, + checkDB bool, +) (*database.Event, error) { if checkDB { dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID) if err != nil { @@ -350,6 +404,11 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str) } } + if decryptedMautrixEvt != nil { + h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt) + } else { + h.postDecryptProcess(ctx, llSummary, dbEvt, evt) + } _, err := h.DB.Event.Upsert(ctx, dbEvt) if err != nil { return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err) @@ -386,12 +445,27 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio return dbEvt, err } -func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error { +func (h *HiClient) processStateAndTimeline( + ctx context.Context, + room *database.Room, + state *mautrix.SyncEventsList, + timeline *mautrix.SyncTimeline, + summary *mautrix.LazyLoadSummary, + newOwnReceipts []id.EventID, + serverNotificationCounts *mautrix.UnreadNotificationCounts, +) error { updatedRoom := &database.Room{ ID: room.ID, - SortingTimestamp: room.SortingTimestamp, - NameQuality: room.NameQuality, + SortingTimestamp: room.SortingTimestamp, + NameQuality: room.NameQuality, + UnreadHighlights: room.UnreadHighlights, + UnreadNotifications: room.UnreadNotifications, + UnreadMessages: room.UnreadMessages, + } + if serverNotificationCounts != nil { + updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount + updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount } heroesChanged := false if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil { @@ -405,6 +479,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } decryptionQueue := make(map[id.SessionID]*database.SessionRequest) allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events)) + newNotifications := make([]SyncNotification, 0) recalculatePreviewEvent := false addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) { if rowID != 0 { @@ -440,12 +515,18 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } return nil } - processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) { + processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) { evt.RoomID = room.ID - dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, false) + dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, false) if err != nil { return -1, err } + if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) { + newNotifications = append(newNotifications, SyncNotification{ + RowID: dbEvt.RowID, + Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound), + }) + } if isTimeline { if dbEvt.CanUseForPreview() { updatedRoom.PreviewEventRowID = dbEvt.RowID @@ -492,7 +573,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } for _, evt := range state.Events { evt.Type.Class = event.StateEventType - rowID, err := processNewEvent(evt, false) + rowID, err := processNewEvent(evt, false, false) if err != nil { return err } @@ -502,13 +583,20 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R var err error if len(timeline.Events) > 0 { timelineIDs := make([]database.EventRowID, len(timeline.Events)) + readUpToIndex := -1 + for i := len(timeline.Events) - 1; i >= 0; i-- { + if slices.Contains(newOwnReceipts, timeline.Events[i].ID) { + readUpToIndex = i + break + } + } for i, evt := range timeline.Events { if evt.StateKey != nil { evt.Type.Class = event.StateEventType } else { evt.Type.Class = event.MessageEventType } - timelineIDs[i], err = processNewEvent(evt, true) + timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex) if err != nil { return err } @@ -578,11 +666,12 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 { ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{ - Meta: room, - Timeline: timelineRowTuples, - State: changedState, - Reset: timeline.Limited, - Events: allNewEvents, + Meta: room, + Timeline: timelineRowTuples, + State: changedState, + Reset: timeline.Limited, + Events: allNewEvents, + Notifications: newNotifications, } } return nil diff --git a/pushrules/ruleset.go b/pushrules/ruleset.go index 609997b4..c42d4799 100644 --- a/pushrules/ruleset.go +++ b/pushrules/ruleset.go @@ -68,6 +68,9 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) { var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}} func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) { + if rs == nil { + return nil + } // Add push rule collections to array in priority order arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride} // Loop until one of the push rule collections matches the room/event combo.