Merge branch 'main' into tulir/extended-profiles-for-everyone
Some checks failed
Go / Lint (latest) (push) Has been cancelled
Go / Build (old, libolm) (push) Has been cancelled
Go / Build (latest, libolm) (push) Has been cancelled
Go / Build (old, goolm) (push) Has been cancelled
Go / Build (latest, goolm) (push) Has been cancelled

This commit is contained in:
Tulir Asokan 2024-09-05 01:44:37 +03:00 committed by GitHub
commit e4e8933b33
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 397 additions and 183 deletions

View file

@ -59,6 +59,7 @@ type BridgeConfig struct {
CommandPrefix string `yaml:"command_prefix"`
PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
AsyncEvents bool `yaml:"async_events"`
BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`

View file

@ -25,6 +25,7 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Str, "bridge", "command_prefix")
helper.Copy(up.Bool, "bridge", "personal_filtering_spaces")
helper.Copy(up.Bool, "bridge", "private_chat_portal_meta")
helper.Copy(up.Bool, "bridge", "async_events")
helper.Copy(up.Bool, "bridge", "bridge_matrix_leave")
helper.Copy(up.Bool, "bridge", "tag_only_on_create")
helper.Copy(up.Bool, "bridge", "mute_only_on_create")

View file

@ -289,7 +289,8 @@ func (as *ASIntent) UploadMediaStream(
var res *bridgev2.FileStreamResult
res, err = cb(tempFile)
if err != nil {
err = fmt.Errorf("failed to write to temp file: %w", err)
err = fmt.Errorf("write callback failed: %w", err)
return
}
var replFile *os.File
if res.ReplacementFile != "" {

View file

@ -7,6 +7,9 @@ bridge:
# Whether the bridge should set names and avatars explicitly for DM portals.
# This is only necessary when using clients that don't support MSC4171.
private_chat_portal_meta: false
# Should events be handled asynchronously within portal rooms?
# If true, events may end up being out of order, but slow events won't block other ones.
async_events: false
# Should leaving Matrix rooms be bridged as leaving groups on the remote network?
bridge_matrix_leave: false

View file

@ -54,6 +54,11 @@ type ProvisioningAPI struct {
// GetAuthFromRequest is a custom function for getting the auth token from
// the request if the Authorization header is not present.
GetAuthFromRequest func(r *http.Request) string
// GetUserIDFromRequest is a custom function for getting the user ID to
// authenticate as instead of using the user ID provided in the query
// parameter.
GetUserIDFromRequest func(r *http.Request) id.UserID
}
type ProvLogin struct {
@ -200,6 +205,9 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
return
}
userID := id.UserID(r.URL.Query().Get("user_id"))
if userID == "" && prov.GetUserIDFromRequest != nil {
userID = prov.GetUserIDFromRequest(r)
}
if auth != prov.br.Config.Provisioning.SharedSecret {
var err error
if strings.HasPrefix(auth, "openid:") {
@ -227,6 +235,14 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
return
}
// TODO handle user being nil?
// TODO per-endpoint permissions?
if !user.Permissions.Login {
jsonResponse(w, http.StatusForbidden, &mautrix.RespError{
Err: "User does not have login permissions",
ErrCode: mautrix.MForbidden.ErrCode,
})
return
}
ctx := context.WithValue(r.Context(), provisioningUserKey, user)
if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {

View file

@ -273,11 +273,16 @@ type MaxFileSizeingNetwork interface {
SetMaxFileSize(maxSize int64)
}
type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error)
type MatrixMessageResponse struct {
DB *database.Message
Pending networkid.TransactionID
HandleEcho func(RemoteMessage, *database.Message) (bool, error)
// If Pending is set, the bridge will not save the provided message to the database.
// This should only be used if AddPendingToSave has been called.
Pending bool
// If RemovePending is set, the bridge will remove the provided transaction ID from pending messages
// after saving the provided message to the database. This should be used with AddPendingToIgnore.
RemovePending networkid.TransactionID
}
type FileRestriction struct {

View file

@ -14,6 +14,7 @@ import (
"runtime/debug"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog"
@ -59,6 +60,7 @@ type portalEvent interface {
type outgoingMessage struct {
db *database.Message
evt *event.Event
ignore bool
handle func(RemoteMessage, *database.Message) (bool, error)
}
@ -274,23 +276,49 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) {
func (portal *Portal) eventLoop() {
for rawEvt := range portal.events {
switch evt := rawEvt.(type) {
case *portalMatrixEvent:
portal.handleMatrixEvent(evt.sender, evt.evt)
case *portalRemoteEvent:
portal.handleRemoteEvent(evt.source, evt.evt)
case *portalCreateEvent:
portal.handleCreateEvent(evt)
default:
panic(fmt.Errorf("illegal type %T in eventLoop", evt))
}
portal.handleSingleEventAsync(rawEvt)
}
}
func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) {
func (portal *Portal) handleSingleEventAsync(rawEvt any) {
log := portal.Log.With().Logger()
if _, isCreate := rawEvt.(*portalCreateEvent); isCreate {
portal.handleSingleEvent(&log, rawEvt, func() {})
} else if portal.Bridge.Config.AsyncEvents {
go portal.handleSingleEvent(&log, rawEvt, func() {})
} else {
doneCh := make(chan struct{})
var backgrounded atomic.Bool
go portal.handleSingleEvent(&log, rawEvt, func() {
close(doneCh)
if backgrounded.Load() {
log.Debug().Msg("Event that took too long finally finished handling")
}
})
tick := time.NewTicker(30 * time.Second)
defer tick.Stop()
for i := 0; i < 10; i++ {
select {
case <-doneCh:
if i > 0 {
log.Debug().Msg("Event that took long finished handling")
}
return
case <-tick.C:
log.Warn().Msg("Event handling is taking long")
}
}
log.Warn().Msg("Event handling is taking too long, continuing in background")
backgrounded.Store(true)
}
}
func (portal *Portal) handleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) {
ctx := log.WithContext(context.Background())
defer func() {
doneCallback()
if err := recover(); err != nil {
logEvt := zerolog.Ctx(evt.ctx).Error()
logEvt := log.Error()
if realErr, ok := err.(error); ok {
logEvt = logEvt.Err(realErr)
} else {
@ -299,10 +327,36 @@ func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) {
logEvt.
Bytes("stack", debug.Stack()).
Msg("Portal creation panicked")
evt.cb(fmt.Errorf("portal creation panicked"))
switch evt := rawEvt.(type) {
case *portalMatrixEvent:
if evt.evt.ID != "" {
go portal.sendErrorStatus(ctx, evt.evt, ErrPanicInEventHandler)
}
case *portalCreateEvent:
evt.cb(fmt.Errorf("portal creation panicked"))
}
}
}()
evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil))
switch evt := rawEvt.(type) {
case *portalMatrixEvent:
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("action", "handle matrix event").
Stringer("event_id", evt.evt.ID).
Str("event_type", evt.evt.Type.Type)
})
portal.handleMatrixEvent(ctx, evt.sender, evt.evt)
case *portalRemoteEvent:
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("action", "handle remote event").
Str("source_id", string(evt.source.ID))
})
portal.handleRemoteEvent(ctx, evt.source, evt.evt)
case *portalCreateEvent:
*log = *zerolog.Ctx(evt.ctx)
evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil))
default:
panic(fmt.Errorf("illegal type %T in eventLoop", evt))
}
}
func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) {
@ -392,29 +446,8 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID,
return false
}
func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) {
log := portal.Log.With().
Str("action", "handle matrix event").
Stringer("event_id", evt.ID).
Str("event_type", evt.Type.Type).
Logger()
ctx := log.WithContext(context.TODO())
defer func() {
if err := recover(); err != nil {
logEvt := log.Error()
if realErr, ok := err.(error); ok {
logEvt = logEvt.Err(realErr)
} else {
logEvt = logEvt.Any(zerolog.ErrorFieldName, err)
}
logEvt.
Bytes("stack", debug.Stack()).
Msg("Matrix event handler panicked")
if evt.ID != "" {
go portal.sendErrorStatus(ctx, evt, ErrPanicInEventHandler)
}
}
}()
func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) {
log := zerolog.Ctx(ctx)
if evt.Mautrix.EventSource&event.SourceEphemeral != 0 {
switch evt.Type {
case event.EphemeralEventReceipt:
@ -775,7 +808,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
}
resp, err := sender.Client.HandleMatrixMessage(ctx, &MatrixMessage{
wrappedEvt := &MatrixMessage{
MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{
Event: evt,
Content: content,
@ -784,52 +817,30 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
},
ThreadRoot: threadRoot,
ReplyTo: replyTo,
})
}
resp, err := sender.Client.HandleMatrixMessage(ctx, wrappedEvt)
if err != nil {
log.Err(err).Msg("Failed to handle Matrix message")
portal.sendErrorStatus(ctx, evt, err)
return
}
message := resp.DB
if message.MXID == "" {
message.MXID = evt.ID
}
if message.Room.ID == "" {
message.Room = portal.PortalKey
}
if message.Timestamp.IsZero() {
message.Timestamp = time.UnixMilli(evt.Timestamp)
}
if message.ReplyTo.MessageID == "" && replyTo != nil {
message.ReplyTo.MessageID = replyTo.ID
message.ReplyTo.PartID = &replyTo.PartID
}
if message.ThreadRoot == "" && threadRoot != nil {
message.ThreadRoot = threadRoot.ID
if threadRoot.ThreadRoot != "" {
message.ThreadRoot = threadRoot.ThreadRoot
}
}
if message.SenderMXID == "" {
message.SenderMXID = evt.Sender
}
if resp.Pending != "" {
// TODO if the event queue is ever removed, this will have to be done by the network connector before sending the request
// (for now this is fine because incoming messages will wait in the queue for this function to return)
portal.outgoingMessagesLock.Lock()
portal.outgoingMessages[resp.Pending] = outgoingMessage{
db: message,
evt: evt,
handle: resp.HandleEcho,
}
portal.outgoingMessagesLock.Unlock()
} else {
// Hack to ensure the ghost row exists
// TODO move to better place (like login)
portal.Bridge.GetGhostByID(ctx, message.SenderID)
err = portal.Bridge.DB.Message.Insert(ctx, message)
if err != nil {
log.Err(err).Msg("Failed to save message to database")
message := wrappedEvt.fillDBMessage(resp.DB)
if !resp.Pending {
if resp.DB == nil {
log.Error().Msg("Network connector didn't return a message to save")
} else {
// Hack to ensure the ghost row exists
// TODO move to better place (like login)
portal.Bridge.GetGhostByID(ctx, message.SenderID)
err = portal.Bridge.DB.Message.Insert(ctx, message)
if err != nil {
log.Err(err).Msg("Failed to save message to database")
}
if resp.RemovePending != "" {
portal.outgoingMessagesLock.Lock()
delete(portal.outgoingMessages, resp.RemovePending)
portal.outgoingMessagesLock.Unlock()
}
}
portal.sendSuccessStatus(ctx, evt)
}
@ -846,6 +857,75 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
}
// AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message.
//
// This should be used when the network connector will return the real message ID from HandleMatrixMessage.
// The [MatrixMessageResponse] should include RemovePending with the transaction ID sto remove it from the lit
// after saving to database.
//
// See also: [MatrixMessage.AddPendingToSave]
func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) {
evt.Portal.outgoingMessagesLock.Lock()
evt.Portal.outgoingMessages[txnID] = outgoingMessage{
ignore: true,
}
evt.Portal.outgoingMessagesLock.Unlock()
}
// AddPendingToSave adds a transaction ID that should be processed and pointed at the existing event if encountered.
//
// This should be used when the network connector returns `Pending: true` from HandleMatrixMessage,
// i.e. when the network connector does not know the message ID at the end of the handler.
// The [MatrixMessageResponse] should set Pending to true to prevent saving the returned message to the database.
//
// The provided function will be called when the message is encountered.
func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID networkid.TransactionID, handleEcho RemoteEchoHandler) {
evt.Portal.outgoingMessagesLock.Lock()
evt.Portal.outgoingMessages[txnID] = outgoingMessage{
db: evt.fillDBMessage(message),
evt: evt.Event,
handle: handleEcho,
}
evt.Portal.outgoingMessagesLock.Unlock()
}
// RemovePending removes a transaction ID from the list of pending messages.
// This should only be called if sending the message fails.
func (evt *MatrixMessage) RemovePending(txnID networkid.TransactionID) {
evt.Portal.outgoingMessagesLock.Lock()
delete(evt.Portal.outgoingMessages, txnID)
evt.Portal.outgoingMessagesLock.Unlock()
}
func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Message {
if message == nil {
message = &database.Message{}
}
if message.MXID == "" {
message.MXID = evt.Event.ID
}
if message.Room.ID == "" {
message.Room = evt.Portal.PortalKey
}
if message.Timestamp.IsZero() {
message.Timestamp = time.UnixMilli(evt.Event.Timestamp)
}
if message.ReplyTo.MessageID == "" && evt.ReplyTo != nil {
message.ReplyTo.MessageID = evt.ReplyTo.ID
message.ReplyTo.PartID = &evt.ReplyTo.PartID
}
if message.ThreadRoot == "" && evt.ThreadRoot != nil {
message.ThreadRoot = evt.ThreadRoot.ID
if evt.ThreadRoot.ThreadRoot != "" {
message.ThreadRoot = evt.ThreadRoot.ThreadRoot
}
}
if message.SenderMXID == "" {
message.SenderMXID = evt.Event.Sender
}
return message
}
func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) {
log := zerolog.Ctx(ctx)
editTargetID := content.RelatesTo.GetReplaceID()
@ -1410,11 +1490,8 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog
portal.sendSuccessStatus(ctx, evt)
}
func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) {
log := portal.Log.With().
Str("source_id", string(source.ID)).
Str("action", "handle remote event").
Logger()
func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) {
log := zerolog.Ctx(ctx)
defer func() {
if err := recover(); err != nil {
logEvt := log.Error()
@ -1433,7 +1510,6 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) {
c = c.Stringer("bridge_evt_type", evtType)
return evt.AddLogContext(c)
})
ctx := log.WithContext(context.TODO())
if portal.MXID == "" {
mcp, ok := evt.(RemoteEventThatMayCreatePortal)
if !ok || !mcp.ShouldCreatePortal() {
@ -1715,6 +1791,8 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage
pending, ok := portal.outgoingMessages[txnID]
if !ok {
return false, nil
} else if pending.ignore {
return true, nil
}
delete(portal.outgoingMessages, txnID)
pending.db.ID = evt.GetID()
@ -1773,7 +1851,8 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin,
}
if len(res.SubEvents) > 0 {
for _, subEvt := range res.SubEvents {
portal.handleRemoteEvent(source, subEvt)
log := portal.Log.With().Str("source_id", string(source.ID)).Str("action", "handle remote subevent").Logger()
portal.handleRemoteEvent(log.WithContext(ctx), source, subEvt)
}
}
return res.ContinueMessageHandling

View file

@ -304,17 +304,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
partIDs = append(partIDs, part.ID)
portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent)
evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID)
out.Events = append(out.Events, &event.Event{
Sender: intent.GetMXID(),
Type: part.Type,
Timestamp: msg.Timestamp.UnixMilli(),
ID: evtID,
RoomID: portal.MXID,
Content: event.Content{
Parsed: part.Content,
Raw: part.Extra,
},
})
dbMessage := &database.Message{
ID: msg.ID,
PartID: part.ID,
@ -327,6 +316,22 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
ReplyTo: ptr.Val(msg.ReplyTo),
Metadata: part.DBMetadata,
}
if part.DontBridge {
dbMessage.SetFakeMXID()
out.DBMessages = append(out.DBMessages, dbMessage)
continue
}
out.Events = append(out.Events, &event.Event{
Sender: intent.GetMXID(),
Type: part.Type,
Timestamp: msg.Timestamp.UnixMilli(),
ID: evtID,
RoomID: portal.MXID,
Content: event.Content{
Parsed: part.Content,
Raw: part.Extra,
},
})
if firstPart == nil {
firstPart = dbMessage
}

View file

@ -37,8 +37,12 @@ func (portal *PortalInternals) EventLoop() {
(*Portal)(portal).eventLoop()
}
func (portal *PortalInternals) HandleCreateEvent(evt *portalCreateEvent) {
(*Portal)(portal).handleCreateEvent(evt)
func (portal *PortalInternals) HandleSingleEventAsync(rawEvt any) {
(*Portal)(portal).handleSingleEventAsync(rawEvt)
}
func (portal *PortalInternals) HandleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) {
(*Portal)(portal).handleSingleEvent(log, rawEvt, doneCallback)
}
func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event) {
@ -53,8 +57,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i
return (*Portal)(portal).checkConfusableName(ctx, userID, name)
}
func (portal *PortalInternals) HandleMatrixEvent(sender *User, evt *event.Event) {
(*Portal)(portal).handleMatrixEvent(sender, evt)
func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) {
(*Portal)(portal).handleMatrixEvent(ctx, sender, evt)
}
func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) {
@ -109,8 +113,8 @@ func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender
(*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt)
}
func (portal *PortalInternals) HandleRemoteEvent(source *UserLogin, evt RemoteEvent) {
(*Portal)(portal).handleRemoteEvent(source, evt)
func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) {
(*Portal)(portal).handleRemoteEvent(ctx, source, evt)
}
func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) {
@ -297,6 +301,10 @@ func (portal *PortalInternals) DoThreadBackfill(ctx context.Context, source *Use
(*Portal)(portal).doThreadBackfill(ctx, source, threadID)
}
func (portal *PortalInternals) CutoffMessages(ctx context.Context, messages []*BackfillMessage, aggressiveDedup, forward bool, lastMessage *database.Message) []*BackfillMessage {
return (*Portal)(portal).cutoffMessages(ctx, messages, aggressiveDedup, forward, lastMessage)
}
func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) {
(*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread)
}

View file

@ -13,6 +13,7 @@ import (
"net/http"
"net/url"
"os"
"slices"
"strconv"
"strings"
"sync/atomic"
@ -20,7 +21,9 @@ import (
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"go.mau.fi/util/ptr"
"go.mau.fi/util/retryafter"
"golang.org/x/exp/maps"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/event"
@ -322,7 +325,9 @@ func (cli *Client) RequestStart(req *http.Request) {
func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) {
var evt *zerolog.Event
if err != nil {
if errors.Is(err, context.Canceled) {
evt = zerolog.Ctx(req.Context()).Warn()
} else if err != nil {
evt = zerolog.Ctx(req.Context()).Err(err)
} else if handlerErr != nil {
evt = zerolog.Ctx(req.Context()).Warn().
@ -355,7 +360,9 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er
if body := req.Context().Value(LogBodyContextKey); body != nil {
evt.Interface("req_body", body)
}
if err != nil {
if errors.Is(err, context.Canceled) {
evt.Msg("Request canceled")
} else if err != nil {
evt.Msg("Request failed")
} else if handlerErr != nil {
evt.Msg("Request parsing failed")
@ -1498,21 +1505,19 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt
Handler: parseRoomStateArray,
})
if err == nil && cli.StateStore != nil {
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID)
if clearErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Stringer("room_id", roomID).
Msg("Failed to clear cached member list after fetching state")
}
for _, evts := range stateMap {
for evtType, evts := range stateMap {
if evtType == event.StateMember {
continue
}
for _, evt := range evts {
UpdateStateStore(ctx, cli.StateStore, evt)
}
}
clearErr = cli.StateStore.MarkMembersFetched(ctx, roomID)
if clearErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Msg("Failed to mark members as fetched after fetching full room state")
updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, maps.Values(stateMap[event.StateMember]))
if updateErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(updateErr).
Stringer("room_id", roomID).
Msg("Failed to update members in state store after fetching members")
}
}
return
@ -1864,24 +1869,26 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R
u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members")
_, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
if err == nil && cli.StateStore != nil {
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin)
if clearErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Stringer("room_id", roomID).
Msg("Failed to clear cached member list after fetching joined members")
}
fakeEvents := make([]*event.Event, len(resp.Joined))
i := 0
for userID, member := range resp.Joined {
updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{
Membership: event.MembershipJoin,
AvatarURL: id.ContentURIString(member.AvatarURL),
Displayname: member.DisplayName,
})
if updateErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(updateErr).
Stringer("room_id", roomID).
Stringer("user_id", userID).
Msg("Failed to update membership in state store after fetching joined members")
fakeEvents[i] = &event.Event{
StateKey: ptr.Ptr(userID.String()),
Type: event.StateMember,
RoomID: roomID,
Content: event.Content{Parsed: &event.MemberEventContent{
Membership: event.MembershipJoin,
AvatarURL: id.ContentURIString(member.AvatarURL),
Displayname: member.DisplayName,
}},
}
i++
}
updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, fakeEvents, event.MembershipJoin)
if updateErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(updateErr).
Stringer("room_id", roomID).
Msg("Failed to update members in state store after fetching joined members")
}
}
return
@ -1910,27 +1917,20 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb
}
}
if err == nil && cli.StateStore != nil {
var clearMemberships []event.Membership
var onlyMemberships []event.Membership
if extra.Membership != "" {
clearMemberships = append(clearMemberships, extra.Membership)
onlyMemberships = []event.Membership{extra.Membership}
} else if extra.NotMembership != "" {
onlyMemberships = []event.Membership{event.MembershipJoin, event.MembershipLeave, event.MembershipInvite, event.MembershipBan, event.MembershipKnock}
onlyMemberships = slices.DeleteFunc(onlyMemberships, func(m event.Membership) bool {
return m == extra.NotMembership
})
}
if extra.NotMembership == "" {
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...)
if clearErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Stringer("room_id", roomID).
Msg("Failed to clear cached member list after fetching joined members")
}
}
for _, evt := range resp.Chunk {
UpdateStateStore(ctx, cli.StateStore, evt)
}
if extra.NotMembership == "" && extra.Membership == "" {
markErr := cli.StateStore.MarkMembersFetched(ctx, roomID)
if markErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(markErr).
Msg("Failed to mark members as fetched after fetching full member list")
}
updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, resp.Chunk, onlyMemberships...)
if updateErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(updateErr).
Stringer("room_id", roomID).
Msg("Failed to update members in state store after fetching members")
}
}
return

View file

@ -245,8 +245,22 @@ func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.Devic
}
}
func (mach *OlmMachine) otkCountIsForCrossSigningKey(otkCount *mautrix.OTKCount) bool {
if mach.crossSigningPubkeys == nil || otkCount.UserID != mach.Client.UserID {
return false
}
switch id.Ed25519(otkCount.DeviceID) {
case mach.crossSigningPubkeys.MasterKey, mach.crossSigningPubkeys.UserSigningKey, mach.crossSigningPubkeys.SelfSigningKey:
return true
}
return false
}
func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) {
if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) {
if mach.otkCountIsForCrossSigningKey(otkCount) {
return
}
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Warn().
Str("target_user_id", otkCount.UserID.String()).

View file

@ -187,6 +187,7 @@ type PolicyRecommendation string
const (
PolicyRecommendationBan PolicyRecommendation = "m.ban"
PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban"
PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban"
)
// ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event.

View file

@ -11,6 +11,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
@ -140,7 +141,7 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
var respData RespWellKnown
err = json.NewDecoder(resp.Body).Decode(&respData)
err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
} else if respData.Server == "" {

12
go.mod
View file

@ -8,18 +8,18 @@ require (
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.22
github.com/rs/xid v1.5.0
github.com/mattn/go-sqlite3 v1.14.23
github.com/rs/xid v1.6.0
github.com/rs/zerolog v1.33.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.17.3
github.com/tidwall/sjson v1.2.5
github.com/yuin/goldmark v1.7.4
go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943
go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2
go.mau.fi/zeroconfig v0.1.3
golang.org/x/crypto v0.26.0
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948
golang.org/x/net v0.28.0
golang.org/x/sync v0.8.0
gopkg.in/yaml.v3 v3.0.1
@ -35,7 +35,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
golang.org/x/sys v0.24.0 // indirect
golang.org/x/text v0.17.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)

23
go.sum
View file

@ -24,15 +24,16 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw=
github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
@ -50,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943 h1:wdJ9XC/M6lVUrwDltHPodaA3SRJq+S+AzGEXdQ/o2AQ=
go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4=
go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 h1:VZQlKBbeJ7KOlYSh6BnN5uWQTY/ypn/bJv0YyEd+pXc=
go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2/go.mod h1:WgYvbt9rVmoFeajP97NunQU7AjgvTPiNExN3oTHeePs=
go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI=
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA=
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
@ -66,10 +67,10 @@ golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=

View file

@ -174,3 +174,7 @@ func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Roo
}
func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {}
func (c *ClientStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
return nil
}

View file

@ -83,6 +83,7 @@ type ReqLogin struct {
Token string `json:"token,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"`
RefreshToken bool `json:"refresh_token,omitempty"`
// Whether or not the returned credentials should be stored in the Client
StoreCredentials bool `json:"-"`

View file

@ -19,6 +19,7 @@ import (
"github.com/rs/zerolog"
"go.mau.fi/util/confusable"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exslices"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@ -194,21 +195,37 @@ func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID,
return err
}
const insertUserProfileQuery = `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (room_id, user_id) DO UPDATE
SET membership=excluded.membership,
displayname=excluded.displayname,
avatar_url=excluded.avatar_url,
name_skeleton=excluded.name_skeleton
`
type userProfileRow struct {
UserID id.UserID
Membership event.Membership
Displayname string
AvatarURL id.ContentURIString
NameSkeleton []byte
}
func (u *userProfileRow) GetMassInsertValues() [5]any {
return [5]any{u.UserID, u.Membership, u.Displayname, u.AvatarURL, u.NameSkeleton}
}
var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
var nameSkeleton []byte
if !store.DisableNameDisambiguation && len(member.Displayname) > 0 {
nameSkeletonArr := confusable.SkeletonHash(member.Displayname)
nameSkeleton = nameSkeletonArr[:]
}
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (room_id, user_id) DO UPDATE
SET membership=excluded.membership,
displayname=excluded.displayname,
avatar_url=excluded.avatar_url,
name_skeleton=excluded.name_skeleton
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton)
_, err := store.Exec(ctx, insertUserProfileQuery, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton)
return err
}
@ -221,6 +238,50 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
}
const userProfileMassInsertBatchSize = 500
func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
return store.DoTxn(ctx, nil, func(ctx context.Context) error {
err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
if err != nil {
return fmt.Errorf("failed to clear cached members: %w", err)
}
rows := make([]*userProfileRow, min(len(evts), userProfileMassInsertBatchSize))
for _, evtsChunk := range exslices.Chunk(evts, userProfileMassInsertBatchSize) {
rows = rows[:0]
for _, evt := range evtsChunk {
content, ok := evt.Content.Parsed.(*event.MemberEventContent)
if !ok {
continue
}
row := &userProfileRow{
UserID: id.UserID(*evt.StateKey),
Membership: content.Membership,
Displayname: content.Displayname,
AvatarURL: content.AvatarURL,
}
if !store.DisableNameDisambiguation && len(content.Displayname) > 0 {
nameSkeletonArr := confusable.SkeletonHash(content.Displayname)
row.NameSkeleton = nameSkeletonArr[:]
}
rows = append(rows, row)
}
query, args := userProfileMassInserter.Build([1]any{roomID}, rows)
_, err = store.Exec(ctx, query, args...)
if err != nil {
return fmt.Errorf("failed to insert members: %w", err)
}
}
if len(onlyMemberships) == 0 {
err = store.MarkMembersFetched(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to mark members as fetched: %w", err)
}
}
return nil
})
}
func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
query := "DELETE FROM mx_user_profile WHERE room_id=$1"
params := make([]any, len(memberships)+1)

View file

@ -29,6 +29,7 @@ type StateStore interface {
SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error
IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error)
ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error
ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
@ -270,9 +271,20 @@ func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id
return nil
}
func (store *MemoryStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
_ = store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
for _, evt := range evts {
UpdateStateStore(ctx, store, evt)
}
if len(onlyMemberships) == 0 {
_ = store.MarkMembersFetched(ctx, roomID)
}
return nil
}
func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
store.membersLock.Lock()
defer store.membersLock.Unlock()
store.membersLock.RLock()
defer store.membersLock.RUnlock()
return maps.Clone(store.Members[roomID]), nil
}