mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Add draft of high-level client framework
This commit is contained in:
parent
0b07ae9942
commit
c1eb217b9e
20 changed files with 2132 additions and 1 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1,4 +1,4 @@
|
|||
.idea/
|
||||
.vscode/
|
||||
*.db
|
||||
*.db*
|
||||
*.log
|
||||
|
|
|
|||
82
hicli/cryptohelper.go
Normal file
82
hicli/cryptohelper.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type hiCryptoHelper HiClient
|
||||
|
||||
var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil)
|
||||
|
||||
func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
|
||||
h.encryptLock.Lock()
|
||||
defer h.encryptLock.Unlock()
|
||||
encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content)
|
||||
if err != nil {
|
||||
if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.NoGroupSession) && !errors.Is(err, crypto.SessionNotShared) {
|
||||
return
|
||||
}
|
||||
h.Log.Debug().
|
||||
Err(err).
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Got session error while encrypting event, sharing group session and trying again")
|
||||
var users []id.UserID
|
||||
users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, roomID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get room member list: %w", err)
|
||||
} else if err = h.Crypto.ShareGroupSession(ctx, roomID, users); err != nil {
|
||||
err = fmt.Errorf("failed to share group session: %w", err)
|
||||
} else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil {
|
||||
err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
|
||||
return h.Crypto.DecryptMegolmEvent(ctx, evt)
|
||||
}
|
||||
|
||||
func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
|
||||
}
|
||||
|
||||
func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
|
||||
err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
|
||||
userID: {deviceID},
|
||||
h.Account.UserID: {"*"},
|
||||
})
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("room_id", roomID).
|
||||
Stringer("session_id", sessionID).
|
||||
Stringer("user_id", userID).
|
||||
Msg("Failed to send room key request")
|
||||
} else {
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Stringer("room_id", roomID).
|
||||
Stringer("session_id", sessionID).
|
||||
Stringer("user_id", userID).
|
||||
Msg("Sent room key request")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hiCryptoHelper) Init(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
65
hicli/database/account.go
Normal file
65
hicli/database/account.go
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1`
|
||||
putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2`
|
||||
upsertAccountQuery = `
|
||||
INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch)
|
||||
VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id)
|
||||
DO UPDATE SET device_id = excluded.device_id,
|
||||
access_token = excluded.access_token,
|
||||
homeserver_url = excluded.homeserver_url,
|
||||
next_batch = excluded.next_batch
|
||||
`
|
||||
)
|
||||
|
||||
type AccountQuery struct {
|
||||
*dbutil.QueryHelper[*Account]
|
||||
}
|
||||
|
||||
func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) {
|
||||
err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID)
|
||||
return
|
||||
}
|
||||
|
||||
func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) {
|
||||
return aq.QueryOne(ctx, getAccountQuery, userID)
|
||||
}
|
||||
|
||||
func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error {
|
||||
return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID)
|
||||
}
|
||||
|
||||
func (aq *AccountQuery) Put(ctx context.Context, account *Account) error {
|
||||
return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...)
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
UserID id.UserID
|
||||
DeviceID id.DeviceID
|
||||
AccessToken string
|
||||
HomeserverURL string
|
||||
NextBatch string
|
||||
}
|
||||
|
||||
func (a *Account) Scan(row dbutil.Scannable) (*Account, error) {
|
||||
return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch))
|
||||
}
|
||||
|
||||
func (a *Account) sqlVariables() []any {
|
||||
return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch}
|
||||
}
|
||||
71
hicli/database/accountdata.go
Normal file
71
hicli/database/accountdata.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"unsafe"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
upsertAccountDataQuery = `
|
||||
INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content
|
||||
`
|
||||
upsertRoomAccountDataQuery = `
|
||||
INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content
|
||||
`
|
||||
)
|
||||
|
||||
type AccountDataQuery struct {
|
||||
*dbutil.QueryHelper[*AccountData]
|
||||
}
|
||||
|
||||
func unsafeJSONString(content json.RawMessage) *string {
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
str := unsafe.String(unsafe.SliceData(content), len(content))
|
||||
return &str
|
||||
}
|
||||
|
||||
func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error {
|
||||
return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content))
|
||||
}
|
||||
|
||||
func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error {
|
||||
return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content))
|
||||
}
|
||||
|
||||
type AccountData struct {
|
||||
UserID id.UserID
|
||||
RoomID id.RoomID
|
||||
Type string
|
||||
Content json.RawMessage
|
||||
}
|
||||
|
||||
func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) {
|
||||
var roomID sql.NullString
|
||||
err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.RoomID = id.RoomID(roomID.String)
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AccountData) sqlVariables() []any {
|
||||
return []any{a.UserID, dbutil.StrPtr(a.RoomID), a.Type, unsafeJSONString(a.Content)}
|
||||
}
|
||||
60
hicli/database/database.go
Normal file
60
hicli/database/database.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/hicli/database/upgrades"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
*dbutil.Database
|
||||
|
||||
Account AccountQuery
|
||||
AccountData AccountDataQuery
|
||||
Room RoomQuery
|
||||
Event EventQuery
|
||||
CurrentState CurrentStateQuery
|
||||
Timeline TimelineQuery
|
||||
SessionRequest SessionRequestQuery
|
||||
}
|
||||
|
||||
func New(rawDB *dbutil.Database) *Database {
|
||||
rawDB.UpgradeTable = upgrades.Table
|
||||
return &Database{
|
||||
Database: rawDB,
|
||||
|
||||
Account: AccountQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccount)},
|
||||
AccountData: AccountDataQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccountData)},
|
||||
Room: RoomQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newRoom)},
|
||||
Event: EventQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newEvent)},
|
||||
CurrentState: CurrentStateQuery{Database: rawDB},
|
||||
Timeline: TimelineQuery{Database: rawDB},
|
||||
SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)},
|
||||
}
|
||||
}
|
||||
|
||||
func newSessionRequest(_ *dbutil.QueryHelper[*SessionRequest]) *SessionRequest {
|
||||
return &SessionRequest{}
|
||||
}
|
||||
|
||||
func newEvent(_ *dbutil.QueryHelper[*Event]) *Event {
|
||||
return &Event{}
|
||||
}
|
||||
|
||||
func newRoom(_ *dbutil.QueryHelper[*Room]) *Room {
|
||||
return &Room{}
|
||||
}
|
||||
|
||||
func newAccountData(_ *dbutil.QueryHelper[*AccountData]) *AccountData {
|
||||
return &AccountData{}
|
||||
}
|
||||
|
||||
func newAccount(_ *dbutil.QueryHelper[*Account]) *Account {
|
||||
return &Account{}
|
||||
}
|
||||
193
hicli/database/event.go
Normal file
193
hicli/database/event.go
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/exgjson"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
getEventBaseQuery = `
|
||||
SELECT rowid, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned,
|
||||
redacted_by, relates_to, megolm_session_id, decryption_error
|
||||
FROM event
|
||||
`
|
||||
getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL`
|
||||
upsertEventQuery = `
|
||||
INSERT INTO event (room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, redacted_by, relates_to, megolm_session_id, decryption_error)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
ON CONFLICT (event_id) DO UPDATE
|
||||
SET decrypted=COALESCE(event.decrypted, excluded.decrypted),
|
||||
decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type),
|
||||
redacted_by=COALESCE(event.redacted_by, excluded.redacted_by),
|
||||
decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END
|
||||
RETURNING rowid
|
||||
`
|
||||
updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3`
|
||||
)
|
||||
|
||||
type EventQuery struct {
|
||||
*dbutil.QueryHelper[*Event]
|
||||
}
|
||||
|
||||
func (eq *EventQuery) GetFailedByMegolmSessionID(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) ([]*Event, error) {
|
||||
return eq.QueryMany(ctx, getFailedEventsByMegolmSessionID, roomID, sessionID)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID int64, err error) {
|
||||
err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID)
|
||||
return
|
||||
}
|
||||
|
||||
func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID int64, decrypted json.RawMessage, decryptedType string) error {
|
||||
return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID)
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
RowID int64
|
||||
|
||||
RoomID id.RoomID
|
||||
ID id.EventID
|
||||
Sender id.UserID
|
||||
Type string
|
||||
StateKey *string
|
||||
Timestamp time.Time
|
||||
|
||||
Content json.RawMessage
|
||||
Decrypted json.RawMessage
|
||||
DecryptedType string
|
||||
Unsigned json.RawMessage
|
||||
|
||||
RedactedBy id.EventID
|
||||
RelatesTo id.EventID
|
||||
|
||||
MegolmSessionID id.SessionID
|
||||
DecryptionError string
|
||||
}
|
||||
|
||||
func MautrixToEvent(evt *event.Event) *Event {
|
||||
dbEvt := &Event{
|
||||
RoomID: evt.RoomID,
|
||||
ID: evt.ID,
|
||||
Sender: evt.Sender,
|
||||
Type: evt.Type.Type,
|
||||
StateKey: evt.StateKey,
|
||||
Timestamp: time.UnixMilli(evt.Timestamp),
|
||||
Content: evt.Content.VeryRaw,
|
||||
RelatesTo: getRelatesTo(evt),
|
||||
MegolmSessionID: getMegolmSessionID(evt),
|
||||
}
|
||||
dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned)
|
||||
if evt.Unsigned.RedactedBecause != nil {
|
||||
dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID
|
||||
}
|
||||
return dbEvt
|
||||
}
|
||||
|
||||
func (e *Event) AsRawMautrix() *event.Event {
|
||||
evt := &event.Event{
|
||||
RoomID: e.RoomID,
|
||||
ID: e.ID,
|
||||
Sender: e.Sender,
|
||||
Type: event.Type{Type: e.Type, Class: event.MessageEventType},
|
||||
StateKey: e.StateKey,
|
||||
Timestamp: e.Timestamp.UnixMilli(),
|
||||
Content: event.Content{VeryRaw: e.Content},
|
||||
}
|
||||
if e.Decrypted != nil {
|
||||
evt.Content.VeryRaw = e.Decrypted
|
||||
evt.Type.Type = e.DecryptedType
|
||||
evt.Mautrix.WasEncrypted = true
|
||||
}
|
||||
if e.StateKey != nil {
|
||||
evt.Type.Class = event.StateEventType
|
||||
}
|
||||
_ = json.Unmarshal(e.Unsigned, &evt.Unsigned)
|
||||
return evt
|
||||
}
|
||||
|
||||
func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
|
||||
var timestamp int64
|
||||
var redactedBy, relatesTo, megolmSessionID, decryptionError, decryptedType sql.NullString
|
||||
err := row.Scan(
|
||||
&e.RowID,
|
||||
&e.RoomID,
|
||||
&e.ID,
|
||||
&e.Sender,
|
||||
&e.Type,
|
||||
&e.StateKey,
|
||||
×tamp,
|
||||
(*[]byte)(&e.Content),
|
||||
(*[]byte)(&e.Decrypted),
|
||||
&decryptedType,
|
||||
(*[]byte)(&e.Unsigned),
|
||||
&redactedBy,
|
||||
&relatesTo,
|
||||
&megolmSessionID,
|
||||
&decryptionError,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.Timestamp = time.UnixMilli(timestamp)
|
||||
e.RedactedBy = id.EventID(redactedBy.String)
|
||||
e.RelatesTo = id.EventID(relatesTo.String)
|
||||
e.MegolmSessionID = id.SessionID(megolmSessionID.String)
|
||||
e.DecryptedType = decryptedType.String
|
||||
e.DecryptionError = decryptionError.String
|
||||
return e, nil
|
||||
}
|
||||
|
||||
var relatesToPath = exgjson.Path("m.relates_to", "event_id")
|
||||
|
||||
func getRelatesTo(evt *event.Event) id.EventID {
|
||||
res := gjson.GetBytes(evt.Content.VeryRaw, relatesToPath)
|
||||
if res.Exists() && res.Type == gjson.String {
|
||||
return id.EventID(res.Str)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getMegolmSessionID(evt *event.Event) id.SessionID {
|
||||
if evt.Type != event.EventEncrypted {
|
||||
return ""
|
||||
}
|
||||
res := gjson.GetBytes(evt.Content.VeryRaw, "session_id")
|
||||
if res.Exists() && res.Type == gjson.String {
|
||||
return id.SessionID(res.Str)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *Event) sqlVariables() []any {
|
||||
return []any{
|
||||
e.RoomID,
|
||||
e.ID,
|
||||
e.Sender,
|
||||
e.Type,
|
||||
e.StateKey,
|
||||
e.Timestamp.UnixMilli(),
|
||||
unsafeJSONString(e.Content),
|
||||
unsafeJSONString(e.Decrypted),
|
||||
dbutil.StrPtr(e.DecryptedType),
|
||||
unsafeJSONString(e.Unsigned),
|
||||
dbutil.StrPtr(e.RedactedBy),
|
||||
dbutil.StrPtr(e.RelatesTo),
|
||||
dbutil.StrPtr(e.MegolmSessionID),
|
||||
dbutil.StrPtr(e.DecryptionError),
|
||||
}
|
||||
}
|
||||
115
hicli/database/room.go
Normal file
115
hicli/database/room.go
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
getRoomByIDQuery = `
|
||||
SELECT room_id, creation_content, name, avatar, topic, lazy_load_summary, encryption_event, has_member_list, prev_batch
|
||||
FROM room WHERE room_id = $1
|
||||
`
|
||||
ensureRoomExistsQuery = `
|
||||
INSERT INTO room (room_id) VALUES ($1)
|
||||
ON CONFLICT (room_id) DO NOTHING
|
||||
`
|
||||
upsertRoomFromSyncQuery = `
|
||||
UPDATE room
|
||||
SET creation_content = COALESCE(room.creation_content, $2),
|
||||
name = COALESCE($3, room.name),
|
||||
avatar = COALESCE($4, room.avatar),
|
||||
topic = COALESCE($5, room.topic),
|
||||
lazy_load_summary = COALESCE($6, room.lazy_load_summary),
|
||||
encryption_event = COALESCE($7, room.encryption_event),
|
||||
has_member_list = room.has_member_list OR $8,
|
||||
prev_batch = COALESCE(room.prev_batch, $9)
|
||||
WHERE room_id = $1
|
||||
`
|
||||
setRoomPrevBatchQuery = `
|
||||
INSERT INTO room (room_id, prev_batch) VALUES ($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET prev_batch = excluded.prev_batch
|
||||
`
|
||||
)
|
||||
|
||||
type RoomQuery struct {
|
||||
*dbutil.QueryHelper[*Room]
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) {
|
||||
return rq.QueryOne(ctx, getRoomByIDQuery, roomID)
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error {
|
||||
return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error {
|
||||
return rq.Exec(ctx, ensureRoomExistsQuery, roomID)
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error {
|
||||
return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch)
|
||||
}
|
||||
|
||||
type Room struct {
|
||||
ID id.RoomID
|
||||
CreationContent *event.CreateEventContent
|
||||
|
||||
Name *string
|
||||
Avatar *id.ContentURI
|
||||
Topic *string
|
||||
|
||||
LazyLoadSummary *mautrix.LazyLoadSummary
|
||||
|
||||
EncryptionEvent *event.EncryptionEventContent
|
||||
HasMemberList bool
|
||||
|
||||
PrevBatch string
|
||||
}
|
||||
|
||||
func (r *Room) Scan(row dbutil.Scannable) (*Room, error) {
|
||||
var prevBatch sql.NullString
|
||||
err := row.Scan(
|
||||
&r.ID,
|
||||
dbutil.JSON{Data: &r.CreationContent},
|
||||
&r.Name,
|
||||
&r.Avatar,
|
||||
&r.Topic,
|
||||
dbutil.JSON{Data: &r.LazyLoadSummary},
|
||||
dbutil.JSON{Data: &r.EncryptionEvent},
|
||||
&r.HasMemberList,
|
||||
&prevBatch,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.PrevBatch = prevBatch.String
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Room) sqlVariables() []any {
|
||||
return []any{
|
||||
r.ID,
|
||||
dbutil.JSONPtr(r.CreationContent),
|
||||
r.Name,
|
||||
r.Avatar,
|
||||
r.Topic,
|
||||
dbutil.JSONPtr(r.LazyLoadSummary),
|
||||
dbutil.JSONPtr(r.EncryptionEvent),
|
||||
r.HasMemberList,
|
||||
dbutil.StrPtr(r.PrevBatch),
|
||||
}
|
||||
}
|
||||
69
hicli/database/sessionrequest.go
Normal file
69
hicli/database/sessionrequest.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
putSessionRequestQueueEntry = `
|
||||
INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (session_id) DO UPDATE
|
||||
SET min_index = MIN(excluded.min_index, session_request.min_index),
|
||||
backup_checked = excluded.backup_checked OR session_request.backup_checked,
|
||||
request_sent = excluded.request_sent OR session_request.request_sent
|
||||
`
|
||||
removeSessionRequestQuery = `
|
||||
DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2
|
||||
`
|
||||
getNextSessionsToRequestQuery = `
|
||||
SELECT room_id, session_id, sender, min_index, backup_checked, request_sent
|
||||
FROM session_request
|
||||
WHERE request_sent = false OR backup_checked = false
|
||||
ORDER BY backup_checked, rowid
|
||||
LIMIT $1
|
||||
`
|
||||
)
|
||||
|
||||
type SessionRequestQuery struct {
|
||||
*dbutil.QueryHelper[*SessionRequest]
|
||||
}
|
||||
|
||||
func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) {
|
||||
return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count)
|
||||
}
|
||||
|
||||
func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error {
|
||||
return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex)
|
||||
}
|
||||
|
||||
func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error {
|
||||
return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...)
|
||||
}
|
||||
|
||||
type SessionRequest struct {
|
||||
RoomID id.RoomID
|
||||
SessionID id.SessionID
|
||||
Sender id.UserID
|
||||
MinIndex uint32
|
||||
BackupChecked bool
|
||||
RequestSent bool
|
||||
}
|
||||
|
||||
func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) {
|
||||
return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent))
|
||||
}
|
||||
|
||||
func (s *SessionRequest) sqlVariables() []any {
|
||||
return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent}
|
||||
}
|
||||
32
hicli/database/state.go
Normal file
32
hicli/database/state.go
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
setCurrentStateQuery = `
|
||||
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership
|
||||
`
|
||||
)
|
||||
|
||||
type CurrentStateQuery struct {
|
||||
*dbutil.Database
|
||||
}
|
||||
|
||||
func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID int64, membership event.Membership) error {
|
||||
_, err := csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
|
||||
return err
|
||||
}
|
||||
149
hicli/database/statestore.go
Normal file
149
hicli/database/statestore.go
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
getMembershipQuery = `
|
||||
SELECT membership FROM current_state
|
||||
WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2
|
||||
`
|
||||
getStateEventContentQuery = `
|
||||
SELECT event.content FROM current_state cs
|
||||
LEFT JOIN event ON event.rowid = cs.event_rowid
|
||||
WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3
|
||||
`
|
||||
getRoomJoinedOrInvitedMembersQuery = `
|
||||
SELECT state_key FROM current_state
|
||||
WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite')
|
||||
`
|
||||
isRoomEncryptedQuery = `
|
||||
SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1
|
||||
`
|
||||
getRoomEncryptionEventQuery = `
|
||||
SELECT room.encryption_event FROM room WHERE room_id = $1
|
||||
`
|
||||
findSharedRoomsQuery = `
|
||||
SELECT room_id FROM current_state
|
||||
WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join'
|
||||
`
|
||||
)
|
||||
|
||||
type ClientStateStore struct {
|
||||
*Database
|
||||
}
|
||||
|
||||
var _ mautrix.StateStore = (*ClientStateStore)(nil)
|
||||
var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil)
|
||||
var _ crypto.StateStore = (*ClientStateStore)(nil)
|
||||
|
||||
func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return c.IsMembership(ctx, roomID, userID, event.MembershipJoin)
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin)
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
var membership event.Membership
|
||||
err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
membership = event.MembershipLeave
|
||||
}
|
||||
return slices.Contains(allowedMemberships, membership)
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
|
||||
content, err := c.TryGetMember(ctx, roomID, userID)
|
||||
if content == nil {
|
||||
content = &event.MemberEventContent{Membership: event.MembershipLeave}
|
||||
}
|
||||
return content, err
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) {
|
||||
err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) {
|
||||
err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
|
||||
rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID)
|
||||
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) {
|
||||
err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) {
|
||||
err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID).
|
||||
Scan(&dbutil.JSON{Data: &content})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) {
|
||||
// TODO for multiuser support, this might need to filter by the local user's membership
|
||||
rows, err := c.Query(ctx, findSharedRoomsQuery, userID)
|
||||
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
|
||||
}
|
||||
|
||||
// Update methods are all intentionally no-ops as the state store wants to have the full event
|
||||
|
||||
func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {}
|
||||
47
hicli/database/timeline.go
Normal file
47
hicli/database/timeline.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
clearTimelineQuery = `
|
||||
DELETE FROM timeline WHERE room_id = $1
|
||||
`
|
||||
setTimelineQuery = `
|
||||
INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2)
|
||||
`
|
||||
)
|
||||
|
||||
type MassInsertableRowID int64
|
||||
|
||||
func (m MassInsertableRowID) GetMassInsertValues() [1]any {
|
||||
return [1]any{m}
|
||||
}
|
||||
|
||||
var setTimelineQueryBuilder = dbutil.NewMassInsertBuilder[MassInsertableRowID, [1]any](setTimelineQuery, "($1, $%d)")
|
||||
|
||||
type TimelineQuery struct {
|
||||
*dbutil.Database
|
||||
}
|
||||
|
||||
func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error {
|
||||
_, err := tq.Exec(ctx, clearTimelineQuery, roomID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []MassInsertableRowID) error {
|
||||
query, params := setTimelineQueryBuilder.Build([1]any{roomID}, rowIDs)
|
||||
_, err := tq.Exec(ctx, query, params...)
|
||||
return err
|
||||
}
|
||||
107
hicli/database/upgrades/00-latest-revision.sql
Normal file
107
hicli/database/upgrades/00-latest-revision.sql
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
-- v0 -> v1: Latest revision
|
||||
CREATE TABLE account (
|
||||
user_id TEXT NOT NULL PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
access_token TEXT NOT NULL,
|
||||
homeserver_url TEXT NOT NULL,
|
||||
|
||||
next_batch TEXT NOT NULL
|
||||
) STRICT;
|
||||
|
||||
CREATE TABLE room (
|
||||
room_id TEXT NOT NULL PRIMARY KEY,
|
||||
creation_content TEXT,
|
||||
|
||||
name TEXT,
|
||||
avatar TEXT,
|
||||
topic TEXT,
|
||||
lazy_load_summary TEXT,
|
||||
|
||||
encryption_event TEXT,
|
||||
has_member_list INTEGER NOT NULL DEFAULT false,
|
||||
|
||||
prev_batch TEXT
|
||||
) STRICT;
|
||||
CREATE INDEX room_type_idx ON room (creation_content ->> 'type');
|
||||
|
||||
CREATE TABLE account_data (
|
||||
user_id TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY (user_id, type)
|
||||
) STRICT;
|
||||
|
||||
CREATE TABLE room_account_data (
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY (user_id, room_id, type),
|
||||
CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
|
||||
) STRICT;
|
||||
|
||||
CREATE TABLE event (
|
||||
rowid INTEGER PRIMARY KEY,
|
||||
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
state_key TEXT,
|
||||
timestamp INTEGER NOT NULL,
|
||||
|
||||
content TEXT NOT NULL,
|
||||
decrypted TEXT,
|
||||
decrypted_type TEXT,
|
||||
unsigned TEXT NOT NULL,
|
||||
|
||||
redacted_by TEXT,
|
||||
relates_to TEXT,
|
||||
|
||||
megolm_session_id TEXT,
|
||||
decryption_error TEXT,
|
||||
|
||||
CONSTRAINT event_id_unique_key UNIQUE (event_id),
|
||||
CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
|
||||
) STRICT;
|
||||
CREATE INDEX event_room_id_idx ON event (room_id);
|
||||
CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by);
|
||||
CREATE INDEX event_relates_to_idx ON event (room_id, relates_to);
|
||||
CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id);
|
||||
|
||||
CREATE TABLE session_request (
|
||||
room_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
min_index INTEGER NOT NULL,
|
||||
backup_checked INTEGER NOT NULL DEFAULT false,
|
||||
request_sent INTEGER NOT NULL DEFAULT false,
|
||||
|
||||
PRIMARY KEY (session_id),
|
||||
CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
|
||||
) STRICT;
|
||||
|
||||
CREATE TABLE timeline (
|
||||
rowid INTEGER PRIMARY KEY,
|
||||
room_id TEXT NOT NULL,
|
||||
event_rowid INTEGER NOT NULL,
|
||||
|
||||
CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
|
||||
CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE
|
||||
) STRICT;
|
||||
CREATE INDEX timeline_room_id_idx ON timeline (room_id);
|
||||
|
||||
CREATE TABLE current_state (
|
||||
room_id TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
state_key TEXT NOT NULL,
|
||||
event_rowid INTEGER NOT NULL,
|
||||
|
||||
membership TEXT,
|
||||
|
||||
PRIMARY KEY (room_id, event_type, state_key),
|
||||
CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
|
||||
CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid)
|
||||
) STRICT, WITHOUT ROWID;
|
||||
22
hicli/database/upgrades/upgrades.go
Normal file
22
hicli/database/upgrades/upgrades.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package upgrades
|
||||
|
||||
import (
|
||||
"embed"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
var Table dbutil.UpgradeTable
|
||||
|
||||
//go:embed *.sql
|
||||
var upgrades embed.FS
|
||||
|
||||
func init() {
|
||||
Table.RegisterFS(upgrades)
|
||||
}
|
||||
194
hicli/decryptionqueue.go
Normal file
194
hicli/decryptionqueue.go
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/hicli/database"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
|
||||
data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted)
|
||||
}
|
||||
|
||||
func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session")
|
||||
}
|
||||
events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption")
|
||||
return
|
||||
} else if len(events) == 0 {
|
||||
log.Trace().Msg("No events to retry decryption for")
|
||||
return
|
||||
}
|
||||
decrypted := events[:0]
|
||||
for _, evt := range events {
|
||||
if evt.Decrypted != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
|
||||
} else {
|
||||
decrypted = append(decrypted, evt)
|
||||
}
|
||||
}
|
||||
if len(decrypted) > 0 {
|
||||
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
for _, evt := range decrypted {
|
||||
err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save decrypted events")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) WakeupRequestQueue() {
|
||||
select {
|
||||
case h.requestQueueWakeup <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) RunRequestQueue(ctx context.Context) {
|
||||
log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
log.Info().Msg("Starting key request queue")
|
||||
defer func() {
|
||||
log.Info().Msg("Stopping key request queue")
|
||||
}()
|
||||
for {
|
||||
err := h.FetchKeysForOutdatedUsers(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to fetch outdated device lists for tracked users")
|
||||
}
|
||||
madeRequests, err := h.RequestQueuedSessions(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to handle session request queue")
|
||||
} else if madeRequests {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-h.requestQueueWakeup:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) {
|
||||
defer doneFunc()
|
||||
log := zerolog.Ctx(ctx)
|
||||
if !req.BackupChecked {
|
||||
sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Stringer("session_id", req.SessionID).
|
||||
Msg("Failed to fetch session from key backup")
|
||||
|
||||
// TODO should this have retries instead of just storing it's checked?
|
||||
req.BackupChecked = true
|
||||
err = h.DB.SessionRequest.Put(ctx, req)
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup")
|
||||
}
|
||||
} else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex {
|
||||
req.BackupChecked = true
|
||||
err = h.DB.SessionRequest.Put(ctx, req)
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup")
|
||||
}
|
||||
} else {
|
||||
log.Debug().Stringer("session_id", req.SessionID).
|
||||
Msg("Found session with sufficiently low first known index, removing from queue")
|
||||
err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex())
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{
|
||||
h.Account.UserID: {"*"},
|
||||
req.Sender: {"*"},
|
||||
})
|
||||
//var err error
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Stringer("session_id", req.SessionID).
|
||||
Msg("Failed to send key request")
|
||||
} else {
|
||||
log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request")
|
||||
req.RequestSent = true
|
||||
err = h.DB.SessionRequest.Put(ctx, req)
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const MaxParallelRequests = 5
|
||||
|
||||
func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) {
|
||||
sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get next events to decrypt: %w", err)
|
||||
} else if len(sessions) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(sessions))
|
||||
for _, req := range sessions {
|
||||
go h.requestQueuedSession(ctx, req, wg.Done)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return true, err
|
||||
}
|
||||
|
||||
func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error {
|
||||
outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(outdatedUsers) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO backoff for users that fail to be fetched?
|
||||
return nil
|
||||
}
|
||||
159
hicli/hicli.go
Normal file
159
hicli/hicli.go
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix.
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/crypto/backup"
|
||||
"maunium.net/go/mautrix/hicli/database"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type HiClient struct {
|
||||
DB *database.Database
|
||||
Account *database.Account
|
||||
Client *mautrix.Client
|
||||
Crypto *crypto.OlmMachine
|
||||
CryptoStore *crypto.SQLCryptoStore
|
||||
ClientStore *database.ClientStateStore
|
||||
Log zerolog.Logger
|
||||
|
||||
Verified bool
|
||||
|
||||
KeyBackupVersion id.KeyBackupVersion
|
||||
KeyBackupKey *backup.MegolmBackupKey
|
||||
|
||||
firstSyncReceived bool
|
||||
syncingID int
|
||||
syncLock sync.Mutex
|
||||
encryptLock sync.Mutex
|
||||
|
||||
requestQueueWakeup chan struct{}
|
||||
}
|
||||
|
||||
func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte) *HiClient {
|
||||
rawDB.Owner = "hicli"
|
||||
rawDB.IgnoreForeignTables = true
|
||||
db := database.New(rawDB)
|
||||
db.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger())
|
||||
c := &HiClient{
|
||||
DB: db,
|
||||
Log: log,
|
||||
|
||||
requestQueueWakeup: make(chan struct{}, 1),
|
||||
}
|
||||
c.ClientStore = &database.ClientStateStore{Database: db}
|
||||
c.Client = &mautrix.Client{
|
||||
UserAgent: mautrix.DefaultUserAgent,
|
||||
Client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
// This needs to be relatively high to allow initial syncs
|
||||
ResponseHeaderTimeout: 180 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
},
|
||||
Timeout: 180 * time.Second,
|
||||
},
|
||||
Syncer: (*hiSyncer)(c),
|
||||
Store: (*hiStore)(c),
|
||||
StateStore: c.ClientStore,
|
||||
Log: log.With().Str("component", "mautrix client").Logger(),
|
||||
}
|
||||
c.CryptoStore = crypto.NewSQLCryptoStore(rawDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey)
|
||||
cryptoLog := log.With().Str("component", "crypto").Logger()
|
||||
c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore)
|
||||
c.Crypto.SessionReceived = c.handleReceivedMegolmSession
|
||||
c.Crypto.DisableRatchetTracking = true
|
||||
c.Crypto.DisableDecryptKeyFetching = true
|
||||
c.Client.Crypto = (*hiCryptoHelper)(c)
|
||||
return c
|
||||
}
|
||||
|
||||
func (h *HiClient) IsLoggedIn() bool {
|
||||
return h.Account != nil
|
||||
}
|
||||
|
||||
func (h *HiClient) Start(ctx context.Context, userID id.UserID) error {
|
||||
err := h.DB.Upgrade(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade hicli db: %w", err)
|
||||
}
|
||||
err = h.CryptoStore.DB.Upgrade(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade crypto db: %w", err)
|
||||
}
|
||||
account, err := h.DB.Account.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if account != nil {
|
||||
zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials")
|
||||
h.Account = account
|
||||
h.CryptoStore.AccountID = account.UserID.String()
|
||||
h.CryptoStore.DeviceID = account.DeviceID
|
||||
h.Client.UserID = account.UserID
|
||||
h.Client.DeviceID = account.DeviceID
|
||||
h.Client.AccessToken = account.AccessToken
|
||||
h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.Crypto.Load(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load olm machine: %w", err)
|
||||
}
|
||||
|
||||
h.Verified, err = h.checkIsCurrentDeviceVerified(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status")
|
||||
if h.Verified {
|
||||
err = h.loadPrivateKeys(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go h.Sync()
|
||||
go h.RunRequestQueue(ctx)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) Sync() {
|
||||
h.Client.StopSync()
|
||||
h.syncLock.Lock()
|
||||
defer h.syncLock.Unlock()
|
||||
h.syncingID++
|
||||
syncingID := h.syncingID
|
||||
log := h.Log.With().
|
||||
Str("action", "sync").
|
||||
Int("sync_id", syncingID).
|
||||
Logger()
|
||||
ctx := log.WithContext(context.Background())
|
||||
log.Info().Msg("Starting syncing")
|
||||
err := h.Client.SyncWithContext(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Fatal error in syncer")
|
||||
} else {
|
||||
log.Info().Msg("Syncing stopped")
|
||||
}
|
||||
}
|
||||
69
hicli/hitest/hitest.go
Normal file
69
hicli/hitest/hitest.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/chzyer/readline"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"go.mau.fi/util/dbutil"
|
||||
_ "go.mau.fi/util/dbutil/litestream"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"go.mau.fi/util/exzerolog"
|
||||
"go.mau.fi/zeroconfig"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/hicli"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
var writerTypeReadline zeroconfig.WriterType = "hitest_readline"
|
||||
|
||||
func main() {
|
||||
rl := exerrors.Must(readline.New("> "))
|
||||
defer func() {
|
||||
_ = rl.Close()
|
||||
}()
|
||||
zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) {
|
||||
return rl.Stdout(), nil
|
||||
})
|
||||
log := exerrors.Must((&zeroconfig.Config{
|
||||
Writers: []zeroconfig.WriterConfig{{
|
||||
Type: writerTypeReadline,
|
||||
Format: zeroconfig.LogFormatPrettyColored,
|
||||
}},
|
||||
}).Compile())
|
||||
exzerolog.SetupDefaults(log)
|
||||
|
||||
rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal"))
|
||||
ctx := log.WithContext(context.Background())
|
||||
cli := hicli.New(rawDB, *log, []byte("meow"))
|
||||
userID, _ := cli.DB.Account.GetFirstUserID(ctx)
|
||||
exerrors.PanicIfNotNil(cli.Start(ctx, userID))
|
||||
if !cli.IsLoggedIn() {
|
||||
rl.SetPrompt("User ID: ")
|
||||
userID := id.UserID(exerrors.Must(rl.Readline()))
|
||||
_, serverName := exerrors.Must2(userID.Parse())
|
||||
discovery, err := mautrix.DiscoverClientAPI(ctx, serverName)
|
||||
if discovery == nil {
|
||||
log.Fatal().Err(err).Msg("Failed to discover homeserver")
|
||||
}
|
||||
password := exerrors.Must(rl.ReadPassword("Password: "))
|
||||
recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: "))
|
||||
exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode)))
|
||||
}
|
||||
rl.SetPrompt("> ")
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
}
|
||||
77
hicli/login.go
Normal file
77
hicli/login.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/hicli/database"
|
||||
)
|
||||
|
||||
func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error {
|
||||
var err error
|
||||
h.Client.HomeserverURL, err = url.Parse(homeserverURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.Login(ctx, &mautrix.ReqLogin{
|
||||
Type: mautrix.AuthTypePassword,
|
||||
Identifier: mautrix.UserIdentifier{
|
||||
Type: mautrix.IdentifierTypeUser,
|
||||
User: username,
|
||||
},
|
||||
Password: password,
|
||||
InitialDeviceDisplayName: "mautrix client",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error {
|
||||
req.StoreCredentials = true
|
||||
req.StoreHomeserverURL = true
|
||||
resp, err := h.Client.Login(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Account = &database.Account{
|
||||
UserID: resp.UserID,
|
||||
DeviceID: resp.DeviceID,
|
||||
AccessToken: resp.AccessToken,
|
||||
HomeserverURL: h.Client.HomeserverURL.String(),
|
||||
}
|
||||
h.CryptoStore.AccountID = resp.UserID.String()
|
||||
h.CryptoStore.DeviceID = resp.DeviceID
|
||||
err = h.DB.Account.Put(ctx, h.Account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.Crypto.Load(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load olm machine: %w", err)
|
||||
}
|
||||
err = h.Crypto.ShareKeys(ctx, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryCode string) error {
|
||||
err := h.LoginPassword(ctx, homeserverURL, username, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.VerifyWithRecoveryCode(ctx, recoveryCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go h.Sync()
|
||||
go h.RunRequestQueue(ctx)
|
||||
return nil
|
||||
}
|
||||
362
hicli/sync.go
Normal file
362
hicli/sync.go
Normal file
|
|
@ -0,0 +1,362 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.mau.fi/util/exzerolog"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/hicli/database"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type syncContext struct {
|
||||
shouldWakeupRequestQueue bool
|
||||
}
|
||||
|
||||
func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
||||
log := zerolog.Ctx(ctx)
|
||||
postponedToDevices := resp.ToDevice.Events[:0]
|
||||
for _, evt := range resp.ToDevice.Events {
|
||||
evt.Type.Class = event.ToDeviceEventType
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Stringer("event_type", &evt.Type).
|
||||
Stringer("sender", evt.Sender).
|
||||
Msg("Failed to parse to-device event, skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.EncryptedEventContent:
|
||||
h.Crypto.HandleEncryptedEvent(ctx, evt)
|
||||
case *event.RoomKeyWithheldEventContent:
|
||||
h.Crypto.HandleRoomKeyWithheld(ctx, content)
|
||||
default:
|
||||
postponedToDevices = append(postponedToDevices, evt)
|
||||
}
|
||||
}
|
||||
resp.ToDevice.Events = postponedToDevices
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
||||
h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
|
||||
go h.asyncPostProcessSyncResponse(ctx, resp, since)
|
||||
if ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue {
|
||||
h.WakeupRequestQueue()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
|
||||
for _, evt := range resp.ToDevice.Events {
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.SecretRequestEventContent:
|
||||
h.Crypto.HandleSecretRequest(ctx, evt.Sender, content)
|
||||
case *event.RoomKeyRequestEventContent:
|
||||
h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
||||
if len(resp.DeviceLists.Changed) > 0 {
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)).
|
||||
Msg("Marking changed device lists for tracked users as outdated")
|
||||
err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark changed device lists as outdated: %w", err)
|
||||
}
|
||||
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
|
||||
}
|
||||
|
||||
for _, evt := range resp.AccountData.Events {
|
||||
evt.Type.Class = event.AccountDataEventType
|
||||
err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
|
||||
}
|
||||
}
|
||||
for roomID, room := range resp.Rooms.Join {
|
||||
err := h.processSyncJoinedRoom(ctx, roomID, room)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process joined room %s: %w", roomID, err)
|
||||
}
|
||||
}
|
||||
for roomID, room := range resp.Rooms.Leave {
|
||||
err := h.processSyncLeftRoom(ctx, roomID, room)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process left room %s: %w", roomID, err)
|
||||
}
|
||||
}
|
||||
h.Account.NextBatch = resp.NextBatch
|
||||
err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save next_batch: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
|
||||
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get room data: %w", err)
|
||||
} else if existingRoomData == nil {
|
||||
err = h.DB.Room.CreateRow(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to ensure room row exists: %w", err)
|
||||
}
|
||||
existingRoomData = &database.Room{ID: roomID}
|
||||
}
|
||||
|
||||
for _, evt := range room.AccountData.Events {
|
||||
evt.Type.Class = event.AccountDataEventType
|
||||
evt.RoomID = roomID
|
||||
err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
|
||||
}
|
||||
}
|
||||
err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error {
|
||||
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get room data: %w", err)
|
||||
} else if existingRoomData == nil {
|
||||
return nil
|
||||
}
|
||||
return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
|
||||
}
|
||||
|
||||
func isDecryptionErrorRetryable(err error) bool {
|
||||
return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld)
|
||||
}
|
||||
|
||||
func removeReplyFallback(evt *event.Event) []byte {
|
||||
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
||||
if ok && content.RelatesTo.GetReplyTo() != "" {
|
||||
prevFormattedBody := content.FormattedBody
|
||||
content.RemoveReplyFallback()
|
||||
if content.FormattedBody != prevFormattedBody {
|
||||
bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody)
|
||||
if err == nil {
|
||||
return bytes
|
||||
}
|
||||
bytes, err = sjson.SetBytes(evt.Content.VeryRaw, "body", content.Body)
|
||||
if err == nil {
|
||||
return bytes
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, string, error) {
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
return nil, "", err
|
||||
}
|
||||
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
withoutFallback := removeReplyFallback(decrypted)
|
||||
if withoutFallback != nil {
|
||||
return withoutFallback, decrypted.Type.Type, nil
|
||||
}
|
||||
return decrypted.Content.VeryRaw, decrypted.Type.Type, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error {
|
||||
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
|
||||
roomDataChanged := false
|
||||
processEvent := func(evt *event.Event) (database.MassInsertableRowID, error) {
|
||||
evt.RoomID = room.ID
|
||||
dbEvt := database.MautrixToEvent(evt)
|
||||
contentWithoutFallback := removeReplyFallback(evt)
|
||||
if contentWithoutFallback != nil {
|
||||
dbEvt.Content = contentWithoutFallback
|
||||
}
|
||||
var decryptionErr error
|
||||
if evt.Type == event.EventEncrypted {
|
||||
dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
|
||||
if decryptionErr != nil {
|
||||
dbEvt.DecryptionError = decryptionErr.Error()
|
||||
}
|
||||
}
|
||||
rowID, err := h.DB.Event.Upsert(ctx, dbEvt)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
|
||||
}
|
||||
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
|
||||
req, ok := decryptionQueue[dbEvt.MegolmSessionID]
|
||||
if !ok {
|
||||
req = &database.SessionRequest{
|
||||
RoomID: room.ID,
|
||||
SessionID: dbEvt.MegolmSessionID,
|
||||
Sender: evt.Sender,
|
||||
}
|
||||
}
|
||||
minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext)
|
||||
req.MinIndex = min(uint32(minIndex), req.MinIndex)
|
||||
decryptionQueue[dbEvt.MegolmSessionID] = req
|
||||
}
|
||||
if evt.StateKey != nil {
|
||||
var membership event.Membership
|
||||
if evt.Type == event.StateMember {
|
||||
membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str)
|
||||
}
|
||||
err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, rowID, membership)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err)
|
||||
}
|
||||
roomDataChanged = processImportantEvent(ctx, evt, room) || roomDataChanged
|
||||
}
|
||||
return database.MassInsertableRowID(rowID), nil
|
||||
}
|
||||
var err error
|
||||
for _, evt := range state.Events {
|
||||
evt.Type.Class = event.StateEventType
|
||||
_, err = processEvent(evt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(timeline.Events) > 0 {
|
||||
timelineIDs := make([]database.MassInsertableRowID, len(timeline.Events))
|
||||
for i, evt := range timeline.Events {
|
||||
if evt.StateKey != nil {
|
||||
evt.Type.Class = event.StateEventType
|
||||
} else {
|
||||
evt.Type.Class = event.MessageEventType
|
||||
}
|
||||
timelineIDs[i], err = processEvent(evt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, entry := range decryptionQueue {
|
||||
err = h.DB.SessionRequest.Put(ctx, entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
|
||||
}
|
||||
}
|
||||
if len(decryptionQueue) > 0 {
|
||||
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
|
||||
}
|
||||
if timeline.Limited {
|
||||
err = h.DB.Timeline.Clear(ctx, room.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clear old timeline: %w", err)
|
||||
}
|
||||
}
|
||||
err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to append timeline: %w", err)
|
||||
}
|
||||
}
|
||||
if timeline.PrevBatch != "" && room.PrevBatch == "" {
|
||||
room.PrevBatch = timeline.PrevBatch
|
||||
roomDataChanged = true
|
||||
}
|
||||
if summary.Heroes != nil {
|
||||
roomDataChanged = roomDataChanged || room.LazyLoadSummary == nil ||
|
||||
!slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) ||
|
||||
!intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) ||
|
||||
!intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount)
|
||||
room.LazyLoadSummary = summary
|
||||
}
|
||||
if roomDataChanged {
|
||||
err = h.DB.Room.Upsert(ctx, room)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save room data: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func intPtrEqual(a, b *int) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == b
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData *database.Room) (roomDataChanged bool) {
|
||||
if evt.StateKey == nil {
|
||||
return
|
||||
}
|
||||
switch evt.Type {
|
||||
case event.StateCreate, event.StateRoomName, event.StateRoomAvatar, event.StateTopic, event.StateEncryption:
|
||||
if *evt.StateKey != "" {
|
||||
return
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Warn().Err(err).
|
||||
Stringer("event_type", &evt.Type).
|
||||
Stringer("event_id", evt.ID).
|
||||
Msg("Failed to parse state event, skipping")
|
||||
return
|
||||
}
|
||||
switch evt.Type {
|
||||
case event.StateCreate:
|
||||
if existingRoomData.CreationContent == nil {
|
||||
roomDataChanged = true
|
||||
}
|
||||
existingRoomData.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent)
|
||||
case event.StateEncryption:
|
||||
newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent)
|
||||
if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm {
|
||||
roomDataChanged = true
|
||||
existingRoomData.EncryptionEvent = newEncryption
|
||||
}
|
||||
case event.StateRoomName:
|
||||
content, ok := evt.Content.Parsed.(*event.RoomNameEventContent)
|
||||
if ok {
|
||||
roomDataChanged = existingRoomData.Name == nil || *existingRoomData.Name != content.Name
|
||||
existingRoomData.Name = &content.Name
|
||||
}
|
||||
case event.StateRoomAvatar:
|
||||
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
|
||||
if ok {
|
||||
roomDataChanged = existingRoomData.Avatar == nil || *existingRoomData.Avatar != content.URL
|
||||
existingRoomData.Avatar = &content.URL
|
||||
}
|
||||
case event.StateTopic:
|
||||
content, ok := evt.Content.Parsed.(*event.TopicEventContent)
|
||||
if ok {
|
||||
roomDataChanged = existingRoomData.Topic == nil || *existingRoomData.Topic != content.Topic
|
||||
existingRoomData.Topic = &content.Topic
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
100
hicli/syncwrap.go
Normal file
100
hicli/syncwrap.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type hiSyncer HiClient
|
||||
|
||||
var _ mautrix.Syncer = (*hiSyncer)(nil)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
syncContextKey contextKey = iota
|
||||
)
|
||||
|
||||
func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
||||
c := (*HiClient)(h)
|
||||
ctx = context.WithValue(ctx, syncContextKey, &syncContext{})
|
||||
err := c.preProcessSyncResponse(ctx, resp, since)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
return c.processSyncResponse(ctx, resp, since)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.postProcessSyncResponse(ctx, resp, since)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.firstSyncReceived = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
|
||||
(*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second")
|
||||
return 1 * time.Second, nil
|
||||
}
|
||||
|
||||
func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
|
||||
if !h.Verified {
|
||||
return &mautrix.Filter{
|
||||
Presence: mautrix.FilterPart{
|
||||
NotRooms: []id.RoomID{"*"},
|
||||
},
|
||||
Room: mautrix.RoomFilter{
|
||||
NotRooms: []id.RoomID{"*"},
|
||||
},
|
||||
}
|
||||
}
|
||||
return &mautrix.Filter{
|
||||
Presence: mautrix.FilterPart{
|
||||
NotRooms: []id.RoomID{"*"},
|
||||
},
|
||||
Room: mautrix.RoomFilter{
|
||||
State: mautrix.FilterPart{
|
||||
LazyLoadMembers: true,
|
||||
},
|
||||
Timeline: mautrix.FilterPart{
|
||||
Limit: 100,
|
||||
LazyLoadMembers: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type hiStore HiClient
|
||||
|
||||
var _ mautrix.SyncStore = (*hiStore)(nil)
|
||||
|
||||
// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing
|
||||
|
||||
func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil }
|
||||
func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil }
|
||||
|
||||
func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
|
||||
// This is intentionally a no-op: we don't want to save the next batch before processing the sync
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) {
|
||||
if h.Account.UserID != userID {
|
||||
return "", fmt.Errorf("mismatching user ID")
|
||||
}
|
||||
return h.Account.NextBatch, nil
|
||||
}
|
||||
158
hicli/verify.go
Normal file
158
hicli/verify.go
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/crypto/backup"
|
||||
"maunium.net/go/mautrix/crypto/ssss"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) {
|
||||
keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx)
|
||||
if keys == nil {
|
||||
return false, fmt.Errorf("own cross-signing keys not found")
|
||||
}
|
||||
isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err)
|
||||
}
|
||||
return isVerified, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error {
|
||||
latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get key backup latest version: %w", err)
|
||||
}
|
||||
h.KeyBackupVersion = latestVersion.Version
|
||||
data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err)
|
||||
}
|
||||
key, err := backup.MegolmBackupKeyFromBytes(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse megolm backup key: %w", err)
|
||||
}
|
||||
err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store megolm backup key: %w", err)
|
||||
}
|
||||
h.KeyBackupKey = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) {
|
||||
secretData, err := h.CryptoStore.GetSecret(ctx, secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get secret %s: %w", secret, err)
|
||||
}
|
||||
data, err := base64.StdEncoding.DecodeString(secretData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) loadPrivateKeys(ctx context.Context) error {
|
||||
zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys")
|
||||
masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get self-signing key: %w", err)
|
||||
}
|
||||
userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user signing key: %w", err)
|
||||
}
|
||||
err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{
|
||||
MasterKey: masterKeySeed,
|
||||
SelfSigningKey: selfSigningKeySeed,
|
||||
UserSigningKey: userSigningKeySeed,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import cross-signing private keys: %w", err)
|
||||
}
|
||||
zerolog.Ctx(ctx).Debug().Msg("Loading key backup key")
|
||||
keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get megolm backup key: %w", err)
|
||||
}
|
||||
h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse megolm backup key: %w", err)
|
||||
}
|
||||
zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version")
|
||||
latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get key backup latest version: %w", err)
|
||||
}
|
||||
h.KeyBackupVersion = latestVersion.Version
|
||||
zerolog.Ctx(ctx).Debug().Msg("Secrets loaded")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error {
|
||||
keys := h.Crypto.CrossSigningKeys
|
||||
err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) VerifyWithRecoveryCode(ctx context.Context, code string) error {
|
||||
_, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get default SSSS key data: %w", err)
|
||||
}
|
||||
key, err := keyData.VerifyRecoveryKey(code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err)
|
||||
}
|
||||
err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign own device: %w", err)
|
||||
}
|
||||
err = h.Crypto.SignOwnMasterKey(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign own master key: %w", err)
|
||||
}
|
||||
err = h.storeCrossSigningPrivateKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store cross-signing private keys: %w", err)
|
||||
}
|
||||
err = h.fetchKeyBackupKey(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch key backup key: %w", err)
|
||||
}
|
||||
h.Verified = true
|
||||
return nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue