hicli: include state in sync and add method to get state

This commit is contained in:
Tulir Asokan 2024-10-09 01:22:13 +03:00
commit cb3a7ce87a
9 changed files with 176 additions and 13 deletions

View file

@ -1469,6 +1469,18 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt
return
}
// StateAsArray gets all the state in a room as an array. It does not update the state store.
// Use State to get the events as a map and also update the state store.
func (cli *Client) StateAsArray(ctx context.Context, roomID id.RoomID) (state []*event.Event, err error) {
_, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v3", "rooms", roomID, "state"), nil, &state)
if err == nil {
for _, evt := range state {
evt.Type.Class = event.StateEventType
}
}
return
}
// GetMediaConfig fetches the configuration of the content repository, such as upload limitations.
func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) {
_, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v1", "media", "config"), nil, &resp)

View file

@ -153,7 +153,10 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) {
other.EncryptionEvent = r.EncryptionEvent
hasChanges = true
}
other.HasMemberList = other.HasMemberList || r.HasMemberList
if r.HasMemberList && !other.HasMemberList {
hasChanges = true
other.HasMemberList = true
}
if r.PreviewEventRowID > other.PreviewEventRowID {
other.PreviewEventRowID = r.PreviewEventRowID
hasChanges = true

View file

@ -4,10 +4,14 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build go1.23
package database
import (
"context"
"fmt"
"slices"
"go.mau.fi/util/dbutil"
@ -20,6 +24,13 @@ const (
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership
`
addCurrentStateQuery = `
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
`
deleteCurrentStateQuery = `
DELETE FROM current_state WHERE room_id = $1
`
getCurrentRoomStateQuery = `
SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned,
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid
@ -30,6 +41,21 @@ const (
getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3`
)
var massInsertCurrentStateBuilder = dbutil.NewMassInsertBuilder[*CurrentStateEntry, [1]any](addCurrentStateQuery, "($1, $%d, $%d, $%d, $%d)")
const currentStateMassInsertBatchSize = 1000
type CurrentStateEntry struct {
EventType event.Type
StateKey string
EventRowID EventRowID
Membership event.Membership
}
func (cse *CurrentStateEntry) GetMassInsertValues() [4]any {
return [4]any{cse.EventType.Type, cse.StateKey, cse.EventRowID, dbutil.StrPtr(cse.Membership)}
}
type CurrentStateQuery struct {
*dbutil.QueryHelper[*Event]
}
@ -38,6 +64,28 @@ func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventTy
return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
}
func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, deleteOld bool, entries []*CurrentStateEntry) error {
var err error
if deleteOld {
err = csq.Exec(ctx, deleteCurrentStateQuery, roomID)
if err != nil {
return fmt.Errorf("failed to delete old state: %w", err)
}
}
for entryChunk := range slices.Chunk(entries, currentStateMassInsertBatchSize) {
query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk)
err = csq.Exec(ctx, query, params...)
if err != nil {
return err
}
}
return nil
}
func (csq *CurrentStateQuery) Add(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error {
return csq.Exec(ctx, addCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
}
func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) {
return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey)
}

View file

@ -39,6 +39,9 @@ const (
SELECT state_key FROM current_state
WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite')
`
getHasFetchedMembersQuery = `
SELECT has_member_list FROM room WHERE room_id = $1
`
isRoomEncryptedQuery = `
SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1
`
@ -116,7 +119,12 @@ func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, ro
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
}
func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) {
func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (hasFetched bool, err error) {
//err = c.QueryRow(ctx, getHasFetchedMembersQuery, roomID).Scan(&hasFetched)
//if errors.Is(err, sql.ErrNoRows) {
// err = nil
//}
//return
return false, fmt.Errorf("not implemented")
}

View file

@ -13,10 +13,11 @@ import (
)
type SyncRoom struct {
Meta *database.Room `json:"meta"`
Timeline []database.TimelineRowTuple `json:"timeline"`
Events []*database.Event `json:"events"`
Reset bool `json:"reset"`
Meta *database.Room `json:"meta"`
Timeline []database.TimelineRowTuple `json:"timeline"`
State map[event.Type]map[string]database.EventRowID `json:"state"`
Events []*database.Event `json:"events"`
Reset bool `json:"reset"`
}
type SyncComplete struct {

View file

@ -49,6 +49,10 @@ func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any
return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) {
return h.GetEventsByRowIDs(ctx, params.RowIDs)
})
case "get_room_state":
return unmarshalAndCall(req.Data, func(params *getRoomStateParams) ([]*database.Event, error) {
return h.GetRoomState(ctx, params.RoomID, params.FetchMembers, params.Refetch)
})
case "paginate":
return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) {
return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit)
@ -111,6 +115,12 @@ type getEventsByRowIDsParams struct {
RowIDs []database.EventRowID `json:"row_ids"`
}
type getRoomStateParams struct {
RoomID id.RoomID `json:"room_id"`
Refetch bool `json:"refetch"`
FetchMembers bool `json:"fetch_members"`
}
type ensureGroupSessionSharedParams struct {
RoomID id.RoomID `json:"room_id"`
}

View file

@ -14,6 +14,7 @@ import (
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/hicli/database"
"maunium.net/go/mautrix/id"
)
@ -62,6 +63,68 @@ func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev
}
}
func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) {
var evts []*event.Event
if refetch {
resp, err := h.Client.StateAsArray(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to refetch state: %w", err)
}
evts = resp
} else if fetchMembers {
resp, err := h.Client.Members(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to fetch members: %w", err)
}
evts = resp.Chunk
}
if evts != nil {
err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
room, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to get room from database: %w", err)
}
updatedRoom := &database.Room{
ID: room.ID,
HasMemberList: true,
}
entries := make([]*database.CurrentStateEntry, len(evts))
for i, evt := range evts {
dbEvt, err := h.processEvent(ctx, evt, nil, false)
if err != nil {
return fmt.Errorf("failed to process event %s: %w", evt.ID, err)
}
entries[i] = &database.CurrentStateEntry{
EventType: evt.Type,
StateKey: *evt.StateKey,
EventRowID: dbEvt.RowID,
}
if evt.Type == event.StateMember {
entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string))
} else {
processImportantEvent(ctx, evt, room, updatedRoom)
}
}
err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries)
if err != nil {
return err
}
roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
if roomChanged {
err = h.DB.Room.Upsert(ctx, updatedRoom)
if err != nil {
return fmt.Errorf("failed to save room data: %w", err)
}
}
return nil
})
if err != nil {
return nil, err
}
}
return h.DB.CurrentState.GetAll(ctx, roomID)
}
type PaginationResponse struct {
Events []*database.Event `json:"events"`
HasMore bool `json:"has_more"`

View file

@ -148,17 +148,23 @@ func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error {
return fmt.Errorf("failed to get room member list: %w", err)
}
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
for _, evt := range resp.Chunk {
entries := make([]*database.CurrentStateEntry, len(resp.Chunk))
for i, evt := range resp.Chunk {
dbEvt, err := h.processEvent(ctx, evt, nil, true)
if err != nil {
return err
}
membership := event.Membership(evt.Content.Raw["membership"].(string))
err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership)
if err != nil {
return err
entries[i] = &database.CurrentStateEntry{
EventType: evt.Type,
StateKey: *evt.StateKey,
EventRowID: dbEvt.RowID,
Membership: event.Membership(evt.Content.Raw["membership"].(string)),
}
}
err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries)
if err != nil {
return err
}
return h.DB.Room.Upsert(ctx, &database.Room{
ID: room.ID,
HasMemberList: true,

View file

@ -410,15 +410,23 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
allNewEvents = append(allNewEvents, dbEvt)
return dbEvt.RowID, nil
}
var err error
changedState := make(map[event.Type]map[string]database.EventRowID)
setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) {
if _, ok := changedState[evtType]; !ok {
changedState[evtType] = make(map[string]database.EventRowID)
}
changedState[evtType][stateKey] = rowID
}
for _, evt := range state.Events {
evt.Type.Class = event.StateEventType
_, err = processNewEvent(evt, false)
rowID, err := processNewEvent(evt, false)
if err != nil {
return err
}
setNewState(evt.Type, *evt.StateKey, rowID)
}
var timelineRowTuples []database.TimelineRowTuple
var err error
if len(timeline.Events) > 0 {
timelineIDs := make([]database.EventRowID, len(timeline.Events))
for i, evt := range timeline.Events {
@ -431,6 +439,9 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
if err != nil {
return err
}
if evt.StateKey != nil {
setNewState(evt.Type, *evt.StateKey, timelineIDs[i])
}
}
for _, entry := range decryptionQueue {
err = h.DB.SessionRequest.Put(ctx, entry)
@ -481,6 +492,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
Meta: room,
Timeline: timelineRowTuples,
State: changedState,
Reset: timeline.Limited,
Events: allNewEvents,
}