mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
hicli: include state in sync and add method to get state
This commit is contained in:
parent
38127b85b2
commit
cb3a7ce87a
9 changed files with 176 additions and 13 deletions
12
client.go
12
client.go
|
|
@ -1469,6 +1469,18 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt
|
|||
return
|
||||
}
|
||||
|
||||
// StateAsArray gets all the state in a room as an array. It does not update the state store.
|
||||
// Use State to get the events as a map and also update the state store.
|
||||
func (cli *Client) StateAsArray(ctx context.Context, roomID id.RoomID) (state []*event.Event, err error) {
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v3", "rooms", roomID, "state"), nil, &state)
|
||||
if err == nil {
|
||||
for _, evt := range state {
|
||||
evt.Type.Class = event.StateEventType
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetMediaConfig fetches the configuration of the content repository, such as upload limitations.
|
||||
func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) {
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v1", "media", "config"), nil, &resp)
|
||||
|
|
|
|||
|
|
@ -153,7 +153,10 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) {
|
|||
other.EncryptionEvent = r.EncryptionEvent
|
||||
hasChanges = true
|
||||
}
|
||||
other.HasMemberList = other.HasMemberList || r.HasMemberList
|
||||
if r.HasMemberList && !other.HasMemberList {
|
||||
hasChanges = true
|
||||
other.HasMemberList = true
|
||||
}
|
||||
if r.PreviewEventRowID > other.PreviewEventRowID {
|
||||
other.PreviewEventRowID = r.PreviewEventRowID
|
||||
hasChanges = true
|
||||
|
|
|
|||
|
|
@ -4,10 +4,14 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build go1.23
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
|
|
@ -20,6 +24,13 @@ const (
|
|||
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership
|
||||
`
|
||||
addCurrentStateQuery = `
|
||||
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT DO NOTHING
|
||||
`
|
||||
deleteCurrentStateQuery = `
|
||||
DELETE FROM current_state WHERE room_id = $1
|
||||
`
|
||||
getCurrentRoomStateQuery = `
|
||||
SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned,
|
||||
transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid
|
||||
|
|
@ -30,6 +41,21 @@ const (
|
|||
getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3`
|
||||
)
|
||||
|
||||
var massInsertCurrentStateBuilder = dbutil.NewMassInsertBuilder[*CurrentStateEntry, [1]any](addCurrentStateQuery, "($1, $%d, $%d, $%d, $%d)")
|
||||
|
||||
const currentStateMassInsertBatchSize = 1000
|
||||
|
||||
type CurrentStateEntry struct {
|
||||
EventType event.Type
|
||||
StateKey string
|
||||
EventRowID EventRowID
|
||||
Membership event.Membership
|
||||
}
|
||||
|
||||
func (cse *CurrentStateEntry) GetMassInsertValues() [4]any {
|
||||
return [4]any{cse.EventType.Type, cse.StateKey, cse.EventRowID, dbutil.StrPtr(cse.Membership)}
|
||||
}
|
||||
|
||||
type CurrentStateQuery struct {
|
||||
*dbutil.QueryHelper[*Event]
|
||||
}
|
||||
|
|
@ -38,6 +64,28 @@ func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventTy
|
|||
return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
|
||||
}
|
||||
|
||||
func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, deleteOld bool, entries []*CurrentStateEntry) error {
|
||||
var err error
|
||||
if deleteOld {
|
||||
err = csq.Exec(ctx, deleteCurrentStateQuery, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete old state: %w", err)
|
||||
}
|
||||
}
|
||||
for entryChunk := range slices.Chunk(entries, currentStateMassInsertBatchSize) {
|
||||
query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk)
|
||||
err = csq.Exec(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (csq *CurrentStateQuery) Add(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error {
|
||||
return csq.Exec(ctx, addCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
|
||||
}
|
||||
|
||||
func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) {
|
||||
return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,6 +39,9 @@ const (
|
|||
SELECT state_key FROM current_state
|
||||
WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite')
|
||||
`
|
||||
getHasFetchedMembersQuery = `
|
||||
SELECT has_member_list FROM room WHERE room_id = $1
|
||||
`
|
||||
isRoomEncryptedQuery = `
|
||||
SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1
|
||||
`
|
||||
|
|
@ -116,7 +119,12 @@ func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, ro
|
|||
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) {
|
||||
func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (hasFetched bool, err error) {
|
||||
//err = c.QueryRow(ctx, getHasFetchedMembersQuery, roomID).Scan(&hasFetched)
|
||||
//if errors.Is(err, sql.ErrNoRows) {
|
||||
// err = nil
|
||||
//}
|
||||
//return
|
||||
return false, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,10 +13,11 @@ import (
|
|||
)
|
||||
|
||||
type SyncRoom struct {
|
||||
Meta *database.Room `json:"meta"`
|
||||
Timeline []database.TimelineRowTuple `json:"timeline"`
|
||||
Events []*database.Event `json:"events"`
|
||||
Reset bool `json:"reset"`
|
||||
Meta *database.Room `json:"meta"`
|
||||
Timeline []database.TimelineRowTuple `json:"timeline"`
|
||||
State map[event.Type]map[string]database.EventRowID `json:"state"`
|
||||
Events []*database.Event `json:"events"`
|
||||
Reset bool `json:"reset"`
|
||||
}
|
||||
|
||||
type SyncComplete struct {
|
||||
|
|
|
|||
|
|
@ -49,6 +49,10 @@ func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any
|
|||
return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) {
|
||||
return h.GetEventsByRowIDs(ctx, params.RowIDs)
|
||||
})
|
||||
case "get_room_state":
|
||||
return unmarshalAndCall(req.Data, func(params *getRoomStateParams) ([]*database.Event, error) {
|
||||
return h.GetRoomState(ctx, params.RoomID, params.FetchMembers, params.Refetch)
|
||||
})
|
||||
case "paginate":
|
||||
return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) {
|
||||
return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit)
|
||||
|
|
@ -111,6 +115,12 @@ type getEventsByRowIDsParams struct {
|
|||
RowIDs []database.EventRowID `json:"row_ids"`
|
||||
}
|
||||
|
||||
type getRoomStateParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
Refetch bool `json:"refetch"`
|
||||
FetchMembers bool `json:"fetch_members"`
|
||||
}
|
||||
|
||||
type ensureGroupSessionSharedParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/hicli/database"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
|
@ -62,6 +63,68 @@ func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev
|
|||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) {
|
||||
var evts []*event.Event
|
||||
if refetch {
|
||||
resp, err := h.Client.StateAsArray(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refetch state: %w", err)
|
||||
}
|
||||
evts = resp
|
||||
} else if fetchMembers {
|
||||
resp, err := h.Client.Members(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch members: %w", err)
|
||||
}
|
||||
evts = resp.Chunk
|
||||
}
|
||||
if evts != nil {
|
||||
err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
room, err := h.DB.Room.Get(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get room from database: %w", err)
|
||||
}
|
||||
updatedRoom := &database.Room{
|
||||
ID: room.ID,
|
||||
HasMemberList: true,
|
||||
}
|
||||
entries := make([]*database.CurrentStateEntry, len(evts))
|
||||
for i, evt := range evts {
|
||||
dbEvt, err := h.processEvent(ctx, evt, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process event %s: %w", evt.ID, err)
|
||||
}
|
||||
entries[i] = &database.CurrentStateEntry{
|
||||
EventType: evt.Type,
|
||||
StateKey: *evt.StateKey,
|
||||
EventRowID: dbEvt.RowID,
|
||||
}
|
||||
if evt.Type == event.StateMember {
|
||||
entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string))
|
||||
} else {
|
||||
processImportantEvent(ctx, evt, room, updatedRoom)
|
||||
}
|
||||
}
|
||||
err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
|
||||
if roomChanged {
|
||||
err = h.DB.Room.Upsert(ctx, updatedRoom)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save room data: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return h.DB.CurrentState.GetAll(ctx, roomID)
|
||||
}
|
||||
|
||||
type PaginationResponse struct {
|
||||
Events []*database.Event `json:"events"`
|
||||
HasMore bool `json:"has_more"`
|
||||
|
|
|
|||
|
|
@ -148,17 +148,23 @@ func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error {
|
|||
return fmt.Errorf("failed to get room member list: %w", err)
|
||||
}
|
||||
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
for _, evt := range resp.Chunk {
|
||||
entries := make([]*database.CurrentStateEntry, len(resp.Chunk))
|
||||
for i, evt := range resp.Chunk {
|
||||
dbEvt, err := h.processEvent(ctx, evt, nil, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
membership := event.Membership(evt.Content.Raw["membership"].(string))
|
||||
err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership)
|
||||
if err != nil {
|
||||
return err
|
||||
entries[i] = &database.CurrentStateEntry{
|
||||
EventType: evt.Type,
|
||||
StateKey: *evt.StateKey,
|
||||
EventRowID: dbEvt.RowID,
|
||||
Membership: event.Membership(evt.Content.Raw["membership"].(string)),
|
||||
}
|
||||
}
|
||||
err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.DB.Room.Upsert(ctx, &database.Room{
|
||||
ID: room.ID,
|
||||
HasMemberList: true,
|
||||
|
|
|
|||
|
|
@ -410,15 +410,23 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
|
|||
allNewEvents = append(allNewEvents, dbEvt)
|
||||
return dbEvt.RowID, nil
|
||||
}
|
||||
var err error
|
||||
changedState := make(map[event.Type]map[string]database.EventRowID)
|
||||
setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) {
|
||||
if _, ok := changedState[evtType]; !ok {
|
||||
changedState[evtType] = make(map[string]database.EventRowID)
|
||||
}
|
||||
changedState[evtType][stateKey] = rowID
|
||||
}
|
||||
for _, evt := range state.Events {
|
||||
evt.Type.Class = event.StateEventType
|
||||
_, err = processNewEvent(evt, false)
|
||||
rowID, err := processNewEvent(evt, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setNewState(evt.Type, *evt.StateKey, rowID)
|
||||
}
|
||||
var timelineRowTuples []database.TimelineRowTuple
|
||||
var err error
|
||||
if len(timeline.Events) > 0 {
|
||||
timelineIDs := make([]database.EventRowID, len(timeline.Events))
|
||||
for i, evt := range timeline.Events {
|
||||
|
|
@ -431,6 +439,9 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if evt.StateKey != nil {
|
||||
setNewState(evt.Type, *evt.StateKey, timelineIDs[i])
|
||||
}
|
||||
}
|
||||
for _, entry := range decryptionQueue {
|
||||
err = h.DB.SessionRequest.Put(ctx, entry)
|
||||
|
|
@ -481,6 +492,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R
|
|||
ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
|
||||
Meta: room,
|
||||
Timeline: timelineRowTuples,
|
||||
State: changedState,
|
||||
Reset: timeline.Limited,
|
||||
Events: allNewEvents,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue