Add draft of high-level client framework

This commit is contained in:
Tulir Asokan 2024-05-26 18:29:22 +03:00
commit c1eb217b9e
20 changed files with 2132 additions and 1 deletions

2
.gitignore vendored
View file

@ -1,4 +1,4 @@
.idea/
.vscode/
*.db
*.db*
*.log

82
hicli/cryptohelper.go Normal file
View file

@ -0,0 +1,82 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"errors"
"fmt"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type hiCryptoHelper HiClient
var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil)
func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
h.encryptLock.Lock()
defer h.encryptLock.Unlock()
encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.NoGroupSession) && !errors.Is(err, crypto.SessionNotShared) {
return
}
h.Log.Debug().
Err(err).
Str("room_id", roomID.String()).
Msg("Got session error while encrypting event, sharing group session and trying again")
var users []id.UserID
users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, roomID)
if err != nil {
err = fmt.Errorf("failed to get room member list: %w", err)
} else if err = h.Crypto.ShareGroupSession(ctx, roomID, users); err != nil {
err = fmt.Errorf("failed to share group session: %w", err)
} else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil {
err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
}
}
return
}
func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
return h.Crypto.DecryptMegolmEvent(ctx, evt)
}
func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
}
func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
userID: {deviceID},
h.Account.UserID: {"*"},
})
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("room_id", roomID).
Stringer("session_id", sessionID).
Stringer("user_id", userID).
Msg("Failed to send room key request")
} else {
zerolog.Ctx(ctx).Debug().
Stringer("room_id", roomID).
Stringer("session_id", sessionID).
Stringer("user_id", userID).
Msg("Sent room key request")
}
}
func (h *hiCryptoHelper) Init(ctx context.Context) error {
return nil
}

65
hicli/database/account.go Normal file
View file

@ -0,0 +1,65 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
const (
getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1`
putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2`
upsertAccountQuery = `
INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch)
VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id)
DO UPDATE SET device_id = excluded.device_id,
access_token = excluded.access_token,
homeserver_url = excluded.homeserver_url,
next_batch = excluded.next_batch
`
)
type AccountQuery struct {
*dbutil.QueryHelper[*Account]
}
func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) {
err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID)
return
}
func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) {
return aq.QueryOne(ctx, getAccountQuery, userID)
}
func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error {
return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID)
}
func (aq *AccountQuery) Put(ctx context.Context, account *Account) error {
return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...)
}
type Account struct {
UserID id.UserID
DeviceID id.DeviceID
AccessToken string
HomeserverURL string
NextBatch string
}
func (a *Account) Scan(row dbutil.Scannable) (*Account, error) {
return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch))
}
func (a *Account) sqlVariables() []any {
return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch}
}

View file

@ -0,0 +1,71 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"encoding/json"
"unsafe"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
upsertAccountDataQuery = `
INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3)
ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content
`
upsertRoomAccountDataQuery = `
INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content
`
)
type AccountDataQuery struct {
*dbutil.QueryHelper[*AccountData]
}
func unsafeJSONString(content json.RawMessage) *string {
if content == nil {
return nil
}
str := unsafe.String(unsafe.SliceData(content), len(content))
return &str
}
func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error {
return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content))
}
func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error {
return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content))
}
type AccountData struct {
UserID id.UserID
RoomID id.RoomID
Type string
Content json.RawMessage
}
func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) {
var roomID sql.NullString
err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content))
if err != nil {
return nil, err
}
a.RoomID = id.RoomID(roomID.String)
return a, nil
}
func (a *AccountData) sqlVariables() []any {
return []any{a.UserID, dbutil.StrPtr(a.RoomID), a.Type, unsafeJSONString(a.Content)}
}

View file

@ -0,0 +1,60 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/hicli/database/upgrades"
)
type Database struct {
*dbutil.Database
Account AccountQuery
AccountData AccountDataQuery
Room RoomQuery
Event EventQuery
CurrentState CurrentStateQuery
Timeline TimelineQuery
SessionRequest SessionRequestQuery
}
func New(rawDB *dbutil.Database) *Database {
rawDB.UpgradeTable = upgrades.Table
return &Database{
Database: rawDB,
Account: AccountQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccount)},
AccountData: AccountDataQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccountData)},
Room: RoomQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newRoom)},
Event: EventQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newEvent)},
CurrentState: CurrentStateQuery{Database: rawDB},
Timeline: TimelineQuery{Database: rawDB},
SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)},
}
}
func newSessionRequest(_ *dbutil.QueryHelper[*SessionRequest]) *SessionRequest {
return &SessionRequest{}
}
func newEvent(_ *dbutil.QueryHelper[*Event]) *Event {
return &Event{}
}
func newRoom(_ *dbutil.QueryHelper[*Room]) *Room {
return &Room{}
}
func newAccountData(_ *dbutil.QueryHelper[*AccountData]) *AccountData {
return &AccountData{}
}
func newAccount(_ *dbutil.QueryHelper[*Account]) *Account {
return &Account{}
}

193
hicli/database/event.go Normal file
View file

