mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Merge branch 'main' into tulir/extended-profiles-for-everyone
Some checks failed
Commit e4e8933b33
19 changed files with 397 additions and 183 deletions
@@ -59,6 +59,7 @@ type BridgeConfig struct {
 	CommandPrefix           string `yaml:"command_prefix"`
 	PersonalFilteringSpaces bool   `yaml:"personal_filtering_spaces"`
 	PrivateChatPortalMeta   bool   `yaml:"private_chat_portal_meta"`
+	AsyncEvents             bool   `yaml:"async_events"`
 	BridgeMatrixLeave       bool   `yaml:"bridge_matrix_leave"`
 	TagOnlyOnCreate         bool   `yaml:"tag_only_on_create"`
 	MuteOnlyOnCreate        bool   `yaml:"mute_only_on_create"`
@@ -25,6 +25,7 @@ func doUpgrade(helper up.Helper) {
 	helper.Copy(up.Str, "bridge", "command_prefix")
 	helper.Copy(up.Bool, "bridge", "personal_filtering_spaces")
 	helper.Copy(up.Bool, "bridge", "private_chat_portal_meta")
+	helper.Copy(up.Bool, "bridge", "async_events")
 	helper.Copy(up.Bool, "bridge", "bridge_matrix_leave")
 	helper.Copy(up.Bool, "bridge", "tag_only_on_create")
 	helper.Copy(up.Bool, "bridge", "mute_only_on_create")
@@ -289,7 +289,8 @@ func (as *ASIntent) UploadMediaStream(
 	var res *bridgev2.FileStreamResult
 	res, err = cb(tempFile)
 	if err != nil {
-		err = fmt.Errorf("failed to write to temp file: %w", err)
+		err = fmt.Errorf("write callback failed: %w", err)
 		return
 	}
+	var replFile *os.File
 	if res.ReplacementFile != "" {
@@ -7,6 +7,9 @@ bridge:
     # Whether the bridge should set names and avatars explicitly for DM portals.
     # This is only necessary when using clients that don't support MSC4171.
     private_chat_portal_meta: false
+    # Should events be handled asynchronously within portal rooms?
+    # If true, events may end up being out of order, but slow events won't block other ones.
+    async_events: false
 
     # Should leaving Matrix rooms be bridged as leaving groups on the remote network?
     bridge_matrix_leave: false
@@ -54,6 +54,11 @@ type ProvisioningAPI struct {
 	// GetAuthFromRequest is a custom function for getting the auth token from
 	// the request if the Authorization header is not present.
 	GetAuthFromRequest func(r *http.Request) string
+
+	// GetUserIDFromRequest is a custom function for getting the user ID to
+	// authenticate as instead of using the user ID provided in the query
+	// parameter.
+	GetUserIDFromRequest func(r *http.Request) id.UserID
 }
 
 type ProvLogin struct {
@@ -200,6 +205,9 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
 			return
 		}
 		userID := id.UserID(r.URL.Query().Get("user_id"))
+		if userID == "" && prov.GetUserIDFromRequest != nil {
+			userID = prov.GetUserIDFromRequest(r)
+		}
 		if auth != prov.br.Config.Provisioning.SharedSecret {
 			var err error
 			if strings.HasPrefix(auth, "openid:") {
@@ -227,6 +235,14 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
 				return
 			}
 		}
 		// TODO handle user being nil?
+		// TODO per-endpoint permissions?
+		if !user.Permissions.Login {
+			jsonResponse(w, http.StatusForbidden, &mautrix.RespError{
+				Err:     "User does not have login permissions",
+				ErrCode: mautrix.MForbidden.ErrCode,
+			})
+			return
+		}
 
 		ctx := context.WithValue(r.Context(), provisioningUserKey, user)
 		if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {
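Together, the two new hooks let a deployment authenticate provisioning requests that carry neither the Authorization header nor the user_id query parameter, for example behind an authenticating reverse proxy. A minimal sketch of wiring them up, assuming the ProvisioningAPI struct above is reachable during connector setup; the cookie and header names here are invented:

    package bridgesetup

    import (
    	"net/http"

    	bridgematrix "maunium.net/go/mautrix/bridgev2/matrix"
    	"maunium.net/go/mautrix/id"
    )

    // configureProvisioningAuth wires up both hooks: the auth token falls
    // back to a cookie, and the user ID comes from a reverse-proxy header.
    // Both names are made up for this sketch.
    func configureProvisioningAuth(prov *bridgematrix.ProvisioningAPI) {
    	prov.GetAuthFromRequest = func(r *http.Request) string {
    		cookie, err := r.Cookie("bridge_auth") // hypothetical cookie name
    		if err != nil {
    			return ""
    		}
    		return cookie.Value
    	}
    	prov.GetUserIDFromRequest = func(r *http.Request) id.UserID {
    		return id.UserID(r.Header.Get("X-Forwarded-User")) // hypothetical header
    	}
    }

Note that the hooks are only fallbacks: a present Authorization header or user_id query parameter still wins, and the shared-secret/OpenID checks in AuthMiddleware run unchanged afterwards.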
@@ -273,11 +273,16 @@ type MaxFileSizeingNetwork interface {
 	SetMaxFileSize(maxSize int64)
 }
 
+type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error)
+
 type MatrixMessageResponse struct {
 	DB *database.Message
 
-	Pending    networkid.TransactionID
-	HandleEcho func(RemoteMessage, *database.Message) (bool, error)
+	// If Pending is set, the bridge will not save the provided message to the database.
+	// This should only be used if AddPendingToSave has been called.
+	Pending bool
+	// If RemovePending is set, the bridge will remove the provided transaction ID from pending messages
+	// after saving the provided message to the database. This should be used with AddPendingToIgnore.
+	RemovePending networkid.TransactionID
 }
 
 type FileRestriction struct {
@@ -14,6 +14,7 @@ import (
 	"runtime/debug"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/rs/zerolog"
@@ -59,6 +60,7 @@ type portalEvent interface {
 type outgoingMessage struct {
 	db     *database.Message
 	evt    *event.Event
 	ignore bool
+	handle func(RemoteMessage, *database.Message) (bool, error)
 }
@@ -274,23 +276,49 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) {
 
 func (portal *Portal) eventLoop() {
 	for rawEvt := range portal.events {
-		switch evt := rawEvt.(type) {
-		case *portalMatrixEvent:
-			portal.handleMatrixEvent(evt.sender, evt.evt)
-		case *portalRemoteEvent:
-			portal.handleRemoteEvent(evt.source, evt.evt)
-		case *portalCreateEvent:
-			portal.handleCreateEvent(evt)
-		default:
-			panic(fmt.Errorf("illegal type %T in eventLoop", evt))
-		}
+		portal.handleSingleEventAsync(rawEvt)
 	}
 }
 
-func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) {
+func (portal *Portal) handleSingleEventAsync(rawEvt any) {
+	log := portal.Log.With().Logger()
+	if _, isCreate := rawEvt.(*portalCreateEvent); isCreate {
+		portal.handleSingleEvent(&log, rawEvt, func() {})
+	} else if portal.Bridge.Config.AsyncEvents {
+		go portal.handleSingleEvent(&log, rawEvt, func() {})
+	} else {
+		doneCh := make(chan struct{})
+		var backgrounded atomic.Bool
+		go portal.handleSingleEvent(&log, rawEvt, func() {
+			close(doneCh)
+			if backgrounded.Load() {
+				log.Debug().Msg("Event that took too long finally finished handling")
+			}
+		})
+		tick := time.NewTicker(30 * time.Second)
+		defer tick.Stop()
+		for i := 0; i < 10; i++ {
+			select {
+			case <-doneCh:
+				if i > 0 {
+					log.Debug().Msg("Event that took long finished handling")
+				}
+				return
+			case <-tick.C:
+				log.Warn().Msg("Event handling is taking long")
+			}
+		}
+		log.Warn().Msg("Event handling is taking too long, continuing in background")
+		backgrounded.Store(true)
+	}
+}
+
+func (portal *Portal) handleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) {
+	ctx := log.WithContext(context.Background())
 	defer func() {
+		doneCallback()
 		if err := recover(); err != nil {
-			logEvt := zerolog.Ctx(evt.ctx).Error()
+			logEvt := log.Error()
 			if realErr, ok := err.(error); ok {
 				logEvt = logEvt.Err(realErr)
 			} else {
@@ -299,10 +327,36 @@ func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) {
 			logEvt.
 				Bytes("stack", debug.Stack()).
 				Msg("Portal creation panicked")
-			evt.cb(fmt.Errorf("portal creation panicked"))
+			switch evt := rawEvt.(type) {
+			case *portalMatrixEvent:
+				if evt.evt.ID != "" {
+					go portal.sendErrorStatus(ctx, evt.evt, ErrPanicInEventHandler)
+				}
+			case *portalCreateEvent:
+				evt.cb(fmt.Errorf("portal creation panicked"))
+			}
 		}
 	}()
-	evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil))
+	switch evt := rawEvt.(type) {
+	case *portalMatrixEvent:
+		log.UpdateContext(func(c zerolog.Context) zerolog.Context {
+			return c.Str("action", "handle matrix event").
+				Stringer("event_id", evt.evt.ID).
+				Str("event_type", evt.evt.Type.Type)
+		})
+		portal.handleMatrixEvent(ctx, evt.sender, evt.evt)
+	case *portalRemoteEvent:
+		log.UpdateContext(func(c zerolog.Context) zerolog.Context {
+			return c.Str("action", "handle remote event").
+				Str("source_id", string(evt.source.ID))
+		})
+		portal.handleRemoteEvent(ctx, evt.source, evt.evt)
+	case *portalCreateEvent:
+		*log = *zerolog.Ctx(evt.ctx)
+		evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil))
+	default:
+		panic(fmt.Errorf("illegal type %T in eventLoop", evt))
+	}
 }
 
 func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) {
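The synchronous branch above is a stall watchdog rather than a hard block: the handler always runs in its own goroutine, the loop waits up to ten 30-second ticks while warning about slow handlers, and then abandons the wait so a single stuck event cannot block the portal's queue forever. A self-contained sketch of the same pattern, with invented names:

    package main

    import (
    	"fmt"
    	"sync/atomic"
    	"time"
    )

    // runWithStallWarning waits for handle() to finish, complaining
    // periodically, and gives up after ten ticks so the caller can keep
    // processing; the handler keeps running in the background.
    func runWithStallWarning(handle func()) {
    	done := make(chan struct{})
    	var backgrounded atomic.Bool
    	go func() {
    		handle()
    		close(done)
    		if backgrounded.Load() {
    			fmt.Println("slow handler finally finished")
    		}
    	}()
    	tick := time.NewTicker(30 * time.Second)
    	defer tick.Stop()
    	for i := 0; i < 10; i++ {
    		select {
    		case <-done:
    			return
    		case <-tick.C:
    			fmt.Println("handler is still running")
    		}
    	}
    	backgrounded.Store(true)
    	fmt.Println("giving up waiting, handler continues in background")
    }

    func main() {
    	runWithStallWarning(func() { time.Sleep(time.Second) })
    }

With async_events enabled, the goroutine is simply fired without any waiting, trading strict per-room ordering for throughput, exactly as the example config warns.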
@@ -392,29 +446,8 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID,
 	return false
 }
 
-func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) {
-	log := portal.Log.With().
-		Str("action", "handle matrix event").
-		Stringer("event_id", evt.ID).
-		Str("event_type", evt.Type.Type).
-		Logger()
-	ctx := log.WithContext(context.TODO())
-	defer func() {
-		if err := recover(); err != nil {
-			logEvt := log.Error()
-			if realErr, ok := err.(error); ok {
-				logEvt = logEvt.Err(realErr)
-			} else {
-				logEvt = logEvt.Any(zerolog.ErrorFieldName, err)
-			}
-			logEvt.
-				Bytes("stack", debug.Stack()).
-				Msg("Matrix event handler panicked")
-			if evt.ID != "" {
-				go portal.sendErrorStatus(ctx, evt, ErrPanicInEventHandler)
-			}
-		}
-	}()
+func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) {
+	log := zerolog.Ctx(ctx)
 	if evt.Mautrix.EventSource&event.SourceEphemeral != 0 {
 		switch evt.Type {
 		case event.EphemeralEventReceipt:
@@ -775,7 +808,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
 		}
 	}
 
-	resp, err := sender.Client.HandleMatrixMessage(ctx, &MatrixMessage{
+	wrappedEvt := &MatrixMessage{
 		MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{
 			Event:   evt,
 			Content: content,
@@ -784,52 +817,30 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
 		},
 		ThreadRoot: threadRoot,
 		ReplyTo:    replyTo,
-	})
+	}
+	resp, err := sender.Client.HandleMatrixMessage(ctx, wrappedEvt)
 	if err != nil {
 		log.Err(err).Msg("Failed to handle Matrix message")
 		portal.sendErrorStatus(ctx, evt, err)
 		return
 	}
-	message := resp.DB
-	if message.MXID == "" {
-		message.MXID = evt.ID
-	}
-	if message.Room.ID == "" {
-		message.Room = portal.PortalKey
-	}
-	if message.Timestamp.IsZero() {
-		message.Timestamp = time.UnixMilli(evt.Timestamp)
-	}
-	if message.ReplyTo.MessageID == "" && replyTo != nil {
-		message.ReplyTo.MessageID = replyTo.ID
-		message.ReplyTo.PartID = &replyTo.PartID
-	}
-	if message.ThreadRoot == "" && threadRoot != nil {
-		message.ThreadRoot = threadRoot.ID
-		if threadRoot.ThreadRoot != "" {
-			message.ThreadRoot = threadRoot.ThreadRoot
-		}
-	}
-	if message.SenderMXID == "" {
-		message.SenderMXID = evt.Sender
-	}
-	if resp.Pending != "" {
-		// TODO if the event queue is ever removed, this will have to be done by the network connector before sending the request
-		// (for now this is fine because incoming messages will wait in the queue for this function to return)
-		portal.outgoingMessagesLock.Lock()
-		portal.outgoingMessages[resp.Pending] = outgoingMessage{
-			db:     message,
-			evt:    evt,
-			handle: resp.HandleEcho,
-		}
-		portal.outgoingMessagesLock.Unlock()
-	} else {
-		// Hack to ensure the ghost row exists
-		// TODO move to better place (like login)
-		portal.Bridge.GetGhostByID(ctx, message.SenderID)
-		err = portal.Bridge.DB.Message.Insert(ctx, message)
-		if err != nil {
-			log.Err(err).Msg("Failed to save message to database")
+	message := wrappedEvt.fillDBMessage(resp.DB)
+	if !resp.Pending {
+		if resp.DB == nil {
+			log.Error().Msg("Network connector didn't return a message to save")
+		} else {
+			// Hack to ensure the ghost row exists
+			// TODO move to better place (like login)
+			portal.Bridge.GetGhostByID(ctx, message.SenderID)
+			err = portal.Bridge.DB.Message.Insert(ctx, message)
+			if err != nil {
+				log.Err(err).Msg("Failed to save message to database")
+			}
+			if resp.RemovePending != "" {
+				portal.outgoingMessagesLock.Lock()
+				delete(portal.outgoingMessages, resp.RemovePending)
+				portal.outgoingMessagesLock.Unlock()
+			}
 		}
 	}
 	portal.sendSuccessStatus(ctx, evt)
 }
@@ -846,6 +857,75 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
 		}
 	}
 
+// AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message.
+//
+// This should be used when the network connector will return the real message ID from HandleMatrixMessage.
+// The [MatrixMessageResponse] should include RemovePending with the transaction ID to remove it from the list
+// after saving to database.
+//
+// See also: [MatrixMessage.AddPendingToSave]
+func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) {
+	evt.Portal.outgoingMessagesLock.Lock()
+	evt.Portal.outgoingMessages[txnID] = outgoingMessage{
+		ignore: true,
+	}
+	evt.Portal.outgoingMessagesLock.Unlock()
+}
+
+// AddPendingToSave adds a transaction ID that should be processed and pointed at the existing event if encountered.
+//
+// This should be used when the network connector returns `Pending: true` from HandleMatrixMessage,
+// i.e. when the network connector does not know the message ID at the end of the handler.
+// The [MatrixMessageResponse] should set Pending to true to prevent saving the returned message to the database.
+//
+// The provided function will be called when the message is encountered.
+func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID networkid.TransactionID, handleEcho RemoteEchoHandler) {
+	evt.Portal.outgoingMessagesLock.Lock()
+	evt.Portal.outgoingMessages[txnID] = outgoingMessage{
+		db:     evt.fillDBMessage(message),
+		evt:    evt.Event,
+		handle: handleEcho,
+	}
+	evt.Portal.outgoingMessagesLock.Unlock()
+}
+
+// RemovePending removes a transaction ID from the list of pending messages.
+// This should only be called if sending the message fails.
+func (evt *MatrixMessage) RemovePending(txnID networkid.TransactionID) {
+	evt.Portal.outgoingMessagesLock.Lock()
+	delete(evt.Portal.outgoingMessages, txnID)
+	evt.Portal.outgoingMessagesLock.Unlock()
+}
+
+func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Message {
+	if message == nil {
+		message = &database.Message{}
+	}
+	if message.MXID == "" {
+		message.MXID = evt.Event.ID
+	}
+	if message.Room.ID == "" {
+		message.Room = evt.Portal.PortalKey
+	}
+	if message.Timestamp.IsZero() {
+		message.Timestamp = time.UnixMilli(evt.Event.Timestamp)
+	}
+	if message.ReplyTo.MessageID == "" && evt.ReplyTo != nil {
+		message.ReplyTo.MessageID = evt.ReplyTo.ID
+		message.ReplyTo.PartID = &evt.ReplyTo.PartID
+	}
+	if message.ThreadRoot == "" && evt.ThreadRoot != nil {
+		message.ThreadRoot = evt.ThreadRoot.ID
+		if evt.ThreadRoot.ThreadRoot != "" {
+			message.ThreadRoot = evt.ThreadRoot.ThreadRoot
+		}
+	}
+	if message.SenderMXID == "" {
+		message.SenderMXID = evt.Event.Sender
+	}
+	return message
+}
+
 func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) {
 	log := zerolog.Ctx(ctx)
 	editTargetID := content.RelatesTo.GetReplaceID()
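Taken together, these helpers give network connectors two ways to deal with remote echoes of their own outgoing messages. A hypothetical connector sketch for the harder flow, where the remote send API does not return the message ID synchronously; the remote API, client type, and token helper are all invented, while the bridgev2 types and the AddPendingToSave / RemovePending / Pending contract come from this diff:

    package connector

    import (
    	"context"

    	"maunium.net/go/mautrix/bridgev2"
    	"maunium.net/go/mautrix/bridgev2/database"
    	"maunium.net/go/mautrix/bridgev2/networkid"
    )

    // remoteAPI stands in for whatever library talks to the remote network;
    // it is entirely invented for this sketch.
    type remoteAPI interface {
    	SendText(ctx context.Context, chatID, txnID, body string) error
    }

    // DemoClient is a hypothetical network connector client.
    type DemoClient struct {
    	remote remoteAPI
    }

    // newTransactionToken is an invented helper; any value that is unique
    // per outgoing message works as a transaction ID.
    func (c *DemoClient) newTransactionToken() string {
    	return "txn-0001"
    }

    func (c *DemoClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) {
    	txnID := networkid.TransactionID(c.newTransactionToken())
    	// Register the pending entry before sending so the remote echo
    	// can't race past this handler.
    	msg.AddPendingToSave(nil, txnID, func(evt bridgev2.RemoteMessage, dbMsg *database.Message) (bool, error) {
    		// Called when the echo shows up; true means it's handled.
    		return true, nil
    	})
    	if err := c.remote.SendText(ctx, string(msg.Portal.ID), string(txnID), msg.Content.Body); err != nil {
    		// Sending failed, so drop the pending entry again to avoid a leak.
    		msg.RemovePending(txnID)
    		return nil, err
    	}
    	// Pending tells the central bridge not to save anything yet; the
    	// echo handler registered above finishes the job.
    	return &bridgev2.MatrixMessageResponse{Pending: true}, nil
    }

For remote APIs that do return the real message ID from the send call, the simpler flow is AddPendingToIgnore before sending, then returning the message in DB along with RemovePending set to the same transaction ID.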
@@ -1410,11 +1490,8 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog
 	portal.sendSuccessStatus(ctx, evt)
 }
 
-func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) {
-	log := portal.Log.With().
-		Str("source_id", string(source.ID)).
-		Str("action", "handle remote event").
-		Logger()
+func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) {
+	log := zerolog.Ctx(ctx)
 	defer func() {
 		if err := recover(); err != nil {
 			logEvt := log.Error()
@@ -1433,7 +1510,6 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) {
 		c = c.Stringer("bridge_evt_type", evtType)
 		return evt.AddLogContext(c)
 	})
-	ctx := log.WithContext(context.TODO())
 	if portal.MXID == "" {
 		mcp, ok := evt.(RemoteEventThatMayCreatePortal)
 		if !ok || !mcp.ShouldCreatePortal() {
@@ -1715,6 +1791,8 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage
 	pending, ok := portal.outgoingMessages[txnID]
 	if !ok {
 		return false, nil
+	} else if pending.ignore {
+		return true, nil
 	}
 	delete(portal.outgoingMessages, txnID)
 	pending.db.ID = evt.GetID()
@@ -1773,7 +1851,8 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin,
 	}
 	if len(res.SubEvents) > 0 {
 		for _, subEvt := range res.SubEvents {
-			portal.handleRemoteEvent(source, subEvt)
+			log := portal.Log.With().Str("source_id", string(source.ID)).Str("action", "handle remote subevent").Logger()
+			portal.handleRemoteEvent(log.WithContext(ctx), source, subEvt)
 		}
 	}
 	return res.ContinueMessageHandling
@@ -304,17 +304,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
 		partIDs = append(partIDs, part.ID)
 		portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent)
 		evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID)
-		out.Events = append(out.Events, &event.Event{
-			Sender:    intent.GetMXID(),
-			Type:      part.Type,
-			Timestamp: msg.Timestamp.UnixMilli(),
-			ID:        evtID,
-			RoomID:    portal.MXID,
-			Content: event.Content{
-				Parsed: part.Content,
-				Raw:    part.Extra,
-			},
-		})
 		dbMessage := &database.Message{
 			ID:     msg.ID,
 			PartID: part.ID,
@@ -327,6 +316,22 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
 			ReplyTo:  ptr.Val(msg.ReplyTo),
 			Metadata: part.DBMetadata,
 		}
+		if part.DontBridge {
+			dbMessage.SetFakeMXID()
+			out.DBMessages = append(out.DBMessages, dbMessage)
+			continue
+		}
+		out.Events = append(out.Events, &event.Event{
+			Sender:    intent.GetMXID(),
+			Type:      part.Type,
+			Timestamp: msg.Timestamp.UnixMilli(),
+			ID:        evtID,
+			RoomID:    portal.MXID,
+			Content: event.Content{
+				Parsed: part.Content,
+				Raw:    part.Extra,
+			},
+		})
 		if firstPart == nil {
 			firstPart = dbMessage
 		}
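A note on the reordering above: the database row is now built before the Matrix event is appended, so parts flagged DontBridge can be given a fake MXID via SetFakeMXID, recorded in DBMessages, and skipped with continue, keeping a database row for the part without ever emitting a Matrix event into the backfill batch.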
@@ -37,8 +37,12 @@ func (portal *PortalInternals) EventLoop() {
 	(*Portal)(portal).eventLoop()
 }
 
-func (portal *PortalInternals) HandleCreateEvent(evt *portalCreateEvent) {
-	(*Portal)(portal).handleCreateEvent(evt)
+func (portal *PortalInternals) HandleSingleEventAsync(rawEvt any) {
+	(*Portal)(portal).handleSingleEventAsync(rawEvt)
+}
+
+func (portal *PortalInternals) HandleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) {
+	(*Portal)(portal).handleSingleEvent(log, rawEvt, doneCallback)
 }
 
 func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event) {
@@ -53,8 +57,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i
 	return (*Portal)(portal).checkConfusableName(ctx, userID, name)
 }
 
-func (portal *PortalInternals) HandleMatrixEvent(sender *User, evt *event.Event) {
-	(*Portal)(portal).handleMatrixEvent(sender, evt)
+func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) {
+	(*Portal)(portal).handleMatrixEvent(ctx, sender, evt)
 }
 
 func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) {
@@ -109,8 +113,8 @@ func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender
 	(*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt)
 }
 
-func (portal *PortalInternals) HandleRemoteEvent(source *UserLogin, evt RemoteEvent) {
-	(*Portal)(portal).handleRemoteEvent(source, evt)
+func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) {
+	(*Portal)(portal).handleRemoteEvent(ctx, source, evt)
 }
 
 func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) {
@@ -297,6 +301,10 @@ func (portal *PortalInternals) DoThreadBackfill(ctx context.Context, source *Use
 	(*Portal)(portal).doThreadBackfill(ctx, source, threadID)
 }
 
+func (portal *PortalInternals) CutoffMessages(ctx context.Context, messages []*BackfillMessage, aggressiveDedup, forward bool, lastMessage *database.Message) []*BackfillMessage {
+	return (*Portal)(portal).cutoffMessages(ctx, messages, aggressiveDedup, forward, lastMessage)
+}
+
 func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) {
 	(*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread)
 }
96 client.go
@@ -13,6 +13,7 @@ import (
 	"net/http"
 	"net/url"
 	"os"
+	"slices"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -20,7 +21,9 @@ import (
 
 	"github.com/rs/zerolog"
 	"github.com/tidwall/gjson"
+	"go.mau.fi/util/ptr"
 	"go.mau.fi/util/retryafter"
+	"golang.org/x/exp/maps"
 
 	"maunium.net/go/mautrix/crypto/backup"
 	"maunium.net/go/mautrix/event"
@@ -322,7 +325,9 @@ func (cli *Client) RequestStart(req *http.Request) {
 
 func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) {
 	var evt *zerolog.Event
-	if err != nil {
+	if errors.Is(err, context.Canceled) {
+		evt = zerolog.Ctx(req.Context()).Warn()
+	} else if err != nil {
 		evt = zerolog.Ctx(req.Context()).Err(err)
 	} else if handlerErr != nil {
 		evt = zerolog.Ctx(req.Context()).Warn().
@@ -355,7 +360,9 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er
 	if body := req.Context().Value(LogBodyContextKey); body != nil {
 		evt.Interface("req_body", body)
 	}
-	if err != nil {
+	if errors.Is(err, context.Canceled) {
+		evt.Msg("Request canceled")
+	} else if err != nil {
 		evt.Msg("Request failed")
 	} else if handlerErr != nil {
 		evt.Msg("Request parsing failed")
@@ -1498,21 +1505,19 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt
 		Handler: parseRoomStateArray,
 	})
 	if err == nil && cli.StateStore != nil {
-		clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID)
-		if clearErr != nil {
-			cli.cliOrContextLog(ctx).Warn().Err(clearErr).
-				Stringer("room_id", roomID).
-				Msg("Failed to clear cached member list after fetching state")
-		}
-		for _, evts := range stateMap {
+		for evtType, evts := range stateMap {
+			if evtType == event.StateMember {
+				continue
+			}
 			for _, evt := range evts {
 				UpdateStateStore(ctx, cli.StateStore, evt)
 			}
 		}
-		clearErr = cli.StateStore.MarkMembersFetched(ctx, roomID)
-		if clearErr != nil {
-			cli.cliOrContextLog(ctx).Warn().Err(clearErr).
-				Msg("Failed to mark members as fetched after fetching full room state")
+		updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, maps.Values(stateMap[event.StateMember]))
+		if updateErr != nil {
+			cli.cliOrContextLog(ctx).Warn().Err(updateErr).
+				Stringer("room_id", roomID).
+				Msg("Failed to update members in state store after fetching members")
 		}
 	}
 	return
@@ -1864,24 +1869,26 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R
 	u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members")
 	_, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
 	if err == nil && cli.StateStore != nil {
-		clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin)
-		if clearErr != nil {
-			cli.cliOrContextLog(ctx).Warn().Err(clearErr).
-				Stringer("room_id", roomID).
-				Msg("Failed to clear cached member list after fetching joined members")
-		}
+		fakeEvents := make([]*event.Event, len(resp.Joined))
+		i := 0
 		for userID, member := range resp.Joined {
-			updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{
-				Membership:  event.MembershipJoin,
-				AvatarURL:   id.ContentURIString(member.AvatarURL),
-				Displayname: member.DisplayName,
-			})
-			if updateErr != nil {
-				cli.cliOrContextLog(ctx).Warn().Err(updateErr).
-					Stringer("room_id", roomID).
-					Stringer("user_id", userID).
-					Msg("Failed to update membership in state store after fetching joined members")
+			fakeEvents[i] = &event.Event{
+				StateKey: ptr.Ptr(userID.String()),
+				Type:     event.StateMember,
+				RoomID:   roomID,
+				Content: event.Content{Parsed: &event.MemberEventContent{
+					Membership:  event.MembershipJoin,
+					AvatarURL:   id.ContentURIString(member.AvatarURL),
+					Displayname: member.DisplayName,
+				}},
 			}
+			i++
+		}
+		updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, fakeEvents, event.MembershipJoin)
+		if updateErr != nil {
+			cli.cliOrContextLog(ctx).Warn().Err(updateErr).
+				Stringer("room_id", roomID).
+				Msg("Failed to update members in state store after fetching joined members")
 		}
 	}
 	return
@@ -1910,27 +1917,20 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb
 		}
 	}
 	if err == nil && cli.StateStore != nil {
-		var clearMemberships []event.Membership
+		var onlyMemberships []event.Membership
 		if extra.Membership != "" {
-			clearMemberships = append(clearMemberships, extra.Membership)
+			onlyMemberships = []event.Membership{extra.Membership}
+		} else if extra.NotMembership != "" {
+			onlyMemberships = []event.Membership{event.MembershipJoin, event.MembershipLeave, event.MembershipInvite, event.MembershipBan, event.MembershipKnock}
+			onlyMemberships = slices.DeleteFunc(onlyMemberships, func(m event.Membership) bool {
+				return m == extra.NotMembership
+			})
 		}
-		if extra.NotMembership == "" {
-			clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...)
-			if clearErr != nil {
-				cli.cliOrContextLog(ctx).Warn().Err(clearErr).
-					Stringer("room_id", roomID).
-					Msg("Failed to clear cached member list after fetching joined members")
-			}
-		}
-		for _, evt := range resp.Chunk {
-			UpdateStateStore(ctx, cli.StateStore, evt)
-		}
-		if extra.NotMembership == "" && extra.Membership == "" {
-			markErr := cli.StateStore.MarkMembersFetched(ctx, roomID)
-			if markErr != nil {
-				cli.cliOrContextLog(ctx).Warn().Err(markErr).
-					Msg("Failed to mark members as fetched after fetching full member list")
-			}
+		updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, resp.Chunk, onlyMemberships...)
+		if updateErr != nil {
+			cli.cliOrContextLog(ctx).Warn().Err(updateErr).
+				Stringer("room_id", roomID).
+				Msg("Failed to update members in state store after fetching members")
 		}
 	}
 	return
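All three member-fetching endpoints (State, JoinedMembers, Members) now funnel cache maintenance through the new ReplaceCachedMembers store call, so clearing and re-inserting happen in one place instead of the old clear-then-SetMember-per-user sequence. A call-site sketch, assuming an already configured client with a StateStore attached:

    package clientutil

    import (
    	"context"

    	"maunium.net/go/mautrix"
    	"maunium.net/go/mautrix/event"
    	"maunium.net/go/mautrix/id"
    )

    // cacheJoinedMembers fetches only joined members; the client then calls
    // ReplaceCachedMembers on its StateStore with the join membership as the
    // onlyMemberships filter, replacing the cached join rows for the room.
    func cacheJoinedMembers(ctx context.Context, cli *mautrix.Client, roomID id.RoomID) error {
    	_, err := cli.Members(ctx, roomID, mautrix.ReqMembers{Membership: event.MembershipJoin})
    	return err
    }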
@@ -245,8 +245,22 @@ func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.Devic
 	}
 }
 
+func (mach *OlmMachine) otkCountIsForCrossSigningKey(otkCount *mautrix.OTKCount) bool {
+	if mach.crossSigningPubkeys == nil || otkCount.UserID != mach.Client.UserID {
+		return false
+	}
+	switch id.Ed25519(otkCount.DeviceID) {
+	case mach.crossSigningPubkeys.MasterKey, mach.crossSigningPubkeys.UserSigningKey, mach.crossSigningPubkeys.SelfSigningKey:
+		return true
+	}
+	return false
+}
+
 func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) {
 	if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) {
+		if mach.otkCountIsForCrossSigningKey(otkCount) {
+			return
+		}
 		// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
 		mach.Log.Warn().
 			Str("target_user_id", otkCount.UserID.String()).
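Context for the new helper: cross-signing key IDs are the unpadded base64 of their Ed25519 public keys, so an OTK count whose device ID matches the user's own master, self-signing, or user-signing public key belongs to a cross-signing key rather than a real device. HandleOTKCounts now returns silently for those instead of logging the mismatched-device warning.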
@@ -187,6 +187,7 @@ type PolicyRecommendation string
 const (
 	PolicyRecommendationBan         PolicyRecommendation = "m.ban"
 	PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban"
+	PolicyRecommendationUnban       PolicyRecommendation = "fi.mau.meowlnir.unban"
 )
 
 // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event.
@@ -11,6 +11,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net"
 	"net/http"
 	"net/url"
@@ -140,7 +141,7 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*
 		return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
 	}
 	var respData RespWellKnown
-	err = json.NewDecoder(resp.Body).Decode(&respData)
+	err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData)
 	if err != nil {
 		return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
 	} else if respData.Server == "" {
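The io.LimitReader change is a hardening fix: without it, json.Decoder would read the attacker-controlled .well-known response body without any size bound. A small standalone sketch of the same guard, using the 50 KiB cap from the diff:

    package wellknownutil

    import (
    	"encoding/json"
    	"io"
    	"net/http"
    )

    // decodeCapped decodes at most 50 KiB of the response body; an
    // oversized body is cut off mid-document and fails to decode instead
    // of being buffered in full.
    func decodeCapped(resp *http.Response, dst any) error {
    	return json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(dst)
    }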
12 go.mod
@@ -8,18 +8,18 @@ require (
 	github.com/gorilla/mux v1.8.0
 	github.com/gorilla/websocket v1.5.0
 	github.com/lib/pq v1.10.9
-	github.com/mattn/go-sqlite3 v1.14.22
-	github.com/rs/xid v1.5.0
+	github.com/mattn/go-sqlite3 v1.14.23
+	github.com/rs/xid v1.6.0
 	github.com/rs/zerolog v1.33.0
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/stretchr/testify v1.9.0
 	github.com/tidwall/gjson v1.17.3
 	github.com/tidwall/sjson v1.2.5
 	github.com/yuin/goldmark v1.7.4
-	go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943
+	go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2
 	go.mau.fi/zeroconfig v0.1.3
 	golang.org/x/crypto v0.26.0
-	golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa
+	golang.org/x/exp v0.0.0-20240823005443-9b4947da3948
 	golang.org/x/net v0.28.0
 	golang.org/x/sync v0.8.0
 	gopkg.in/yaml.v3 v3.0.1
@@ -35,7 +35,7 @@ require (
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/tidwall/match v1.1.1 // indirect
 	github.com/tidwall/pretty v1.2.0 // indirect
-	golang.org/x/sys v0.24.0 // indirect
-	golang.org/x/text v0.17.0 // indirect
+	golang.org/x/sys v0.25.0 // indirect
+	golang.org/x/text v0.18.0 // indirect
 	gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
 )
23 go.sum
@@ -24,15 +24,16 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
 github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
 github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
 github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
-github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
+github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
 github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw=
 github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
-github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
+github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
+github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
 github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
 github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
@@ -50,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
 github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
 github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
 github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
-go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943 h1:wdJ9XC/M6lVUrwDltHPodaA3SRJq+S+AzGEXdQ/o2AQ=
-go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4=
+go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 h1:VZQlKBbeJ7KOlYSh6BnN5uWQTY/ypn/bJv0YyEd+pXc=
+go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2/go.mod h1:WgYvbt9rVmoFeajP97NunQU7AjgvTPiNExN3oTHeePs=
 go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
 go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
 golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
 golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
-golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI=
-golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
+golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA=
+golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
 golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
 golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
 golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
@@ -66,10 +67,10 @@ golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
-golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
-golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
+golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
+golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
+golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
@@ -174,3 +174,7 @@ func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Roo
 }
 
 func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {}
+
+func (c *ClientStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
+	return nil
+}
@@ -83,6 +83,7 @@ type ReqLogin struct {
 	Token                    string      `json:"token,omitempty"`
 	DeviceID                 id.DeviceID `json:"device_id,omitempty"`
 	InitialDeviceDisplayName string      `json:"initial_device_display_name,omitempty"`
+	RefreshToken             bool        `json:"refresh_token,omitempty"`
 
 	// Whether or not the returned credentials should be stored in the Client
 	StoreCredentials bool `json:"-"`
@@ -19,6 +19,7 @@ import (
 	"github.com/rs/zerolog"
 	"go.mau.fi/util/confusable"
 	"go.mau.fi/util/dbutil"
+	"go.mau.fi/util/exslices"
 
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/id"
@@ -194,21 +195,37 @@ func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID,
 	return err
 }
 
+const insertUserProfileQuery = `
+	INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton)
+	VALUES ($1, $2, $3, $4, $5, $6)
+	ON CONFLICT (room_id, user_id) DO UPDATE
+		SET membership=excluded.membership,
+		    displayname=excluded.displayname,
+		    avatar_url=excluded.avatar_url,
+		    name_skeleton=excluded.name_skeleton
+`
+
+type userProfileRow struct {
+	UserID       id.UserID
+	Membership   event.Membership
+	Displayname  string
+	AvatarURL    id.ContentURIString
+	NameSkeleton []byte
+}
+
+func (u *userProfileRow) GetMassInsertValues() [5]any {
+	return [5]any{u.UserID, u.Membership, u.Displayname, u.AvatarURL, u.NameSkeleton}
+}
+
+var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
+
 func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
 	var nameSkeleton []byte
 	if !store.DisableNameDisambiguation && len(member.Displayname) > 0 {
 		nameSkeletonArr := confusable.SkeletonHash(member.Displayname)
 		nameSkeleton = nameSkeletonArr[:]
 	}
-	_, err := store.Exec(ctx, `
-		INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton)
-		VALUES ($1, $2, $3, $4, $5, $6)
-		ON CONFLICT (room_id, user_id) DO UPDATE
-			SET membership=excluded.membership,
-			    displayname=excluded.displayname,
-			    avatar_url=excluded.avatar_url,
-			    name_skeleton=excluded.name_skeleton
-	`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton)
+	_, err := store.Exec(ctx, insertUserProfileQuery, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton)
 	return err
 }
@@ -221,6 +238,50 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room
 	return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
 }
 
+const userProfileMassInsertBatchSize = 500
+
+func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
+	return store.DoTxn(ctx, nil, func(ctx context.Context) error {
+		err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
+		if err != nil {
+			return fmt.Errorf("failed to clear cached members: %w", err)
+		}
+		rows := make([]*userProfileRow, min(len(evts), userProfileMassInsertBatchSize))
+		for _, evtsChunk := range exslices.Chunk(evts, userProfileMassInsertBatchSize) {
+			rows = rows[:0]
+			for _, evt := range evtsChunk {
+				content, ok := evt.Content.Parsed.(*event.MemberEventContent)
+				if !ok {
+					continue
+				}
+				row := &userProfileRow{
+					UserID:      id.UserID(*evt.StateKey),
+					Membership:  content.Membership,
+					Displayname: content.Displayname,
+					AvatarURL:   content.AvatarURL,
+				}
+				if !store.DisableNameDisambiguation && len(content.Displayname) > 0 {
+					nameSkeletonArr := confusable.SkeletonHash(content.Displayname)
+					row.NameSkeleton = nameSkeletonArr[:]
+				}
+				rows = append(rows, row)
+			}
+			query, args := userProfileMassInserter.Build([1]any{roomID}, rows)
+			_, err = store.Exec(ctx, query, args...)
+			if err != nil {
+				return fmt.Errorf("failed to insert members: %w", err)
+			}
+		}
+		if len(onlyMemberships) == 0 {
+			err = store.MarkMembersFetched(ctx, roomID)
+			if err != nil {
+				return fmt.Errorf("failed to mark members as fetched: %w", err)
+			}
+		}
+		return nil
+	})
+}
+
 func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
 	query := "DELETE FROM mx_user_profile WHERE room_id=$1"
 	params := make([]any, len(memberships)+1)
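For reference, the mass-insert builder expands the value tuple once per row, reusing the static $1 (the room ID) for every row and numbering the per-row placeholders after it; capping batches at 500 rows keeps the bound-parameter count per statement at 1 + 500×5, presumably to stay clear of database placeholder limits. A sketch, as if written inside the same package, with two invented rows:

    // exampleMassInsert shows the builder's shape with invented data.
    func exampleMassInsert(ctx context.Context, store *SQLStateStore, roomID id.RoomID) error {
    	rows := []*userProfileRow{
    		{UserID: "@alice:example.com", Membership: event.MembershipJoin},
    		{UserID: "@bob:example.com", Membership: event.MembershipLeave},
    	}
    	query, args := userProfileMassInserter.Build([1]any{roomID}, rows)
    	// query repeats the tuple with $1 reused for room_id, roughly:
    	//   VALUES ($1, $2, $3, $4, $5, $6), ($1, $7, $8, $9, $10, $11)
    	// and args is the room ID followed by each row's five values in order.
    	_, err := store.Exec(ctx, query, args...)
    	return err
    }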
@@ -29,6 +29,7 @@ type StateStore interface {
 	SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error
 	IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error)
 	ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error
+	ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error
 
 	SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
 	GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
@@ -270,9 +271,20 @@ func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id
 	return nil
 }
 
+func (store *MemoryStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
+	_ = store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
+	for _, evt := range evts {
+		UpdateStateStore(ctx, store, evt)
+	}
+	if len(onlyMemberships) == 0 {
+		_ = store.MarkMembersFetched(ctx, roomID)
+	}
+	return nil
+}
+
 func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
-	store.membersLock.Lock()
-	defer store.membersLock.Unlock()
+	store.membersLock.RLock()
+	defer store.membersLock.RUnlock()
 	return maps.Clone(store.Members[roomID]), nil
 }