From c1eb217b9e02531e2fdf630e0df2bed1b4390d51 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 26 May 2024 18:29:22 +0300 Subject: [PATCH] Add draft of high-level client framework --- .gitignore | 2 +- hicli/cryptohelper.go | 82 ++++ hicli/database/account.go | 65 ++++ hicli/database/accountdata.go | 71 ++++ hicli/database/database.go | 60 +++ hicli/database/event.go | 193 ++++++++++ hicli/database/room.go | 115 ++++++ hicli/database/sessionrequest.go | 69 ++++ hicli/database/state.go | 32 ++ hicli/database/statestore.go | 149 +++++++ hicli/database/timeline.go | 47 +++ .../database/upgrades/00-latest-revision.sql | 107 ++++++ hicli/database/upgrades/upgrades.go | 22 ++ hicli/decryptionqueue.go | 194 ++++++++++ hicli/hicli.go | 159 ++++++++ hicli/hitest/hitest.go | 69 ++++ hicli/login.go | 77 ++++ hicli/sync.go | 362 ++++++++++++++++++ hicli/syncwrap.go | 100 +++++ hicli/verify.go | 158 ++++++++ 20 files changed, 2132 insertions(+), 1 deletion(-) create mode 100644 hicli/cryptohelper.go create mode 100644 hicli/database/account.go create mode 100644 hicli/database/accountdata.go create mode 100644 hicli/database/database.go create mode 100644 hicli/database/event.go create mode 100644 hicli/database/room.go create mode 100644 hicli/database/sessionrequest.go create mode 100644 hicli/database/state.go create mode 100644 hicli/database/statestore.go create mode 100644 hicli/database/timeline.go create mode 100644 hicli/database/upgrades/00-latest-revision.sql create mode 100644 hicli/database/upgrades/upgrades.go create mode 100644 hicli/decryptionqueue.go create mode 100644 hicli/hicli.go create mode 100644 hicli/hitest/hitest.go create mode 100644 hicli/login.go create mode 100644 hicli/sync.go create mode 100644 hicli/syncwrap.go create mode 100644 hicli/verify.go diff --git a/.gitignore b/.gitignore index f37a7d0c..c01f2f30 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ .idea/ .vscode/ -*.db +*.db* *.log diff --git a/hicli/cryptohelper.go b/hicli/cryptohelper.go new file mode 100644 index 00000000..eb054af9 --- /dev/null +++ b/hicli/cryptohelper.go @@ -0,0 +1,82 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type hiCryptoHelper HiClient + +var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil) + +func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { + h.encryptLock.Lock() + defer h.encryptLock.Unlock() + encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content) + if err != nil { + if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.NoGroupSession) && !errors.Is(err, crypto.SessionNotShared) { + return + } + h.Log.Debug(). + Err(err). + Str("room_id", roomID.String()). + Msg("Got session error while encrypting event, sharing group session and trying again") + var users []id.UserID + users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to get room member list: %w", err) + } else if err = h.Crypto.ShareGroupSession(ctx, roomID, users); err != nil { + err = fmt.Errorf("failed to share group session: %w", err) + } else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil { + err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) + } + } + return +} + +func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { + return h.Crypto.DecryptMegolmEvent(ctx, evt) +} + +func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { + return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) +} + +func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { + err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{ + userID: {deviceID}, + h.Account.UserID: {"*"}, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Stringer("user_id", userID). + Msg("Failed to send room key request") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Stringer("user_id", userID). + Msg("Sent room key request") + } +} + +func (h *hiCryptoHelper) Init(ctx context.Context) error { + return nil +} diff --git a/hicli/database/account.go b/hicli/database/account.go new file mode 100644 index 00000000..49b50771 --- /dev/null +++ b/hicli/database/account.go @@ -0,0 +1,65 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +const ( + getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1` + putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2` + upsertAccountQuery = ` + INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch) + VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id) + DO UPDATE SET device_id = excluded.device_id, + access_token = excluded.access_token, + homeserver_url = excluded.homeserver_url, + next_batch = excluded.next_batch + ` +) + +type AccountQuery struct { + *dbutil.QueryHelper[*Account] +} + +func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) { + err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID) + return +} + +func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) { + return aq.QueryOne(ctx, getAccountQuery, userID) +} + +func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error { + return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID) +} + +func (aq *AccountQuery) Put(ctx context.Context, account *Account) error { + return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...) +} + +type Account struct { + UserID id.UserID + DeviceID id.DeviceID + AccessToken string + HomeserverURL string + NextBatch string +} + +func (a *Account) Scan(row dbutil.Scannable) (*Account, error) { + return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch)) +} + +func (a *Account) sqlVariables() []any { + return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch} +} diff --git a/hicli/database/accountdata.go b/hicli/database/accountdata.go new file mode 100644 index 00000000..963886c3 --- /dev/null +++ b/hicli/database/accountdata.go @@ -0,0 +1,71 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "encoding/json" + "unsafe" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + upsertAccountDataQuery = ` + INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3) + ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content + ` + upsertRoomAccountDataQuery = ` + INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content + ` +) + +type AccountDataQuery struct { + *dbutil.QueryHelper[*AccountData] +} + +func unsafeJSONString(content json.RawMessage) *string { + if content == nil { + return nil + } + str := unsafe.String(unsafe.SliceData(content), len(content)) + return &str +} + +func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error { + return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content)) +} + +func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error { + return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content)) +} + +type AccountData struct { + UserID id.UserID + RoomID id.RoomID + Type string + Content json.RawMessage +} + +func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) { + var roomID sql.NullString + err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content)) + if err != nil { + return nil, err + } + a.RoomID = id.RoomID(roomID.String) + return a, nil +} + +func (a *AccountData) sqlVariables() []any { + return []any{a.UserID, dbutil.StrPtr(a.RoomID), a.Type, unsafeJSONString(a.Content)} +} diff --git a/hicli/database/database.go b/hicli/database/database.go new file mode 100644 index 00000000..c1273ab7 --- /dev/null +++ b/hicli/database/database.go @@ -0,0 +1,60 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/hicli/database/upgrades" +) + +type Database struct { + *dbutil.Database + + Account AccountQuery + AccountData AccountDataQuery + Room RoomQuery + Event EventQuery + CurrentState CurrentStateQuery + Timeline TimelineQuery + SessionRequest SessionRequestQuery +} + +func New(rawDB *dbutil.Database) *Database { + rawDB.UpgradeTable = upgrades.Table + return &Database{ + Database: rawDB, + + Account: AccountQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccount)}, + AccountData: AccountDataQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccountData)}, + Room: RoomQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newRoom)}, + Event: EventQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newEvent)}, + CurrentState: CurrentStateQuery{Database: rawDB}, + Timeline: TimelineQuery{Database: rawDB}, + SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)}, + } +} + +func newSessionRequest(_ *dbutil.QueryHelper[*SessionRequest]) *SessionRequest { + return &SessionRequest{} +} + +func newEvent(_ *dbutil.QueryHelper[*Event]) *Event { + return &Event{} +} + +func newRoom(_ *dbutil.QueryHelper[*Room]) *Room { + return &Room{} +} + +func newAccountData(_ *dbutil.QueryHelper[*AccountData]) *AccountData { + return &AccountData{} +} + +func newAccount(_ *dbutil.QueryHelper[*Account]) *Account { + return &Account{} +} diff --git a/hicli/database/event.go b/hicli/database/event.go new file mode 100644 index 00000000..b7b15eea --- /dev/null +++ b/hicli/database/event.go @@ -0,0 +1,193 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "database/sql" + "encoding/json" + "time" + + "github.com/tidwall/gjson" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exgjson" + "golang.org/x/net/context" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getEventBaseQuery = ` + SELECT rowid, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, + redacted_by, relates_to, megolm_session_id, decryption_error + FROM event + ` + getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL` + upsertEventQuery = ` + INSERT INTO event (room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, redacted_by, relates_to, megolm_session_id, decryption_error) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + ON CONFLICT (event_id) DO UPDATE + SET decrypted=COALESCE(event.decrypted, excluded.decrypted), + decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type), + redacted_by=COALESCE(event.redacted_by, excluded.redacted_by), + decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END + RETURNING rowid + ` + updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3` +) + +type EventQuery struct { + *dbutil.QueryHelper[*Event] +} + +func (eq *EventQuery) GetFailedByMegolmSessionID(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) ([]*Event, error) { + return eq.QueryMany(ctx, getFailedEventsByMegolmSessionID, roomID, sessionID) +} + +func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID int64, err error) { + err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID) + return +} + +func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID int64, decrypted json.RawMessage, decryptedType string) error { + return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID) +} + +type Event struct { + RowID int64 + + RoomID id.RoomID + ID id.EventID + Sender id.UserID + Type string + StateKey *string + Timestamp time.Time + + Content json.RawMessage + Decrypted json.RawMessage + DecryptedType string + Unsigned json.RawMessage + + RedactedBy id.EventID + RelatesTo id.EventID + + MegolmSessionID id.SessionID + DecryptionError string +} + +func MautrixToEvent(evt *event.Event) *Event { + dbEvt := &Event{ + RoomID: evt.RoomID, + ID: evt.ID, + Sender: evt.Sender, + Type: evt.Type.Type, + StateKey: evt.StateKey, + Timestamp: time.UnixMilli(evt.Timestamp), + Content: evt.Content.VeryRaw, + RelatesTo: getRelatesTo(evt), + MegolmSessionID: getMegolmSessionID(evt), + } + dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned) + if evt.Unsigned.RedactedBecause != nil { + dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID + } + return dbEvt +} + +func (e *Event) AsRawMautrix() *event.Event { + evt := &event.Event{ + RoomID: e.RoomID, + ID: e.ID, + Sender: e.Sender, + Type: event.Type{Type: e.Type, Class: event.MessageEventType}, + StateKey: e.StateKey, + Timestamp: e.Timestamp.UnixMilli(), + Content: event.Content{VeryRaw: e.Content}, + } + if e.Decrypted != nil { + evt.Content.VeryRaw = e.Decrypted + evt.Type.Type = e.DecryptedType + evt.Mautrix.WasEncrypted = true + } + if e.StateKey != nil { + evt.Type.Class = event.StateEventType + } + _ = json.Unmarshal(e.Unsigned, &evt.Unsigned) + return evt +} + +func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { + var timestamp int64 + var redactedBy, relatesTo, megolmSessionID, decryptionError, decryptedType sql.NullString + err := row.Scan( + &e.RowID, + &e.RoomID, + &e.ID, + &e.Sender, + &e.Type, + &e.StateKey, + ×tamp, + (*[]byte)(&e.Content), + (*[]byte)(&e.Decrypted), + &decryptedType, + (*[]byte)(&e.Unsigned), + &redactedBy, + &relatesTo, + &megolmSessionID, + &decryptionError, + ) + if err != nil { + return nil, err + } + e.Timestamp = time.UnixMilli(timestamp) + e.RedactedBy = id.EventID(redactedBy.String) + e.RelatesTo = id.EventID(relatesTo.String) + e.MegolmSessionID = id.SessionID(megolmSessionID.String) + e.DecryptedType = decryptedType.String + e.DecryptionError = decryptionError.String + return e, nil +} + +var relatesToPath = exgjson.Path("m.relates_to", "event_id") + +func getRelatesTo(evt *event.Event) id.EventID { + res := gjson.GetBytes(evt.Content.VeryRaw, relatesToPath) + if res.Exists() && res.Type == gjson.String { + return id.EventID(res.Str) + } + return "" +} + +func getMegolmSessionID(evt *event.Event) id.SessionID { + if evt.Type != event.EventEncrypted { + return "" + } + res := gjson.GetBytes(evt.Content.VeryRaw, "session_id") + if res.Exists() && res.Type == gjson.String { + return id.SessionID(res.Str) + } + return "" +} + +func (e *Event) sqlVariables() []any { + return []any{ + e.RoomID, + e.ID, + e.Sender, + e.Type, + e.StateKey, + e.Timestamp.UnixMilli(), + unsafeJSONString(e.Content), + unsafeJSONString(e.Decrypted), + dbutil.StrPtr(e.DecryptedType), + unsafeJSONString(e.Unsigned), + dbutil.StrPtr(e.RedactedBy), + dbutil.StrPtr(e.RelatesTo), + dbutil.StrPtr(e.MegolmSessionID), + dbutil.StrPtr(e.DecryptionError), + } +} diff --git a/hicli/database/room.go b/hicli/database/room.go new file mode 100644 index 00000000..c7d13fca --- /dev/null +++ b/hicli/database/room.go @@ -0,0 +1,115 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getRoomByIDQuery = ` + SELECT room_id, creation_content, name, avatar, topic, lazy_load_summary, encryption_event, has_member_list, prev_batch + FROM room WHERE room_id = $1 + ` + ensureRoomExistsQuery = ` + INSERT INTO room (room_id) VALUES ($1) + ON CONFLICT (room_id) DO NOTHING + ` + upsertRoomFromSyncQuery = ` + UPDATE room + SET creation_content = COALESCE(room.creation_content, $2), + name = COALESCE($3, room.name), + avatar = COALESCE($4, room.avatar), + topic = COALESCE($5, room.topic), + lazy_load_summary = COALESCE($6, room.lazy_load_summary), + encryption_event = COALESCE($7, room.encryption_event), + has_member_list = room.has_member_list OR $8, + prev_batch = COALESCE(room.prev_batch, $9) + WHERE room_id = $1 + ` + setRoomPrevBatchQuery = ` + INSERT INTO room (room_id, prev_batch) VALUES ($1, $2) + ON CONFLICT (room_id) DO UPDATE SET prev_batch = excluded.prev_batch + ` +) + +type RoomQuery struct { + *dbutil.QueryHelper[*Room] +} + +func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) { + return rq.QueryOne(ctx, getRoomByIDQuery, roomID) +} + +func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error { + return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...) +} + +func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error { + return rq.Exec(ctx, ensureRoomExistsQuery, roomID) +} + +func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error { + return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch) +} + +type Room struct { + ID id.RoomID + CreationContent *event.CreateEventContent + + Name *string + Avatar *id.ContentURI + Topic *string + + LazyLoadSummary *mautrix.LazyLoadSummary + + EncryptionEvent *event.EncryptionEventContent + HasMemberList bool + + PrevBatch string +} + +func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { + var prevBatch sql.NullString + err := row.Scan( + &r.ID, + dbutil.JSON{Data: &r.CreationContent}, + &r.Name, + &r.Avatar, + &r.Topic, + dbutil.JSON{Data: &r.LazyLoadSummary}, + dbutil.JSON{Data: &r.EncryptionEvent}, + &r.HasMemberList, + &prevBatch, + ) + if err != nil { + return nil, err + } + r.PrevBatch = prevBatch.String + return r, nil +} + +func (r *Room) sqlVariables() []any { + return []any{ + r.ID, + dbutil.JSONPtr(r.CreationContent), + r.Name, + r.Avatar, + r.Topic, + dbutil.JSONPtr(r.LazyLoadSummary), + dbutil.JSONPtr(r.EncryptionEvent), + r.HasMemberList, + dbutil.StrPtr(r.PrevBatch), + } +} diff --git a/hicli/database/sessionrequest.go b/hicli/database/sessionrequest.go new file mode 100644 index 00000000..6690c13f --- /dev/null +++ b/hicli/database/sessionrequest.go @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +const ( + putSessionRequestQueueEntry = ` + INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (session_id) DO UPDATE + SET min_index = MIN(excluded.min_index, session_request.min_index), + backup_checked = excluded.backup_checked OR session_request.backup_checked, + request_sent = excluded.request_sent OR session_request.request_sent + ` + removeSessionRequestQuery = ` + DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2 + ` + getNextSessionsToRequestQuery = ` + SELECT room_id, session_id, sender, min_index, backup_checked, request_sent + FROM session_request + WHERE request_sent = false OR backup_checked = false + ORDER BY backup_checked, rowid + LIMIT $1 + ` +) + +type SessionRequestQuery struct { + *dbutil.QueryHelper[*SessionRequest] +} + +func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) { + return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count) +} + +func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error { + return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex) +} + +func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error { + return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...) +} + +type SessionRequest struct { + RoomID id.RoomID + SessionID id.SessionID + Sender id.UserID + MinIndex uint32 + BackupChecked bool + RequestSent bool +} + +func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) { + return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent)) +} + +func (s *SessionRequest) sqlVariables() []any { + return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent} +} diff --git a/hicli/database/state.go b/hicli/database/state.go new file mode 100644 index 00000000..47c91dcf --- /dev/null +++ b/hicli/database/state.go @@ -0,0 +1,32 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + setCurrentStateQuery = ` + INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership + ` +) + +type CurrentStateQuery struct { + *dbutil.Database +} + +func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID int64, membership event.Membership) error { + _, err := csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) + return err +} diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go new file mode 100644 index 00000000..e0471ef2 --- /dev/null +++ b/hicli/database/statestore.go @@ -0,0 +1,149 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + + "go.mau.fi/util/dbutil" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getMembershipQuery = ` + SELECT membership FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2 + ` + getStateEventContentQuery = ` + SELECT event.content FROM current_state cs + LEFT JOIN event ON event.rowid = cs.event_rowid + WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3 + ` + getRoomJoinedOrInvitedMembersQuery = ` + SELECT state_key FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite') + ` + isRoomEncryptedQuery = ` + SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1 + ` + getRoomEncryptionEventQuery = ` + SELECT room.encryption_event FROM room WHERE room_id = $1 + ` + findSharedRoomsQuery = ` + SELECT room_id FROM current_state + WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join' + ` +) + +type ClientStateStore struct { + *Database +} + +var _ mautrix.StateStore = (*ClientStateStore)(nil) +var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil) +var _ crypto.StateStore = (*ClientStateStore)(nil) + +func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return c.IsMembership(ctx, roomID, userID, event.MembershipJoin) +} + +func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin) +} + +func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { + var membership event.Membership + err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership) + if errors.Is(err, sql.ErrNoRows) { + err = nil + membership = event.MembershipLeave + } + return slices.Contains(allowedMemberships, membership) +} + +func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + content, err := c.TryGetMember(ctx, roomID, userID) + if content == nil { + content = &event.MemberEventContent{Membership: event.MembershipLeave} + } + return content, err +} + +func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) { + err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) { + err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { + rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + +func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) { + err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) { + err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID). + Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) { + // TODO for multiuser support, this might need to filter by the local user's membership + rows, err := c.Query(ctx, findSharedRoomsQuery, userID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() +} + +// Update methods are all intentionally no-ops as the state store wants to have the full event + +func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { + return nil +} + +func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { + return nil +} + +func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { + return nil +} + +func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { + return nil +} + +func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { + return nil +} + +func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {} diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go new file mode 100644 index 00000000..585e55bb --- /dev/null +++ b/hicli/database/timeline.go @@ -0,0 +1,47 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +const ( + clearTimelineQuery = ` + DELETE FROM timeline WHERE room_id = $1 + ` + setTimelineQuery = ` + INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) + ` +) + +type MassInsertableRowID int64 + +func (m MassInsertableRowID) GetMassInsertValues() [1]any { + return [1]any{m} +} + +var setTimelineQueryBuilder = dbutil.NewMassInsertBuilder[MassInsertableRowID, [1]any](setTimelineQuery, "($1, $%d)") + +type TimelineQuery struct { + *dbutil.Database +} + +func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error { + _, err := tq.Exec(ctx, clearTimelineQuery, roomID) + return err +} + +func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []MassInsertableRowID) error { + query, params := setTimelineQueryBuilder.Build([1]any{roomID}, rowIDs) + _, err := tq.Exec(ctx, query, params...) + return err +} diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql new file mode 100644 index 00000000..cc85f25a --- /dev/null +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -0,0 +1,107 @@ +-- v0 -> v1: Latest revision +CREATE TABLE account ( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + access_token TEXT NOT NULL, + homeserver_url TEXT NOT NULL, + + next_batch TEXT NOT NULL +) STRICT; + +CREATE TABLE room ( + room_id TEXT NOT NULL PRIMARY KEY, + creation_content TEXT, + + name TEXT, + avatar TEXT, + topic TEXT, + lazy_load_summary TEXT, + + encryption_event TEXT, + has_member_list INTEGER NOT NULL DEFAULT false, + + prev_batch TEXT +) STRICT; +CREATE INDEX room_type_idx ON room (creation_content ->> 'type'); + +CREATE TABLE account_data ( + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL, + + PRIMARY KEY (user_id, type) +) STRICT; + +CREATE TABLE room_account_data ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL, + + PRIMARY KEY (user_id, room_id, type), + CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; + +CREATE TABLE event ( + rowid INTEGER PRIMARY KEY, + + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + sender TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT, + timestamp INTEGER NOT NULL, + + content TEXT NOT NULL, + decrypted TEXT, + decrypted_type TEXT, + unsigned TEXT NOT NULL, + + redacted_by TEXT, + relates_to TEXT, + + megolm_session_id TEXT, + decryption_error TEXT, + + CONSTRAINT event_id_unique_key UNIQUE (event_id), + CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; +CREATE INDEX event_room_id_idx ON event (room_id); +CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by); +CREATE INDEX event_relates_to_idx ON event (room_id, relates_to); +CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id); + +CREATE TABLE session_request ( + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + sender TEXT NOT NULL, + min_index INTEGER NOT NULL, + backup_checked INTEGER NOT NULL DEFAULT false, + request_sent INTEGER NOT NULL DEFAULT false, + + PRIMARY KEY (session_id), + CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; + +CREATE TABLE timeline ( + rowid INTEGER PRIMARY KEY, + room_id TEXT NOT NULL, + event_rowid INTEGER NOT NULL, + + CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, + CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE +) STRICT; +CREATE INDEX timeline_room_id_idx ON timeline (room_id); + +CREATE TABLE current_state ( + room_id TEXT NOT NULL, + event_type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_rowid INTEGER NOT NULL, + + membership TEXT, + + PRIMARY KEY (room_id, event_type, state_key), + CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, + CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) +) STRICT, WITHOUT ROWID; diff --git a/hicli/database/upgrades/upgrades.go b/hicli/database/upgrades/upgrades.go new file mode 100644 index 00000000..9d0bd1a0 --- /dev/null +++ b/hicli/database/upgrades/upgrades.go @@ -0,0 +1,22 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package upgrades + +import ( + "embed" + + "go.mau.fi/util/dbutil" +) + +var Table dbutil.UpgradeTable + +//go:embed *.sql +var upgrades embed.FS + +func init() { + Table.RegisterFS(upgrades) +} diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go new file mode 100644 index 00000000..551713a8 --- /dev/null +++ b/hicli/decryptionqueue.go @@ -0,0 +1,194 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "sync" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) { + data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID) + if err != nil { + return nil, err + } else if data == nil { + return nil, nil + } + decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey) + if err != nil { + return nil, err + } + return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted) +} + +func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) { + log := zerolog.Ctx(ctx) + err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex) + if err != nil { + log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session") + } + events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID) + if err != nil { + log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption") + return + } else if len(events) == 0 { + log.Trace().Msg("No events to retry decryption for") + return + } + decrypted := events[:0] + for _, evt := range events { + if evt.Decrypted != nil { + continue + } + + evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) + if err != nil { + log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") + } else { + decrypted = append(decrypted, evt) + } + } + if len(decrypted) > 0 { + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + for _, evt := range decrypted { + err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType) + if err != nil { + return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err) + } + } + return nil + }) + if err != nil { + log.Err(err).Msg("Failed to save decrypted events") + } + } +} + +func (h *HiClient) WakeupRequestQueue() { + select { + case h.requestQueueWakeup <- struct{}{}: + default: + } +} + +func (h *HiClient) RunRequestQueue(ctx context.Context) { + log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Starting key request queue") + defer func() { + log.Info().Msg("Stopping key request queue") + }() + for { + err := h.FetchKeysForOutdatedUsers(ctx) + if err != nil { + log.Err(err).Msg("Failed to fetch outdated device lists for tracked users") + } + madeRequests, err := h.RequestQueuedSessions(ctx) + if err != nil { + log.Err(err).Msg("Failed to handle session request queue") + } else if madeRequests { + continue + } + select { + case <-ctx.Done(): + return + case <-h.requestQueueWakeup: + } + } +} + +func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) { + defer doneFunc() + log := zerolog.Ctx(ctx) + if !req.BackupChecked { + sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID) + if err != nil { + log.Err(err). + Stringer("session_id", req.SessionID). + Msg("Failed to fetch session from key backup") + + // TODO should this have retries instead of just storing it's checked? + req.BackupChecked = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup") + } + } else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex { + req.BackupChecked = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup") + } + } else { + log.Debug().Stringer("session_id", req.SessionID). + Msg("Found session with sufficiently low first known index, removing from queue") + err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex()) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue") + } + } + } else { + err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{ + h.Account.UserID: {"*"}, + req.Sender: {"*"}, + }) + //var err error + if err != nil { + log.Err(err). + Stringer("session_id", req.SessionID). + Msg("Failed to send key request") + } else { + log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request") + req.RequestSent = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request") + } + } + } +} + +const MaxParallelRequests = 5 + +func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) { + sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests) + if err != nil { + return false, fmt.Errorf("failed to get next events to decrypt: %w", err) + } else if len(sessions) == 0 { + return false, nil + } + var wg sync.WaitGroup + wg.Add(len(sessions)) + for _, req := range sessions { + go h.requestQueuedSession(ctx, req, wg.Done) + } + wg.Wait() + + return true, err +} + +func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error { + outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx) + if err != nil { + return err + } else if len(outdatedUsers) == 0 { + return nil + } + _, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false) + if err != nil { + return err + } + // TODO backoff for users that fail to be fetched? + return nil +} diff --git a/hicli/hicli.go b/hicli/hicli.go new file mode 100644 index 00000000..9b889d3c --- /dev/null +++ b/hicli/hicli.go @@ -0,0 +1,159 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix. +package hicli + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +type HiClient struct { + DB *database.Database + Account *database.Account + Client *mautrix.Client + Crypto *crypto.OlmMachine + CryptoStore *crypto.SQLCryptoStore + ClientStore *database.ClientStateStore + Log zerolog.Logger + + Verified bool + + KeyBackupVersion id.KeyBackupVersion + KeyBackupKey *backup.MegolmBackupKey + + firstSyncReceived bool + syncingID int + syncLock sync.Mutex + encryptLock sync.Mutex + + requestQueueWakeup chan struct{} +} + +func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte) *HiClient { + rawDB.Owner = "hicli" + rawDB.IgnoreForeignTables = true + db := database.New(rawDB) + db.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger()) + c := &HiClient{ + DB: db, + Log: log, + + requestQueueWakeup: make(chan struct{}, 1), + } + c.ClientStore = &database.ClientStateStore{Database: db} + c.Client = &mautrix.Client{ + UserAgent: mautrix.DefaultUserAgent, + Client: &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + // This needs to be relatively high to allow initial syncs + ResponseHeaderTimeout: 180 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: 180 * time.Second, + }, + Syncer: (*hiSyncer)(c), + Store: (*hiStore)(c), + StateStore: c.ClientStore, + Log: log.With().Str("component", "mautrix client").Logger(), + } + c.CryptoStore = crypto.NewSQLCryptoStore(rawDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey) + cryptoLog := log.With().Str("component", "crypto").Logger() + c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore) + c.Crypto.SessionReceived = c.handleReceivedMegolmSession + c.Crypto.DisableRatchetTracking = true + c.Crypto.DisableDecryptKeyFetching = true + c.Client.Crypto = (*hiCryptoHelper)(c) + return c +} + +func (h *HiClient) IsLoggedIn() bool { + return h.Account != nil +} + +func (h *HiClient) Start(ctx context.Context, userID id.UserID) error { + err := h.DB.Upgrade(ctx) + if err != nil { + return fmt.Errorf("failed to upgrade hicli db: %w", err) + } + err = h.CryptoStore.DB.Upgrade(ctx) + if err != nil { + return fmt.Errorf("failed to upgrade crypto db: %w", err) + } + account, err := h.DB.Account.Get(ctx, userID) + if err != nil { + return err + } + if account != nil { + zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials") + h.Account = account + h.CryptoStore.AccountID = account.UserID.String() + h.CryptoStore.DeviceID = account.DeviceID + h.Client.UserID = account.UserID + h.Client.DeviceID = account.DeviceID + h.Client.AccessToken = account.AccessToken + h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL) + if err != nil { + return err + } + err = h.Crypto.Load(ctx) + if err != nil { + return fmt.Errorf("failed to load olm machine: %w", err) + } + + h.Verified, err = h.checkIsCurrentDeviceVerified(ctx) + if err != nil { + return err + } + zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status") + if h.Verified { + err = h.loadPrivateKeys(ctx) + if err != nil { + return err + } + go h.Sync() + go h.RunRequestQueue(ctx) + } + } + return nil +} + +func (h *HiClient) Sync() { + h.Client.StopSync() + h.syncLock.Lock() + defer h.syncLock.Unlock() + h.syncingID++ + syncingID := h.syncingID + log := h.Log.With(). + Str("action", "sync"). + Int("sync_id", syncingID). + Logger() + ctx := log.WithContext(context.Background()) + log.Info().Msg("Starting syncing") + err := h.Client.SyncWithContext(ctx) + if err != nil { + log.Err(err).Msg("Fatal error in syncer") + } else { + log.Info().Msg("Syncing stopped") + } +} diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go new file mode 100644 index 00000000..ec94a328 --- /dev/null +++ b/hicli/hitest/hitest.go @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package main + +import ( + "context" + "io" + "os" + "os/signal" + "syscall" + + "github.com/chzyer/readline" + _ "github.com/mattn/go-sqlite3" + "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exzerolog" + "go.mau.fi/zeroconfig" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/hicli" + "maunium.net/go/mautrix/id" +) + +var writerTypeReadline zeroconfig.WriterType = "hitest_readline" + +func main() { + rl := exerrors.Must(readline.New("> ")) + defer func() { + _ = rl.Close() + }() + zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) { + return rl.Stdout(), nil + }) + log := exerrors.Must((&zeroconfig.Config{ + Writers: []zeroconfig.WriterConfig{{ + Type: writerTypeReadline, + Format: zeroconfig.LogFormatPrettyColored, + }}, + }).Compile()) + exzerolog.SetupDefaults(log) + + rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal")) + ctx := log.WithContext(context.Background()) + cli := hicli.New(rawDB, *log, []byte("meow")) + userID, _ := cli.DB.Account.GetFirstUserID(ctx) + exerrors.PanicIfNotNil(cli.Start(ctx, userID)) + if !cli.IsLoggedIn() { + rl.SetPrompt("User ID: ") + userID := id.UserID(exerrors.Must(rl.Readline())) + _, serverName := exerrors.Must2(userID.Parse()) + discovery, err := mautrix.DiscoverClientAPI(ctx, serverName) + if discovery == nil { + log.Fatal().Err(err).Msg("Failed to discover homeserver") + } + password := exerrors.Must(rl.ReadPassword("Password: ")) + recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: ")) + exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode))) + } + rl.SetPrompt("> ") + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c +} diff --git a/hicli/login.go b/hicli/login.go new file mode 100644 index 00000000..47ea5a4d --- /dev/null +++ b/hicli/login.go @@ -0,0 +1,77 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "net/url" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/hicli/database" +) + +func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error { + var err error + h.Client.HomeserverURL, err = url.Parse(homeserverURL) + if err != nil { + return err + } + return h.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: username, + }, + Password: password, + InitialDeviceDisplayName: "mautrix client", + }) +} + +func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error { + req.StoreCredentials = true + req.StoreHomeserverURL = true + resp, err := h.Client.Login(ctx, req) + if err != nil { + return err + } + h.Account = &database.Account{ + UserID: resp.UserID, + DeviceID: resp.DeviceID, + AccessToken: resp.AccessToken, + HomeserverURL: h.Client.HomeserverURL.String(), + } + h.CryptoStore.AccountID = resp.UserID.String() + h.CryptoStore.DeviceID = resp.DeviceID + err = h.DB.Account.Put(ctx, h.Account) + if err != nil { + return err + } + err = h.Crypto.Load(ctx) + if err != nil { + return fmt.Errorf("failed to load olm machine: %w", err) + } + err = h.Crypto.ShareKeys(ctx, 0) + if err != nil { + return err + } + return nil +} + +func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryCode string) error { + err := h.LoginPassword(ctx, homeserverURL, username, password) + if err != nil { + return err + } + err = h.VerifyWithRecoveryCode(ctx, recoveryCode) + if err != nil { + return err + } + go h.Sync() + go h.RunRequestQueue(ctx) + return nil +} diff --git a/hicli/sync.go b/hicli/sync.go new file mode 100644 index 00000000..d0064015 --- /dev/null +++ b/hicli/sync.go @@ -0,0 +1,362 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "errors" + "fmt" + + "github.com/rs/zerolog" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exzerolog" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +type syncContext struct { + shouldWakeupRequestQueue bool +} + +func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + log := zerolog.Ctx(ctx) + postponedToDevices := resp.ToDevice.Events[:0] + for _, evt := range resp.ToDevice.Events { + evt.Type.Class = event.ToDeviceEventType + err := evt.Content.ParseRaw(evt.Type) + if err != nil { + log.Warn().Err(err). + Stringer("event_type", &evt.Type). + Stringer("sender", evt.Sender). + Msg("Failed to parse to-device event, skipping") + continue + } + + switch content := evt.Content.Parsed.(type) { + case *event.EncryptedEventContent: + h.Crypto.HandleEncryptedEvent(ctx, evt) + case *event.RoomKeyWithheldEventContent: + h.Crypto.HandleRoomKeyWithheld(ctx, content) + default: + postponedToDevices = append(postponedToDevices, evt) + } + } + resp.ToDevice.Events = postponedToDevices + + return nil +} + +func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount) + go h.asyncPostProcessSyncResponse(ctx, resp, since) + if ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue { + h.WakeupRequestQueue() + } + return nil +} + +func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { + for _, evt := range resp.ToDevice.Events { + switch content := evt.Content.Parsed.(type) { + case *event.SecretRequestEventContent: + h.Crypto.HandleSecretRequest(ctx, evt.Sender, content) + case *event.RoomKeyRequestEventContent: + h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content) + } + } +} + +func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + if len(resp.DeviceLists.Changed) > 0 { + zerolog.Ctx(ctx).Debug(). + Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)). + Msg("Marking changed device lists for tracked users as outdated") + err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed) + if err != nil { + return fmt.Errorf("failed to mark changed device lists as outdated: %w", err) + } + ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true + } + + for _, evt := range resp.AccountData.Events { + evt.Type.Class = event.AccountDataEventType + err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw) + if err != nil { + return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) + } + } + for roomID, room := range resp.Rooms.Join { + err := h.processSyncJoinedRoom(ctx, roomID, room) + if err != nil { + return fmt.Errorf("failed to process joined room %s: %w", roomID, err) + } + } + for roomID, room := range resp.Rooms.Leave { + err := h.processSyncLeftRoom(ctx, roomID, room) + if err != nil { + return fmt.Errorf("failed to process left room %s: %w", roomID, err) + } + } + h.Account.NextBatch = resp.NextBatch + err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch) + if err != nil { + return fmt.Errorf("failed to save next_batch: %w", err) + } + return nil +} + +func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error { + existingRoomData, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room data: %w", err) + } else if existingRoomData == nil { + err = h.DB.Room.CreateRow(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to ensure room row exists: %w", err) + } + existingRoomData = &database.Room{ID: roomID} + } + + for _, evt := range room.AccountData.Events { + evt.Type.Class = event.AccountDataEventType + evt.RoomID = roomID + err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw) + if err != nil { + return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) + } + } + err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) + if err != nil { + return err + } + return nil +} + +func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error { + existingRoomData, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room data: %w", err) + } else if existingRoomData == nil { + return nil + } + return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) +} + +func isDecryptionErrorRetryable(err error) bool { + return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld) +} + +func removeReplyFallback(evt *event.Event) []byte { + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if ok && content.RelatesTo.GetReplyTo() != "" { + prevFormattedBody := content.FormattedBody + content.RemoveReplyFallback() + if content.FormattedBody != prevFormattedBody { + bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody) + if err == nil { + return bytes + } + bytes, err = sjson.SetBytes(evt.Content.VeryRaw, "body", content.Body) + if err == nil { + return bytes + } + } + } + return nil +} + +func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, string, error) { + err := evt.Content.ParseRaw(evt.Type) + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { + return nil, "", err + } + decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt) + if err != nil { + return nil, "", err + } + withoutFallback := removeReplyFallback(decrypted) + if withoutFallback != nil { + return withoutFallback, decrypted.Type.Type, nil + } + return decrypted.Content.VeryRaw, decrypted.Type.Type, nil +} + +func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error { + decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + roomDataChanged := false + processEvent := func(evt *event.Event) (database.MassInsertableRowID, error) { + evt.RoomID = room.ID + dbEvt := database.MautrixToEvent(evt) + contentWithoutFallback := removeReplyFallback(evt) + if contentWithoutFallback != nil { + dbEvt.Content = contentWithoutFallback + } + var decryptionErr error + if evt.Type == event.EventEncrypted { + dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) + if decryptionErr != nil { + dbEvt.DecryptionError = decryptionErr.Error() + } + } + rowID, err := h.DB.Event.Upsert(ctx, dbEvt) + if err != nil { + return -1, fmt.Errorf("failed to save event %s: %w", evt.ID, err) + } + if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { + req, ok := decryptionQueue[dbEvt.MegolmSessionID] + if !ok { + req = &database.SessionRequest{ + RoomID: room.ID, + SessionID: dbEvt.MegolmSessionID, + Sender: evt.Sender, + } + } + minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext) + req.MinIndex = min(uint32(minIndex), req.MinIndex) + decryptionQueue[dbEvt.MegolmSessionID] = req + } + if evt.StateKey != nil { + var membership event.Membership + if evt.Type == event.StateMember { + membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str) + } + err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, rowID, membership) + if err != nil { + return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err) + } + roomDataChanged = processImportantEvent(ctx, evt, room) || roomDataChanged + } + return database.MassInsertableRowID(rowID), nil + } + var err error + for _, evt := range state.Events { + evt.Type.Class = event.StateEventType + _, err = processEvent(evt) + if err != nil { + return err + } + } + if len(timeline.Events) > 0 { + timelineIDs := make([]database.MassInsertableRowID, len(timeline.Events)) + for i, evt := range timeline.Events { + if evt.StateKey != nil { + evt.Type.Class = event.StateEventType + } else { + evt.Type.Class = event.MessageEventType + } + timelineIDs[i], err = processEvent(evt) + if err != nil { + return err + } + } + for _, entry := range decryptionQueue { + err = h.DB.SessionRequest.Put(ctx, entry) + if err != nil { + return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) + } + } + if len(decryptionQueue) > 0 { + ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true + } + if timeline.Limited { + err = h.DB.Timeline.Clear(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to clear old timeline: %w", err) + } + } + err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs) + if err != nil { + return fmt.Errorf("failed to append timeline: %w", err) + } + } + if timeline.PrevBatch != "" && room.PrevBatch == "" { + room.PrevBatch = timeline.PrevBatch + roomDataChanged = true + } + if summary.Heroes != nil { + roomDataChanged = roomDataChanged || room.LazyLoadSummary == nil || + !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) || + !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) || + !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) + room.LazyLoadSummary = summary + } + if roomDataChanged { + err = h.DB.Room.Upsert(ctx, room) + if err != nil { + return fmt.Errorf("failed to save room data: %w", err) + } + } + return nil +} + +func intPtrEqual(a, b *int) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} + +func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData *database.Room) (roomDataChanged bool) { + if evt.StateKey == nil { + return + } + switch evt.Type { + case event.StateCreate, event.StateRoomName, event.StateRoomAvatar, event.StateTopic, event.StateEncryption: + if *evt.StateKey != "" { + return + } + default: + return + } + err := evt.Content.ParseRaw(evt.Type) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("event_type", &evt.Type). + Stringer("event_id", evt.ID). + Msg("Failed to parse state event, skipping") + return + } + switch evt.Type { + case event.StateCreate: + if existingRoomData.CreationContent == nil { + roomDataChanged = true + } + existingRoomData.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent) + case event.StateEncryption: + newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent) + if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm { + roomDataChanged = true + existingRoomData.EncryptionEvent = newEncryption + } + case event.StateRoomName: + content, ok := evt.Content.Parsed.(*event.RoomNameEventContent) + if ok { + roomDataChanged = existingRoomData.Name == nil || *existingRoomData.Name != content.Name + existingRoomData.Name = &content.Name + } + case event.StateRoomAvatar: + content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) + if ok { + roomDataChanged = existingRoomData.Avatar == nil || *existingRoomData.Avatar != content.URL + existingRoomData.Avatar = &content.URL + } + case event.StateTopic: + content, ok := evt.Content.Parsed.(*event.TopicEventContent) + if ok { + roomDataChanged = existingRoomData.Topic == nil || *existingRoomData.Topic != content.Topic + existingRoomData.Topic = &content.Topic + } + } + return +} diff --git a/hicli/syncwrap.go b/hicli/syncwrap.go new file mode 100644 index 00000000..eccdb7b1 --- /dev/null +++ b/hicli/syncwrap.go @@ -0,0 +1,100 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "time" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type hiSyncer HiClient + +var _ mautrix.Syncer = (*hiSyncer)(nil) + +type contextKey int + +const ( + syncContextKey contextKey = iota +) + +func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + c := (*HiClient)(h) + ctx = context.WithValue(ctx, syncContextKey, &syncContext{}) + err := c.preProcessSyncResponse(ctx, resp, since) + if err != nil { + return err + } + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + return c.processSyncResponse(ctx, resp, since) + }) + if err != nil { + return err + } + err = c.postProcessSyncResponse(ctx, resp, since) + if err != nil { + return err + } + c.firstSyncReceived = true + return nil +} + +func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { + (*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second") + return 1 * time.Second, nil +} + +func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { + if !h.Verified { + return &mautrix.Filter{ + Presence: mautrix.FilterPart{ + NotRooms: []id.RoomID{"*"}, + }, + Room: mautrix.RoomFilter{ + NotRooms: []id.RoomID{"*"}, + }, + } + } + return &mautrix.Filter{ + Presence: mautrix.FilterPart{ + NotRooms: []id.RoomID{"*"}, + }, + Room: mautrix.RoomFilter{ + State: mautrix.FilterPart{ + LazyLoadMembers: true, + }, + Timeline: mautrix.FilterPart{ + Limit: 100, + LazyLoadMembers: true, + }, + }, + } +} + +type hiStore HiClient + +var _ mautrix.SyncStore = (*hiStore)(nil) + +// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing + +func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil } +func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil } + +func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error { + // This is intentionally a no-op: we don't want to save the next batch before processing the sync + return nil +} + +func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) { + if h.Account.UserID != userID { + return "", fmt.Errorf("mismatching user ID") + } + return h.Account.NextBatch, nil +} diff --git a/hicli/verify.go b/hicli/verify.go new file mode 100644 index 00000000..2062519a --- /dev/null +++ b/hicli/verify.go @@ -0,0 +1,158 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/base64" + "fmt" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/crypto/ssss" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) { + keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx) + if keys == nil { + return false, fmt.Errorf("own cross-signing keys not found") + } + isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey) + if err != nil { + return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) + } + return isVerified, nil +} + +func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error { + latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return fmt.Errorf("failed to get key backup latest version: %w", err) + } + h.KeyBackupVersion = latestVersion.Version + data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey) + if err != nil { + return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err) + } + key, err := backup.MegolmBackupKeyFromBytes(data) + if err != nil { + return fmt.Errorf("failed to parse megolm backup key: %w", err) + } + err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes())) + if err != nil { + return fmt.Errorf("failed to store megolm backup key: %w", err) + } + h.KeyBackupKey = key + return nil +} + +func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) { + secretData, err := h.CryptoStore.GetSecret(ctx, secret) + if err != nil { + return nil, fmt.Errorf("failed to get secret %s: %w", secret, err) + } + data, err := base64.StdEncoding.DecodeString(secretData) + if err != nil { + return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err) + } + return data, nil +} + +func (h *HiClient) loadPrivateKeys(ctx context.Context) error { + zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys") + masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster) + if err != nil { + return fmt.Errorf("failed to get master key: %w", err) + } + selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning) + if err != nil { + return fmt.Errorf("failed to get self-signing key: %w", err) + } + userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning) + if err != nil { + return fmt.Errorf("failed to get user signing key: %w", err) + } + err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{ + MasterKey: masterKeySeed, + SelfSigningKey: selfSigningKeySeed, + UserSigningKey: userSigningKeySeed, + }) + if err != nil { + return fmt.Errorf("failed to import cross-signing private keys: %w", err) + } + zerolog.Ctx(ctx).Debug().Msg("Loading key backup key") + keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1) + if err != nil { + return fmt.Errorf("failed to get megolm backup key: %w", err) + } + h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey) + if err != nil { + return fmt.Errorf("failed to parse megolm backup key: %w", err) + } + zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version") + latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return fmt.Errorf("failed to get key backup latest version: %w", err) + } + h.KeyBackupVersion = latestVersion.Version + zerolog.Ctx(ctx).Debug().Msg("Secrets loaded") + return nil +} + +func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error { + keys := h.Crypto.CrossSigningKeys + err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed())) + if err != nil { + return err + } + err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed())) + if err != nil { + return err + } + err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed())) + if err != nil { + return err + } + return nil +} + +func (h *HiClient) VerifyWithRecoveryCode(ctx context.Context, code string) error { + _, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx) + if err != nil { + return fmt.Errorf("failed to get default SSSS key data: %w", err) + } + key, err := keyData.VerifyRecoveryKey(code) + if err != nil { + return err + } + err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) + } + err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity()) + if err != nil { + return fmt.Errorf("failed to sign own device: %w", err) + } + err = h.Crypto.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign own master key: %w", err) + } + err = h.storeCrossSigningPrivateKeys(ctx) + if err != nil { + return fmt.Errorf("failed to store cross-signing private keys: %w", err) + } + err = h.fetchKeyBackupKey(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch key backup key: %w", err) + } + h.Verified = true + return nil +}