@ -0,0 +1,193 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"database/sql"
"encoding/json"
"time"
"github.com/tidwall/gjson"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exgjson"
"golang.org/x/net/context"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
getEventBaseQuery = `
SELECT rowid, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned,
redacted_by, relates_to, megolm_session_id, decryption_error
FROM event
`
getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL`
upsertEventQuery = `
INSERT INTO event (room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, redacted_by, relates_to, megolm_session_id, decryption_error)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
ON CONFLICT (event_id) DO UPDATE
SET decrypted=COALESCE(event.decrypted, excluded.decrypted),
decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type),
redacted_by=COALESCE(event.redacted_by, excluded.redacted_by),
decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END
RETURNING rowid
`
updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3`
)
type EventQuery struct {
*dbutil.QueryHelper[*Event]
}
func (eq *EventQuery) GetFailedByMegolmSessionID(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) ([]*Event, error) {
return eq.QueryMany(ctx, getFailedEventsByMegolmSessionID, roomID, sessionID)
}
func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID int64, err error) {
err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID)
return
}
func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID int64, decrypted json.RawMessage, decryptedType string) error {
return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID)
}
type Event struct {
RowID int64
RoomID id.RoomID
ID id.EventID
Sender id.UserID
Type string
StateKey *string
Timestamp time.Time
Content json.RawMessage
Decrypted json.RawMessage
DecryptedType string
Unsigned json.RawMessage
RedactedBy id.EventID
RelatesTo id.EventID
MegolmSessionID id.SessionID
DecryptionError string
}
func MautrixToEvent(evt *event.Event) *Event {
dbEvt := &Event{
RoomID: evt.RoomID,
ID: evt.ID,
Sender: evt.Sender,
Type: evt.Type.Type,
StateKey: evt.StateKey,
Timestamp: time.UnixMilli(evt.Timestamp),
Content: evt.Content.VeryRaw,
RelatesTo: getRelatesTo(evt),
MegolmSessionID: getMegolmSessionID(evt),
}
dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned)
if evt.Unsigned.RedactedBecause != nil {
dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID
}
return dbEvt
}
func (e *Event) AsRawMautrix() *event.Event {
evt := &event.Event{
RoomID: e.RoomID,
ID: e.ID,
Sender: e.Sender,
Type: event.Type{Type: e.Type, Class: event.MessageEventType},
StateKey: e.StateKey,
Timestamp: e.Timestamp.UnixMilli(),
Content: event.Content{VeryRaw: e.Content},
}
if e.Decrypted != nil {
evt.Content.VeryRaw = e.Decrypted
evt.Type.Type = e.DecryptedType
evt.Mautrix.WasEncrypted = true
}
if e.StateKey != nil {
evt.Type.Class = event.StateEventType
}
_ = json.Unmarshal(e.Unsigned, &evt.Unsigned)
return evt
}
func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
var timestamp int64
var redactedBy, relatesTo, megolmSessionID, decryptionError, decryptedType sql.NullString
err := row.Scan(
&e.RowID,
&e.RoomID,
&e.ID,
&e.Sender,
&e.Type,
&e.StateKey,
&timestamp,
(*[]byte)(&e.Content),
(*[]byte)(&e.Decrypted),
&decryptedType,
(*[]byte)(&e.Unsigned),
&redactedBy,
&relatesTo,
&megolmSessionID,
&decryptionError,
)
if err != nil {
return nil, err
}
e.Timestamp = time.UnixMilli(timestamp)
e.RedactedBy = id.EventID(redactedBy.String)
e.RelatesTo = id.EventID(relatesTo.String)
e.MegolmSessionID = id.SessionID(megolmSessionID.String)
e.DecryptedType = decryptedType.String
e.DecryptionError = decryptionError.String
return e, nil
}
var relatesToPath = exgjson.Path("m.relates_to", "event_id")
func getRelatesTo(evt *event.Event) id.EventID {
res := gjson.GetBytes(evt.Content.VeryRaw, relatesToPath)
if res.Exists() && res.Type == gjson.String {
return id.EventID(res.Str)
}
return ""
}
func getMegolmSessionID(evt *event.Event) id.SessionID {
if evt.Type != event.EventEncrypted {
return ""
}
res := gjson.GetBytes(evt.Content.VeryRaw, "session_id")
if res.Exists() && res.Type == gjson.String {
return id.SessionID(res.Str)
}
return ""
}
func (e *Event) sqlVariables() []any {
return []any{
e.RoomID,
e.ID,
e.Sender,
e.Type,
e.StateKey,
e.Timestamp.UnixMilli(),
unsafeJSONString(e.Content),
unsafeJSONString(e.Decrypted),
dbutil.StrPtr(e.DecryptedType),
unsafeJSONString(e.Unsigned),
dbutil.StrPtr(e.RedactedBy),
dbutil.StrPtr(e.RelatesTo),
dbutil.StrPtr(e.MegolmSessionID),
dbutil.StrPtr(e.DecryptionError),
}
}

115
hicli/database/room.go Normal file
View file

