diff --git a/client.go b/client.go index 836afc55..125bba0d 100644 --- a/client.go +++ b/client.go @@ -1469,6 +1469,18 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt return } +// StateAsArray gets all the state in a room as an array. It does not update the state store. +// Use State to get the events as a map and also update the state store. +func (cli *Client) StateAsArray(ctx context.Context, roomID id.RoomID) (state []*event.Event, err error) { + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v3", "rooms", roomID, "state"), nil, &state) + if err == nil { + for _, evt := range state { + evt.Type.Class = event.StateEventType + } + } + return +} + // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v1", "media", "config"), nil, &resp) diff --git a/hicli/database/room.go b/hicli/database/room.go index 5778e5f5..e7138d94 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -153,7 +153,10 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { other.EncryptionEvent = r.EncryptionEvent hasChanges = true } - other.HasMemberList = other.HasMemberList || r.HasMemberList + if r.HasMemberList && !other.HasMemberList { + hasChanges = true + other.HasMemberList = true + } if r.PreviewEventRowID > other.PreviewEventRowID { other.PreviewEventRowID = r.PreviewEventRowID hasChanges = true diff --git a/hicli/database/state.go b/hicli/database/state.go index e74f2950..5dc13729 100644 --- a/hicli/database/state.go +++ b/hicli/database/state.go @@ -4,10 +4,14 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build go1.23 + package database import ( "context" + "fmt" + "slices" "go.mau.fi/util/dbutil" @@ -20,6 +24,13 @@ const ( INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership ` + addCurrentStateQuery = ` + INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING + ` + deleteCurrentStateQuery = ` + DELETE FROM current_state WHERE room_id = $1 + ` getCurrentRoomStateQuery = ` SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned, transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid @@ -30,6 +41,21 @@ const ( getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3` ) +var massInsertCurrentStateBuilder = dbutil.NewMassInsertBuilder[*CurrentStateEntry, [1]any](addCurrentStateQuery, "($1, $%d, $%d, $%d, $%d)") + +const currentStateMassInsertBatchSize = 1000 + +type CurrentStateEntry struct { + EventType event.Type + StateKey string + EventRowID EventRowID + Membership event.Membership +} + +func (cse *CurrentStateEntry) GetMassInsertValues() [4]any { + return [4]any{cse.EventType.Type, cse.StateKey, cse.EventRowID, dbutil.StrPtr(cse.Membership)} +} + type CurrentStateQuery struct { *dbutil.QueryHelper[*Event] } @@ -38,6 +64,28 @@ func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventTy return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) } +func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, deleteOld bool, entries []*CurrentStateEntry) error { + var err error + if deleteOld { + err = csq.Exec(ctx, deleteCurrentStateQuery, roomID) + if err != nil { + return fmt.Errorf("failed to delete old state: %w", err) + } + } + for entryChunk := range slices.Chunk(entries, currentStateMassInsertBatchSize) { + query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk) + err = csq.Exec(ctx, query, params...) + if err != nil { + return err + } + } + return nil +} + +func (csq *CurrentStateQuery) Add(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error { + return csq.Exec(ctx, addCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) +} + func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) { return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey) } diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go index 1779afa5..fcd6aceb 100644 --- a/hicli/database/statestore.go +++ b/hicli/database/statestore.go @@ -39,6 +39,9 @@ const ( SELECT state_key FROM current_state WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite') ` + getHasFetchedMembersQuery = ` + SELECT has_member_list FROM room WHERE room_id = $1 + ` isRoomEncryptedQuery = ` SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1 ` @@ -116,7 +119,12 @@ func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, ro return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } -func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) { +func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (hasFetched bool, err error) { + //err = c.QueryRow(ctx, getHasFetchedMembersQuery, roomID).Scan(&hasFetched) + //if errors.Is(err, sql.ErrNoRows) { + // err = nil + //} + //return return false, fmt.Errorf("not implemented") } diff --git a/hicli/events.go b/hicli/events.go index c7228541..ea03be7e 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -13,10 +13,11 @@ import ( ) type SyncRoom struct { - Meta *database.Room `json:"meta"` - Timeline []database.TimelineRowTuple `json:"timeline"` - Events []*database.Event `json:"events"` - Reset bool `json:"reset"` + Meta *database.Room `json:"meta"` + Timeline []database.TimelineRowTuple `json:"timeline"` + State map[event.Type]map[string]database.EventRowID `json:"state"` + Events []*database.Event `json:"events"` + Reset bool `json:"reset"` } type SyncComplete struct { diff --git a/hicli/json-commands.go b/hicli/json-commands.go index 29a2ac73..12026f6b 100644 --- a/hicli/json-commands.go +++ b/hicli/json-commands.go @@ -49,6 +49,10 @@ func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) { return h.GetEventsByRowIDs(ctx, params.RowIDs) }) + case "get_room_state": + return unmarshalAndCall(req.Data, func(params *getRoomStateParams) ([]*database.Event, error) { + return h.GetRoomState(ctx, params.RoomID, params.FetchMembers, params.Refetch) + }) case "paginate": return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) { return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit) @@ -111,6 +115,12 @@ type getEventsByRowIDsParams struct { RowIDs []database.EventRowID `json:"row_ids"` } +type getRoomStateParams struct { + RoomID id.RoomID `json:"room_id"` + Refetch bool `json:"refetch"` + FetchMembers bool `json:"fetch_members"` +} + type ensureGroupSessionSharedParams struct { RoomID id.RoomID `json:"room_id"` } diff --git a/hicli/paginate.go b/hicli/paginate.go index 878a033a..4109e7af 100644 --- a/hicli/paginate.go +++ b/hicli/paginate.go @@ -14,6 +14,7 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" ) @@ -62,6 +63,68 @@ func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev } } +func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) { + var evts []*event.Event + if refetch { + resp, err := h.Client.StateAsArray(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to refetch state: %w", err) + } + evts = resp + } else if fetchMembers { + resp, err := h.Client.Members(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to fetch members: %w", err) + } + evts = resp.Chunk + } + if evts != nil { + err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + room, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room from database: %w", err) + } + updatedRoom := &database.Room{ + ID: room.ID, + HasMemberList: true, + } + entries := make([]*database.CurrentStateEntry, len(evts)) + for i, evt := range evts { + dbEvt, err := h.processEvent(ctx, evt, nil, false) + if err != nil { + return fmt.Errorf("failed to process event %s: %w", evt.ID, err) + } + entries[i] = &database.CurrentStateEntry{ + EventType: evt.Type, + StateKey: *evt.StateKey, + EventRowID: dbEvt.RowID, + } + if evt.Type == event.StateMember { + entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string)) + } else { + processImportantEvent(ctx, evt, room, updatedRoom) + } + } + err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries) + if err != nil { + return err + } + roomChanged := updatedRoom.CheckChangesAndCopyInto(room) + if roomChanged { + err = h.DB.Room.Upsert(ctx, updatedRoom) + if err != nil { + return fmt.Errorf("failed to save room data: %w", err) + } + } + return nil + }) + if err != nil { + return nil, err + } + } + return h.DB.CurrentState.GetAll(ctx, roomID) +} + type PaginationResponse struct { Events []*database.Event `json:"events"` HasMore bool `json:"has_more"` diff --git a/hicli/send.go b/hicli/send.go index 8824f3c3..1c76a5a2 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -148,17 +148,23 @@ func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error { return fmt.Errorf("failed to get room member list: %w", err) } err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - for _, evt := range resp.Chunk { + entries := make([]*database.CurrentStateEntry, len(resp.Chunk)) + for i, evt := range resp.Chunk { dbEvt, err := h.processEvent(ctx, evt, nil, true) if err != nil { return err } - membership := event.Membership(evt.Content.Raw["membership"].(string)) - err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) - if err != nil { - return err + entries[i] = &database.CurrentStateEntry{ + EventType: evt.Type, + StateKey: *evt.StateKey, + EventRowID: dbEvt.RowID, + Membership: event.Membership(evt.Content.Raw["membership"].(string)), } } + err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries) + if err != nil { + return err + } return h.DB.Room.Upsert(ctx, &database.Room{ ID: room.ID, HasMemberList: true, diff --git a/hicli/sync.go b/hicli/sync.go index 3b40af9f..aaf1f5c6 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -410,15 +410,23 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R allNewEvents = append(allNewEvents, dbEvt) return dbEvt.RowID, nil } - var err error + changedState := make(map[event.Type]map[string]database.EventRowID) + setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) { + if _, ok := changedState[evtType]; !ok { + changedState[evtType] = make(map[string]database.EventRowID) + } + changedState[evtType][stateKey] = rowID + } for _, evt := range state.Events { evt.Type.Class = event.StateEventType - _, err = processNewEvent(evt, false) + rowID, err := processNewEvent(evt, false) if err != nil { return err } + setNewState(evt.Type, *evt.StateKey, rowID) } var timelineRowTuples []database.TimelineRowTuple + var err error if len(timeline.Events) > 0 { timelineIDs := make([]database.EventRowID, len(timeline.Events)) for i, evt := range timeline.Events { @@ -431,6 +439,9 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R if err != nil { return err } + if evt.StateKey != nil { + setNewState(evt.Type, *evt.StateKey, timelineIDs[i]) + } } for _, entry := range decryptionQueue { err = h.DB.SessionRequest.Put(ctx, entry) @@ -481,6 +492,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{ Meta: room, Timeline: timelineRowTuples, + State: changedState, Reset: timeline.Limited, Events: allNewEvents, }