hicli: add html sanitization and push rule evaluation

This commit is contained in:
Tulir Asokan 2024-10-17 00:21:53 +03:00
commit 758e80a5f0
16 changed files with 823 additions and 71 deletions

1
go.mod
View file

@ -28,6 +28,7 @@ require (
golang.org/x/sync v0.8.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0
mvdan.cc/xurls/v2 v2.5.0
)
require (

2
go.sum
View file

@ -83,3 +83,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8=
mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE=

View file

@ -25,9 +25,10 @@ import (
const (
getEventBaseQuery = `
SELECT rowid, -1, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned,
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error,
reactions, last_edit_rowid
SELECT rowid, -1,
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
FROM event
`
getEventByRowID = getEventBaseQuery + `WHERE rowid = $1`
@ -36,10 +37,11 @@ const (
getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL`
insertEventBaseQuery = `
INSERT INTO event (
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned,
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
`
insertEventQuery = insertEventBaseQuery + `RETURNING rowid`
upsertEventQuery = insertEventBaseQuery + `
@ -50,7 +52,8 @@ const (
decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END,
send_error=excluded.send_error,
timestamp=excluded.timestamp,
unsigned=COALESCE(excluded.unsigned, event.unsigned)
unsigned=COALESCE(excluded.unsigned, event.unsigned),
local_content=COALESCE(excluded.local_content, event.local_content)
ON CONFLICT (transaction_id) DO UPDATE
SET event_id=excluded.event_id,
timestamp=excluded.timestamp,
@ -59,7 +62,7 @@ const (
`
updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1`
updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1`
updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3`
updateEventDecryptedQuery = `UPDATE event SET decrypted = $2, decrypted_type = $3, decryption_error = NULL, unread_type = $4, local_content = $5 WHERE rowid = $1`
getEventReactionsQuery = getEventBaseQuery + `
WHERE room_id = ?
AND type = 'm.reaction'
@ -131,8 +134,16 @@ func (eq *EventQuery) UpdateSendError(ctx context.Context, rowID EventRowID, sen
return eq.Exec(ctx, updateEventSendErrorQuery, rowID, sendError)
}
func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID EventRowID, decrypted json.RawMessage, decryptedType string) error {
return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID)
// UpdateDecrypted stores the result of a successful decryption for the row
// identified by evt.RowID using updateEventDecryptedQuery.
//
// The arguments are bound in the order the query expects: $1 row ID,
// $2 decrypted content, $3 decrypted type, $4 unread type and $5 local content.
// The same query also resets decryption_error to NULL.
func (eq *EventQuery) UpdateDecrypted(ctx context.Context, evt *Event) error {
return eq.Exec(
ctx,
updateEventDecryptedQuery,
evt.RowID,
unsafeJSONString(evt.Decrypted),
evt.DecryptedType,
evt.UnreadType,
dbutil.JSONPtr(evt.LocalContent),
)
}
func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error {
@ -264,6 +275,24 @@ func (m EventRowID) GetMassInsertValues() [1]any {
return [1]any{m}
}
// LocalContent holds data that is derived locally from an event and stored in
// the event table's local_content column as JSON.
type LocalContent struct {
// SanitizedHTML is omitted from the JSON when empty.
SanitizedHTML string `json:"sanitized_html,omitempty"`
}
// UnreadType is a bit set describing in which ways an event counts as unread.
type UnreadType int

// Is reports whether ut shares at least one set bit with flag.
// Passing UnreadTypeNone therefore always yields false.
func (ut UnreadType) Is(flag UnreadType) bool {
	return (ut & flag) != UnreadTypeNone
}

// Individual unread flags. Each non-zero value occupies its own bit so that
// the flags can be combined with bitwise OR.
const (
	UnreadTypeNone      UnreadType = 0
	UnreadTypeNormal    UnreadType = 1 << 0
	UnreadTypeNotify    UnreadType = 1 << 1
	UnreadTypeHighlight UnreadType = 1 << 2
	UnreadTypeSound     UnreadType = 1 << 3
)
type Event struct {
RowID EventRowID `json:"rowid"`
TimelineRowID TimelineRowID `json:"timeline_rowid"`
@ -279,6 +308,7 @@ type Event struct {
Decrypted json.RawMessage `json:"decrypted,omitempty"`
DecryptedType string `json:"decrypted_type,omitempty"`
Unsigned json.RawMessage `json:"unsigned,omitempty"`
LocalContent *LocalContent `json:"local_content,omitempty"`
TransactionID string `json:"transaction_id,omitempty"`
@ -292,6 +322,7 @@ type Event struct {
Reactions map[string]int `json:"reactions,omitempty"`
LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"`
UnreadType UnreadType `json:"unread_type,omitempty"`
}
func MautrixToEvent(evt *event.Event) *Event {
@ -318,6 +349,9 @@ func MautrixToEvent(evt *event.Event) *Event {
}
func (e *Event) AsRawMautrix() *event.Event {
if e == nil {
return nil
}
evt := &event.Event{
RoomID: e.RoomID,
ID: e.ID,
@ -355,6 +389,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
(*[]byte)(&e.Decrypted),
&decryptedType,
(*[]byte)(&e.Unsigned),
dbutil.JSON{Data: &e.LocalContent},
&transactionID,
&redactedBy,
&relatesTo,
@ -364,6 +399,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
&sendError,
dbutil.JSON{Data: &e.Reactions},
&e.LastEditRowID,
&e.UnreadType,
)
if err != nil {
return nil, err
@ -425,6 +461,7 @@ func (e *Event) sqlVariables() []any {
unsafeJSONString(e.Decrypted),
dbutil.StrPtr(e.DecryptedType),
unsafeJSONString(e.Unsigned),
dbutil.JSONPtr(e.LocalContent),
dbutil.StrPtr(e.TransactionID),
dbutil.StrPtr(e.RedactedBy),
dbutil.StrPtr(e.RelatesTo),
@ -434,9 +471,26 @@ func (e *Event) sqlVariables() []any {
dbutil.StrPtr(e.SendError),
dbutil.JSON{Data: reactions},
e.LastEditRowID,
e.UnreadType,
}
}
// GetNonPushUnreadType returns the unread type an event gets from its type
// alone, without evaluating push rules.
//
// Edits (m.replace relations) never count as unread. Messages and stickers
// count as UnreadTypeNormal; for encrypted events the decrypted type is what
// gets checked. Everything else is UnreadTypeNone.
func (e *Event) GetNonPushUnreadType() UnreadType {
	if e.RelationType == event.RelReplace {
		return UnreadTypeNone
	}
	// Look through the encryption wrapper when there is one.
	effectiveType := e.Type
	if effectiveType == event.EventEncrypted.Type {
		effectiveType = e.DecryptedType
	}
	if effectiveType == event.EventMessage.Type || effectiveType == event.EventSticker.Type {
		return UnreadTypeNormal
	}
	return UnreadTypeNone
}
func (e *Event) CanUseForPreview() bool {
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type ||
(e.Type == event.EventEncrypted.Type &&

View file

@ -23,8 +23,8 @@ import (
const (
getRoomBaseQuery = `
SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias,
lazy_load_summary, encryption_event, has_member_list,
preview_event_rowid, sorting_timestamp, prev_batch
lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp,
unread_highlights, unread_notifications, unread_messages, prev_batch
FROM room
`
getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2`
@ -47,7 +47,10 @@ const (
has_member_list = room.has_member_list OR $11,
preview_event_rowid = COALESCE($12, room.preview_event_rowid),
sorting_timestamp = COALESCE($13, room.sorting_timestamp),
prev_batch = COALESCE($14, room.prev_batch)
unread_highlights = COALESCE($14, room.unread_highlights),
unread_notifications = COALESCE($15, room.unread_notifications),
unread_messages = COALESCE($16, room.unread_messages),
prev_batch = COALESCE($17, room.prev_batch)
WHERE room_id = $1
`
setRoomPrevBatchQuery = `
@ -143,8 +146,11 @@ type Room struct {
EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"`
HasMemberList bool `json:"has_member_list"`
PreviewEventRowID EventRowID `json:"preview_event_rowid"`
SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"`
PreviewEventRowID EventRowID `json:"preview_event_rowid"`
SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"`
UnreadHighlights int `json:"unread_highlights"`
UnreadNotifications int `json:"unread_notifications"`
UnreadMessages int `json:"unread_messages"`
PrevBatch string `json:"prev_batch"`
}
@ -188,6 +194,18 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) {
other.SortingTimestamp = r.SortingTimestamp
hasChanges = true
}
if r.UnreadHighlights != other.UnreadHighlights {
other.UnreadHighlights = r.UnreadHighlights
hasChanges = true
}
if r.UnreadNotifications != other.UnreadNotifications {
other.UnreadNotifications = r.UnreadNotifications
hasChanges = true
}
if r.UnreadMessages != other.UnreadMessages {
other.UnreadMessages = r.UnreadMessages
hasChanges = true
}
if r.PrevBatch != "" && other.PrevBatch == "" {
other.PrevBatch = r.PrevBatch
hasChanges = true
@ -212,6 +230,9 @@ func (r *Room) Scan(row dbutil.Scannable) (*Room, error) {
&r.HasMemberList,
&previewEventRowID,
&sortingTimestamp,
&r.UnreadHighlights,
&r.UnreadNotifications,
&r.UnreadMessages,
&prevBatch,
)
if err != nil {
@ -238,6 +259,9 @@ func (r *Room) sqlVariables() []any {
r.HasMemberList,
dbutil.NumPtr(r.PreviewEventRowID),
dbutil.UnixMilliPtr(r.SortingTimestamp.Time),
r.UnreadHighlights,
r.UnreadNotifications,
r.UnreadMessages,
dbutil.StrPtr(r.PrevBatch),
}
}

View file

@ -30,8 +30,10 @@ const (
DELETE FROM current_state WHERE room_id = $1
`
getCurrentRoomStateQuery = `
SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned,
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid
SELECT event.rowid, -1,
event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type,
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
FROM current_state cs
JOIN event ON cs.event_rowid = event.rowid
WHERE cs.room_id = $1

View file

@ -34,8 +34,10 @@ const (
`
findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline`
getTimelineQuery = `
SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned,
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid
SELECT event.rowid, timeline.rowid,
event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
FROM timeline
JOIN event ON event.rowid = timeline.event_rowid
WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2)

View file

@ -1,4 +1,4 @@
-- v0 -> v2 (compatible with v1+): Latest revision
-- v0 -> v3 (compatible with v1+): Latest revision
CREATE TABLE account (
user_id TEXT NOT NULL PRIMARY KEY,
device_id TEXT NOT NULL,
@ -9,29 +9,34 @@ CREATE TABLE account (
) STRICT;
CREATE TABLE room (
room_id TEXT NOT NULL PRIMARY KEY,
creation_content TEXT,
room_id TEXT NOT NULL PRIMARY KEY,
creation_content TEXT,
name TEXT,
name_quality INTEGER NOT NULL DEFAULT 0,
avatar TEXT,
explicit_avatar INTEGER NOT NULL DEFAULT 0,
topic TEXT,
canonical_alias TEXT,
lazy_load_summary TEXT,
name TEXT,
name_quality INTEGER NOT NULL DEFAULT 0,
avatar TEXT,
explicit_avatar INTEGER NOT NULL DEFAULT 0,
topic TEXT,
canonical_alias TEXT,
lazy_load_summary TEXT,
encryption_event TEXT,
has_member_list INTEGER NOT NULL DEFAULT false,
encryption_event TEXT,
has_member_list INTEGER NOT NULL DEFAULT false,
preview_event_rowid INTEGER,
sorting_timestamp INTEGER,
preview_event_rowid INTEGER,
sorting_timestamp INTEGER,
unread_highlights INTEGER NOT NULL DEFAULT 0,
unread_notifications INTEGER NOT NULL DEFAULT 0,
unread_messages INTEGER NOT NULL DEFAULT 0,
prev_batch TEXT,
prev_batch TEXT,
CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL
) STRICT;
CREATE INDEX room_type_idx ON room (creation_content ->> 'type');
CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC);
-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_notifications > 0);
-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_messages > 0);
CREATE TABLE account_data (
user_id TEXT NOT NULL,
@ -66,6 +71,7 @@ CREATE TABLE event (
decrypted TEXT,
decrypted_type TEXT,
unsigned TEXT NOT NULL,
local_content TEXT,
transaction_id TEXT,
@ -79,6 +85,7 @@ CREATE TABLE event (
reactions TEXT,
last_edit_rowid INTEGER,
unread_type INTEGER NOT NULL DEFAULT 0,
CONSTRAINT event_id_unique_key UNIQUE (event_id),
CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id),

View file

@ -0,0 +1,6 @@
-- v3 (compatible with v1+): Add more fields to events
-- Per-event additions: nullable local_content and a non-null unread_type defaulting to 0.
ALTER TABLE event ADD COLUMN local_content TEXT;
ALTER TABLE event ADD COLUMN unread_type INTEGER NOT NULL DEFAULT 0;
-- Per-room unread counters, all non-null and defaulting to 0 for existing rows.
ALTER TABLE room ADD COLUMN unread_highlights INTEGER NOT NULL DEFAULT 0;
ALTER TABLE room ADD COLUMN unread_notifications INTEGER NOT NULL DEFAULT 0;
ALTER TABLE room ADD COLUMN unread_messages INTEGER NOT NULL DEFAULT 0;

View file

@ -59,14 +59,14 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
} else {
decrypted = append(decrypted, evt)
h.cacheMedia(ctx, mautrixEvt, evt.RowID)
h.postDecryptProcess(ctx, nil, evt, mautrixEvt)
}
}
if len(decrypted) > 0 {
var newPreview database.EventRowID
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
for _, evt := range decrypted {
err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType)
err = h.DB.Event.UpdateDecrypted(ctx, evt)
if err != nil {
return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err)
}

View file

@ -13,11 +13,17 @@ import (
)
type SyncRoom struct {
Meta *database.Room `json:"meta"`
Timeline []database.TimelineRowTuple `json:"timeline"`
State map[event.Type]map[string]database.EventRowID `json:"state"`
Events []*database.Event `json:"events"`
Reset bool `json:"reset"`
Meta *database.Room `json:"meta"`
Timeline []database.TimelineRowTuple `json:"timeline"`
State map[event.Type]map[string]database.EventRowID `json:"state"`
Events []*database.Event `json:"events"`
Reset bool `json:"reset"`
Notifications []SyncNotification `json:"notifications"`
}
// SyncNotification is one entry of SyncRoom.Notifications.
type SyncNotification struct {
// RowID is the database row ID of the event; serialized as "event_rowid".
RowID database.EventRowID `json:"event_rowid"`
// Sound is serialized as "sound".
// NOTE(review): presumably set when the matching push rule asks for a sound - confirm against the producer.
Sound bool `json:"sound"`
}
type SyncComplete struct {

476
hicli/html.go Normal file
View file

@ -0,0 +1,476 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"bytes"
"errors"
"fmt"
"io"
"net/url"
"regexp"
"slices"
"strconv"
"strings"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
"mvdan.cc/xurls/v2"
"maunium.net/go/mautrix/id"
)
// tagIsAllowed reports whether tag is on the sanitizer's allowlist of HTML
// elements. Tags that are not listed here are not written to the output.
func tagIsAllowed(tag atom.Atom) bool {
	switch tag {
	// Headings, paragraphs and other block-level elements.
	case atom.H1, atom.H2, atom.H3, atom.H4, atom.H5, atom.H6,
		atom.P, atom.Div, atom.Blockquote, atom.Pre, atom.Hr, atom.Br,
		atom.Details, atom.Summary:
		return true
	// Lists.
	case atom.Ul, atom.Ol, atom.Li:
		return true
	// Tables.
	case atom.Table, atom.Thead, atom.Tbody, atom.Tr, atom.Th, atom.Td, atom.Caption:
		return true
	// Inline formatting, links and images.
	case atom.A, atom.B, atom.I, atom.U, atom.S, atom.Strong, atom.Em, atom.Del,
		atom.Sup, atom.Sub, atom.Code, atom.Span, atom.Font, atom.Img:
		return true
	}
	return false
}
// isSelfClosing reports whether tag is one of the elements (img, br, hr) that
// the sanitizer writes without pushing onto the open-tag stack.
func isSelfClosing(tag atom.Atom) bool {
	return tag == atom.Img || tag == atom.Br || tag == atom.Hr
}
var (
	// languageRegex matches a class value of the form "language-<name>",
	// where the name consists of ASCII letters, digits and dashes.
	languageRegex = regexp.MustCompile(`^language-[a-zA-Z0-9-]+$`)
	// allowedColorRegex matches a six-digit hex color such as "#ff00aa".
	allowedColorRegex = regexp.MustCompile(`^#[0-9a-fA-F]{6}$`)
)
// calculateMediaSize fits an image of the given pixel dimensions into a
// 320x240 container while keeping its aspect ratio.
//
// Images that already fit are returned unchanged. ok is false (and both sizes
// are zero) when either input dimension is not positive.
//
// This is approximately a mirror of web/src/util/mediasize.ts in gomuks
func calculateMediaSize(widthInt, heightInt int) (width, height float64, ok bool) {
	const (
		imageContainerWidth       float64 = 320
		imageContainerHeight      float64 = 240
		imageContainerAspectRatio         = imageContainerWidth / imageContainerHeight
	)
	if widthInt <= 0 || heightInt <= 0 {
		return 0, 0, false
	}
	width, height = float64(widthInt), float64(heightInt)
	if width <= imageContainerWidth && height <= imageContainerHeight {
		return width, height, true
	}
	switch aspectRatio := width / height; {
	case aspectRatio > imageContainerAspectRatio:
		// Wider than the container: the width is the limiting side.
		return imageContainerWidth, imageContainerWidth / aspectRatio, true
	case aspectRatio < imageContainerAspectRatio:
		// Taller than the container: the height is the limiting side.
		return imageContainerHeight * aspectRatio, imageContainerHeight, true
	default:
		return imageContainerWidth, imageContainerHeight, true
	}
}
// parseImgAttributes extracts the <img> attributes the sanitizer cares about.
//
// isCustomEmoji is set when a data-mx-emoticon attribute is present, whatever
// its value. width and height are parsed as integers and end up as 0 when the
// value is not a valid number. If an attribute occurs more than once, the
// last occurrence wins.
func parseImgAttributes(attrs []html.Attribute) (src, alt, title string, isCustomEmoji bool, width, height int) {
	for i := range attrs {
		value := attrs[i].Val
		switch attrs[i].Key {
		case "src":
			src = value
		case "alt":
			alt = value
		case "title":
			title = value
		case "data-mx-emoticon":
			isCustomEmoji = true
		case "width":
			// Parse failures are ignored on purpose: Atoi yields 0 in that case.
			width, _ = strconv.Atoi(value)
		case "height":
			height, _ = strconv.Atoi(value)
		}
	}
	return src, alt, title, isCustomEmoji, width, height
}
// parseSpanAttributes extracts the attributes the sanitizer understands on
// <span> and <font> tags.
//
// Colors are only accepted when they match allowedColorRegex; anything else
// leaves the corresponding result untouched. isSpoiler is set whenever a
// data-mx-spoiler attribute is present, and spoiler carries its value.
func parseSpanAttributes(attrs []html.Attribute) (bgColor, textColor, spoiler, maths string, isSpoiler bool) {
	for i := range attrs {
		name, value := attrs[i].Key, attrs[i].Val
		switch {
		case name == "data-mx-spoiler":
			spoiler, isSpoiler = value, true
		case name == "data-mx-maths":
			maths = value
		case name == "data-mx-bg-color" && allowedColorRegex.MatchString(value):
			bgColor = value
		case (name == "data-mx-color" || name == "color") && allowedColorRegex.MatchString(value):
			textColor = value
		}
	}
	return bgColor, textColor, spoiler, maths, isSpoiler
}
// parseAAttributes returns the whitespace-trimmed href of an <a> tag, or an
// empty string when there is none. The last href attribute wins.
func parseAAttributes(attrs []html.Attribute) (href string) {
	for _, attr := range attrs {
		if attr.Key == "href" {
			href = strings.TrimSpace(attr.Val)
		}
	}
	return href
}
// attributeIsAllowed reports whether attr may be copied verbatim onto tag.
//
// Only three combinations are accepted: a numeric start on <ol>, a
// language-* class on <code>, and data-mx-maths on <div>.
func attributeIsAllowed(tag atom.Atom, attr html.Attribute) bool {
	switch {
	case tag == atom.Ol && attr.Key == "start":
		_, err := strconv.Atoi(attr.Val)
		return err == nil
	case tag == atom.Code && attr.Key == "class":
		return languageRegex.MatchString(attr.Val)
	case tag == atom.Div && attr.Key == "data-mx-maths":
		return true
	default:
		return false
	}
}
// plainUserOrAliasMentionRegex finds plain-text user IDs (@localpart:server)
// and room aliases (#alias:server) so that they can be turned into links.
// The server part may be followed by an optional ":port" of up to five digits;
// the colon is part of the optional group so that a trailing colon in normal
// prose is not swallowed.
//
// Funny user IDs will just need to be linkified by the sender, no auto-linkification for them.
var plainUserOrAliasMentionRegex = regexp.MustCompile(`[@#][a-zA-Z0-9._=/+-]{0,254}:[a-zA-Z0-9.-]+(?::\d{1,5})?`)
// getNextItem returns the first entry of items whose start offset is at or
// after minIndex, together with its position in items.
//
// items is expected to be in the format returned by regexp's FindAllIndex:
// each entry holds a start and an end offset. When nothing qualifies, all
// integers are -1 and ok is false.
func getNextItem(items [][]int, minIndex int) (index, start, end int, ok bool) {
	index = slices.IndexFunc(items, func(span []int) bool {
		return span[0] >= minIndex
	})
	if index < 0 {
		return -1, -1, -1, false
	}
	return index, items[index][0], items[index][1], true
}
func writeMention(w *strings.Builder, mention []byte) {
w.WriteString(`<a class="hicli-matrix-uri"`)
writeAttribute(w, "href", (&id.MatrixURI{
Sigil1: rune(mention[0]),
MXID1: string(mention[1:]),
}).String())
w.WriteByte('>')
writeEscapedBytes(w, mention)
w.WriteString("</a>")
}
// writeURL writes a URL found in plain text as a link that opens in a new
// tab. URLs without a scheme get "https" for the href, while the visible text
// stays exactly as written. Input that cannot be parsed as a URL is written
// as escaped text without a link.
func writeURL(w *strings.Builder, addr []byte) {
	target, err := url.Parse(string(addr))
	if err != nil {
		writeEscapedBytes(w, addr)
		return
	}
	if target.Scheme == "" {
		target.Scheme = "https"
	}
	href := target.String()
	w.WriteString(`<a target="_blank" rel="noreferrer noopener"`)
	writeAttribute(w, "href", href)
	w.WriteByte('>')
	writeEscapedBytes(w, addr)
	w.WriteString("</a>")
}
// linkifyAndWriteBytes writes s as escaped HTML, turning plain-text mentions
// and URLs into links along the way.
//
// Both kinds of matches are collected up front and then consumed from left to
// right. When a mention and a URL start at the same offset, the mention wins.
// Matches that start before the current position (because they overlap with
// something already written) are skipped by getNextItem.
func linkifyAndWriteBytes(w *strings.Builder, s []byte) {
	mentions := plainUserOrAliasMentionRegex.FindAllIndex(s, -1)
	urls := xurls.Relaxed().FindAllIndex(s, -1)
	cursor := 0
	for {
		mentionIdx, mentionStart, mentionEnd, haveMention := getNextItem(mentions, cursor)
		urlIdx, urlStart, urlEnd, haveURL := getNextItem(urls, cursor)
		if !haveMention && !haveURL {
			break
		}
		if haveMention && (!haveURL || mentionStart <= urlStart) {
			writeEscapedBytes(w, s[cursor:mentionStart])
			writeMention(w, s[mentionStart:mentionEnd])
			cursor = mentionEnd
			mentions = mentions[mentionIdx:]
			continue
		}
		writeEscapedBytes(w, s[cursor:urlStart])
		writeURL(w, s[urlStart:urlEnd])
		cursor = urlEnd
		urls = urls[urlIdx:]
	}
	// Whatever follows the last match is ordinary text.
	writeEscapedBytes(w, s[cursor:])
}
const escapedChars = "&'<>\"\r"
func writeEscapedBytes(w *strings.Builder, s []byte) {
i := bytes.IndexAny(s, escapedChars)
for i != -1 {
w.Write(s[:i])
var esc string
switch s[i] {
case '&':
esc = "&amp;"
case '\'':
// "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
esc = "&#39;"
case '<':
esc = "&lt;"
case '>':
esc = "&gt;"
case '"':
// "&#34;" is shorter than "&quot;".
esc = "&#34;"
case '\r':
esc = "&#13;"
default:
panic("unrecognized escape character")
}
s = s[i+1:]
w.WriteString(esc)
i = bytes.IndexAny(s, escapedChars)
}
w.Write(s)
}
// writeEscapedString is the string counterpart of writeEscapedBytes: it
// writes s to w with all characters from escapedChars replaced by their HTML
// entities.
func writeEscapedString(w *strings.Builder, s string) {
	for len(s) > 0 {
		i := strings.IndexAny(s, escapedChars)
		if i < 0 {
			break
		}
		w.WriteString(s[:i])
		switch s[i] {
		case '&':
			w.WriteString("&amp;")
		case '\'':
			// "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
			w.WriteString("&#39;")
		case '<':
			w.WriteString("&lt;")
		case '>':
			w.WriteString("&gt;")
		case '"':
			// "&#34;" is shorter than "&quot;".
			w.WriteString("&#34;")
		case '\r':
			w.WriteString("&#13;")
		default:
			panic("unrecognized escape character")
		}
		s = s[i+1:]
	}
	w.WriteString(s)
}
// writeAttribute writes ` key="value"` (with a leading space) to w. The value
// is HTML-escaped; the key is written as-is, so callers must only pass
// trusted keys.
func writeAttribute(w *strings.Builder, key, value string) {
	w.WriteString(" " + key + `="`)
	writeEscapedString(w, value)
	w.WriteByte('"')
}
// isPassthroughLinkScheme reports whether links with the given scheme are
// kept with their href unchanged.
func isPassthroughLinkScheme(scheme string) bool {
	switch scheme {
	case "bitcoin", "ftp", "geo", "http", "https", "im", "irc", "ircs", "magnet", "mailto",
		"mms", "news", "nntp", "openpgp4fpr", "sip", "sftp", "sms", "smsto", "ssh",
		"tel", "urn", "webcal", "wtai", "xmpp":
		return true
	}
	return false
}

// writeA writes the opening of an <a> tag, without the closing '>'.
//
// The href is only written when it uses an allowed scheme. matrix: URIs and
// https://matrix.to links are normalized and marked with the
// hicli-matrix-uri class; mxc: URIs are rewritten using
// HTMLSanitizerImgSrcTemplate; every other accepted link additionally gets
// target and rel attributes so that it opens in a new tab. If the href is
// missing, unparseable or not allowed, a bare "<a" is written.
func writeA(w *strings.Builder, attr []html.Attribute) {
	w.WriteString("<a")
	href := parseAAttributes(attr)
	if href == "" {
		return
	}
	parsedURL, err := url.Parse(href)
	if err != nil {
		return
	}
	isMatrixLink := false
	switch {
	case parsedURL.Scheme == "https" && parsedURL.Host == "matrix.to":
		uri, err := id.ProcessMatrixToURL(parsedURL)
		if err != nil {
			return
		}
		href, isMatrixLink = uri.String(), true
	case parsedURL.Scheme == "matrix":
		uri, err := id.ProcessMatrixURI(parsedURL)
		if err != nil {
			return
		}
		href, isMatrixLink = uri.String(), true
	case parsedURL.Scheme == "mxc":
		mxc := id.ContentURIString(href).ParseOrIgnore()
		if !mxc.IsValid() {
			return
		}
		href = fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID)
	case !isPassthroughLinkScheme(parsedURL.Scheme):
		return
	}
	if isMatrixLink {
		// Matrix links carry a class and no new-tab attributes.
		writeAttribute(w, "class", "hicli-matrix-uri")
		writeAttribute(w, "href", href)
		return
	}
	writeAttribute(w, "href", href)
	writeAttribute(w, "target", "_blank")
	writeAttribute(w, "rel", "noreferrer noopener")
}
// HTMLSanitizerImgSrcTemplate is the fmt template used to build the URL that
// replaces mxc:// references in sanitized <img src> and <a href> attributes.
// It is formatted with the homeserver and the file ID, in that order. The
// default keeps the mxc:// URI form.
var HTMLSanitizerImgSrcTemplate = "mxc://%s/%s"
// writeImg writes the opening of an <img> tag, without the closing '>'.
//
// alt is always written and title only when non-empty. The src is only kept
// when it is a valid mxc:// URI, in which case it is rewritten using
// HTMLSanitizerImgSrcTemplate; otherwise nothing after alt/title is written.
// The class depends on the kind of image: custom emoji, an inline image with
// known dimensions (which also gets an inline size style), or an inline image
// without usable dimensions.
func writeImg(w *strings.Builder, attr []html.Attribute) {
	src, alt, title, isCustomEmoji, width, height := parseImgAttributes(attr)
	w.WriteString("<img")
	writeAttribute(w, "alt", alt)
	if title != "" {
		writeAttribute(w, "title", title)
	}
	mxc := id.ContentURIString(src).ParseOrIgnore()
	if !mxc.IsValid() {
		return
	}
	writeAttribute(w, "src", fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID))
	writeAttribute(w, "loading", "lazy")
	if isCustomEmoji {
		writeAttribute(w, "class", "hicli-custom-emoji")
		return
	}
	cWidth, cHeight, sizeOK := calculateMediaSize(width, height)
	if !sizeOK {
		writeAttribute(w, "class", "hicli-sizeless-inline-img")
		return
	}
	writeAttribute(w, "class", "hicli-sized-inline-img")
	writeAttribute(w, "style", fmt.Sprintf("width: %.2fpx; height: %.2fpx;", cWidth, cHeight))
}
// writeSpan writes the opening of a <span> tag, without the closing '>'. It
// is used for both <span> and <font> input tags.
//
// When the tag is a spoiler with a non-empty reason, a separate
// <span class="spoiler-reason"> element containing the reason is written in
// front of it. Validated colors from parseSpanAttributes are turned into an
// inline style attribute.
func writeSpan(w *strings.Builder, attr []html.Attribute) {
	bgColor, textColor, spoiler, _, isSpoiler := parseSpanAttributes(attr)
	if isSpoiler && spoiler != "" {
		w.WriteString(`<span class="spoiler-reason">`)
		// The reason comes straight from a sender-controlled attribute value,
		// so it must be escaped before it is placed into element content.
		writeEscapedString(w, spoiler)
		w.WriteString(" </span>")
	}
	w.WriteString("<span")
	if isSpoiler {
		writeAttribute(w, "class", "hicli-spoiler")
	}
	// Both colors have already been checked against allowedColorRegex.
	var style string
	if bgColor != "" {
		style += "background-color: " + bgColor + ";"
	}
	if textColor != "" {
		style += "color: " + textColor + ";"
	}
	if style != "" {
		writeAttribute(w, "style", style)
	}
}
// tagStack tracks the tags that are currently open, innermost last.
type tagStack []atom.Atom

// contains reports whether any of the given tags is currently open.
func (ts *tagStack) contains(tags ...atom.Atom) bool {
	return slices.ContainsFunc(*ts, func(open atom.Atom) bool {
		return slices.Contains(tags, open)
	})
}

// push records tag as the innermost open tag.
func (ts *tagStack) push(tag atom.Atom) {
	*ts = append(*ts, tag)
}

// pop removes the innermost open tag if, and only if, it is tag. It reports
// whether a tag was removed.
func (ts *tagStack) pop(tag atom.Atom) bool {
	stack := *ts
	last := len(stack) - 1
	if last < 0 || stack[last] != tag {
		return false
	}
	*ts = stack[:last]
	return true
}
func sanitizeAndLinkifyHTML(body string) (string, error) {
tz := html.NewTokenizer(strings.NewReader(body))
var built strings.Builder
ts := make(tagStack, 2)
Loop:
for {
switch tz.Next() {
case html.ErrorToken:
err := tz.Err()
if errors.Is(err, io.EOF) {
break Loop
}
return "", err
case html.StartTagToken, html.SelfClosingTagToken:
token := tz.Token()
if !tagIsAllowed(token.DataAtom) {
continue
}
tagIsSelfClosing := isSelfClosing(token.DataAtom)
if token.Type == html.SelfClosingTagToken && !tagIsSelfClosing {
continue
}
switch token.DataAtom {
case atom.A:
writeA(&built, token.Attr)
case atom.Img:
writeImg(&built, token.Attr)
case atom.Span, atom.Font:
writeSpan(&built, token.Attr)
default:
built.WriteByte('<')
built.WriteString(token.Data)
for _, attr := range token.Attr {
if attributeIsAllowed(token.DataAtom, attr) {
writeAttribute(&built, attr.Key, attr.Val)
}
}
}
built.WriteByte('>')
if !tagIsSelfClosing {
ts.push(token.DataAtom)
}
case html.EndTagToken:
tagName, _ := tz.TagName()
tag := atom.Lookup(tagName)
if tagIsAllowed(tag) && ts.pop(tag) {
built.WriteString("</")
built.Write(tagName)
built.WriteByte('>')
}
case html.TextToken:
if ts.contains(atom.Pre, atom.Code, atom.A) {
writeEscapedBytes(&built, tz.Text())
} else {
linkifyAndWriteBytes(&built, tz.Text())
}
case html.DoctypeToken, html.CommentToken:
// ignore
}
}
slices.Reverse(ts)
for _, t := range ts {
built.WriteString("</")
built.WriteString(t.String())
built.WriteByte('>')
}
return built.String(), nil
}

View file

@ -59,7 +59,7 @@ func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev
} else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil {
return nil, fmt.Errorf("failed to get event from server: %w", err)
} else {
return h.processEvent(ctx, serverEvt, nil, false)
return h.processEvent(ctx, serverEvt, nil, nil, false)
}
}
@ -90,7 +90,7 @@ func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMemb
}
entries := make([]*database.CurrentStateEntry, len(evts))
for i, evt := range evts {
dbEvt, err := h.processEvent(ctx, evt, nil, false)
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false)
if err != nil {
return fmt.Errorf("failed to process event %s: %w", evt.ID, err)
}
@ -186,7 +186,7 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
iOffset := 0
for i, evt := range resp.Chunk {
dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, true)
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true)
if err != nil {
return err
} else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil {

80
hicli/pushrules.go Normal file
View file

@ -0,0 +1,80 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/hicli/database"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
)
// pushRoom is the room view handed to the push rule evaluator. It satisfies
// pushrules.EventfulRoom by reading from the HiClient's database.
type pushRoom struct {
// ctx is used for the database queries and logging done by the methods.
ctx context.Context
roomID id.RoomID
h *HiClient
// ll may be nil; GetMemberCount then loads it from the room's database entry on first use.
ll *mautrix.LazyLoadSummary
}
// GetOwnDisplayname is not implemented yet and always returns an empty string.
func (p *pushRoom) GetOwnDisplayname() string {
// TODO implement
return ""
}
// GetMemberCount returns the joined member count from the room's lazy load
// summary, or 0 when it is not known.
//
// If no summary was provided, it is loaded from the room's database entry
// and remembered for later calls. A database error is logged and treated as
// "unknown".
func (p *pushRoom) GetMemberCount() int {
	if p.ll == nil {
		switch room, err := p.h.DB.Room.Get(p.ctx, p.roomID); {
		case err != nil:
			zerolog.Ctx(p.ctx).Err(err).
				Stringer("room_id", p.roomID).
				Msg("Failed to get room by ID in push rule evaluator")
		case room != nil:
			p.ll = room.LazyLoadSummary
		}
	}
	if p.ll == nil || p.ll.JoinedMemberCount == nil {
		// TODO query db?
		return 0
	}
	return *p.ll.JoinedMemberCount
}
// GetEvent looks up an event by ID in the local database and converts it
// with AsRawMautrix.
//
// A lookup error is only logged; the conversion is attempted regardless.
// NOTE(review): this relies on AsRawMautrix returning nil for a nil *Event,
// and presumably on GetByID returning nil for a failed or missing lookup - confirm.
func (p *pushRoom) GetEvent(id id.EventID) *event.Event {
evt, err := p.h.DB.Event.GetByID(p.ctx, id)
if err != nil {
zerolog.Ctx(p.ctx).Err(err).
Stringer("event_id", id).
Msg("Failed to get event by ID in push rule evaluator")
}
return evt.AsRawMautrix()
}
var _ pushrules.EventfulRoom = (*pushRoom)(nil)
// evaluatePushRules runs the currently loaded push rules against evt and
// returns baseType with the notify, highlight and sound flags added according
// to the actions of the matching rule.
//
// llSummary may be nil, in which case the room's summary is loaded from the
// database if the member count is needed.
func (h *HiClient) evaluatePushRules(ctx context.Context, llSummary *mautrix.LazyLoadSummary, baseType database.UnreadType, evt *event.Event) database.UnreadType {
	rules := h.PushRules.Load()
	room := &pushRoom{
		ctx:    ctx,
		roomID: evt.RoomID,
		h:      h,
		ll:     llSummary,
	}
	should := rules.GetMatchingRule(room, evt).GetActions().Should()
	unreadType := baseType
	if should.Notify {
		unreadType |= database.UnreadTypeNotify
	}
	if should.Highlight {
		unreadType |= database.UnreadTypeHighlight
	}
	if should.PlaySound {
		unreadType |= database.UnreadTypeSound
	}
	return unreadType
}

View file

@ -218,7 +218,7 @@ func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error {
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
entries := make([]*database.CurrentStateEntry, len(resp.Chunk))
for i, evt := range resp.Chunk {
dbEvt, err := h.processEvent(ctx, evt, nil, true)
dbEvt, err := h.processEvent(ctx, evt, nil, nil, true)
if err != nil {
return err
}

View file

@ -134,11 +134,15 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy
return nil
}
func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt {
func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) {
receiptList := make([]*database.Receipt, 0)
var newOwnReceipts []id.EventID
for eventID, receipts := range *content {
for receiptType, users := range receipts {
for userID, receiptInfo := range users {
if userID == h.Account.UserID {
newOwnReceipts = append(newOwnReceipts, eventID)
}
receiptList = append(receiptList, &database.Receipt{
UserID: userID,
ReceiptType: receiptType,
@ -149,7 +153,12 @@ func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt {
}
}
}
return receiptList
return receiptList, newOwnReceipts
}
// receiptsToSave pairs a batch of receipts with the room they belong to, so
// that they can be collected first and written to the database later.
type receiptsToSave struct {
roomID id.RoomID
receipts []*database.Receipt
}
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
@ -172,10 +181,8 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID,
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
}
}
err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
if err != nil {
return err
}
var receipts []receiptsToSave
var newOwnReceipts []id.EventID
for _, evt := range room.Ephemeral.Events {
evt.Type.Class = event.EphemeralEventType
err = evt.Content.ParseRaw(evt.Type)
@ -185,18 +192,24 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID,
}
switch evt.Type {
case event.EphemeralEventReceipt:
err = h.DB.Receipt.PutMany(ctx, roomID, receiptsToList(evt.Content.AsReceipt())...)
if err != nil {
return fmt.Errorf("failed to save receipts: %w", err)
}
var receiptsList []*database.Receipt
receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt())
receipts = append(receipts, receiptsToSave{roomID, receiptsList})
case event.EphemeralEventTyping:
go h.EventHandler(&Typing{
RoomID: roomID,
TypingEventContent: *evt.Content.AsTyping(),
})
}
if evt.Type != event.EphemeralEventReceipt {
continue
}
err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications)
if err != nil {
return err
}
for _, rs := range receipts {
err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...)
if err != nil {
return fmt.Errorf("failed to save receipts: %w", err)
}
}
return nil
@ -209,7 +222,8 @@ func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, ro
} else if existingRoomData == nil {
return nil
}
return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
// TODO delete room instead of processing?
return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil)
}
func isDecryptionErrorRetryable(err error) bool {
@ -318,7 +332,47 @@ func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID datab
}
}
func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) {
// calculateLocalContent builds the locally computed content for a message or
// sticker event. It returns nil for every other event type, and also when the
// parsed content is not a *event.MessageEventContent.
//
// When dbEvt is a replacement (RelationType == event.RelReplace) and the
// content carries NewContent, the replacement content is rendered instead of
// the outer content. Bodies whose Format is event.FormatHTML are passed through
// sanitizeAndLinkifyHTML; all other bodies have their plain Body written
// through linkifyAndWriteBytes. The result is stored in SanitizedHTML.
func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent {
	isRenderable := evt.Type == event.EventMessage || evt.Type == event.EventSticker
	if !isRenderable {
		return nil
	}
	// The ParseRaw error is discarded; only the resulting Parsed value is inspected.
	_ = evt.Content.ParseRaw(evt.Type)
	msg, isMessage := evt.Content.Parsed.(*event.MessageEventContent)
	if !isMessage {
		return nil
	}
	if dbEvt.RelationType == event.RelReplace && msg.NewContent != nil {
		msg = msg.NewContent
	}
	if msg == nil {
		return nil
	}
	if msg.Format != event.FormatHTML {
		var plain strings.Builder
		linkifyAndWriteBytes(&plain, []byte(msg.Body))
		return &database.LocalContent{SanitizedHTML: plain.String()}
	}
	// The second return value of sanitizeAndLinkifyHTML is discarded here.
	rendered, _ := sanitizeAndLinkifyHTML(msg.FormattedBody)
	return &database.LocalContent{SanitizedHTML: rendered}
}
// postDecryptProcess fills in the derived fields of dbEvt from evt, which the
// caller passes as either the decrypted event or the original event.
//
// If dbEvt already has a row ID, media for the event is cached via cacheMedia.
// UnreadType is then set from evaluatePushRules, seeded with
// dbEvt.GetNonPushUnreadType(), and LocalContent from calculateLocalContent.
func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) {
	if alreadyStored := dbEvt.RowID != 0; alreadyStored {
		h.cacheMedia(ctx, evt, dbEvt.RowID)
	}
	baseUnread := dbEvt.GetNonPushUnreadType()
	dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, baseUnread, evt)
	dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt)
}
func (h *HiClient) processEvent(
ctx context.Context,
evt *event.Event,
llSummary *mautrix.LazyLoadSummary,
decryptionQueue map[id.SessionID]*database.SessionRequest,
checkDB bool,
) (*database.Event, error) {
if checkDB {
dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID)
if err != nil {
@ -350,6 +404,11 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio
evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
}
}
if decryptedMautrixEvt != nil {
h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt)
} else {
h.postDecryptProcess(ctx, llSummary, dbEvt, evt)
}
_, err := h.DB.Event.Upsert(ctx, dbEvt)
if err != nil {
return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
@ -386,12 +445,27 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio
return dbEvt, err
}
func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error {
func (h *HiClient) processStateAndTimeline(
ctx context.Context,
room *database.Room,
state *mautrix.SyncEventsList,
timeline *mautrix.SyncTimeline,
summary *mautrix.LazyLoadSummary,
newOwnReceipts []id.EventID,
serverNotificationCounts *mautrix.UnreadNotificationCounts,
) error {
updatedRoom := &database.Room{
ID: room.ID,
SortingTimestamp: room.SortingTimestamp,
NameQuality: room.NameQuality,
SortingTimestamp: room.SortingTimestamp,
NameQuality: room.NameQuality,
UnreadHighlights: room.UnreadHighlights,
UnreadNotifications: room.UnreadNotifications,
UnreadMessages: room.UnreadMessages,
}
if serverNotificationCounts != nil {
updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount
updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount
}
heroesChanged := false
if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil {
@ -405,6 +479,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
}
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events))
newNotifications := make([]SyncNotification, 0)
recalculatePreviewEvent := false
addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) {
if rowID != 0 {
@ -440,12 +515,18 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
}
return nil
}
processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) {
processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) {
evt.RoomID = room.ID
dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, false)
dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, false)
if err != nil {
return -1, err
}
if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) {
newNotifications = append(newNotifications, SyncNotification{
RowID: dbEvt.RowID,
Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound),
})
}
if isTimeline {
if dbEvt.CanUseForPreview() {
updatedRoom.PreviewEventRowID = dbEvt.RowID
@ -492,7 +573,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
}
for _, evt := range state.Events {
evt.Type.Class = event.StateEventType
rowID, err := processNewEvent(evt, false)
rowID, err := processNewEvent(evt, false, false)
if err != nil {
return err
}
@ -502,13 +583,20 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
var err error
if len(timeline.Events) > 0 {
timelineIDs := make([]database.EventRowID, len(timeline.Events))
readUpToIndex := -1
for i := len(timeline.Events) - 1; i >= 0; i-- {
if slices.Contains(newOwnReceipts, timeline.Events[i].ID) {
readUpToIndex = i
break
}
}
for i, evt := range timeline.Events {
if evt.StateKey != nil {
evt.Type.Class = event.StateEventType
} else {
evt.Type.Class = event.MessageEventType
}
timelineIDs[i], err = processNewEvent(evt, true)
timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex)
if err != nil {
return err
}
@ -578,11 +666,12 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
}
if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
Meta: room,
Timeline: timelineRowTuples,
State: changedState,
Reset: timeline.Limited,
Events: allNewEvents,
Meta: room,
Timeline: timelineRowTuples,
State: changedState,
Reset: timeline.Limited,
Events: allNewEvents,
Notifications: newNotifications,
}
}
return nil

View file

@ -68,6 +68,9 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) {
var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}}
func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) {
if rs == nil {
return nil
}
// Add push rule collections to array in priority order
arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride}
// Loop until one of the push rule collections matches the room/event combo.