@ -0,0 +1,115 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
getRoomByIDQuery = `
SELECT room_id, creation_content, name, avatar, topic, lazy_load_summary, encryption_event, has_member_list, prev_batch
FROM room WHERE room_id = $1
`
ensureRoomExistsQuery = `
INSERT INTO room (room_id) VALUES ($1)
ON CONFLICT (room_id) DO NOTHING
`
upsertRoomFromSyncQuery = `
UPDATE room
SET creation_content = COALESCE(room.creation_content, $2),
name = COALESCE($3, room.name),
avatar = COALESCE($4, room.avatar),
topic = COALESCE($5, room.topic),
lazy_load_summary = COALESCE($6, room.lazy_load_summary),
encryption_event = COALESCE($7, room.encryption_event),
has_member_list = room.has_member_list OR $8,
prev_batch = COALESCE(room.prev_batch, $9)
WHERE room_id = $1
`
setRoomPrevBatchQuery = `
INSERT INTO room (room_id, prev_batch) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET prev_batch = excluded.prev_batch
`
)
type RoomQuery struct {
*dbutil.QueryHelper[*Room]
}
func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) {
return rq.QueryOne(ctx, getRoomByIDQuery, roomID)
}
func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error {
return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...)
}
func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error {
return rq.Exec(ctx, ensureRoomExistsQuery, roomID)
}
func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error {
return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch)
}
type Room struct {
ID id.RoomID
CreationContent *event.CreateEventContent
Name *string
Avatar *id.ContentURI
Topic *string
LazyLoadSummary *mautrix.LazyLoadSummary
EncryptionEvent *event.EncryptionEventContent
HasMemberList bool
PrevBatch string
}
func (r *Room) Scan(row dbutil.Scannable) (*Room, error) {
var prevBatch sql.NullString
err := row.Scan(
&r.ID,
dbutil.JSON{Data: &r.CreationContent},
&r.Name,
&r.Avatar,
&r.Topic,
dbutil.JSON{Data: &r.LazyLoadSummary},
dbutil.JSON{Data: &r.EncryptionEvent},
&r.HasMemberList,
&prevBatch,
)
if err != nil {
return nil, err
}
r.PrevBatch = prevBatch.String
return r, nil
}
func (r *Room) sqlVariables() []any {
return []any{
r.ID,
dbutil.JSONPtr(r.CreationContent),
r.Name,
r.Avatar,
r.Topic,
dbutil.JSONPtr(r.LazyLoadSummary),
dbutil.JSONPtr(r.EncryptionEvent),
r.HasMemberList,
dbutil.StrPtr(r.PrevBatch),
}
}

View file

@ -0,0 +1,69 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
const (
putSessionRequestQueueEntry = `
INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (session_id) DO UPDATE
SET min_index = MIN(excluded.min_index, session_request.min_index),
backup_checked = excluded.backup_checked OR session_request.backup_checked,
request_sent = excluded.request_sent OR session_request.request_sent
`
removeSessionRequestQuery = `
DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2
`
getNextSessionsToRequestQuery = `
SELECT room_id, session_id, sender, min_index, backup_checked, request_sent
FROM session_request
WHERE request_sent = false OR backup_checked = false
ORDER BY backup_checked, rowid
LIMIT $1
`
)
type SessionRequestQuery struct {
*dbutil.QueryHelper[*SessionRequest]
}
func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) {
return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count)
}
func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error {
return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex)
}
func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error {
return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...)
}
type SessionRequest struct {
RoomID id.RoomID
SessionID id.SessionID
Sender id.UserID
MinIndex uint32
BackupChecked bool
RequestSent bool
}
func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) {
return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent))
}
func (s *SessionRequest) sqlVariables() []any {
return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent}
}

32
hicli/database/state.go Normal file
View file

@ -0,0 +1,32 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
setCurrentStateQuery = `
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership
`
)
type CurrentStateQuery struct {
*dbutil.Database
}
func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID int64, membership event.Membership) error {
_, err := csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
return err
}

View file

@ -0,0 +1,149 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"errors"
"go.mau.fi/util/dbutil"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
getMembershipQuery = `
SELECT membership FROM current_state
WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2
`
getStateEventContentQuery = `
SELECT event.content FROM current_state cs
LEFT JOIN event ON event.rowid = cs.event_rowid
WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3
`
getRoomJoinedOrInvitedMembersQuery = `
SELECT state_key FROM current_state
WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite')
`
isRoomEncryptedQuery = `
SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1
`
getRoomEncryptionEventQuery = `
SELECT room.encryption_event FROM room WHERE room_id = $1
`
findSharedRoomsQuery = `
SELECT room_id FROM current_state
WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join'
`
)
type ClientStateStore struct {
*Database
}
var _ mautrix.StateStore = (*ClientStateStore)(nil)
var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil)
var _ crypto.StateStore = (*ClientStateStore)(nil)
func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return c.IsMembership(ctx, roomID, userID, event.MembershipJoin)
}
func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin)
}
func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
var membership event.Membership
err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership)
if errors.Is(err, sql.ErrNoRows) {
err = nil
membership = event.MembershipLeave
}
return slices.Contains(allowedMemberships, membership)
}
func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
content, err := c.TryGetMember(ctx, roomID, userID)
if content == nil {
content = &event.MemberEventContent{Membership: event.MembershipLeave}
}
return content, err
}
func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) {
err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) {
err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID)
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
}
func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) {
err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) {
err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID).
Scan(&dbutil.JSON{Data: &content})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) {
// TODO for multiuser support, this might need to filter by the local user's membership
rows, err := c.Query(ctx, findSharedRoomsQuery, userID)
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
}
// Update methods are all intentionally no-ops as the state store wants to have the full event
func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
return nil
}
func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
return nil
}
func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
return nil
}
func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
return nil
}
func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
return nil
}
func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {}

View file

@ -0,0 +1,47 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
const (
clearTimelineQuery = `
DELETE FROM timeline WHERE room_id = $1
`
setTimelineQuery = `
INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2)
`
)
type MassInsertableRowID int64
func (m MassInsertableRowID) GetMassInsertValues() [1]any {
return [1]any{m}
}
var setTimelineQueryBuilder = dbutil.NewMassInsertBuilder[MassInsertableRowID, [1]any](setTimelineQuery, "($1, $%d)")
type TimelineQuery struct {
*dbutil.Database
}
func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error {
_, err := tq.Exec(ctx, clearTimelineQuery, roomID)
return err
}
func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []MassInsertableRowID) error {
query, params := setTimelineQueryBuilder.Build([1]any{roomID}, rowIDs)
_, err := tq.Exec(ctx, query, params...)
return err
}

View file

@ -0,0 +1,107 @@
-- v0 -> v1: Latest revision
CREATE TABLE account (
user_id TEXT NOT NULL PRIMARY KEY,
device_id TEXT NOT NULL,
access_token TEXT NOT NULL,
homeserver_url TEXT NOT NULL,
next_batch TEXT NOT NULL
) STRICT;
CREATE TABLE room (
room_id TEXT NOT NULL PRIMARY KEY,
creation_content TEXT,
name TEXT,
avatar TEXT,
topic TEXT,
lazy_load_summary TEXT,
encryption_event TEXT,
has_member_list INTEGER NOT NULL DEFAULT false,
prev_batch TEXT
) STRICT;
CREATE INDEX room_type_idx ON room (creation_content ->> 'type');
CREATE TABLE account_data (
user_id TEXT NOT NULL,
type TEXT NOT NULL,
content TEXT NOT NULL,
PRIMARY KEY (user_id, type)
) STRICT;
CREATE TABLE room_account_data (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
content TEXT NOT NULL,
PRIMARY KEY (user_id, room_id, type),
CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
) STRICT;
CREATE TABLE event (
rowid INTEGER PRIMARY KEY,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
sender TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT,
timestamp INTEGER NOT NULL,
content TEXT NOT NULL,
decrypted TEXT,
decrypted_type TEXT,
unsigned TEXT NOT NULL,
redacted_by TEXT,
relates_to TEXT,
megolm_session_id TEXT,
decryption_error TEXT,
CONSTRAINT event_id_unique_key UNIQUE (event_id),
CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
) STRICT;
CREATE INDEX event_room_id_idx ON event (room_id);
CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by);
CREATE INDEX event_relates_to_idx ON event (room_id, relates_to);
CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id);
CREATE TABLE session_request (
room_id TEXT NOT NULL,
session_id TEXT NOT NULL,
sender TEXT NOT NULL,
min_index INTEGER NOT NULL,
backup_checked INTEGER NOT NULL DEFAULT false,
request_sent INTEGER NOT NULL DEFAULT false,
PRIMARY KEY (session_id),
CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
) STRICT;
CREATE TABLE timeline (
rowid INTEGER PRIMARY KEY,
room_id TEXT NOT NULL,
event_rowid INTEGER NOT NULL,
CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE
) STRICT;
CREATE INDEX timeline_room_id_idx ON timeline (room_id);
CREATE TABLE current_state (
room_id TEXT NOT NULL,
event_type TEXT NOT NULL,
state_key TEXT NOT NULL,
event_rowid INTEGER NOT NULL,
membership TEXT,
PRIMARY KEY (room_id, event_type, state_key),
CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid)
) STRICT, WITHOUT ROWID;

View file

@ -0,0 +1,22 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package upgrades
import (
"embed"
"go.mau.fi/util/dbutil"
)
var Table dbutil.UpgradeTable
//go:embed *.sql
var upgrades embed.FS
func init() {
Table.RegisterFS(upgrades)
}

194
hicli/decryptionqueue.go Normal file
View file

@ -0,0 +1,194 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"fmt"
"sync"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/hicli/database"
"maunium.net/go/mautrix/id"
)
func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID)
if err != nil {
return nil, err
} else if data == nil {
return nil, nil
}
decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey)
if err != nil {
return nil, err
}
return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted)
}
func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) {
log := zerolog.Ctx(ctx)
err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex)
if err != nil {
log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session")
}
events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID)
if err != nil {
log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption")
return
} else if len(events) == 0 {
log.Trace().Msg("No events to retry decryption for")
return
}
decrypted := events[:0]
for _, evt := range events {
if evt.Decrypted != nil {
continue
}
evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
if err != nil {
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
} else {
decrypted = append(decrypted, evt)
}
}
if len(decrypted) > 0 {
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
for _, evt := range decrypted {
err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType)
if err != nil {
return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err)
}
}
return nil
})
if err != nil {
log.Err(err).Msg("Failed to save decrypted events")
}
}
}
func (h *HiClient) WakeupRequestQueue() {
select {
case h.requestQueueWakeup <- struct{}{}:
default:
}
}
func (h *HiClient) RunRequestQueue(ctx context.Context) {
log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger()
ctx = log.WithContext(ctx)
log.Info().Msg("Starting key request queue")
defer func() {
log.Info().Msg("Stopping key request queue")
}()
for {
err := h.FetchKeysForOutdatedUsers(ctx)
if err != nil {
log.Err(err).Msg("Failed to fetch outdated device lists for tracked users")
}
madeRequests, err := h.RequestQueuedSessions(ctx)
if err != nil {
log.Err(err).Msg("Failed to handle session request queue")
} else if madeRequests {
continue
}
select {
case <-ctx.Done():
return
case <-h.requestQueueWakeup:
}
}
}
func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) {
defer doneFunc()
log := zerolog.Ctx(ctx)
if !req.BackupChecked {
sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID)
if err != nil {
log.Err(err).
Stringer("session_id", req.SessionID).
Msg("Failed to fetch session from key backup")
// TODO should this have retries instead of just storing it's checked?
req.BackupChecked = true
err = h.DB.SessionRequest.Put(ctx, req)
if err != nil {
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup")
}
} else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex {
req.BackupChecked = true
err = h.DB.SessionRequest.Put(ctx, req)
if err != nil {
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup")
}
} else {
log.Debug().Stringer("session_id", req.SessionID).
Msg("Found session with sufficiently low first known index, removing from queue")
err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex())
if err != nil {
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue")
}
}
} else {
err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{
h.Account.UserID: {"*"},
req.Sender: {"*"},
})
//var err error
if err != nil {
log.Err(err).
Stringer("session_id", req.SessionID).
Msg("Failed to send key request")
} else {
log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request")
req.RequestSent = true
err = h.DB.SessionRequest.Put(ctx, req)
if err != nil {
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request")
}
}
}
}
const MaxParallelRequests = 5
func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) {
sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests)
if err != nil {
return false, fmt.Errorf("failed to get next events to decrypt: %w", err)
} else if len(sessions) == 0 {
return false, nil
}
var wg sync.WaitGroup
wg.Add(len(sessions))
for _, req := range sessions {
go h.requestQueuedSession(ctx, req, wg.Done)
}
wg.Wait()
return true, err
}
func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error {
outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx)
if err != nil {
return err
} else if len(outdatedUsers) == 0 {
return nil
}
_, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false)
if err != nil {
return err
}
// TODO backoff for users that fail to be fetched?
return nil
}

159
hicli/hicli.go Normal file
View file

@ -0,0 +1,159 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix.
package hicli
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/hicli/database"
"maunium.net/go/mautrix/id"
)
type HiClient struct {
DB *database.Database
Account *database.Account
Client *mautrix.Client
Crypto *crypto.OlmMachine
CryptoStore *crypto.SQLCryptoStore
ClientStore *database.ClientStateStore
Log zerolog.Logger
Verified bool
KeyBackupVersion id.KeyBackupVersion
KeyBackupKey *backup.MegolmBackupKey
firstSyncReceived bool
syncingID int
syncLock sync.Mutex
encryptLock sync.Mutex
requestQueueWakeup chan struct{}
}
func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte) *HiClient {
rawDB.Owner = "hicli"
rawDB.IgnoreForeignTables = true
db := database.New(rawDB)
db.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger())
c := &HiClient{
DB: db,
Log: log,
requestQueueWakeup: make(chan struct{}, 1),
}
c.ClientStore = &database.ClientStateStore{Database: db}
c.Client = &mautrix.Client{
UserAgent: mautrix.DefaultUserAgent,
Client: &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
// This needs to be relatively high to allow initial syncs
ResponseHeaderTimeout: 180 * time.Second,
ForceAttemptHTTP2: true,
},
Timeout: 180 * time.Second,
},
Syncer: (*hiSyncer)(c),
Store: (*hiStore)(c),
StateStore: c.ClientStore,
Log: log.With().Str("component", "mautrix client").Logger(),
}
c.CryptoStore = crypto.NewSQLCryptoStore(rawDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey)
cryptoLog := log.With().Str("component", "crypto").Logger()
c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore)
c.Crypto.SessionReceived = c.handleReceivedMegolmSession
c.Crypto.DisableRatchetTracking = true
c.Crypto.DisableDecryptKeyFetching = true
c.Client.Crypto = (*hiCryptoHelper)(c)
return c
}
func (h *HiClient) IsLoggedIn() bool {
return h.Account != nil
}
func (h *HiClient) Start(ctx context.Context, userID id.UserID) error {
err := h.DB.Upgrade(ctx)
if err != nil {
return fmt.Errorf("failed to upgrade hicli db: %w", err)
}
err = h.CryptoStore.DB.Upgrade(ctx)
if err != nil {
return fmt.Errorf("failed to upgrade crypto db: %w", err)
}
account, err := h.DB.Account.Get(ctx, userID)
if err != nil {
return err
}
if account != nil {
zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials")
h.Account = account
h.CryptoStore.AccountID = account.UserID.String()
h.CryptoStore.DeviceID = account.DeviceID
h.Client.UserID = account.UserID
h.Client.DeviceID = account.DeviceID
h.Client.AccessToken = account.AccessToken
h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL)
if err != nil {
return err
}
err = h.Crypto.Load(ctx)
if err != nil {
return fmt.Errorf("failed to load olm machine: %w", err)
}
h.Verified, err = h.checkIsCurrentDeviceVerified(ctx)
if err != nil {
return err
}
zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status")
if h.Verified {
err = h.loadPrivateKeys(ctx)
if err != nil {
return err
}
go h.Sync()
go h.RunRequestQueue(ctx)
}
}
return nil
}
func (h *HiClient) Sync() {
h.Client.StopSync()
h.syncLock.Lock()
defer h.syncLock.Unlock()
h.syncingID++
syncingID := h.syncingID
log := h.Log.With().
Str("action", "sync").
Int("sync_id", syncingID).
Logger()
ctx := log.WithContext(context.Background())
log.Info().Msg("Starting syncing")
err := h.Client.SyncWithContext(ctx)
if err != nil {
log.Err(err).Msg("Fatal error in syncer")
} else {
log.Info().Msg("Syncing stopped")
}
}

69
hicli/hitest/hitest.go Normal file
View file

@ -0,0 +1,69 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package main
import (
"context"
"io"
"os"
"os/signal"
"syscall"
"github.com/chzyer/readline"
_ "github.com/mattn/go-sqlite3"
"go.mau.fi/util/dbutil"
_ "go.mau.fi/util/dbutil/litestream"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exzerolog"
"go.mau.fi/zeroconfig"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/hicli"
"maunium.net/go/mautrix/id"
)
var writerTypeReadline zeroconfig.WriterType = "hitest_readline"
func main() {
rl := exerrors.Must(readline.New("> "))
defer func() {
_ = rl.Close()
}()
zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) {
return rl.Stdout(), nil
})
log := exerrors.Must((&zeroconfig.Config{
Writers: []zeroconfig.WriterConfig{{
Type: writerTypeReadline,
Format: zeroconfig.LogFormatPrettyColored,
}},
}).Compile())
exzerolog.SetupDefaults(log)
rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal"))
ctx := log.WithContext(context.Background())
cli := hicli.New(rawDB, *log, []byte("meow"))
userID, _ := cli.DB.Account.GetFirstUserID(ctx)
exerrors.PanicIfNotNil(cli.Start(ctx, userID))
if !cli.IsLoggedIn() {
rl.SetPrompt("User ID: ")
userID := id.UserID(exerrors.Must(rl.Readline()))
_, serverName := exerrors.Must2(userID.Parse())
discovery, err := mautrix.DiscoverClientAPI(ctx, serverName)
if discovery == nil {
log.Fatal().Err(err).Msg("Failed to discover homeserver")
}
password := exerrors.Must(rl.ReadPassword("Password: "))
recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: "))
exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode)))
}
rl.SetPrompt("> ")
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
}

77
hicli/login.go Normal file
View file

@ -0,0 +1,77 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"fmt"
"net/url"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/hicli/database"
)
func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error {
var err error
h.Client.HomeserverURL, err = url.Parse(homeserverURL)
if err != nil {
return err
}
return h.Login(ctx, &mautrix.ReqLogin{
Type: mautrix.AuthTypePassword,
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
User: username,
},
Password: password,
InitialDeviceDisplayName: "mautrix client",
})
}
func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error {
req.StoreCredentials = true
req.StoreHomeserverURL = true
resp, err := h.Client.Login(ctx, req)
if err != nil {
return err
}
h.Account = &database.Account{
UserID: resp.UserID,
DeviceID: resp.DeviceID,
AccessToken: resp.AccessToken,
HomeserverURL: h.Client.HomeserverURL.String(),
}
h.CryptoStore.AccountID = resp.UserID.String()
h.CryptoStore.DeviceID = resp.DeviceID
err = h.DB.Account.Put(ctx, h.Account)
if err != nil {
return err
}
err = h.Crypto.Load(ctx)
if err != nil {
return fmt.Errorf("failed to load olm machine: %w", err)
}
err = h.Crypto.ShareKeys(ctx, 0)
if err != nil {
return err
}
return nil
}
func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryCode string) error {
err := h.LoginPassword(ctx, homeserverURL, username, password)
if err != nil {
return err
}
err = h.VerifyWithRecoveryCode(ctx, recoveryCode)
if err != nil {
return err
}
go h.Sync()
go h.RunRequestQueue(ctx)
return nil
}

362
hicli/sync.go Normal file
View file

@ -0,0 +1,362 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"errors"
"fmt"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/exzerolog"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/hicli/database"
"maunium.net/go/mautrix/id"
)
type syncContext struct {
shouldWakeupRequestQueue bool
}
func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
log := zerolog.Ctx(ctx)
postponedToDevices := resp.ToDevice.Events[:0]
for _, evt := range resp.ToDevice.Events {
evt.Type.Class = event.ToDeviceEventType
err := evt.Content.ParseRaw(evt.Type)
if err != nil {
log.Warn().Err(err).
Stringer("event_type", &evt.Type).
Stringer("sender", evt.Sender).
Msg("Failed to parse to-device event, skipping")
continue
}
switch content := evt.Content.Parsed.(type) {
case *event.EncryptedEventContent:
h.Crypto.HandleEncryptedEvent(ctx, evt)
case *event.RoomKeyWithheldEventContent:
h.Crypto.HandleRoomKeyWithheld(ctx, content)
default:
postponedToDevices = append(postponedToDevices, evt)
}
}
resp.ToDevice.Events = postponedToDevices
return nil
}
func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
go h.asyncPostProcessSyncResponse(ctx, resp, since)
if ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue {
h.WakeupRequestQueue()
}
return nil
}
func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
for _, evt := range resp.ToDevice.Events {
switch content := evt.Content.Parsed.(type) {
case *event.SecretRequestEventContent:
h.Crypto.HandleSecretRequest(ctx, evt.Sender, content)
case *event.RoomKeyRequestEventContent:
h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content)
}
}
}
func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
if len(resp.DeviceLists.Changed) > 0 {
zerolog.Ctx(ctx).Debug().
Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)).
Msg("Marking changed device lists for tracked users as outdated")
err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed)
if err != nil {
return fmt.Errorf("failed to mark changed device lists as outdated: %w", err)
}
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
}
for _, evt := range resp.AccountData.Events {
evt.Type.Class = event.AccountDataEventType
err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw)
if err != nil {
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
}
}
for roomID, room := range resp.Rooms.Join {
err := h.processSyncJoinedRoom(ctx, roomID, room)
if err != nil {
return fmt.Errorf("failed to process joined room %s: %w", roomID, err)
}
}
for roomID, room := range resp.Rooms.Leave {
err := h.processSyncLeftRoom(ctx, roomID, room)
if err != nil {
return fmt.Errorf("failed to process left room %s: %w", roomID, err)
}
}
h.Account.NextBatch = resp.NextBatch
err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch)
if err != nil {
return fmt.Errorf("failed to save next_batch: %w", err)
}
return nil
}
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to get room data: %w", err)
} else if existingRoomData == nil {
err = h.DB.Room.CreateRow(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to ensure room row exists: %w", err)
}
existingRoomData = &database.Room{ID: roomID}
}
for _, evt := range room.AccountData.Events {
evt.Type.Class = event.AccountDataEventType
evt.RoomID = roomID
err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw)
if err != nil {
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
}
}
err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
if err != nil {
return err
}
return nil
}
func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error {
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to get room data: %w", err)
} else if existingRoomData == nil {
return nil
}
return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
}
func isDecryptionErrorRetryable(err error) bool {
return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld)
}
func removeReplyFallback(evt *event.Event) []byte {
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if ok && content.RelatesTo.GetReplyTo() != "" {
prevFormattedBody := content.FormattedBody
content.RemoveReplyFallback()
if content.FormattedBody != prevFormattedBody {
bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody)
if err == nil {
return bytes
}
bytes, err = sjson.SetBytes(evt.Content.VeryRaw, "body", content.Body)
if err == nil {
return bytes
}
}
}
return nil
}
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, string, error) {
err := evt.Content.ParseRaw(evt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
return nil, "", err
}
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
if err != nil {
return nil, "", err
}
withoutFallback := removeReplyFallback(decrypted)
if withoutFallback != nil {
return withoutFallback, decrypted.Type.Type, nil
}
return decrypted.Content.VeryRaw, decrypted.Type.Type, nil
}
func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error {
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
roomDataChanged := false
processEvent := func(evt *event.Event) (database.MassInsertableRowID, error) {
evt.RoomID = room.ID
dbEvt := database.MautrixToEvent(evt)
contentWithoutFallback := removeReplyFallback(evt)
if contentWithoutFallback != nil {
dbEvt.Content = contentWithoutFallback
}
var decryptionErr error
if evt.Type == event.EventEncrypted {
dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
if decryptionErr != nil {
dbEvt.DecryptionError = decryptionErr.Error()
}
}
rowID, err := h.DB.Event.Upsert(ctx, dbEvt)
if err != nil {
return -1, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
}
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
req, ok := decryptionQueue[dbEvt.MegolmSessionID]
if !ok {
req = &database.SessionRequest{
RoomID: room.ID,
SessionID: dbEvt.MegolmSessionID,
Sender: evt.Sender,
}
}
minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext)
req.MinIndex = min(uint32(minIndex), req.MinIndex)
decryptionQueue[dbEvt.MegolmSessionID] = req
}
if evt.StateKey != nil {
var membership event.Membership
if evt.Type == event.StateMember {
membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str)
}
err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, rowID, membership)
if err != nil {
return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err)
}
roomDataChanged = processImportantEvent(ctx, evt, room) || roomDataChanged
}
return database.MassInsertableRowID(rowID), nil
}
var err error
for _, evt := range state.Events {
evt.Type.Class = event.StateEventType
_, err = processEvent(evt)
if err != nil {
return err
}
}
if len(timeline.Events) > 0 {
timelineIDs := make([]database.MassInsertableRowID, len(timeline.Events))
for i, evt := range timeline.Events {
if evt.StateKey != nil {
evt.Type.Class = event.StateEventType
} else {
evt.Type.Class = event.MessageEventType
}
timelineIDs[i], err = processEvent(evt)
if err != nil {
return err
}
}
for _, entry := range decryptionQueue {
err = h.DB.SessionRequest.Put(ctx, entry)
if err != nil {
return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
}
}
if len(decryptionQueue) > 0 {
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
}
if timeline.Limited {
err = h.DB.Timeline.Clear(ctx, room.ID)
if err != nil {
return fmt.Errorf("failed to clear old timeline: %w", err)
}
}
err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs)
if err != nil {
return fmt.Errorf("failed to append timeline: %w", err)
}
}
if timeline.PrevBatch != "" && room.PrevBatch == "" {
room.PrevBatch = timeline.PrevBatch
roomDataChanged = true
}
if summary.Heroes != nil {
roomDataChanged = roomDataChanged || room.LazyLoadSummary == nil ||
!slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) ||
!intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) ||
!intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount)
room.LazyLoadSummary = summary
}
if roomDataChanged {
err = h.DB.Room.Upsert(ctx, room)
if err != nil {
return fmt.Errorf("failed to save room data: %w", err)
}
}
return nil
}
func intPtrEqual(a, b *int) bool {
if a == nil || b == nil {
return a == b
}
return *a == *b
}
func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData *database.Room) (roomDataChanged bool) {
if evt.StateKey == nil {
return
}
switch evt.Type {
case event.StateCreate, event.StateRoomName, event.StateRoomAvatar, event.StateTopic, event.StateEncryption:
if *evt.StateKey != "" {
return
}
default:
return
}
err := evt.Content.ParseRaw(evt.Type)
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Stringer("event_type", &evt.Type).
Stringer("event_id", evt.ID).
Msg("Failed to parse state event, skipping")
return
}
switch evt.Type {
case event.StateCreate:
if existingRoomData.CreationContent == nil {
roomDataChanged = true
}
existingRoomData.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent)
case event.StateEncryption:
newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent)
if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm {
roomDataChanged = true
existingRoomData.EncryptionEvent = newEncryption
}
case event.StateRoomName:
content, ok := evt.Content.Parsed.(*event.RoomNameEventContent)
if ok {
roomDataChanged = existingRoomData.Name == nil || *existingRoomData.Name != content.Name
existingRoomData.Name = &content.Name
}
case event.StateRoomAvatar:
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
if ok {
roomDataChanged = existingRoomData.Avatar == nil || *existingRoomData.Avatar != content.URL
existingRoomData.Avatar = &content.URL
}
case event.StateTopic:
content, ok := evt.Content.Parsed.(*event.TopicEventContent)
if ok {
roomDataChanged = existingRoomData.Topic == nil || *existingRoomData.Topic != content.Topic
existingRoomData.Topic = &content.Topic
}
}
return
}

100
hicli/syncwrap.go Normal file
View file

@ -0,0 +1,100 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"fmt"
"time"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
)
type hiSyncer HiClient
var _ mautrix.Syncer = (*hiSyncer)(nil)
type contextKey int
const (
syncContextKey contextKey = iota
)
func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
c := (*HiClient)(h)
ctx = context.WithValue(ctx, syncContextKey, &syncContext{})
err := c.preProcessSyncResponse(ctx, resp, since)
if err != nil {
return err
}
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
return c.processSyncResponse(ctx, resp, since)
})
if err != nil {
return err
}
err = c.postProcessSyncResponse(ctx, resp, since)
if err != nil {
return err
}
c.firstSyncReceived = true
return nil
}
func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
(*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second")
return 1 * time.Second, nil
}
func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
if !h.Verified {
return &mautrix.Filter{
Presence: mautrix.FilterPart{
NotRooms: []id.RoomID{"*"},
},
Room: mautrix.RoomFilter{
NotRooms: []id.RoomID{"*"},
},
}
}
return &mautrix.Filter{
Presence: mautrix.FilterPart{
NotRooms: []id.RoomID{"*"},
},
Room: mautrix.RoomFilter{
State: mautrix.FilterPart{
LazyLoadMembers: true,
},
Timeline: mautrix.FilterPart{
Limit: 100,
LazyLoadMembers: true,
},
},
}
}
type hiStore HiClient
var _ mautrix.SyncStore = (*hiStore)(nil)
// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing
func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil }
func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil }
func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
// This is intentionally a no-op: we don't want to save the next batch before processing the sync
return nil
}
func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) {
if h.Account.UserID != userID {
return "", fmt.Errorf("mismatching user ID")
}
return h.Account.NextBatch, nil
}

158
hicli/verify.go Normal file
View file

@ -0,0 +1,158 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"encoding/base64"
"fmt"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/crypto/ssss"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) {
keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx)
if keys == nil {
return false, fmt.Errorf("own cross-signing keys not found")
}
isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey)
if err != nil {
return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err)
}
return isVerified, nil
}
func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error {
latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
if err != nil {
return fmt.Errorf("failed to get key backup latest version: %w", err)
}
h.KeyBackupVersion = latestVersion.Version
data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey)
if err != nil {
return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err)
}
key, err := backup.MegolmBackupKeyFromBytes(data)
if err != nil {
return fmt.Errorf("failed to parse megolm backup key: %w", err)
}
err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes()))
if err != nil {
return fmt.Errorf("failed to store megolm backup key: %w", err)
}
h.KeyBackupKey = key
return nil
}
func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) {
secretData, err := h.CryptoStore.GetSecret(ctx, secret)
if err != nil {
return nil, fmt.Errorf("failed to get secret %s: %w", secret, err)
}
data, err := base64.StdEncoding.DecodeString(secretData)
if err != nil {
return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err)
}
return data, nil
}
func (h *HiClient) loadPrivateKeys(ctx context.Context) error {
zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys")
masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster)
if err != nil {
return fmt.Errorf("failed to get master key: %w", err)
}
selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning)
if err != nil {
return fmt.Errorf("failed to get self-signing key: %w", err)
}
userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning)
if err != nil {
return fmt.Errorf("failed to get user signing key: %w", err)
}
err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{
MasterKey: masterKeySeed,
SelfSigningKey: selfSigningKeySeed,
UserSigningKey: userSigningKeySeed,
})
if err != nil {
return fmt.Errorf("failed to import cross-signing private keys: %w", err)
}
zerolog.Ctx(ctx).Debug().Msg("Loading key backup key")
keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1)
if err != nil {
return fmt.Errorf("failed to get megolm backup key: %w", err)
}
h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey)
if err != nil {
return fmt.Errorf("failed to parse megolm backup key: %w", err)
}
zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version")
latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
if err != nil {
return fmt.Errorf("failed to get key backup latest version: %w", err)
}
h.KeyBackupVersion = latestVersion.Version
zerolog.Ctx(ctx).Debug().Msg("Secrets loaded")
return nil
}
func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error {
keys := h.Crypto.CrossSigningKeys
err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed()))
if err != nil {
return err
}
err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed()))
if err != nil {
return err
}
err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed()))
if err != nil {
return err
}
return nil
}
func (h *HiClient) VerifyWithRecoveryCode(ctx context.Context, code string) error {
_, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx)
if err != nil {
return fmt.Errorf("failed to get default SSSS key data: %w", err)
}
key, err := keyData.VerifyRecoveryKey(code)
if err != nil {
return err
}
err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key)
if err != nil {
return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err)
}
err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity())
if err != nil {
return fmt.Errorf("failed to sign own device: %w", err)
}
err = h.Crypto.SignOwnMasterKey(ctx)
if err != nil {
return fmt.Errorf("failed to sign own master key: %w", err)
}
err = h.storeCrossSigningPrivateKeys(ctx)
if err != nil {
return fmt.Errorf("failed to store cross-signing private keys: %w", err)
}
err = h.fetchKeyBackupKey(ctx, key)
if err != nil {
return fmt.Errorf("failed to fetch key backup key: %w", err)
}
h.Verified = true
return nil
}