%s", html.EscapeString(qr)),
+ Info: &event.FileInfo{
+ MimeType: "image/png",
+ Width: qrSizePx,
+ Height: qrSizePx,
+ Size: len(qrData),
+ },
}
if *prevEventID != "" {
content.SetEdit(*prevEventID)
@@ -217,18 +279,55 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return nil
}
+func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error {
+ for _, att := range atts {
+ if att.FileName == "" {
+ return fmt.Errorf("missing attachment filename")
+ }
+ mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType)
+ if err != nil {
+ return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err)
+ }
+ content := &event.MessageEventContent{
+ MsgType: att.Type,
+ FileName: att.FileName,
+ URL: mxc,
+ File: file,
+ Info: &event.FileInfo{
+ MimeType: att.Info.MimeType,
+ Width: att.Info.Width,
+ Height: att.Info.Height,
+ Size: att.Info.Size,
+ },
+ Body: att.FileName,
+ }
+ _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil)
+ if err != nil {
+ return nil
+ }
+ }
+ return nil
+}
+
type contextKey int
const (
contextKeyPrevEventID contextKey = iota
)
-func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep) {
+func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
prevEvent, ok := ce.Ctx.Value(contextKeyPrevEventID).(*id.EventID)
if !ok {
prevEvent = new(id.EventID)
ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent)
}
+ cancelCtx, cancelFunc := context.WithCancel(ce.Ctx)
+ defer cancelFunc()
+ StoreCommandState(ce.User, &CommandState{
+ Action: "Login",
+ Cancel: cancelFunc,
+ })
+ defer StoreCommandState(ce.User, nil)
switch step.DisplayAndWaitParams.Type {
case bridgev2.LoginDisplayTypeQR:
err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent)
@@ -248,7 +347,7 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait,
login.Cancel()
return
}
- nextStep, err := login.Wait(ce.Ctx)
+ nextStep, err := login.Wait(cancelCtx)
// Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited)
if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) {
_, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{
@@ -262,12 +361,13 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait,
ce.Reply("Login failed: %v", err)
return
}
- doLoginStep(ce, login, nextStep)
+ doLoginStep(ce, login, nextStep, override)
}
type cookieLoginCommandState struct {
- Login bridgev2.LoginProcessCookies
- Data *bridgev2.LoginCookiesParams
+ Login bridgev2.LoginProcessCookies
+ Data *bridgev2.LoginCookiesParams
+ Override *bridgev2.UserLogin
}
func (clcs *cookieLoginCommandState) prompt(ce *Event) {
@@ -379,7 +479,7 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) {
ce.Reply("Login failed: %v", err)
return
}
- doLoginStep(ce, clcs.Login, nextStep)
+ doLoginStep(ce, clcs.Login, nextStep, clcs.Override)
}
func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
@@ -399,27 +499,43 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
return decoded
}
-func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep) {
+func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
+ ce.Log.Debug().Any("next_step", step).Msg("Got next login step")
if step.Instructions != "" {
ce.Reply(step.Instructions)
}
switch step.Type {
case bridgev2.LoginStepTypeDisplayAndWait:
- doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step)
+ doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step, override)
case bridgev2.LoginStepTypeCookies:
(&cookieLoginCommandState{
- Login: login.(bridgev2.LoginProcessCookies),
- Data: step.CookiesParams,
+ Login: login.(bridgev2.LoginProcessCookies),
+ Data: step.CookiesParams,
+ Override: override,
}).prompt(ce)
case bridgev2.LoginStepTypeUserInput:
+ err := sendUserInputAttachments(ce, step.UserInputParams.Attachments)
+ if err != nil {
+ ce.Reply("Failed to send attachments: %v", err)
+ }
(&userInputLoginCommandState{
Login: login.(bridgev2.LoginProcessUserInput),
RemainingFields: step.UserInputParams.Fields,
Data: make(map[string]string),
+ Override: override,
}).promptNext(ce)
case bridgev2.LoginStepTypeComplete:
- // Nothing to do other than instructions
+ if override != nil && override.ID != step.CompleteParams.UserLoginID {
+ ce.Log.Info().
+ Str("old_login_id", string(override.ID)).
+ Str("new_login_id", string(step.CompleteParams.UserLoginID)).
+ Msg("Login resulted in different remote ID than what was being overridden. Deleting previous login")
+ override.Delete(ce.Ctx, status.BridgeState{
+ StateEvent: status.StateLoggedOut,
+ Reason: "LOGIN_OVERRIDDEN",
+ }, bridgev2.DeleteOpts{LogoutRemote: true})
+ }
default:
panic(fmt.Errorf("unknown login step type %q", step.Type))
}
diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go
index 1aca596c..391c3685 100644
--- a/bridgev2/commands/processor.go
+++ b/bridgev2/commands/processor.go
@@ -17,8 +17,7 @@ import (
"github.com/rs/zerolog"
"maunium.net/go/mautrix/bridgev2"
-
- "maunium.net/go/mautrix/bridge/status"
+ "maunium.net/go/mautrix/bridgev2/status"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -42,10 +41,11 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor {
}
proc.AddHandlers(
CommandHelp, CommandCancel,
- CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals,
- CommandLogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin,
+ CommandRegisterPush, CommandSendAccountData, CommandResetNetwork,
+ CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
+ CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin,
CommandSetRelay, CommandUnsetRelay,
- CommandResolveIdentifier, CommandStartChat, CommandSearch,
+ CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute,
CommandSudo, CommandDoIn,
)
return proc
diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go
index af756c87..94c19739 100644
--- a/bridgev2/commands/relay.go
+++ b/bridgev2/commands/relay.go
@@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) {
}
onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly
var relay *bridgev2.UserLogin
- if len(ce.Args) == 0 {
+ if len(ce.Args) == 0 && ce.Portal.Receiver == "" {
relay = ce.User.GetDefaultLogin()
isLoggedIn := relay != nil
if onlySetDefaultRelays {
@@ -73,9 +73,19 @@ func fnSetRelay(ce *Event) {
}
}
} else {
- relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+ var targetID networkid.UserLoginID
+ if ce.Portal.Receiver != "" {
+ targetID = ce.Portal.Receiver
+ if len(ce.Args) > 0 && ce.Args[0] != string(targetID) {
+ ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID)
+ return
+ }
+ } else {
+ targetID = networkid.UserLoginID(ce.Args[0])
+ }
+ relay = ce.Bridge.GetCachedUserLoginByID(targetID)
if relay == nil {
- ce.Reply("User login with ID `%s` not found", ce.Args[0])
+ ce.Reply("User login with ID `%s` not found", targetID)
return
} else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) {
// All good
diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go
index 42f528b0..c7b05a6e 100644
--- a/bridgev2/commands/startchat.go
+++ b/bridgev2/commands/startchat.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,13 +8,21 @@ package commands
import (
"context"
+ "errors"
"fmt"
"html"
+ "maps"
+ "slices"
"strings"
"time"
+ "github.com/rs/zerolog"
+
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/bridgev2/provisionutil"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
)
@@ -27,6 +35,36 @@ var CommandResolveIdentifier = &FullHandler{
Args: "[_login ID_] <_identifier_>",
},
RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
+}
+
+var CommandSyncChat = &FullHandler{
+ Func: func(ce *Event) {
+ login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to find login for sync")
+ ce.Reply("Failed to find login: %v", err)
+ return
+ } else if login == nil {
+ ce.Reply("No login found for sync")
+ return
+ }
+ info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to get chat info for sync")
+ ce.Reply("Failed to get chat info: %v", err)
+ return
+ }
+ ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{})
+ ce.React("✅️")
+ },
+ Name: "sync-portal",
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Sync the current portal room",
+ },
+ RequiresPortal: true,
+ RequiresLogin: true,
}
var CommandStartChat = &FullHandler{
@@ -39,11 +77,18 @@ var CommandStartChat = &FullHandler{
Args: "[_login ID_] <_identifier_>",
},
RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
}
-func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
- remainingArgs := ce.Args[1:]
- login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
+ var remainingArgs []string
+ if len(ce.Args) > 1 {
+ remainingArgs = ce.Args[1:]
+ }
+ var login *bridgev2.UserLogin
+ if len(ce.Args) > 0 {
+ login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+ }
if login == nil || login.UserMXID != ce.User.MXID {
remainingArgs = ce.Args
login = ce.User.GetDefaultLogin()
@@ -55,24 +100,13 @@ func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Even
return login, api, remainingArgs
}
-func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string {
- var targetName string
- var targetMXID id.UserID
- if resp.Ghost != nil {
- if resp.UserInfo != nil {
- resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
- }
- targetName = resp.Ghost.Name
- targetMXID = resp.Ghost.Intent.GetMXID()
- } else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
- targetName = *resp.UserInfo.Name
- }
- if targetMXID != "" {
- return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL())
- } else if targetName != "" {
- return fmt.Sprintf("`%s` / %s", resp.UserID, targetName)
+func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string {
+ if resp.MXID != "" {
+ return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL())
+ } else if resp.Name != "" {
+ return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name)
} else {
- return fmt.Sprintf("`%s`", resp.UserID)
+ return fmt.Sprintf("`%s`", resp.ID)
}
}
@@ -85,65 +119,137 @@ func fnResolveIdentifier(ce *Event) {
if api == nil {
return
}
- createChat := ce.Command == "start-chat"
+ allLogins := ce.User.GetUserLogins()
+ createChat := ce.Command == "start-chat" || ce.Command == "pm"
identifier := strings.Join(identifierParts, " ")
- resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat)
+ resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat)
+ for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ {
+ resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat)
+ }
if err != nil {
- ce.Log.Err(err).Msg("Failed to resolve identifier")
ce.Reply("Failed to resolve identifier: %v", err)
return
} else if resp == nil {
ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true)
return
}
- formattedName := formatResolveIdentifierResult(ce.Ctx, resp)
+ formattedName := formatResolveIdentifierResult(resp)
if createChat {
- if resp.Chat == nil {
- ce.Reply("Interface error: network connector did not return chat for create chat request")
- return
+ name := resp.Portal.Name
+ if name == "" {
+ name = resp.Portal.MXID.String()
}
- portal := resp.Chat.Portal
- if portal == nil {
- portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to get portal")
- ce.Reply("Failed to get portal: %v", err)
- return
- }
- }
- if resp.Chat.PortalInfo == nil {
- resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to get portal info")
- ce.Reply("Failed to get portal info: %v", err)
- return
- }
- }
- if portal.MXID != "" {
- name := portal.Name
- if name == "" {
- name = portal.MXID.String()
- }
- portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{})
- ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
+ if !resp.JustCreated {
+ ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
} else {
- err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to create room")
- ce.Reply("Failed to create room: %v", err)
- return
- }
- name := portal.Name
- if name == "" {
- name = portal.MXID.String()
- }
- ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
+ ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
}
} else {
ce.Reply("Found %s", formattedName)
}
}
+var CommandCreateGroup = &FullHandler{
+ Func: fnCreateGroup,
+ Name: "create-group",
+ Aliases: []string{"create"},
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Create a new group chat for the current Matrix room",
+ Args: "[_group type_]",
+ },
+ RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI],
+}
+
+func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) {
+ evt, err := provider.GetStateEvent(ctx, roomID, evtType, "")
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation")
+ } else if evt != nil {
+ content, _ = evt.Content.Parsed.(T)
+ }
+ return
+}
+
+func fnCreateGroup(ce *Event) {
+ ce.Bridge.Matrix.GetCapabilities()
+ login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group")
+ if api == nil {
+ return
+ }
+ stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState)
+ if !ok {
+ ce.Reply("Matrix connector doesn't support fetching room state")
+ return
+ }
+ members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to get room members for group creation")
+ ce.Reply("Failed to get room members: %v", err)
+ return
+ }
+ caps := ce.Bridge.Network.GetCapabilities()
+ params := &bridgev2.GroupCreateParams{
+ Username: "",
+ Participants: make([]networkid.UserID, 0, len(members)-2),
+ Parent: nil, // TODO check space parent event
+ Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider),
+ Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider),
+ Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider),
+ Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider),
+ RoomID: ce.RoomID,
+ }
+ for userID, member := range members {
+ if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() {
+ continue
+ }
+ if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok {
+ params.Participants = append(params.Participants, parsedUserID)
+ } else if !ce.Bridge.Config.SplitPortals {
+ if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil {
+ ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member")
+ } else if user != nil {
+ // TODO add user logins to participants
+ //for _, login := range user.GetUserLogins() {
+ // params.Participants = append(params.Participants, login.GetUserID())
+ //}
+ }
+ }
+ }
+
+ if len(caps.Provisioning.GroupCreation) == 0 {
+ ce.Reply("No group creation types defined in network capabilities")
+ return
+ } else if len(remainingArgs) > 0 {
+ params.Type = remainingArgs[0]
+ } else if len(caps.Provisioning.GroupCreation) == 1 {
+ for params.Type = range caps.Provisioning.GroupCreation {
+ // The loop assigns the variable we want
+ }
+ } else {
+ types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `")
+ ce.Reply("Please specify type of group to create: `%s`", types)
+ return
+ }
+ resp, err := provisionutil.CreateGroup(ce.Ctx, login, params)
+ if err != nil {
+ ce.Reply("Failed to create group: %v", err)
+ return
+ }
+ var postfix string
+ if len(resp.FailedParticipants) > 0 {
+ failedParticipantsStrings := make([]string, len(resp.FailedParticipants))
+ i := 0
+ for participantID, meta := range resp.FailedParticipants {
+ failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason)
+ i++
+ }
+ postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n")
+ }
+ ce.Reply("Successfully created group `%s`%s", resp.ID, postfix)
+}
+
var CommandSearch = &FullHandler{
Func: fnSearch,
Name: "search",
@@ -153,6 +259,7 @@ var CommandSearch = &FullHandler{
Args: "<_query_>",
},
RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.UserSearchingNetworkAPI],
}
func fnSearch(ce *Event) {
@@ -160,35 +267,67 @@ func fnSearch(ce *Event) {
ce.Reply("Usage: `$cmdprefix search
+ Blockquote = "blockquote", // blockquote
+ InlineLink = "inline_link", // a
+ UserLink = "user_link", //
+ RoomLink = "room_link", //
+ EventLink = "event_link", //
+ AtRoomMention = "at_room_mention", // @room (no html tag)
+ UnorderedList = "unordered_list", // ul + li
+ OrderedList = "ordered_list", // ol + li
+ ListStart = "ordered_list.start", //
+ ListJumpValue = "ordered_list.jump_value", // -
+ CustomEmoji = "custom_emoji", //
+ Spoiler = "spoiler", //
+ SpoilerReason = "spoiler.reason", //
+ TextForegroundColor = "color.foreground", //
+ TextBackgroundColor = "color.background", //
+ HorizontalLine = "horizontal_line", // hr
+ Headers = "headers", // h1, h2, h3, h4, h5, h6
+ Superscript = "superscript", // sup
+ Subscript = "subscript", // sub
+ Math = "math", //
+ DetailsSummary = "details_summary", // ...
...
+ Table = "table", // table, thead, tbody, tr, th, td
+}
diff --git a/event/capabilities.go b/event/capabilities.go
new file mode 100644
index 00000000..a86c726b
--- /dev/null
+++ b/event/capabilities.go
@@ -0,0 +1,414 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package event
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "mime"
+ "slices"
+ "strings"
+
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
+ "golang.org/x/exp/constraints"
+ "golang.org/x/exp/maps"
+)
+
+type RoomFeatures struct {
+ ID string `json:"id,omitempty"`
+
+ // N.B. New fields need to be added to the Hash function to be included in the deduplication hash.
+
+ Formatting FormattingFeatureMap `json:"formatting,omitempty"`
+ File FileFeatureMap `json:"file,omitempty"`
+ State StateFeatureMap `json:"state,omitempty"`
+ MemberActions MemberFeatureMap `json:"member_actions,omitempty"`
+
+ MaxTextLength int `json:"max_text_length,omitempty"`
+
+ LocationMessage CapabilitySupportLevel `json:"location_message,omitempty"`
+ Poll CapabilitySupportLevel `json:"poll,omitempty"`
+ Thread CapabilitySupportLevel `json:"thread,omitempty"`
+ Reply CapabilitySupportLevel `json:"reply,omitempty"`
+
+ Edit CapabilitySupportLevel `json:"edit,omitempty"`
+ EditMaxCount int `json:"edit_max_count,omitempty"`
+ EditMaxAge *jsontime.Seconds `json:"edit_max_age,omitempty"`
+ Delete CapabilitySupportLevel `json:"delete,omitempty"`
+ DeleteForMe bool `json:"delete_for_me,omitempty"`
+ DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"`
+
+ DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"`
+
+ Reaction CapabilitySupportLevel `json:"reaction,omitempty"`
+ ReactionCount int `json:"reaction_count,omitempty"`
+ AllowedReactions []string `json:"allowed_reactions,omitempty"`
+ CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"`
+
+ ReadReceipts bool `json:"read_receipts,omitempty"`
+ TypingNotifications bool `json:"typing_notifications,omitempty"`
+ Archive bool `json:"archive,omitempty"`
+ MarkAsUnread bool `json:"mark_as_unread,omitempty"`
+ DeleteChat bool `json:"delete_chat,omitempty"`
+ DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"`
+
+ MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"`
+
+ PerMessageProfileRelay bool `json:"-"`
+}
+
+func (rf *RoomFeatures) GetID() string {
+ if rf.ID != "" {
+ return rf.ID
+ }
+ return base64.RawURLEncoding.EncodeToString(rf.Hash())
+}
+
+func (rf *RoomFeatures) Clone() *RoomFeatures {
+ if rf == nil {
+ return nil
+ }
+ clone := *rf
+ clone.File = clone.File.Clone()
+ clone.Formatting = maps.Clone(clone.Formatting)
+ clone.State = clone.State.Clone()
+ clone.MemberActions = clone.MemberActions.Clone()
+ clone.EditMaxAge = ptr.Clone(clone.EditMaxAge)
+ clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge)
+ clone.DisappearingTimer = clone.DisappearingTimer.Clone()
+ clone.AllowedReactions = slices.Clone(clone.AllowedReactions)
+ clone.MessageRequest = clone.MessageRequest.Clone()
+ return &clone
+}
+
+type MemberFeatureMap map[MemberAction]CapabilitySupportLevel
+
+func (mfm MemberFeatureMap) Clone() MemberFeatureMap {
+ return maps.Clone(mfm)
+}
+
+type MemberAction string
+
+const (
+ MemberActionBan MemberAction = "ban"
+ MemberActionKick MemberAction = "kick"
+ MemberActionLeave MemberAction = "leave"
+ MemberActionRevokeInvite MemberAction = "revoke_invite"
+ MemberActionInvite MemberAction = "invite"
+)
+
+type StateFeatureMap map[string]*StateFeatures
+
+func (sfm StateFeatureMap) Clone() StateFeatureMap {
+ dup := maps.Clone(sfm)
+ for key, value := range dup {
+ dup[key] = value.Clone()
+ }
+ return dup
+}
+
+type StateFeatures struct {
+ Level CapabilitySupportLevel `json:"level"`
+}
+
+func (sf *StateFeatures) Clone() *StateFeatures {
+ if sf == nil {
+ return nil
+ }
+ clone := *sf
+ return &clone
+}
+
+func (sf *StateFeatures) Hash() []byte {
+ return sf.Level.Hash()
+}
+
+type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel
+
+type FileFeatureMap map[CapabilityMsgType]*FileFeatures
+
+func (ffm FileFeatureMap) Clone() FileFeatureMap {
+ dup := maps.Clone(ffm)
+ for key, value := range dup {
+ dup[key] = value.Clone()
+ }
+ return dup
+}
+
+type DisappearingTimerCapability struct {
+ Types []DisappearingType `json:"types"`
+ Timers []jsontime.Milliseconds `json:"timers,omitempty"`
+
+ OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"`
+}
+
+func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability {
+ if dtc == nil {
+ return nil
+ }
+ clone := *dtc
+ clone.Types = slices.Clone(clone.Types)
+ clone.Timers = slices.Clone(clone.Timers)
+ return &clone
+}
+
+func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool {
+ if dtc == nil || content == nil || content.Type == DisappearingTypeNone {
+ return true
+ }
+ return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer))
+}
+
+type MessageRequestFeatures struct {
+ AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"`
+ AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"`
+}
+
+func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures {
+ return ptr.Clone(mrf)
+}
+
+func (mrf *MessageRequestFeatures) Hash() []byte {
+ if mrf == nil {
+ return nil
+ }
+ hasher := sha256.New()
+ hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage)
+ hashValue(hasher, "accept_with_button", mrf.AcceptWithButton)
+ return hasher.Sum(nil)
+}
+
+type CapabilityMsgType = MessageType
+
+// Message types which are used for event capability signaling, but aren't real values for the msgtype field.
+const (
+ CapMsgVoice CapabilityMsgType = "org.matrix.msc3245.voice"
+ CapMsgGIF CapabilityMsgType = "fi.mau.gif"
+ CapMsgSticker CapabilityMsgType = "m.sticker"
+)
+
+type CapabilitySupportLevel int
+
+func (csl CapabilitySupportLevel) Partial() bool {
+ return csl >= CapLevelPartialSupport
+}
+
+func (csl CapabilitySupportLevel) Full() bool {
+ return csl >= CapLevelFullySupported
+}
+
+func (csl CapabilitySupportLevel) Reject() bool {
+ return csl <= CapLevelRejected
+}
+
+const (
+ CapLevelRejected CapabilitySupportLevel = -2 // The feature is unsupported and messages using it will be rejected.
+ CapLevelDropped CapabilitySupportLevel = -1 // The feature is unsupported and has no fallback. The message will go through, but data may be lost.
+ CapLevelUnsupported CapabilitySupportLevel = 0 // The feature is unsupported, but may have a fallback.
+ CapLevelPartialSupport CapabilitySupportLevel = 1 // The feature is partially supported (e.g. it may be converted to a different format).
+ CapLevelFullySupported CapabilitySupportLevel = 2 // The feature is fully supported and can be safely used.
+)
+
+type FormattingFeature string
+
+const (
+ FmtBold FormattingFeature = "bold" // strong, b
+ FmtItalic FormattingFeature = "italic" // em, i
+ FmtUnderline FormattingFeature = "underline" // u
+ FmtStrikethrough FormattingFeature = "strikethrough" // del, s
+ FmtInlineCode FormattingFeature = "inline_code" // code
+ FmtCodeBlock FormattingFeature = "code_block" // pre + code
+ FmtSyntaxHighlighting FormattingFeature = "code_block.syntax_highlighting" //
+ FmtBlockquote FormattingFeature = "blockquote" // blockquote
+ FmtInlineLink FormattingFeature = "inline_link" // a
+ FmtUserLink FormattingFeature = "user_link" //
+ FmtRoomLink FormattingFeature = "room_link" //
+ FmtEventLink FormattingFeature = "event_link" //
+ FmtAtRoomMention FormattingFeature = "at_room_mention" // @room (no html tag)
+ FmtUnorderedList FormattingFeature = "unordered_list" // ul + li
+ FmtOrderedList FormattingFeature = "ordered_list" // ol + li
+ FmtListStart FormattingFeature = "ordered_list.start" //
+ FmtListJumpValue FormattingFeature = "ordered_list.jump_value" // -
+ FmtCustomEmoji FormattingFeature = "custom_emoji" //
+ FmtSpoiler FormattingFeature = "spoiler" //
+ FmtSpoilerReason FormattingFeature = "spoiler.reason" //
+ FmtTextForegroundColor FormattingFeature = "color.foreground" //
+ FmtTextBackgroundColor FormattingFeature = "color.background" //
+ FmtHorizontalLine FormattingFeature = "horizontal_line" // hr
+ FmtHeaders FormattingFeature = "headers" // h1, h2, h3, h4, h5, h6
+ FmtSuperscript FormattingFeature = "superscript" // sup
+ FmtSubscript FormattingFeature = "subscript" // sub
+ FmtMath FormattingFeature = "math" //
+ FmtDetailsSummary FormattingFeature = "details_summary" // ...
...
+ FmtTable FormattingFeature = "table" // table, thead, tbody, tr, th, td
+)
+
+type FileFeatures struct {
+ // N.B. New fields need to be added to the Hash function to be included in the deduplication hash.
+
+ MimeTypes map[string]CapabilitySupportLevel `json:"mime_types"`
+
+ Caption CapabilitySupportLevel `json:"caption,omitempty"`
+ MaxCaptionLength int `json:"max_caption_length,omitempty"`
+
+ MaxSize int64 `json:"max_size,omitempty"`
+ MaxWidth int `json:"max_width,omitempty"`
+ MaxHeight int `json:"max_height,omitempty"`
+ MaxDuration *jsontime.Seconds `json:"max_duration,omitempty"`
+
+ ViewOnce bool `json:"view_once,omitempty"`
+}
+
+func (ff *FileFeatures) GetMimeSupport(inputType string) CapabilitySupportLevel {
+ match, ok := ff.MimeTypes[inputType]
+ if ok {
+ return match
+ }
+ if strings.IndexByte(inputType, ';') != -1 {
+ plainMime, _, _ := mime.ParseMediaType(inputType)
+ if plainMime != "" {
+ if match, ok = ff.MimeTypes[plainMime]; ok {
+ return match
+ }
+ }
+ }
+ if slash := strings.IndexByte(inputType, '/'); slash > 0 {
+ generalType := fmt.Sprintf("%s/*", inputType[:slash])
+ if match, ok = ff.MimeTypes[generalType]; ok {
+ return match
+ }
+ }
+ match, ok = ff.MimeTypes["*/*"]
+ if ok {
+ return match
+ }
+ return CapLevelRejected
+}
+
+type hashable interface {
+ Hash() []byte
+}
+
+func hashMap[Key ~string, Value hashable](w io.Writer, name string, data map[Key]Value) {
+ keys := maps.Keys(data)
+ slices.Sort(keys)
+ exerrors.Must(w.Write([]byte(name)))
+ for _, key := range keys {
+ exerrors.Must(w.Write([]byte(key)))
+ exerrors.Must(w.Write(data[key].Hash()))
+ exerrors.Must(w.Write([]byte{0}))
+ }
+}
+
+func hashValue(w io.Writer, name string, data hashable) {
+ exerrors.Must(w.Write([]byte(name)))
+ exerrors.Must(w.Write(data.Hash()))
+}
+
+func hashInt[T constraints.Integer](w io.Writer, name string, data T) {
+ exerrors.Must(w.Write(binary.BigEndian.AppendUint64([]byte(name), uint64(data))))
+}
+
+func hashBool[T ~bool](w io.Writer, name string, data T) {
+ exerrors.Must(w.Write([]byte(name)))
+ if data {
+ exerrors.Must(w.Write([]byte{1}))
+ } else {
+ exerrors.Must(w.Write([]byte{0}))
+ }
+}
+
+func (csl CapabilitySupportLevel) Hash() []byte {
+ return []byte{byte(csl + 128)}
+}
+
+func (rf *RoomFeatures) Hash() []byte {
+ hasher := sha256.New()
+
+ hashMap(hasher, "formatting", rf.Formatting)
+ hashMap(hasher, "file", rf.File)
+ hashMap(hasher, "state", rf.State)
+ hashMap(hasher, "member_actions", rf.MemberActions)
+
+ hashInt(hasher, "max_text_length", rf.MaxTextLength)
+
+ hashValue(hasher, "location_message", rf.LocationMessage)
+ hashValue(hasher, "poll", rf.Poll)
+ hashValue(hasher, "thread", rf.Thread)
+ hashValue(hasher, "reply", rf.Reply)
+
+ hashValue(hasher, "edit", rf.Edit)
+ hashInt(hasher, "edit_max_count", rf.EditMaxCount)
+ hashInt(hasher, "edit_max_age", rf.EditMaxAge.Get())
+
+ hashValue(hasher, "delete", rf.Delete)
+ hashBool(hasher, "delete_for_me", rf.DeleteForMe)
+ hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get())
+ hashValue(hasher, "disappearing_timer", rf.DisappearingTimer)
+
+ hashValue(hasher, "reaction", rf.Reaction)
+ hashInt(hasher, "reaction_count", rf.ReactionCount)
+ hasher.Write([]byte("allowed_reactions"))
+ for _, reaction := range rf.AllowedReactions {
+ hasher.Write([]byte(reaction))
+ }
+ hashBool(hasher, "custom_emoji_reactions", rf.CustomEmojiReactions)
+
+ hashBool(hasher, "read_receipts", rf.ReadReceipts)
+ hashBool(hasher, "typing_notifications", rf.TypingNotifications)
+ hashBool(hasher, "archive", rf.Archive)
+ hashBool(hasher, "mark_as_unread", rf.MarkAsUnread)
+ hashBool(hasher, "delete_chat", rf.DeleteChat)
+ hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone)
+ hashValue(hasher, "message_request", rf.MessageRequest)
+
+ return hasher.Sum(nil)
+}
+
+func (dtc *DisappearingTimerCapability) Hash() []byte {
+ if dtc == nil {
+ return nil
+ }
+ hasher := sha256.New()
+ hasher.Write([]byte("types"))
+ for _, t := range dtc.Types {
+ hasher.Write([]byte(t))
+ }
+ hasher.Write([]byte("timers"))
+ for _, timer := range dtc.Timers {
+ hashInt(hasher, "", timer.Milliseconds())
+ }
+ return hasher.Sum(nil)
+}
+
+func (ff *FileFeatures) Hash() []byte {
+ hasher := sha256.New()
+ hashMap(hasher, "mime_types", ff.MimeTypes)
+ hashValue(hasher, "caption", ff.Caption)
+ hashInt(hasher, "max_caption_length", ff.MaxCaptionLength)
+ hashInt(hasher, "max_size", ff.MaxSize)
+ hashInt(hasher, "max_width", ff.MaxWidth)
+ hashInt(hasher, "max_height", ff.MaxHeight)
+ hashInt(hasher, "max_duration", ff.MaxDuration.Get())
+ hashBool(hasher, "view_once", ff.ViewOnce)
+ return hasher.Sum(nil)
+}
+
+func (ff *FileFeatures) Clone() *FileFeatures {
+ if ff == nil {
+ return nil
+ }
+ clone := *ff
+ clone.MimeTypes = maps.Clone(clone.MimeTypes)
+ clone.MaxDuration = ptr.Clone(clone.MaxDuration)
+ return &clone
+}
diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go
new file mode 100644
index 00000000..ce07c4c0
--- /dev/null
+++ b/event/cmdschema/content.go
@@ -0,0 +1,78 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "reflect"
+ "slices"
+
+ "go.mau.fi/util/exsync"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type EventContent struct {
+ Command string `json:"command"`
+ Aliases []string `json:"aliases,omitempty"`
+ Parameters []*Parameter `json:"parameters,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ TailParam string `json:"fi.mau.tail_parameter,omitempty"`
+}
+
+func (ec *EventContent) Validate() error {
+ if ec == nil {
+ return fmt.Errorf("event content is nil")
+ } else if ec.Command == "" {
+ return fmt.Errorf("command is empty")
+ }
+ var tailFound bool
+ dupMap := exsync.NewSet[string]()
+ for i, p := range ec.Parameters {
+ if err := p.Validate(); err != nil {
+ return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err)
+ } else if !dupMap.Add(p.Key) {
+ return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1)
+ } else if p.Key == ec.TailParam {
+ tailFound = true
+ } else if tailFound && !p.Optional {
+ return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam)
+ }
+ }
+ if ec.TailParam != "" && !tailFound {
+ return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam)
+ }
+ return nil
+}
+
+func (ec *EventContent) IsValid() bool {
+ return ec.Validate() == nil
+}
+
+func (ec *EventContent) StateKey(owner id.UserID) string {
+ hash := sha256.Sum256([]byte(ec.Command + owner.String()))
+ return base64.StdEncoding.EncodeToString(hash[:])
+}
+
+func (ec *EventContent) Equals(other *EventContent) bool {
+ if ec == nil || other == nil {
+ return ec == other
+ }
+ return ec.Command == other.Command &&
+ slices.Equal(ec.Aliases, other.Aliases) &&
+ slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) &&
+ ec.Description.Equals(other.Description) &&
+ ec.TailParam == other.TailParam
+}
+
+func init() {
+ event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{})
+}
diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go
new file mode 100644
index 00000000..4193b297
--- /dev/null
+++ b/event/cmdschema/parameter.go
@@ -0,0 +1,286 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+
+ "go.mau.fi/util/exslices"
+
+ "maunium.net/go/mautrix/event"
+)
+
+type Parameter struct {
+ Key string `json:"key"`
+ Schema *ParameterSchema `json:"schema"`
+ Optional bool `json:"optional,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ DefaultValue any `json:"fi.mau.default_value,omitempty"`
+}
+
+func (p *Parameter) Equals(other *Parameter) bool {
+ if p == nil || other == nil {
+ return p == other
+ }
+ return p.Key == other.Key &&
+ p.Schema.Equals(other.Schema) &&
+ p.Optional == other.Optional &&
+ p.Description.Equals(other.Description) &&
+ p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values
+}
+
+func (p *Parameter) Validate() error {
+ if p == nil {
+ return fmt.Errorf("parameter is nil")
+ } else if p.Key == "" {
+ return fmt.Errorf("key is empty")
+ }
+ return p.Schema.Validate()
+}
+
+func (p *Parameter) IsValid() bool {
+ return p.Validate() == nil
+}
+
+func (p *Parameter) GetDefaultValue() any {
+ if p != nil && p.DefaultValue != nil {
+ return p.DefaultValue
+ } else if p == nil || p.Optional {
+ return nil
+ }
+ return p.Schema.GetDefaultValue()
+}
+
+type PrimitiveType string
+
+const (
+ PrimitiveTypeString PrimitiveType = "string"
+ PrimitiveTypeInteger PrimitiveType = "integer"
+ PrimitiveTypeBoolean PrimitiveType = "boolean"
+ PrimitiveTypeServerName PrimitiveType = "server_name"
+ PrimitiveTypeUserID PrimitiveType = "user_id"
+ PrimitiveTypeRoomID PrimitiveType = "room_id"
+ PrimitiveTypeRoomAlias PrimitiveType = "room_alias"
+ PrimitiveTypeEventID PrimitiveType = "event_id"
+)
+
+func (pt PrimitiveType) Schema() *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypePrimitive,
+ Type: pt,
+ }
+}
+
+func (pt PrimitiveType) IsValid() bool {
+ switch pt {
+ case PrimitiveTypeString,
+ PrimitiveTypeInteger,
+ PrimitiveTypeBoolean,
+ PrimitiveTypeServerName,
+ PrimitiveTypeUserID,
+ PrimitiveTypeRoomID,
+ PrimitiveTypeRoomAlias,
+ PrimitiveTypeEventID:
+ return true
+ default:
+ return false
+ }
+}
+
+type SchemaType string
+
+const (
+ SchemaTypePrimitive SchemaType = "primitive"
+ SchemaTypeArray SchemaType = "array"
+ SchemaTypeUnion SchemaType = "union"
+ SchemaTypeLiteral SchemaType = "literal"
+)
+
+type ParameterSchema struct {
+ SchemaType SchemaType `json:"schema_type"`
+ Type PrimitiveType `json:"type,omitempty"` // Only for primitive
+ Items *ParameterSchema `json:"items,omitempty"` // Only for array
+ Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union
+ Value any `json:"value,omitempty"` // Only for literal
+}
+
+func Literal(value any) *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypeLiteral,
+ Value: value,
+ }
+}
+
+func Enum(values ...any) *ParameterSchema {
+ return Union(exslices.CastFunc(values, Literal)...)
+}
+
+func flattenUnion(variants []*ParameterSchema) []*ParameterSchema {
+ var flattened []*ParameterSchema
+ for _, variant := range variants {
+ switch variant.SchemaType {
+ case SchemaTypeArray:
+ panic(fmt.Errorf("illegal array schema in union"))
+ case SchemaTypeUnion:
+ flattened = append(flattened, flattenUnion(variant.Variants)...)
+ default:
+ flattened = append(flattened, variant)
+ }
+ }
+ return flattened
+}
+
+func Union(variants ...*ParameterSchema) *ParameterSchema {
+ needsFlattening := false
+ for _, variant := range variants {
+ if variant.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in union"))
+ } else if variant.SchemaType == SchemaTypeUnion {
+ needsFlattening = true
+ }
+ }
+ if needsFlattening {
+ variants = flattenUnion(variants)
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeUnion,
+ Variants: variants,
+ }
+}
+
+func Array(items *ParameterSchema) *ParameterSchema {
+ if items.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in array"))
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeArray,
+ Items: items,
+ }
+}
+
+func (ps *ParameterSchema) GetDefaultValue() any {
+ if ps == nil {
+ return nil
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ switch ps.Type {
+ case PrimitiveTypeInteger:
+ return 0
+ case PrimitiveTypeBoolean:
+ return false
+ default:
+ return ""
+ }
+ case SchemaTypeArray:
+ return []any{}
+ case SchemaTypeUnion:
+ if len(ps.Variants) > 0 {
+ return ps.Variants[0].GetDefaultValue()
+ }
+ return nil
+ case SchemaTypeLiteral:
+ return ps.Value
+ default:
+ return nil
+ }
+}
+
+func (ps *ParameterSchema) IsValid() bool {
+ return ps.validate("") == nil
+}
+
+func (ps *ParameterSchema) Validate() error {
+ return ps.validate("")
+}
+
+func (ps *ParameterSchema) validate(parent SchemaType) error {
+ if ps == nil {
+ return fmt.Errorf("schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ if !ps.Type.IsValid() {
+ return fmt.Errorf("invalid primitive type %s", ps.Type)
+ } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("primitive schema has extra fields")
+ }
+ return nil
+ case SchemaTypeArray:
+ if parent != "" {
+ return fmt.Errorf("arrays can't be nested in other types")
+ } else if err := ps.Items.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("item schema is invalid: %w", err)
+ } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("array schema has extra fields")
+ }
+ return nil
+ case SchemaTypeUnion:
+ if len(ps.Variants) == 0 {
+ return fmt.Errorf("no variants specified for union")
+ } else if parent != "" && parent != SchemaTypeArray {
+ return fmt.Errorf("unions can't be nested in anything other than arrays")
+ }
+ for i, v := range ps.Variants {
+ if err := v.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("variant #%d is invalid: %w", i+1, err)
+ }
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Value != nil {
+ return fmt.Errorf("union schema has extra fields")
+ }
+ return nil
+ case SchemaTypeLiteral:
+ switch typedVal := ps.Value.(type) {
+ case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue:
+ // ok
+ case map[string]any:
+ if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" {
+ return fmt.Errorf("literal value has invalid map data")
+ }
+ default:
+ return fmt.Errorf("literal value has unsupported type %T", ps.Value)
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Variants != nil {
+ return fmt.Errorf("literal schema has extra fields")
+ }
+ return nil
+ default:
+ return fmt.Errorf("invalid schema type %s", ps.SchemaType)
+ }
+}
+
+func (ps *ParameterSchema) Equals(other *ParameterSchema) bool {
+ if ps == nil || other == nil {
+ return ps == other
+ }
+ return ps.SchemaType == other.SchemaType &&
+ ps.Type == other.Type &&
+ ps.Items.Equals(other.Items) &&
+ slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) &&
+ ps.Value == other.Value // TODO this won't work for room/event ID values
+}
+
+func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool {
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type == prim
+ case SchemaTypeUnion:
+ for _, variant := range ps.Variants {
+ if variant.AllowsPrimitive(prim) {
+ return true
+ }
+ }
+ return false
+ case SchemaTypeArray:
+ return ps.Items.AllowsPrimitive(prim)
+ default:
+ return false
+ }
+}
diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go
new file mode 100644
index 00000000..92e69b60
--- /dev/null
+++ b/event/cmdschema/parse.go
@@ -0,0 +1,478 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+const botArrayOpener = "<"
+const botArrayCloser = ">"
+
+func parseQuoted(val string) (parsed, remaining string, quoted bool) {
+ if len(val) == 0 {
+ return
+ }
+ if !strings.HasPrefix(val, `"`) {
+ spaceIdx := strings.IndexByte(val, ' ')
+ if spaceIdx == -1 {
+ parsed = val
+ } else {
+ parsed = val[:spaceIdx]
+ remaining = strings.TrimLeft(val[spaceIdx+1:], " ")
+ }
+ return
+ }
+ val = val[1:]
+ var buf strings.Builder
+ for {
+ quoteIdx := strings.IndexByte(val, '"')
+ var valUntilQuote string
+ if quoteIdx == -1 {
+ valUntilQuote = val
+ } else {
+ valUntilQuote = val[:quoteIdx]
+ }
+ escapeIdx := strings.IndexByte(valUntilQuote, '\\')
+ if escapeIdx >= 0 {
+ buf.WriteString(val[:escapeIdx])
+ if len(val) > escapeIdx+1 {
+ buf.WriteByte(val[escapeIdx+1])
+ }
+ val = val[min(escapeIdx+2, len(val)):]
+ } else if quoteIdx >= 0 {
+ buf.WriteString(val[:quoteIdx])
+ val = val[quoteIdx+1:]
+ break
+ } else if buf.Len() == 0 {
+ // Unterminated quote, no escape characters, val is the whole input
+ return val, "", true
+ } else {
+ // Unterminated quote, but there were escape characters previously
+ buf.WriteString(val)
+ val = ""
+ break
+ }
+ }
+ return buf.String(), strings.TrimLeft(val, " "), true
+}
+
+// ParseInput tries to parse the given text into a bot command event matching this command definition.
+//
+// If the prefix doesn't match, this will return a nil content and nil error.
+// If the prefix does match, some content is always returned, but there may still be an error if parsing failed.
+func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) {
+ prefix := ec.parsePrefix(input, sigils, owner.String())
+ if prefix == "" {
+ return nil, nil
+ }
+ content = &event.MessageEventContent{
+ MsgType: event.MsgText,
+ Body: input,
+ Mentions: &event.Mentions{UserIDs: []id.UserID{owner}},
+ MSC4391BotCommand: &event.MSC4391BotCommandInput{
+ Command: ec.Command,
+ },
+ }
+ content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):])
+ return content, err
+}
+
+func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) {
+ args := make(map[string]any)
+ var retErr error
+ setError := func(err error) {
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+ processParameter := func(param *Parameter, isLast, isTail, isNamed bool) {
+ origInput := input
+ var nextVal string
+ var wasQuoted bool
+ if param.Schema.SchemaType == SchemaTypeArray {
+ hasOpener := strings.HasPrefix(input, botArrayOpener)
+ arrayClosed := false
+ if hasOpener {
+ input = input[len(botArrayOpener):]
+ if strings.HasPrefix(input, botArrayCloser) {
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ }
+ }
+ var collector []any
+ for len(input) > 0 && !arrayClosed {
+ //origInput = input
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) {
+ // The value wasn't quoted and has the array delimiter at the end, close the array
+ nextVal = strings.TrimRight(nextVal, botArrayCloser)
+ arrayClosed = true
+ } else if hasOpener && strings.HasPrefix(input, botArrayCloser) {
+ // The value was quoted or there was a space, and the next character is the
+ // array delimiter, close the array
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ } else if !hasOpener && !isLast {
+ // For array arguments in the middle without the <> delimiters, stop after the first item
+ arrayClosed = true
+ }
+ parsedVal, err := param.Schema.Items.ParseString(nextVal)
+ if err == nil {
+ collector = append(collector, parsedVal)
+ } else if hasOpener || isLast {
+ setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err))
+ } else {
+ //input = origInput
+ }
+ }
+ args[param.Key] = collector
+ } else {
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if (isLast || isTail) && !wasQuoted && len(input) > 0 {
+ // If the last argument is not quoted, just treat the rest of the string
+ // as the argument without escapes (arguments with escapes should be quoted).
+ nextVal += " " + input
+ input = ""
+ }
+ // Special case for named boolean parameters: if no value is given, treat it as true
+ if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) {
+ args[param.Key] = true
+ return
+ }
+ if nextVal == "" && !wasQuoted && !isNamed && !param.Optional {
+ setError(fmt.Errorf("missing value for required parameter %s", param.Key))
+ }
+ parsedVal, err := param.Schema.ParseString(nextVal)
+ if err != nil {
+ args[param.Key] = param.GetDefaultValue()
+ // For optional parameters that fail to parse, restore the input and try passing it as the next parameter
+ if param.Optional && !isLast && !isNamed {
+ input = strings.TrimLeft(origInput, " ")
+ } else if !param.Optional || isNamed {
+ setError(fmt.Errorf("failed to parse %s: %w", param.Key, err))
+ }
+ } else {
+ args[param.Key] = parsedVal
+ }
+ }
+ }
+ skipParams := make([]bool, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ for strings.HasPrefix(input, "--") {
+ nameEndIdx := strings.IndexAny(input, " =")
+ if nameEndIdx == -1 {
+ nameEndIdx = len(input)
+ }
+ overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx])
+ if overrideParam != nil {
+ // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input
+ input = strings.TrimPrefix(input[nameEndIdx:], "=")
+ skipParams[paramIdx] = true
+ processParameter(overrideParam, false, false, true)
+ } else {
+ break
+ }
+ }
+ isTail := param.Key == ec.TailParam
+ if skipParams[i] || (param.Optional && !isTail) {
+ continue
+ }
+ processParameter(param, i == len(ec.Parameters)-1, isTail, false)
+ }
+ jsonArgs, marshalErr := json.Marshal(args)
+ if marshalErr != nil {
+ return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr)
+ }
+ return jsonArgs, retErr
+}
+
+func (ec *EventContent) parameterByName(name string) (*Parameter, int) {
+ for i, param := range ec.Parameters {
+ if strings.EqualFold(param.Key, name) {
+ return param, i
+ }
+ }
+ return nil, -1
+}
+
+func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) {
+ input := origInput
+ var chosenSigil string
+ for _, sigil := range sigils {
+ if strings.HasPrefix(input, sigil) {
+ chosenSigil = sigil
+ break
+ }
+ }
+ if chosenSigil == "" {
+ return ""
+ }
+ input = input[len(chosenSigil):]
+ var chosenAlias string
+ if !strings.HasPrefix(input, ec.Command) {
+ for _, alias := range ec.Aliases {
+ if strings.HasPrefix(input, alias) {
+ chosenAlias = alias
+ break
+ }
+ }
+ if chosenAlias == "" {
+ return ""
+ }
+ } else {
+ chosenAlias = ec.Command
+ }
+ input = strings.TrimPrefix(input[len(chosenAlias):], owner)
+ if input == "" || input[0] == ' ' {
+ input = strings.TrimLeft(input, " ")
+ return origInput[:len(origInput)-len(input)]
+ }
+ return ""
+}
+
+func (pt PrimitiveType) ValidateValue(value any) bool {
+ _, err := pt.NormalizeValue(value)
+ return err == nil
+}
+
+func normalizeNumber(value any) (int, error) {
+ switch typedValue := value.(type) {
+ case int:
+ return typedValue, nil
+ case int64:
+ return int(typedValue), nil
+ case float64:
+ return int(typedValue), nil
+ case json.Number:
+ if i, err := typedValue.Int64(); err != nil {
+ return 0, fmt.Errorf("failed to parse json.Number: %w", err)
+ } else {
+ return int(i), nil
+ }
+ default:
+ return 0, fmt.Errorf("unsupported type %T for integer", value)
+ }
+}
+
+func (pt PrimitiveType) NormalizeValue(value any) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return normalizeNumber(value)
+ case PrimitiveTypeBoolean:
+ bv, ok := value.(bool)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for boolean", value)
+ }
+ return bv, nil
+ case PrimitiveTypeString, PrimitiveTypeServerName:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for string", value)
+ }
+ return str, pt.validateStringValue(str)
+ case PrimitiveTypeUserID, PrimitiveTypeRoomAlias:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value)
+ } else if plainErr := pt.validateStringValue(str); plainErr == nil {
+ return str, nil
+ } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID {
+ return parsed.UserID(), nil
+ } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias {
+ return parsed.RoomAlias(), nil
+ } else {
+ return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1)
+ }
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ riv, err := NormalizeRoomIDValue(value)
+ if err != nil {
+ return nil, err
+ }
+ return riv, riv.Validate()
+ default:
+ return nil, fmt.Errorf("cannot normalize value for argument type %s", pt)
+ }
+}
+
+func (pt PrimitiveType) validateStringValue(value string) error {
+ switch pt {
+ case PrimitiveTypeString:
+ return nil
+ case PrimitiveTypeServerName:
+ if !id.ValidateServerName(value) {
+ return fmt.Errorf("invalid server name: %q", value)
+ }
+ return nil
+ case PrimitiveTypeUserID:
+ _, _, err := id.UserID(value).ParseAndValidateRelaxed()
+ return err
+ case PrimitiveTypeRoomAlias:
+ sigil, localpart, serverName := id.ParseCommonIdentifier(value)
+ if sigil != '#' || localpart == "" || serverName == "" {
+ return fmt.Errorf("invalid room alias: %q", value)
+ } else if !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name in room alias: %q", serverName)
+ }
+ return nil
+ default:
+ panic(fmt.Errorf("validateStringValue called with invalid type %s", pt))
+ }
+}
+
+func parseBoolean(val string) (bool, error) {
+ if len(val) == 0 {
+ return false, fmt.Errorf("cannot parse empty string as boolean")
+ }
+ switch strings.ToLower(val) {
+ case "t", "true", "y", "yes", "1":
+ return true, nil
+ case "f", "false", "n", "no", "0":
+ return false, nil
+ default:
+ return false, fmt.Errorf("invalid boolean string: %q", val)
+ }
+}
+
+var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`)
+
+func parseRoomOrEventID(value string) (*RoomIDValue, error) {
+ if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") {
+ matches := markdownLinkRegex.FindStringSubmatch(value)
+ if len(matches) == 2 {
+ value = matches[1]
+ }
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil && strings.HasPrefix(value, "!") {
+ return &RoomIDValue{
+ Type: PrimitiveTypeRoomID,
+ RoomID: id.RoomID(value),
+ }, nil
+ }
+ if err != nil {
+ return nil, err
+ } else if parsed.Sigil1 != '!' {
+ return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1)
+ } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' {
+ return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2)
+ }
+ valType := PrimitiveTypeRoomID
+ if parsed.MXID2 != "" {
+ valType = PrimitiveTypeEventID
+ }
+ return &RoomIDValue{
+ Type: valType,
+ RoomID: parsed.RoomID(),
+ Via: parsed.Via,
+ EventID: parsed.EventID(),
+ }, nil
+}
+
+func (pt PrimitiveType) ParseString(value string) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return strconv.Atoi(value)
+ case PrimitiveTypeBoolean:
+ return parseBoolean(value)
+ case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID:
+ return value, pt.validateStringValue(value)
+ case PrimitiveTypeRoomAlias:
+ plainErr := pt.validateStringValue(value)
+ if plainErr == nil {
+ return value, nil
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 != '#' {
+ return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1)
+ }
+ return parsed.RoomAlias(), nil
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, err
+ } else if pt != parsed.Type {
+ return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("cannot parse string for argument type %s", pt)
+ }
+}
+
+func (ps *ParameterSchema) ParseString(value string) (any, error) {
+ if ps == nil {
+ return nil, fmt.Errorf("parameter schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type.ParseString(value)
+ case SchemaTypeLiteral:
+ switch typedValue := ps.Value.(type) {
+ case string:
+ if value == typedValue {
+ return typedValue, nil
+ } else {
+ return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value)
+ }
+ case int, int64, float64, json.Number:
+ expectedVal, _ := normalizeNumber(typedValue)
+ intVal, err := strconv.Atoi(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse integer literal: %w", err)
+ } else if intVal != expectedVal {
+ return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal)
+ }
+ return intVal, nil
+ case bool:
+ boolVal, err := parseBoolean(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse boolean literal: %w", err)
+ } else if boolVal != typedValue {
+ return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal)
+ }
+ return boolVal, nil
+ case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage:
+ expectedVal, _ := NormalizeRoomIDValue(typedValue)
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err)
+ } else if !parsed.Equals(expectedVal) {
+ return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("unsupported literal type %T", ps.Value)
+ }
+ case SchemaTypeUnion:
+ var errs []error
+ for _, variant := range ps.Variants {
+ if parsed, err := variant.ParseString(value); err == nil {
+ return parsed, nil
+ } else {
+ errs = append(errs, err)
+ }
+ }
+ return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...))
+ case SchemaTypeArray:
+ return nil, fmt.Errorf("cannot parse string for array schema type")
+ default:
+ return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType)
+ }
+}
diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go
new file mode 100644
index 00000000..1e0d1817
--- /dev/null
+++ b/event/cmdschema/parse_test.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "bytes"
+ "encoding/json"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exbytes"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/event/cmdschema/testdata"
+)
+
+type QuoteParseOutput struct {
+ Parsed string
+ Remaining string
+ Quoted bool
+}
+
+func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error {
+ var arr []any
+ if err := json.Unmarshal(data, &arr); err != nil {
+ return err
+ }
+ qpo.Parsed = arr[0].(string)
+ qpo.Remaining = arr[1].(string)
+ qpo.Quoted = arr[2].(bool)
+ return nil
+}
+
+type QuoteParseTestData struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Output QuoteParseOutput `json:"output"`
+}
+
+func loadFile[T any](name string) (into T) {
+ quoteData := exerrors.Must(testdata.FS.ReadFile(name))
+ exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into))
+ return
+}
+
+func TestParseQuoted(t *testing.T) {
+ qptd := loadFile[[]QuoteParseTestData]("parse_quote.json")
+ for _, test := range qptd {
+ t.Run(test.Name, func(t *testing.T) {
+ parsed, remaining, quoted := parseQuoted(test.Input)
+ assert.Equalf(t, test.Output, QuoteParseOutput{
+ Parsed: parsed,
+ Remaining: remaining,
+ Quoted: quoted,
+ }, "Failed with input `%s`", test.Input)
+ // Note: can't just test that requoted == input, because some inputs
+ // have unnecessary escapes which won't survive roundtripping
+ t.Run("roundtrip", func(t *testing.T) {
+ requoted := quoteString(parsed) + " " + remaining
+ reparsed, newRemaining, _ := parseQuoted(requoted)
+ assert.Equal(t, parsed, reparsed)
+ assert.Equal(t, remaining, newRemaining)
+ })
+ })
+ }
+}
+
+type CommandTestData struct {
+ Spec *EventContent
+ Tests []*CommandTestUnit
+}
+
+type CommandTestUnit struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Broken string `json:"broken,omitempty"`
+ Error bool `json:"error"`
+ Output json.RawMessage `json:"output"`
+}
+
+func compactJSON(input json.RawMessage) json.RawMessage {
+ var buf bytes.Buffer
+ exerrors.PanicIfNotNil(json.Compact(&buf, input))
+ return buf.Bytes()
+}
+
+func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) {
+ for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) {
+ t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) {
+ ctd := loadFile[CommandTestData]("commands/" + cmd.Name())
+ for _, test := range ctd.Tests {
+ outputStr := exbytes.UnsafeString(compactJSON(test.Output))
+ t.Run(test.Name, func(t *testing.T) {
+ if test.Broken != "" {
+ t.Skip(test.Broken)
+ }
+ output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input)
+ if test.Error {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ if outputStr == "null" {
+ assert.Nil(t, output)
+ } else {
+ assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command)
+ assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go
new file mode 100644
index 00000000..98c421fc
--- /dev/null
+++ b/event/cmdschema/roomid.go
@@ -0,0 +1,135 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+ "strings"
+
+ "maunium.net/go/mautrix/id"
+)
+
+var ParameterSchemaJoinableRoom = Union(
+ PrimitiveTypeRoomID.Schema(),
+ PrimitiveTypeRoomAlias.Schema(),
+)
+
+type RoomIDValue struct {
+ Type PrimitiveType `json:"type"`
+ RoomID id.RoomID `json:"id"`
+ Via []string `json:"via,omitempty"`
+ EventID id.EventID `json:"event_id,omitempty"`
+}
+
+func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) {
+ switch typedValue := input.(type) {
+ case map[string]any, json.RawMessage:
+ var raw json.RawMessage
+ if raw, err = json.Marshal(input); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ } else if err = json.Unmarshal(raw, &riv); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ }
+ case *RoomIDValue:
+ riv = typedValue
+ case RoomIDValue:
+ riv = &typedValue
+ default:
+ err = fmt.Errorf("unsupported type %T for room or event ID", input)
+ }
+ return
+}
+
+func (riv *RoomIDValue) String() string {
+ return riv.URI().String()
+}
+
+func (riv *RoomIDValue) URI() *id.MatrixURI {
+ if riv == nil {
+ return nil
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ return riv.RoomID.URI(riv.Via...)
+ case PrimitiveTypeEventID:
+ return riv.RoomID.EventURI(riv.EventID, riv.Via...)
+ default:
+ return nil
+ }
+}
+
+func (riv *RoomIDValue) Equals(other *RoomIDValue) bool {
+ if riv == nil || other == nil {
+ return riv == other
+ }
+ return riv.Type == other.Type &&
+ riv.RoomID == other.RoomID &&
+ riv.EventID == other.EventID &&
+ slices.Equal(riv.Via, other.Via)
+}
+
+func (riv *RoomIDValue) Validate() error {
+ if riv == nil {
+ return fmt.Errorf("value is nil")
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ if riv.EventID != "" {
+ return fmt.Errorf("event ID must be empty for room ID type")
+ }
+ case PrimitiveTypeEventID:
+ if !strings.HasPrefix(riv.EventID.String(), "$") {
+ return fmt.Errorf("event ID not valid: %q", riv.EventID)
+ }
+ default:
+ return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type)
+ }
+ for _, via := range riv.Via {
+ if !id.ValidateServerName(via) {
+ return fmt.Errorf("invalid server name %q in vias", via)
+ }
+ }
+ sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID)
+ if sigil != '!' {
+ return fmt.Errorf("room ID does not start with !: %q", riv.RoomID)
+ } else if localpart == "" && serverName == "" {
+ return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID)
+ } else if serverName != "" && !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name %q in room ID", serverName)
+ }
+ return nil
+}
+
+func (riv *RoomIDValue) IsValid() bool {
+ return riv.Validate() == nil
+}
+
+type RoomIDOrString string
+
+func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error {
+ if len(data) == 0 {
+ return fmt.Errorf("empty data for room ID or string")
+ }
+ if data[0] == '"' {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(str)
+ return nil
+ }
+ var riv RoomIDValue
+ if err := json.Unmarshal(data, &riv); err != nil {
+ return err
+ } else if err = riv.Validate(); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(riv.String())
+ return nil
+}
diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go
new file mode 100644
index 00000000..c5c57c53
--- /dev/null
+++ b/event/cmdschema/stringify.go
@@ -0,0 +1,122 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "strconv"
+ "strings"
+)
+
+var quoteEscaper = strings.NewReplacer(
+ `"`, `\"`,
+ `\`, `\\`,
+)
+
+const charsToQuote = ` \` + botArrayOpener + botArrayCloser
+
+func quoteString(val string) string {
+ if val == "" {
+ return `""`
+ }
+ val = quoteEscaper.Replace(val)
+ if strings.ContainsAny(val, charsToQuote) {
+ return `"` + val + `"`
+ }
+ return val
+}
+
+func (ec *EventContent) StringifyArgs(args any) string {
+ var argMap map[string]any
+ switch typedArgs := args.(type) {
+ case json.RawMessage:
+ err := json.Unmarshal(typedArgs, &argMap)
+ if err != nil {
+ return ""
+ }
+ case map[string]any:
+ argMap = typedArgs
+ default:
+ if b, err := json.Marshal(args); err != nil {
+ return ""
+ } else if err = json.Unmarshal(b, &argMap); err != nil {
+ return ""
+ }
+ }
+ parts := make([]string, 0, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ isLast := i == len(ec.Parameters)-1
+ val := argMap[param.Key]
+ if val == nil {
+ val = param.DefaultValue
+ if val == nil && !param.Optional {
+ val = param.Schema.GetDefaultValue()
+ }
+ }
+ if val == nil {
+ continue
+ }
+ var stringified string
+ if param.Schema.SchemaType == SchemaTypeArray {
+ stringified = arrayArgumentToString(val, isLast)
+ } else {
+ stringified = singleArgumentToString(val)
+ }
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ return strings.Join(parts, " ")
+}
+
+func arrayArgumentToString(val any, isLast bool) string {
+ valArr, ok := val.([]any)
+ if !ok {
+ return ""
+ }
+ parts := make([]string, 0, len(valArr))
+ for _, elem := range valArr {
+ stringified := singleArgumentToString(elem)
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ joinedParts := strings.Join(parts, " ")
+ if isLast && len(parts) > 0 {
+ return joinedParts
+ }
+ return botArrayOpener + joinedParts + botArrayCloser
+}
+
+func singleArgumentToString(val any) string {
+ switch typedVal := val.(type) {
+ case string:
+ return quoteString(typedVal)
+ case json.Number:
+ return typedVal.String()
+ case bool:
+ return strconv.FormatBool(typedVal)
+ case int:
+ return strconv.Itoa(typedVal)
+ case int64:
+ return strconv.FormatInt(typedVal, 10)
+ case float64:
+ return strconv.FormatInt(int64(typedVal), 10)
+ case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue:
+ normalized, err := NormalizeRoomIDValue(typedVal)
+ if err != nil {
+ return ""
+ }
+ uri := normalized.URI()
+ if uri == nil {
+ return ""
+ }
+ return quoteString(uri.String())
+ default:
+ return ""
+ }
+}
diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json
new file mode 100644
index 00000000..e53382db
--- /dev/null
+++ b/event/cmdschema/testdata/commands.schema.json
@@ -0,0 +1,281 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "commands.schema.json",
+ "title": "ParseInput test cases",
+ "description": "JSON schema for test case files containing command specifications and test cases",
+ "type": "object",
+ "required": [
+ "spec",
+ "tests"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "spec": {
+ "title": "MSC4391 Command Description",
+ "description": "JSON schema defining the structure of a bot command event content",
+ "type": "object",
+ "required": [
+ "command"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "command": {
+ "type": "string",
+ "description": "The command name that triggers this bot command"
+ },
+ "aliases": {
+ "type": "array",
+ "description": "Alternative names/aliases for this command",
+ "items": {
+ "type": "string"
+ }
+ },
+ "parameters": {
+ "type": "array",
+ "description": "List of parameters accepted by this command",
+ "items": {
+ "$ref": "#/$defs/Parameter"
+ }
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of the command"
+ },
+ "fi.mau.tail_parameter": {
+ "type": "string",
+ "description": "The key of the parameter that accepts remaining arguments as tail text"
+ },
+ "source": {
+ "type": "string",
+ "description": "The user ID of the bot that responds to this command"
+ }
+ }
+ },
+ "tests": {
+ "type": "array",
+ "description": "Array of test cases for the command",
+ "items": {
+ "type": "object",
+ "description": "A single test case for command parsing",
+ "required": [
+ "name",
+ "input"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "The command input string to parse"
+ },
+ "output": {
+ "description": "The expected parsed parameter values, or null if the parsing is expected to fail",
+ "oneOf": [
+ {
+ "type": "object",
+ "additionalProperties": true
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
+ "error": {
+ "type": "boolean",
+ "description": "Whether parsing should result in an error. May still produce output.",
+ "default": false
+ }
+ }
+ }
+ }
+ },
+ "$defs": {
+ "ExtensibleTextContainer": {
+ "type": "object",
+ "description": "Container for text that can have multiple representations",
+ "required": [
+ "m.text"
+ ],
+ "properties": {
+ "m.text": {
+ "type": "array",
+ "description": "Array of text representations in different formats",
+ "items": {
+ "$ref": "#/$defs/ExtensibleText"
+ }
+ }
+ }
+ },
+ "ExtensibleText": {
+ "type": "object",
+ "description": "A text representation with a specific MIME type",
+ "required": [
+ "body"
+ ],
+ "properties": {
+ "body": {
+ "type": "string",
+ "description": "The text content"
+ },
+ "mimetype": {
+ "type": "string",
+ "description": "The MIME type of the text (e.g., text/plain, text/html)",
+ "default": "text/plain",
+ "examples": [
+ "text/plain",
+ "text/html"
+ ]
+ }
+ }
+ },
+ "Parameter": {
+ "type": "object",
+ "description": "A parameter definition for a command",
+ "required": [
+ "key",
+ "schema"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "key": {
+ "type": "string",
+ "description": "The identifier for this parameter"
+ },
+ "schema": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema defining the type and structure of this parameter"
+ },
+ "optional": {
+ "type": "boolean",
+ "description": "Whether this parameter is optional",
+ "default": false
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of this parameter"
+ },
+ "fi.mau.default_value": {
+ "description": "Default value for this parameter if not provided"
+ }
+ }
+ },
+ "ParameterSchema": {
+ "type": "object",
+ "description": "Schema definition for a parameter value",
+ "required": [
+ "schema_type"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "schema_type": {
+ "type": "string",
+ "enum": [
+ "primitive",
+ "array",
+ "union",
+ "literal"
+ ],
+ "description": "The type of schema"
+ }
+ },
+ "allOf": [
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "primitive"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "type"
+ ],
+ "properties": {
+ "type": {
+ "type": "string",
+ "enum": [
+ "string",
+ "integer",
+ "boolean",
+ "server_name",
+ "user_id",
+ "room_id",
+ "room_alias",
+ "event_id"
+ ],
+ "description": "The primitive type (only for schema_type: primitive)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "array"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "items"
+ ],
+ "properties": {
+ "items": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema for array items (only for schema_type: array)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "union"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "variants"
+ ],
+ "properties": {
+ "variants": {
+ "type": "array",
+ "description": "The possible variants (only for schema_type: union)",
+ "items": {
+ "$ref": "#/$defs/ParameterSchema"
+ },
+ "minItems": 1
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "literal"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "value"
+ ],
+ "properties": {
+ "value": {
+ "description": "The literal value (only for schema_type: literal)"
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+}
diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json
new file mode 100644
index 00000000..6ce1f4da
--- /dev/null
+++ b/event/cmdschema/testdata/commands/flags.json
@@ -0,0 +1,126 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "flag",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "user",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "user_id"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true,
+ "fi.mau.default_value": false
+ }
+ ],
+ "fi.mau.tail_parameter": "user"
+ },
+ "tests": [
+ {
+ "name": "no flags",
+ "input": "/flag mrrp",
+ "output": {
+ "meow": "mrrp",
+ "user": null
+ }
+ },
+ {
+ "name": "no flags, has tail",
+ "input": "/flag mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com"
+ }
+ },
+ {
+ "name": "named flag at start",
+ "input": "/flag --woof=yes mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "boolean flag without value",
+ "input": "/flag --woof mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "user id flag without value",
+ "input": "/flag --user --woof mrrp",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle",
+ "input": "/flag mrrp --woof=yes @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle with different value",
+ "input": "/flag mrrp --woof=no @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named",
+ "input": "/flag --woof=no --meow=mrrp --user=@user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named with quotes",
+ "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"",
+ "output": {
+ "meow": "meow meow mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "invalid value for named parameter",
+ "input": "/flag --user=meowings mrrp --woof",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json
new file mode 100644
index 00000000..1351c292
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_id_or_alias.json
@@ -0,0 +1,85 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "room",
+ "schema": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "room": "#test:matrix.org"
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link with url encoding",
+ "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net",
+ "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!#test/room\nversion 11, with @🐈️:maunium.net",
+ "via": [
+ "maunium.net"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json
new file mode 100644
index 00000000..aa266054
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_reference_list.json
@@ -0,0 +1,106 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "rooms",
+ "schema": {
+ "schema_type": "array",
+ "items": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "rooms": [
+ "#test:matrix.org"
+ ]
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "two room ids",
+ "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"
+ },
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI and matrix.to URL",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ },
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json
new file mode 100644
index 00000000..94667323
--- /dev/null
+++ b/event/cmdschema/testdata/commands/simple.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test simple",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "success",
+ "input": "/test simple mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "directed success",
+ "input": "/test simple@testbot mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "missing parameter",
+ "input": "/test simple",
+ "error": true,
+ "output": {
+ "meow": ""
+ }
+ },
+ {
+ "name": "directed at another bot",
+ "input": "/test simple@anotherbot mrrp",
+ "error": false,
+ "output": null
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json
new file mode 100644
index 00000000..9782f8ec
--- /dev/null
+++ b/event/cmdschema/testdata/commands/tail.json
@@ -0,0 +1,60 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "tail",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "reason",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true
+ }
+ ],
+ "fi.mau.tail_parameter": "reason"
+ },
+ "tests": [
+ {
+ "name": "no tail or flag",
+ "input": "/tail mrrp",
+ "output": {
+ "meow": "mrrp",
+ "reason": ""
+ }
+ },
+ {
+ "name": "tail, no flag",
+ "input": "/tail mrrp meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow"
+ }
+ },
+ {
+ "name": "flag before tail",
+ "input": "/tail mrrp --woof meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow",
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go
new file mode 100644
index 00000000..eceea3d2
--- /dev/null
+++ b/event/cmdschema/testdata/data.go
@@ -0,0 +1,14 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package testdata
+
+import (
+ "embed"
+)
+
+//go:embed *
+var FS embed.FS
diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json
new file mode 100644
index 00000000..8f52b7f5
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.json
@@ -0,0 +1,30 @@
+[
+ {"name": "empty string", "input": "", "output": ["", "", false]},
+ {"name": "single word", "input": "meow", "output": ["meow", "", false]},
+ {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]},
+ {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]},
+ {"name": "only spaces", "input": " ", "output": ["", "", false]},
+ {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]},
+ {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]},
+ {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]},
+ {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]},
+ {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]},
+ {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]},
+ {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]},
+ {"name": "just opening quote", "input": "\"", "output": ["", "", true]},
+ {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]},
+ {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]},
+ {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]},
+ {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]},
+ {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]},
+ {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]},
+ {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]},
+ {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]},
+ {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]},
+ {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]},
+ {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]}
+]
diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json
new file mode 100644
index 00000000..9f249116
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.schema.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "parse_quote.schema.json",
+ "title": "parseQuote test cases",
+ "description": "Test cases for the parseQuoted function",
+ "type": "array",
+ "items": {
+ "type": "object",
+ "required": [
+ "name",
+ "input",
+ "output"
+ ],
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "Name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "Input string to be parsed"
+ },
+ "output": {
+ "type": "array",
+ "description": "Expected output of parsing: [first word, remaining text, was quoted]",
+ "minItems": 3,
+ "maxItems": 3,
+ "prefixItems": [
+ {
+ "type": "string",
+ "description": "First parsed word"
+ },
+ {
+ "type": "string",
+ "description": "Remaining text after the first word"
+ },
+ {
+ "type": "boolean",
+ "description": "Whether the first word was quoted"
+ }
+ ]
+ }
+ },
+ "additionalProperties": false
+ }
+}
diff --git a/event/content.go b/event/content.go
index ab57c658..814aeec4 100644
--- a/event/content.go
+++ b/event/content.go
@@ -18,6 +18,7 @@ import (
// This is used by Content.ParseRaw() for creating the correct type of struct.
var TypeMap = map[Type]reflect.Type{
StateMember: reflect.TypeOf(MemberEventContent{}),
+ StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}),
StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}),
StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}),
StateRoomName: reflect.TypeOf(RoomNameEventContent{}),
@@ -38,7 +39,9 @@ var TypeMap = map[Type]reflect.Type{
StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}),
StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}),
StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}),
- StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}),
+
+ StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
+ StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}),
StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}),
@@ -48,6 +51,8 @@ var TypeMap = map[Type]reflect.Type{
StateUnstablePolicyUser: reflect.TypeOf(ModPolicyContent{}),
StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}),
+ StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}),
+ StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}),
EventMessage: reflect.TypeOf(MessageEventContent{}),
EventSticker: reflect.TypeOf(MessageEventContent{}),
@@ -58,7 +63,11 @@ var TypeMap = map[Type]reflect.Type{
EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}),
EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}),
- BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
+ BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
+ BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}),
+ BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}),
+ BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}),
+ BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}),
AccountDataRoomTags: reflect.TypeOf(TagEventContent{}),
AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}),
@@ -67,9 +76,11 @@ var TypeMap = map[Type]reflect.Type{
AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}),
AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}),
- EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
- EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
- EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
+ EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
+ EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}),
+ BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}),
InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}),
InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}),
@@ -121,7 +132,7 @@ var TypeMap = map[Type]reflect.Type{
// When being marshaled into JSON, the data in Parsed will be marshaled first and then recursively merged
// with the data in Raw. Values in Raw are preferred, but nested objects will be recursed into before merging,
// rather than overriding the whole object with the one in Raw).
-// If one of them is nil, the only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead.
+// If one of them is nil, then only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead.
type Content struct {
VeryRaw json.RawMessage
Raw map[string]interface{}
diff --git a/event/delayed.go b/event/delayed.go
new file mode 100644
index 00000000..fefb62af
--- /dev/null
+++ b/event/delayed.go
@@ -0,0 +1,70 @@
+package event
+
+import (
+ "encoding/json"
+
+ "go.mau.fi/util/jsontime"
+
+ "maunium.net/go/mautrix/id"
+)
+
+type ScheduledDelayedEvent struct {
+ DelayID id.DelayID `json:"delay_id"`
+ RoomID id.RoomID `json:"room_id"`
+ Type Type `json:"type"`
+ StateKey *string `json:"state_key,omitempty"`
+ Delay int64 `json:"delay"`
+ RunningSince jsontime.UnixMilli `json:"running_since"`
+ Content Content `json:"content"`
+}
+
+func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) {
+ evt := &Event{
+ ID: eventID,
+ RoomID: e.RoomID,
+ Type: e.Type,
+ StateKey: e.StateKey,
+ Content: e.Content,
+ Timestamp: ts.UnixMilli(),
+ }
+ return evt, evt.Content.ParseRaw(evt.Type)
+}
+
+type FinalisedDelayedEvent struct {
+ DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"`
+ Outcome DelayOutcome `json:"outcome"`
+ Reason DelayReason `json:"reason"`
+ Error json.RawMessage `json:"error,omitempty"`
+ EventID id.EventID `json:"event_id,omitempty"`
+ Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
+}
+
+type DelayStatus string
+
+var (
+ DelayStatusScheduled DelayStatus = "scheduled"
+ DelayStatusFinalised DelayStatus = "finalised"
+)
+
+type DelayAction string
+
+var (
+ DelayActionSend DelayAction = "send"
+ DelayActionCancel DelayAction = "cancel"
+ DelayActionRestart DelayAction = "restart"
+)
+
+type DelayOutcome string
+
+var (
+ DelayOutcomeSend DelayOutcome = "send"
+ DelayOutcomeCancel DelayOutcome = "cancel"
+)
+
+type DelayReason string
+
+var (
+ DelayReasonAction DelayReason = "action"
+ DelayReasonError DelayReason = "error"
+ DelayReasonDelay DelayReason = "delay"
+)
diff --git a/event/encryption.go b/event/encryption.go
index cf9c2814..c60cb91a 100644
--- a/event/encryption.go
+++ b/event/encryption.go
@@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error {
return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext)
case id.AlgorithmMegolmV1:
if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' {
- return id.InputNotJSONString
+ return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString)
}
content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1]
}
@@ -132,8 +132,9 @@ type RoomKeyRequestEventContent struct {
type RequestedKeyInfo struct {
Algorithm id.Algorithm `json:"algorithm"`
RoomID id.RoomID `json:"room_id"`
- SenderKey id.SenderKey `json:"sender_key"`
SessionID id.SessionID `json:"session_id"`
+ // Deprecated: Matrix v1.3
+ SenderKey id.SenderKey `json:"sender_key"`
}
type RoomKeyWithheldCode string
diff --git a/event/events.go b/event/events.go
index 38f0d848..72c1e161 100644
--- a/event/events.go
+++ b/event/events.go
@@ -118,6 +118,9 @@ type MautrixInfo struct {
DecryptionDuration time.Duration
CheckpointSent bool
+ // When using MSC4222 and the state_after field, this field is set
+ // for timeline events to indicate they shouldn't update room state.
+ IgnoreState bool
}
func (evt *Event) GetStateKey() string {
@@ -127,31 +130,29 @@ func (evt *Event) GetStateKey() string {
return ""
}
-type StrippedState struct {
- Content Content `json:"content"`
- Type Type `json:"type"`
- StateKey string `json:"state_key"`
- Sender id.UserID `json:"sender"`
-}
-
type Unsigned struct {
- PrevContent *Content `json:"prev_content,omitempty"`
- PrevSender id.UserID `json:"prev_sender,omitempty"`
- ReplacesState id.EventID `json:"replaces_state,omitempty"`
- Age int64 `json:"age,omitempty"`
- TransactionID string `json:"transaction_id,omitempty"`
- Relations *Relations `json:"m.relations,omitempty"`
- RedactedBecause *Event `json:"redacted_because,omitempty"`
- InviteRoomState []StrippedState `json:"invite_room_state,omitempty"`
+ PrevContent *Content `json:"prev_content,omitempty"`
+ PrevSender id.UserID `json:"prev_sender,omitempty"`
+ Membership Membership `json:"membership,omitempty"`
+ ReplacesState id.EventID `json:"replaces_state,omitempty"`
+ Age int64 `json:"age,omitempty"`
+ TransactionID string `json:"transaction_id,omitempty"`
+ Relations *Relations `json:"m.relations,omitempty"`
+ RedactedBecause *Event `json:"redacted_because,omitempty"`
+ InviteRoomState []*Event `json:"invite_room_state,omitempty"`
- BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"`
- BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"`
- BeeperHSOrderString BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"`
- BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"`
+ BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"`
+ BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"`
+ BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"`
+ BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"`
+
+ ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"`
+ ElementPolicyServerSpammy bool `json:"io.element.synapse.policy_server_spammy,omitempty"`
}
func (us *Unsigned) IsEmpty() bool {
- return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 &&
+ return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" &&
us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil &&
- us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero()
+ us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() &&
+ !us.ElementSoftFailed
}
diff --git a/event/member.go b/event/member.go
index ebafdcb7..9956a36b 100644
--- a/event/member.go
+++ b/event/member.go
@@ -7,8 +7,6 @@
package event
import (
- "encoding/json"
-
"maunium.net/go/mautrix/id"
)
@@ -35,19 +33,37 @@ const (
// MemberEventContent represents the content of a m.room.member state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroommember
type MemberEventContent struct {
- Membership Membership `json:"membership"`
- AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
- Displayname string `json:"displayname,omitempty"`
- IsDirect bool `json:"is_direct,omitempty"`
- ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
- Reason string `json:"reason,omitempty"`
+ Membership Membership `json:"membership"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ Displayname string `json:"displayname,omitempty"`
+ IsDirect bool `json:"is_direct,omitempty"`
+ ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
+ Reason string `json:"reason,omitempty"`
+ JoinAuthorisedViaUsersServer id.UserID `json:"join_authorised_via_users_server,omitempty"`
+ MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"`
+
+ MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"`
+}
+
+type SignedThirdPartyInvite struct {
+ Token string `json:"token"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
+ MXID string `json:"mxid"`
}
type ThirdPartyInvite struct {
- DisplayName string `json:"display_name"`
- Signed struct {
- Token string `json:"token"`
- Signatures json.RawMessage `json:"signatures"`
- MXID string `json:"mxid"`
- }
+ DisplayName string `json:"display_name"`
+ Signed SignedThirdPartyInvite `json:"signed"`
+}
+
+type ThirdPartyInviteEventContent struct {
+ DisplayName string `json:"display_name"`
+ KeyValidityURL string `json:"key_validity_url"`
+ PublicKey id.Ed25519 `json:"public_key"`
+ PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"`
+}
+
+type ThirdPartyInviteKey struct {
+ KeyValidityURL string `json:"key_validity_url,omitempty"`
+ PublicKey id.Ed25519 `json:"public_key"`
}
diff --git a/event/message.go b/event/message.go
index 92bdcf07..3fb3dc82 100644
--- a/event/message.go
+++ b/event/message.go
@@ -32,7 +32,7 @@ func (mt MessageType) IsText() bool {
func (mt MessageType) IsMedia() bool {
switch mt {
- case MsgImage, MsgVideo, MsgAudio, MsgFile, MessageType(EventSticker.Type):
+ case MsgImage, MsgVideo, MsgAudio, MsgFile, CapMsgSticker:
return true
default:
return false
@@ -135,11 +135,42 @@ type MessageEventContent struct {
BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"`
BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"`
BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"`
+ BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"`
BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"`
+ BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"`
+
MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"`
MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"`
+
+ MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"`
+}
+
+func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType {
+ switch content.MsgType {
+ case CapMsgSticker:
+ return CapMsgSticker
+ case "":
+ if content.URL != "" || content.File != nil {
+ return CapMsgSticker
+ }
+ case MsgImage:
+ return MsgImage
+ case MsgAudio:
+ if content.MSC3245Voice != nil {
+ return CapMsgVoice
+ }
+ return MsgAudio
+ case MsgVideo:
+ if content.Info != nil && content.Info.MauGIF {
+ return CapMsgGIF
+ }
+ return MsgVideo
+ case MsgFile:
+ return MsgFile
+ }
+ return ""
}
func (content *MessageEventContent) GetFileName() string {
@@ -184,6 +215,7 @@ func (content *MessageEventContent) SetEdit(original id.EventID) {
content.RelatesTo = (&RelatesTo{}).SetReplace(original)
if content.MsgType == MsgText || content.MsgType == MsgNotice {
content.Body = "* " + content.Body
+ content.Mentions = &Mentions{}
if content.Format == FormatHTML && len(content.FormattedBody) > 0 {
content.FormattedBody = "* " + content.FormattedBody
}
@@ -244,24 +276,46 @@ func (m *Mentions) Has(userID id.UserID) bool {
return m != nil && slices.Contains(m.UserIDs, userID)
}
+func (m *Mentions) Merge(other *Mentions) *Mentions {
+ if m == nil {
+ return other
+ } else if other == nil {
+ return m
+ }
+ return &Mentions{
+ UserIDs: slices.Concat(m.UserIDs, other.UserIDs),
+ Room: m.Room || other.Room,
+ }
+}
+
+type MSC4391BotCommandInputCustom[T any] struct {
+ Command string `json:"command"`
+ Arguments T `json:"arguments,omitempty"`
+}
+
+type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage]
+
type EncryptedFileInfo struct {
attachment.EncryptedFile
URL id.ContentURIString `json:"url"`
}
type FileInfo struct {
- MimeType string `json:"mimetype,omitempty"`
- ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"`
- ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"`
- ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"`
+ MimeType string
+ ThumbnailInfo *FileInfo
+ ThumbnailURL id.ContentURIString
+ ThumbnailFile *EncryptedFileInfo
- Blurhash string `json:"blurhash,omitempty"`
- AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
+ Blurhash string
+ AnoaBlurhash string
- Width int `json:"-"`
- Height int `json:"-"`
- Duration int `json:"-"`
- Size int `json:"-"`
+ MauGIF bool
+ IsAnimated bool
+
+ Width int
+ Height int
+ Duration int
+ Size int
}
type serializableFileInfo struct {
@@ -273,6 +327,9 @@ type serializableFileInfo struct {
Blurhash string `json:"blurhash,omitempty"`
AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
+ MauGIF bool `json:"fi.mau.gif,omitempty"`
+ IsAnimated bool `json:"is_animated,omitempty"`
+
Width json.Number `json:"w,omitempty"`
Height json.Number `json:"h,omitempty"`
Duration json.Number `json:"duration,omitempty"`
@@ -289,6 +346,9 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI
ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo),
ThumbnailFile: fileInfo.ThumbnailFile,
+ MauGIF: fileInfo.MauGIF,
+ IsAnimated: fileInfo.IsAnimated,
+
Blurhash: fileInfo.Blurhash,
AnoaBlurhash: fileInfo.AnoaBlurhash,
}
@@ -317,6 +377,8 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) {
MimeType: sfi.MimeType,
ThumbnailURL: sfi.ThumbnailURL,
ThumbnailFile: sfi.ThumbnailFile,
+ MauGIF: sfi.MauGIF,
+ IsAnimated: sfi.IsAnimated,
Blurhash: sfi.Blurhash,
AnoaBlurhash: sfi.AnoaBlurhash,
}
diff --git a/event/message_test.go b/event/message_test.go
index 562a6622..c721df35 100644
--- a/event/message_test.go
+++ b/event/message_test.go
@@ -33,7 +33,7 @@ const invalidMessageEvent = `{
func TestMessageEventContent__ParseInvalid(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(invalidMessageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) {
assert.Equal(t, id.RoomID("!bar"), evt.RoomID)
err = evt.Content.ParseRaw(evt.Type)
- assert.NotNil(t, err)
+ assert.Error(t, err)
}
const messageEvent = `{
@@ -68,7 +68,7 @@ const messageEvent = `{
func TestMessageEventContent__ParseEdit(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(messageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -110,7 +110,7 @@ const imageMessageEvent = `{
func TestMessageEventContent__ParseMedia(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(imageMessageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) {
content := evt.Content.Parsed.(*event.MessageEventContent)
assert.Equal(t, event.MsgImage, content.MsgType)
parsedURL, err := content.URL.Parse()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL)
assert.Nil(t, content.NewContent)
assert.Equal(t, "image/png", content.GetInfo().MimeType)
@@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}`
func TestMessageEventContent__Marshal(t *testing.T) {
data, err := json.Marshal(parsedMessage)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expectedMarshalResult, string(data))
}
@@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun
func TestMessageEventContent__Marshal_Custom(t *testing.T) {
data, err := json.Marshal(customParsedMessage)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expectedCustomMarshalResult, string(data))
}
diff --git a/event/poll.go b/event/poll.go
index 37333015..9082f65e 100644
--- a/event/poll.go
+++ b/event/poll.go
@@ -29,16 +29,13 @@ func (content *PollResponseEventContent) SetRelatesTo(rel *RelatesTo) {
}
type MSC1767Message struct {
- Text string `json:"org.matrix.msc1767.text,omitempty"`
- HTML string `json:"org.matrix.msc1767.html,omitempty"`
- Message []struct {
- MimeType string `json:"mimetype"`
- Body string `json:"body"`
- } `json:"org.matrix.msc1767.message,omitempty"`
+ Text string `json:"org.matrix.msc1767.text,omitempty"`
+ HTML string `json:"org.matrix.msc1767.html,omitempty"`
+ Message []ExtensibleText `json:"org.matrix.msc1767.message,omitempty"`
}
type PollStartEventContent struct {
- RelatesTo *RelatesTo `json:"m.relates_to"`
+ RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
Mentions *Mentions `json:"m.mentions,omitempty"`
PollStart struct {
Kind string `json:"kind"`
diff --git a/event/powerlevels.go b/event/powerlevels.go
index 2f4d4573..668eb6d3 100644
--- a/event/powerlevels.go
+++ b/event/powerlevels.go
@@ -7,6 +7,8 @@
package event
import (
+ "math"
+ "slices"
"sync"
"go.mau.fi/util/ptr"
@@ -26,6 +28,9 @@ type PowerLevelsEventContent struct {
Events map[string]int `json:"events,omitempty"`
EventsDefault int `json:"events_default,omitempty"`
+ beeperEphemeralLock sync.RWMutex
+ BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"`
+
Notifications *NotificationPowerLevels `json:"notifications,omitempty"`
StateDefaultPtr *int `json:"state_default,omitempty"`
@@ -34,6 +39,12 @@ type PowerLevelsEventContent struct {
KickPtr *int `json:"kick,omitempty"`
BanPtr *int `json:"ban,omitempty"`
RedactPtr *int `json:"redact,omitempty"`
+
+ BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"`
+
+ // This is not a part of power levels, it's added by mautrix-go internally in certain places
+ // in order to detect creator power accurately.
+ CreateEvent *Event `json:"-"`
}
func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
@@ -45,6 +56,7 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
UsersDefault: pl.UsersDefault,
Events: maps.Clone(pl.Events),
EventsDefault: pl.EventsDefault,
+ BeeperEphemeral: maps.Clone(pl.BeeperEphemeral),
StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr),
Notifications: pl.Notifications.Clone(),
@@ -53,6 +65,10 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
KickPtr: ptr.Clone(pl.KickPtr),
BanPtr: ptr.Clone(pl.BanPtr),
RedactPtr: ptr.Clone(pl.RedactPtr),
+
+ BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr),
+
+ CreateEvent: pl.CreateEvent,
}
}
@@ -111,7 +127,17 @@ func (pl *PowerLevelsEventContent) StateDefault() int {
return 50
}
+func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int {
+ if pl.BeeperEphemeralDefaultPtr != nil {
+ return *pl.BeeperEphemeralDefaultPtr
+ }
+ return pl.EventsDefault
+}
+
func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
+ if pl.isCreator(userID) {
+ return math.MaxInt
+ }
pl.usersLock.RLock()
defer pl.usersLock.RUnlock()
level, ok := pl.Users[userID]
@@ -121,9 +147,19 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
return level
}
+const maxPL = 1<<53 - 1
+
func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) {
pl.usersLock.Lock()
defer pl.usersLock.Unlock()
+ if pl.isCreator(userID) {
+ return
+ }
+ if level == math.MaxInt && maxPL < math.MaxInt {
+ // Hack to avoid breaking on 32-bit systems (they're only slightly supported)
+ x := int64(maxPL)
+ level = int(x)
+ }
if level == pl.UsersDefault {
delete(pl.Users, userID)
} else {
@@ -138,9 +174,24 @@ func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int)
return pl.EnsureUserLevelAs("", target, level)
}
+func (pl *PowerLevelsEventContent) createContent() *CreateEventContent {
+ if pl.CreateEvent == nil {
+ return &CreateEventContent{}
+ }
+ return pl.CreateEvent.Content.AsCreate()
+}
+
+func (pl *PowerLevelsEventContent) isCreator(userID id.UserID) bool {
+ cc := pl.createContent()
+ return cc.SupportsCreatorPower() && (userID == pl.CreateEvent.Sender || slices.Contains(cc.AdditionalCreators, userID))
+}
+
func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool {
+ if pl.isCreator(target) {
+ return false
+ }
existingLevel := pl.GetUserLevel(target)
- if actor != "" {
+ if actor != "" && !pl.isCreator(actor) {
actorLevel := pl.GetUserLevel(actor)
if actorLevel <= existingLevel || actorLevel < level {
return false
@@ -166,6 +217,29 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int {
return level
}
+func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int {
+ pl.beeperEphemeralLock.RLock()
+ defer pl.beeperEphemeralLock.RUnlock()
+ level, ok := pl.BeeperEphemeral[eventType.String()]
+ if !ok {
+ return pl.BeeperEphemeralDefault()
+ }
+ return level
+}
+
+func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) {
+ pl.beeperEphemeralLock.Lock()
+ defer pl.beeperEphemeralLock.Unlock()
+ if level == pl.BeeperEphemeralDefault() {
+ delete(pl.BeeperEphemeral, eventType.String())
+ } else {
+ if pl.BeeperEphemeral == nil {
+ pl.BeeperEphemeral = make(map[string]int)
+ }
+ pl.BeeperEphemeral[eventType.String()] = level
+ }
+}
+
func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) {
pl.eventsLock.Lock()
defer pl.eventsLock.Unlock()
@@ -185,7 +259,7 @@ func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) b
func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool {
existingLevel := pl.GetEventLevel(eventType)
- if actor != "" {
+ if actor != "" && !pl.isCreator(actor) {
actorLevel := pl.GetUserLevel(actor)
if existingLevel > actorLevel || level > actorLevel {
return false
diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go
new file mode 100644
index 00000000..f5861583
--- /dev/null
+++ b/event/powerlevels_ephemeral_test.go
@@ -0,0 +1,67 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package event_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/event"
+)
+
+func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 45,
+ }
+
+ assert.Equal(t, 45, pl.BeeperEphemeralDefault())
+
+ override := 60
+ pl.BeeperEphemeralDefaultPtr = &override
+ assert.Equal(t, 60, pl.BeeperEphemeralDefault())
+}
+
+func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 25,
+ }
+ evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
+
+ assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType))
+
+ pl.SetBeeperEphemeralLevel(evtType, 50)
+ assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType))
+ require.NotNil(t, pl.BeeperEphemeral)
+ assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()])
+
+ pl.SetBeeperEphemeralLevel(evtType, 25)
+ _, exists := pl.BeeperEphemeral[evtType.String()]
+ assert.False(t, exists)
+}
+
+func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) {
+ override := 70
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 35,
+ BeeperEphemeral: map[string]int{"com.example.ephemeral": 90},
+ BeeperEphemeralDefaultPtr: &override,
+ }
+
+ cloned := pl.Clone()
+ require.NotNil(t, cloned)
+ require.NotNil(t, cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"])
+
+ cloned.BeeperEphemeral["com.example.ephemeral"] = 99
+ *cloned.BeeperEphemeralDefaultPtr = 71
+
+ assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"])
+ assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr)
+}
diff --git a/event/relations.go b/event/relations.go
index ea40cc06..2316cbc7 100644
--- a/event/relations.go
+++ b/event/relations.go
@@ -15,10 +15,11 @@ import (
type RelationType string
const (
- RelReplace RelationType = "m.replace"
- RelReference RelationType = "m.reference"
- RelAnnotation RelationType = "m.annotation"
- RelThread RelationType = "m.thread"
+ RelReplace RelationType = "m.replace"
+ RelReference RelationType = "m.reference"
+ RelAnnotation RelationType = "m.annotation"
+ RelThread RelationType = "m.thread"
+ RelBeeperTranscription RelationType = "com.beeper.transcription"
)
type RelatesTo struct {
@@ -33,7 +34,7 @@ type RelatesTo struct {
type InReplyTo struct {
EventID id.EventID `json:"event_id,omitempty"`
- UnstableRoomID id.RoomID `json:"room_id,omitempty"`
+ UnstableRoomID id.RoomID `json:"com.beeper.cross_room_id,omitempty"`
}
func (rel *RelatesTo) Copy() *RelatesTo {
@@ -100,6 +101,10 @@ func (rel *RelatesTo) SetReplace(mxid id.EventID) *RelatesTo {
}
func (rel *RelatesTo) SetReplyTo(mxid id.EventID) *RelatesTo {
+ if rel.Type != RelThread {
+ rel.Type = ""
+ rel.EventID = ""
+ }
rel.InReplyTo = &InReplyTo{EventID: mxid}
rel.IsFallingBack = false
return rel
diff --git a/event/reply.go b/event/reply.go
index 1a88c619..5f55bb80 100644
--- a/event/reply.go
+++ b/event/reply.go
@@ -32,12 +32,13 @@ func TrimReplyFallbackText(text string) string {
}
func (content *MessageEventContent) RemoveReplyFallback() {
- if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved {
- if content.Format == FormatHTML {
- content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
+ if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML {
+ origHTML := content.FormattedBody
+ content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
+ if content.FormattedBody != origHTML {
+ content.Body = TrimReplyFallbackText(content.Body)
+ content.replyFallbackRemoved = true
}
- content.Body = TrimReplyFallbackText(content.Body)
- content.replyFallbackRemoved = true
}
}
@@ -47,5 +48,27 @@ func (content *MessageEventContent) GetReplyTo() id.EventID {
}
func (content *MessageEventContent) SetReply(inReplyTo *Event) {
- content.RelatesTo = (&RelatesTo{}).SetReplyTo(inReplyTo.ID)
+ if content.RelatesTo == nil {
+ content.RelatesTo = &RelatesTo{}
+ }
+ content.RelatesTo.SetReplyTo(inReplyTo.ID)
+ if content.Mentions == nil {
+ content.Mentions = &Mentions{}
+ }
+ content.Mentions.Add(inReplyTo.Sender)
+}
+
+func (content *MessageEventContent) SetThread(inReplyTo *Event) {
+ root := inReplyTo.ID
+ relatable, ok := inReplyTo.Content.Parsed.(Relatable)
+ if ok {
+ targetRoot := relatable.OptionalGetRelatesTo().GetThreadParent()
+ if targetRoot != "" {
+ root = targetRoot
+ }
+ }
+ if content.RelatesTo == nil {
+ content.RelatesTo = &RelatesTo{}
+ }
+ content.RelatesTo.SetThread(root, inReplyTo.ID)
}
diff --git a/event/state.go b/event/state.go
index 15972892..ace170a5 100644
--- a/event/state.go
+++ b/event/state.go
@@ -7,6 +7,12 @@
package event
import (
+ "encoding/base64"
+ "encoding/json"
+ "slices"
+
+ "go.mau.fi/util/jsontime"
+
"maunium.net/go/mautrix/id"
)
@@ -42,7 +48,52 @@ type ServerACLEventContent struct {
// TopicEventContent represents the content of a m.room.topic state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomtopic
type TopicEventContent struct {
- Topic string `json:"topic"`
+ Topic string `json:"topic"`
+ ExtensibleTopic *ExtensibleTopic `json:"m.topic,omitempty"`
+}
+
+// ExtensibleTopic represents the contents of the m.topic field within the
+// m.room.topic state event as described in [MSC3765].
+//
+// [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765
+type ExtensibleTopic = ExtensibleTextContainer
+
+type ExtensibleTextContainer struct {
+ Text []ExtensibleText `json:"m.text"`
+}
+
+func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool {
+ if c == nil || description == nil {
+ return c == description
+ }
+ return slices.Equal(c.Text, description.Text)
+}
+
+func MakeExtensibleText(text string) *ExtensibleTextContainer {
+ return &ExtensibleTextContainer{
+ Text: []ExtensibleText{{
+ Body: text,
+ MimeType: "text/plain",
+ }},
+ }
+}
+
+func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer {
+ return &ExtensibleTextContainer{
+ Text: []ExtensibleText{{
+ Body: plaintext,
+ MimeType: "text/plain",
+ }, {
+ Body: html,
+ MimeType: "text/html",
+ }},
+ }
+}
+
+// ExtensibleText represents the contents of an m.text field.
+type ExtensibleText struct {
+ MimeType string `json:"mimetype,omitempty"`
+ Body string `json:"body"`
}
// TombstoneEventContent represents the content of a m.room.tombstone state event.
@@ -52,35 +103,64 @@ type TombstoneEventContent struct {
ReplacementRoom id.RoomID `json:"replacement_room"`
}
+func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID {
+ if tec == nil {
+ return ""
+ }
+ return tec.ReplacementRoom
+}
+
type Predecessor struct {
RoomID id.RoomID `json:"room_id"`
EventID id.EventID `json:"event_id"`
}
-type RoomVersion string
+// Deprecated: use id.RoomVersion instead
+type RoomVersion = id.RoomVersion
+// Deprecated: use id.RoomVX constants instead
const (
- RoomV1 RoomVersion = "1"
- RoomV2 RoomVersion = "2"
- RoomV3 RoomVersion = "3"
- RoomV4 RoomVersion = "4"
- RoomV5 RoomVersion = "5"
- RoomV6 RoomVersion = "6"
- RoomV7 RoomVersion = "7"
- RoomV8 RoomVersion = "8"
- RoomV9 RoomVersion = "9"
- RoomV10 RoomVersion = "10"
- RoomV11 RoomVersion = "11"
+ RoomV1 = id.RoomV1
+ RoomV2 = id.RoomV2
+ RoomV3 = id.RoomV3
+ RoomV4 = id.RoomV4
+ RoomV5 = id.RoomV5
+ RoomV6 = id.RoomV6
+ RoomV7 = id.RoomV7
+ RoomV8 = id.RoomV8
+ RoomV9 = id.RoomV9
+ RoomV10 = id.RoomV10
+ RoomV11 = id.RoomV11
+ RoomV12 = id.RoomV12
)
// CreateEventContent represents the content of a m.room.create state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomcreate
type CreateEventContent struct {
- Type RoomType `json:"type,omitempty"`
- Creator id.UserID `json:"creator,omitempty"`
- Federate bool `json:"m.federate,omitempty"`
- RoomVersion RoomVersion `json:"room_version,omitempty"`
- Predecessor *Predecessor `json:"predecessor,omitempty"`
+ Type RoomType `json:"type,omitempty"`
+ Federate *bool `json:"m.federate,omitempty"`
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ Predecessor *Predecessor `json:"predecessor,omitempty"`
+
+ // Room v12+ only
+ AdditionalCreators []id.UserID `json:"additional_creators,omitempty"`
+
+ // Deprecated: use the event sender instead
+ Creator id.UserID `json:"creator,omitempty"`
+}
+
+func (cec *CreateEventContent) GetPredecessor() (p Predecessor) {
+ if cec != nil && cec.Predecessor != nil {
+ p = *cec.Predecessor
+ }
+ return
+}
+
+func (cec *CreateEventContent) SupportsCreatorPower() bool {
+ if cec == nil {
+ return false
+ }
+ return cec.RoomVersion.PrivilegedRoomCreators()
}
// JoinRule specifies how open a room is to new members.
@@ -158,7 +238,8 @@ type BridgeInfoSection struct {
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
ExternalURL string `json:"external_url,omitempty"`
- Receiver string `json:"fi.mau.receiver,omitempty"`
+ Receiver string `json:"fi.mau.receiver,omitempty"`
+ MessageRequest bool `json:"com.beeper.message_request,omitempty"`
}
// BridgeEventContent represents the content of a m.bridge state event.
@@ -172,6 +253,32 @@ type BridgeEventContent struct {
BeeperRoomType string `json:"com.beeper.room_type,omitempty"`
BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"`
+
+ TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"`
+ TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"`
+}
+
+// DisappearingType represents the type of a disappearing message timer.
+type DisappearingType string
+
+const (
+ DisappearingTypeNone DisappearingType = ""
+ DisappearingTypeAfterRead DisappearingType = "after_read"
+ DisappearingTypeAfterSend DisappearingType = "after_send"
+)
+
+type BeeperDisappearingTimer struct {
+ Type DisappearingType `json:"type"`
+ Timer jsontime.Milliseconds `json:"timer"`
+}
+
+type marshalableBeeperDisappearingTimer BeeperDisappearingTimer
+
+func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) {
+ if bdt == nil || bdt.Type == DisappearingTypeNone {
+ return []byte("{}"), nil
+ }
+ return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt))
}
type SpaceChildEventContent struct {
@@ -188,25 +295,63 @@ type SpaceParentEventContent struct {
type PolicyRecommendation string
const (
- PolicyRecommendationBan PolicyRecommendation = "m.ban"
- PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban"
- PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban"
+ PolicyRecommendationBan PolicyRecommendation = "m.ban"
+ PolicyRecommendationUnstableTakedown PolicyRecommendation = "org.matrix.msc4204.takedown"
+ PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban"
+ PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban"
)
+type PolicyHashes struct {
+ SHA256 string `json:"sha256"`
+}
+
+func (ph *PolicyHashes) DecodeSHA256() *[32]byte {
+ if ph == nil || ph.SHA256 == "" {
+ return nil
+ }
+ decoded, _ := base64.StdEncoding.DecodeString(ph.SHA256)
+ if len(decoded) == 32 {
+ return (*[32]byte)(decoded)
+ }
+ return nil
+}
+
// ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event.
// https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists
type ModPolicyContent struct {
- Entity string `json:"entity"`
+ Entity string `json:"entity,omitempty"`
Reason string `json:"reason"`
Recommendation PolicyRecommendation `json:"recommendation"`
+ UnstableHashes *PolicyHashes `json:"org.matrix.msc4205.hashes,omitempty"`
}
-// Deprecated: MSC2716 has been abandoned
-type InsertionMarkerContent struct {
- InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"`
- Timestamp int64 `json:"com.beeper.timestamp,omitempty"`
+func (mpc *ModPolicyContent) EntityOrHash() string {
+ if mpc.UnstableHashes != nil && mpc.UnstableHashes.SHA256 != "" {
+ return mpc.UnstableHashes.SHA256
+ }
+ return mpc.Entity
}
type ElementFunctionalMembersContent struct {
ServiceMembers []id.UserID `json:"service_members"`
}
+
+func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool {
+ if slices.Contains(efmc.ServiceMembers, mxid) {
+ return false
+ }
+ efmc.ServiceMembers = append(efmc.ServiceMembers, mxid)
+ return true
+}
+
+type PolicyServerPublicKeys struct {
+ Ed25519 id.Ed25519 `json:"ed25519,omitempty"`
+}
+
+type RoomPolicyEventContent struct {
+ Via string `json:"via,omitempty"`
+ PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"`
+
+ // Deprecated, only for legacy use
+ PublicKey id.Ed25519 `json:"public_key,omitempty"`
+}
diff --git a/event/type.go b/event/type.go
index f2b841ad..80b86728 100644
--- a/event/type.go
+++ b/event/type.go
@@ -108,13 +108,14 @@ func (et *Type) IsCustom() bool {
func (et *Type) GuessClass() TypeClass {
switch et.Type {
- case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type,
+ case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type,
StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type,
StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type,
StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type,
- StateInsertionMarker.Type, StateElementFunctionalMembers.Type:
+ StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type,
+ StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type:
return StateEventType
- case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type:
+ case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type:
return EphemeralEventType
case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type,
AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type,
@@ -126,7 +127,8 @@ func (et *Type) GuessClass() TypeClass {
InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type,
InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type,
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
- CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type:
+ CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type,
+ EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type:
return MessageEventType
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
ToDeviceBeeperRoomKeyAck.Type:
@@ -176,6 +178,7 @@ var (
StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType}
StateGuestAccess = Type{"m.room.guest_access", StateEventType}
StateMember = Type{"m.room.member", StateEventType}
+ StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType}
StatePowerLevels = Type{"m.room.power_levels", StateEventType}
StateRoomName = Type{"m.room.name", StateEventType}
StateTopic = Type{"m.room.topic", StateEventType}
@@ -192,6 +195,9 @@ var (
StateSpaceChild = Type{"m.space.child", StateEventType}
StateSpaceParent = Type{"m.space.parent", StateEventType}
+ StateRoomPolicy = Type{"m.room.policy", StateEventType}
+ StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType}
+
StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType}
StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType}
StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType}
@@ -199,10 +205,10 @@ var (
StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType}
StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType}
- // Deprecated: MSC2716 has been abandoned
- StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType}
-
StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType}
+ StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType}
+ StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType}
+ StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType}
)
// Message events
@@ -231,17 +237,24 @@ var (
CallNegotiate = Type{"m.call.negotiate", MessageEventType}
CallHangup = Type{"m.call.hangup", MessageEventType}
- BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
+ BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
+ BeeperTranscription = Type{"com.beeper.transcription", MessageEventType}
+ BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType}
+ BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType}
+ BeeperSendState = Type{"com.beeper.send_state", MessageEventType}
EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType}
EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType}
+ EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType}
)
// Ephemeral events
var (
- EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
- EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
- EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
+ EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
+ EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType}
+ BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType}
)
// Account data events
diff --git a/event/voip.go b/event/voip.go
index 28f56c95..cd8364a1 100644
--- a/event/voip.go
+++ b/event/voip.go
@@ -76,7 +76,7 @@ func (cv *CallVersion) Int() (int, error) {
type BaseCallEventContent struct {
CallID string `json:"call_id"`
PartyID string `json:"party_id"`
- Version CallVersion `json:"version"`
+ Version CallVersion `json:"version,omitempty"`
}
type CallInviteEventContent struct {
diff --git a/example/main.go b/example/main.go
index d8006d46..2bf4bef3 100644
--- a/example/main.go
+++ b/example/main.go
@@ -143,7 +143,7 @@ func main() {
if err != nil {
log.Error().Err(err).Msg("Failed to send event")
} else {
- log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent")
+ log.Info().Stringer("event_id", resp.EventID).Msg("Event sent")
}
}
cancelSync()
diff --git a/federation/cache.go b/federation/cache.go
new file mode 100644
index 00000000..24154974
--- /dev/null
+++ b/federation/cache.go
@@ -0,0 +1,153 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "sync"
+ "time"
+)
+
+// ResolutionCache is an interface for caching resolved server names.
+type ResolutionCache interface {
+ StoreResolution(*ResolvedServerName)
+ // LoadResolution loads a resolved server name from the cache.
+ // Expired entries MUST NOT be returned.
+ LoadResolution(serverName string) (*ResolvedServerName, error)
+}
+
+type KeyCache interface {
+ StoreKeys(*ServerKeyResponse)
+ StoreFetchError(serverName string, err error)
+ ShouldReQuery(serverName string) bool
+ LoadKeys(serverName string) (*ServerKeyResponse, error)
+}
+
+type InMemoryCache struct {
+ MinKeyRefetchDelay time.Duration
+
+ resolutions map[string]*ResolvedServerName
+ resolutionsLock sync.RWMutex
+ keys map[string]*ServerKeyResponse
+ lastReQueryAt map[string]time.Time
+ lastError map[string]*resolutionErrorCache
+ keysLock sync.RWMutex
+}
+
+var (
+ _ ResolutionCache = (*InMemoryCache)(nil)
+ _ KeyCache = (*InMemoryCache)(nil)
+)
+
+func NewInMemoryCache() *InMemoryCache {
+ return &InMemoryCache{
+ resolutions: make(map[string]*ResolvedServerName),
+ keys: make(map[string]*ServerKeyResponse),
+ lastReQueryAt: make(map[string]time.Time),
+ lastError: make(map[string]*resolutionErrorCache),
+ MinKeyRefetchDelay: 1 * time.Hour,
+ }
+}
+
+func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) {
+ c.resolutionsLock.Lock()
+ defer c.resolutionsLock.Unlock()
+ c.resolutions[resolution.ServerName] = resolution
+}
+
+func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) {
+ c.resolutionsLock.RLock()
+ defer c.resolutionsLock.RUnlock()
+ resolution, ok := c.resolutions[serverName]
+ if !ok || time.Until(resolution.Expires) < 0 {
+ return nil, nil
+ }
+ return resolution, nil
+}
+
+func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) {
+ c.keysLock.Lock()
+ defer c.keysLock.Unlock()
+ c.keys[keys.ServerName] = keys
+ delete(c.lastError, keys.ServerName)
+}
+
+type resolutionErrorCache struct {
+ Error error
+ Time time.Time
+ Count int
+}
+
+const MaxBackoff = 7 * 24 * time.Hour
+
+func (rec *resolutionErrorCache) ShouldRetry() bool {
+ backoff := time.Duration(math.Exp(float64(rec.Count))) * time.Second
+ return time.Since(rec.Time) > backoff
+}
+
+var ErrRecentKeyQueryFailed = errors.New("last retry was too recent")
+
+func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) {
+ c.keysLock.RLock()
+ defer c.keysLock.RUnlock()
+ keys, ok := c.keys[serverName]
+ if !ok || time.Until(keys.ValidUntilTS.Time) < 0 {
+ err, ok := c.lastError[serverName]
+ if ok && !err.ShouldRetry() {
+ return nil, fmt.Errorf(
+ "%w (%s ago) and failed with %w",
+ ErrRecentKeyQueryFailed,
+ time.Since(err.Time).String(),
+ err.Error,
+ )
+ }
+ return nil, nil
+ }
+ return keys, nil
+}
+
+func (c *InMemoryCache) StoreFetchError(serverName string, err error) {
+ c.keysLock.Lock()
+ defer c.keysLock.Unlock()
+ errorCache, ok := c.lastError[serverName]
+ if ok {
+ errorCache.Time = time.Now()
+ errorCache.Error = err
+ errorCache.Count++
+ } else {
+ c.lastError[serverName] = &resolutionErrorCache{Error: err, Time: time.Now(), Count: 1}
+ }
+}
+
+func (c *InMemoryCache) ShouldReQuery(serverName string) bool {
+ c.keysLock.Lock()
+ defer c.keysLock.Unlock()
+ lastQuery, ok := c.lastReQueryAt[serverName]
+ if ok && time.Since(lastQuery) < c.MinKeyRefetchDelay {
+ return false
+ }
+ c.lastReQueryAt[serverName] = time.Now()
+ return true
+}
+
+type noopCache struct{}
+
+func (*noopCache) StoreKeys(_ *ServerKeyResponse) {}
+func (*noopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil }
+func (*noopCache) StoreFetchError(_ string, _ error) {}
+func (*noopCache) ShouldReQuery(_ string) bool { return true }
+func (*noopCache) StoreResolution(_ *ResolvedServerName) {}
+func (*noopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil }
+
+var (
+ _ ResolutionCache = (*noopCache)(nil)
+ _ KeyCache = (*noopCache)(nil)
+)
+
+var NoopCache *noopCache
diff --git a/federation/client.go b/federation/client.go
index 098df095..183fb5d1 100644
--- a/federation/client.go
+++ b/federation/client.go
@@ -9,7 +9,6 @@ package federation
import (
"bytes"
"context"
- "encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -22,6 +21,7 @@ import (
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/federation/signutil"
"maunium.net/go/mautrix/id"
)
@@ -30,17 +30,25 @@ type Client struct {
ServerName string
UserAgent string
Key *SigningKey
+
+ ResponseSizeLimit int64
}
-func NewClient(serverName string, key *SigningKey) *Client {
+func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client {
return &Client{
HTTP: &http.Client{
- Transport: NewServerResolvingTransport(),
+ Transport: NewServerResolvingTransport(cache),
Timeout: 120 * time.Second,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ // Federation requests do not allow redirects.
+ return http.ErrUseLastResponse
+ },
},
UserAgent: mautrix.DefaultUserAgent,
ServerName: serverName,
Key: key,
+
+ ResponseSizeLimit: mautrix.DefaultResponseSizeLimit,
}
}
@@ -54,7 +62,7 @@ func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *Serve
return
}
-func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *ServerKeyResponse, err error) {
+func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *QueryKeysResponse, err error) {
err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp)
return
}
@@ -81,7 +89,7 @@ type RespSendTransaction struct {
}
func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) {
- err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp)
+ err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp)
return
}
@@ -220,6 +228,26 @@ func (c *Client) Query(ctx context.Context, serverName, queryType string, queryP
return
}
+func queryToValues(query map[string]string) url.Values {
+ values := make(url.Values, len(query))
+ for k, v := range query {
+ values[k] = []string{v}
+ }
+ return values
+}
+
+func (c *Client) PublicRooms(ctx context.Context, serverName string, req *mautrix.ReqPublicRooms) (resp *mautrix.RespPublicRooms, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: serverName,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "publicRooms"},
+ Query: queryToValues(req.Query()),
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
type RespOpenIDUserInfo struct {
Sub id.UserID `json:"sub"`
}
@@ -235,6 +263,169 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken
return
}
+type ReqMakeJoin struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+ SupportedVersions []id.RoomVersion
+}
+
+type RespMakeJoin struct {
+ RoomVersion id.RoomVersion `json:"room_version"`
+ Event PDU `json:"event"`
+}
+
+type ReqSendJoin struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ OmitMembers bool
+ Event PDU
+ Via string
+}
+
+type ReqSendKnock struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type RespSendJoin struct {
+ AuthChain []PDU `json:"auth_chain"`
+ Event PDU `json:"event"`
+ MembersOmitted bool `json:"members_omitted"`
+ ServersInRoom []string `json:"servers_in_room"`
+ State []PDU `json:"state"`
+}
+
+type RespSendKnock struct {
+ KnockRoomState []PDU `json:"knock_room_state"`
+}
+
+type ReqSendInvite struct {
+ RoomID id.RoomID `json:"-"`
+ UserID id.UserID `json:"-"`
+ Event PDU `json:"event"`
+ InviteRoomState []PDU `json:"invite_room_state"`
+ RoomVersion id.RoomVersion `json:"room_version"`
+}
+
+type RespSendInvite struct {
+ Event PDU `json:"event"`
+}
+
+type ReqMakeLeave struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+}
+
+type ReqSendLeave struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type (
+ ReqMakeKnock = ReqMakeJoin
+ RespMakeKnock = RespMakeJoin
+ RespMakeLeave = RespMakeJoin
+)
+
+func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_join", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_join", req.RoomID, req.EventID},
+ Query: url.Values{
+ "omit_members": {strconv.FormatBool(req.OmitMembers)},
+ },
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.UserID.Homeserver(),
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "invite", req.RoomID, req.UserID},
+ Authenticate: true,
+ RequestJSON: req,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ })
+ return
+}
+
type URLPath []any
func (fup URLPath) FullPath() []any {
@@ -286,15 +477,27 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b
WrappedError: err,
}
}
- defer func() {
- _ = resp.Body.Close()
- }()
+ if !params.DontReadBody {
+ defer resp.Body.Close()
+ }
var body []byte
- if resp.StatusCode >= 400 {
+ if resp.StatusCode >= 300 {
body, err = mautrix.ParseErrorResponse(req, resp)
return body, resp, err
} else if params.ResponseJSON != nil || !params.DontReadBody {
- body, err = io.ReadAll(resp.Body)
+ if resp.ContentLength > c.ResponseSizeLimit {
+ return body, resp, mautrix.HTTPError{
+ Request: req,
+ Response: resp,
+
+ Message: "not reading response",
+ WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024),
+ }
+ }
+ body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1))
+ if err == nil && len(body) > int(c.ResponseSizeLimit) {
+ err = mautrix.ErrBodyReadReachedLimit
+ }
if err != nil {
return body, resp, mautrix.HTTPError{
Request: req,
@@ -354,16 +557,12 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt
Message: "client not configured for authentication",
}
}
- var contentAny any
- if reqJSON != nil {
- contentAny = reqJSON
- }
auth, err := (&signableRequest{
Method: req.Method,
URI: reqURL.RequestURI(),
Origin: c.ServerName,
Destination: params.ServerName,
- Content: contentAny,
+ Content: reqJSON,
}).Sign(c.Key)
if err != nil {
return nil, mautrix.HTTPError{
@@ -377,11 +576,19 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt
}
type signableRequest struct {
- Method string `json:"method"`
- URI string `json:"uri"`
- Origin string `json:"origin"`
- Destination string `json:"destination"`
- Content any `json:"content,omitempty"`
+ Method string `json:"method"`
+ URI string `json:"uri"`
+ Origin string `json:"origin"`
+ Destination string `json:"destination"`
+ Content json.RawMessage `json:"content,omitempty"`
+}
+
+func (r *signableRequest) Verify(key id.SigningKey, sig string) error {
+ message, err := json.Marshal(r)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ return signutil.VerifyJSONRaw(key, sig, message)
}
func (r *signableRequest) Sign(key *SigningKey) (string, error) {
@@ -389,11 +596,10 @@ func (r *signableRequest) Sign(key *SigningKey) (string, error) {
if err != nil {
return "", err
}
- return fmt.Sprintf(
- `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`,
- r.Origin,
- r.Destination,
- key.ID,
- base64.RawURLEncoding.EncodeToString(sig),
- ), nil
+ return XMatrixAuth{
+ Origin: r.Origin,
+ Destination: r.Destination,
+ KeyID: key.ID,
+ Signature: sig,
+ }.String(), nil
}
diff --git a/federation/client_test.go b/federation/client_test.go
index ba3c3ed4..ece399ea 100644
--- a/federation/client_test.go
+++ b/federation/client_test.go
@@ -16,7 +16,7 @@ import (
)
func TestClient_Version(t *testing.T) {
- cli := federation.NewClient("", nil)
+ cli := federation.NewClient("", nil, nil)
resp, err := cli.Version(context.TODO(), "maunium.net")
require.NoError(t, err)
require.Equal(t, "Synapse", resp.Server.Name)
diff --git a/federation/context.go b/federation/context.go
new file mode 100644
index 00000000..eedb2dc1
--- /dev/null
+++ b/federation/context.go
@@ -0,0 +1,42 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "context"
+ "net/http"
+)
+
+type contextKey int
+
+const (
+ contextKeyIPPort contextKey = iota
+ contextKeyDestinationServer
+ contextKeyOriginServer
+)
+
+func DestinationServerNameFromRequest(r *http.Request) string {
+ return DestinationServerName(r.Context())
+}
+
+func DestinationServerName(ctx context.Context) string {
+ if dest, ok := ctx.Value(contextKeyDestinationServer).(string); ok {
+ return dest
+ }
+ return ""
+}
+
+func OriginServerNameFromRequest(r *http.Request) string {
+ return OriginServerName(r.Context())
+}
+
+func OriginServerName(ctx context.Context) string {
+ if origin, ok := ctx.Value(contextKeyOriginServer).(string); ok {
+ return origin
+ }
+ return ""
+}
diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go
new file mode 100644
index 00000000..c72933c2
--- /dev/null
+++ b/federation/eventauth/eventauth.go
@@ -0,0 +1,851 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth
+
+import (
+ "encoding/json"
+ "encoding/json/jsontext"
+ "errors"
+ "fmt"
+ "slices"
+ "strconv"
+ "strings"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exgjson"
+ "go.mau.fi/util/exstrings"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+type AuthFailError struct {
+ Index string
+ Message string
+ Wrapped error
+}
+
+func (afe AuthFailError) Error() string {
+ if afe.Message != "" {
+ return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message)
+ } else if afe.Wrapped != nil {
+ return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error())
+ }
+ return fmt.Sprintf("fail %s", afe.Index)
+}
+
+func (afe AuthFailError) Unwrap() error {
+ return afe.Wrapped
+}
+
+var mFederatePath = exgjson.Path("m.federate")
+
+var (
+ ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"}
+ ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"}
+ ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"}
+ ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion}
+ ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"}
+ ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"}
+
+ ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"}
+ ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"}
+ ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"}
+ ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"}
+
+ ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"}
+ ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"}
+ ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"}
+ ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"}
+ ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"}
+ ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"}
+ ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"}
+ ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"}
+ ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"}
+
+ ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"}
+
+ ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"}
+ ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"}
+ ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"}
+ ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"}
+ ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"}
+ ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"}
+ ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"}
+ ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"}
+ ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"}
+ ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"}
+ ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"}
+ ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"}
+ ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"}
+ ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"}
+ ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"}
+ ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"}
+ ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"}
+ ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"}
+ ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"}
+ ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"}
+ ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"}
+ ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"}
+ ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"}
+ ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"}
+ ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"}
+ ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"}
+ ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"}
+ ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"}
+
+ ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"}
+
+ ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"}
+
+ ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"}
+
+ ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"}
+
+ ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"}
+ ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"}
+ ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"}
+ ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"}
+ ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"}
+ ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"}
+ ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"}
+)
+
+func isRejected(evt *pdu.PDU) bool {
+ return evt.InternalMeta.Rejected
+}
+
+type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error)
+
+func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error {
+ if evt.Type == event.StateCreate.Type {
+ // 1. If type is m.room.create:
+ return authorizeCreate(roomVersion, evt)
+ }
+ var createEvt *pdu.PDU
+ if roomVersion.RoomIDIsCreateEventID() {
+ // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event,
+ // with the sigil ! instead of $, reject.
+ if len(evt.RoomID) != 44 {
+ return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID))
+ } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil {
+ return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err)
+ } else if len(createEvts) != 1 {
+ return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID)
+ } else if isRejected(createEvts[0]) {
+ return ErrRejectedCreateEvent
+ } else {
+ createEvt = createEvts[0]
+ }
+ }
+ authEvents, err := getEvents(evt.AuthEvents)
+ if err != nil {
+ return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err)
+ }
+ expectedAuthEvents := evt.AuthEventSelection(roomVersion)
+ deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents))
+ // 3. Considering the event’s auth_events:
+ for i, ae := range authEvents {
+ authEvtID := evt.AuthEvents[i]
+ if ae == nil {
+ return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID)
+ } else if ae.StateKey == nil {
+ // This approximately falls under rule 3.2.
+ return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID)
+ }
+ key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey}
+ if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound {
+ // 3.1. If there are duplicate entries for a given type and state_key pair, reject.
+ return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID)
+ } else if !expectedAuthEvents.Has(key) {
+ // 3.2. If there are entries whose type and state_key don’t match those specified by
+ // the auth events selection algorithm described in the server specification, reject.
+ return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey)
+ } else if isRejected(ae) {
+ // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject.
+ return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID)
+ } else if ae.RoomID != evt.RoomID {
+ // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject.
+ return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID)
+ } else {
+ deduplicator[key] = authEvtID
+ }
+ if ae.Type == event.StateCreate.Type {
+ if createEvt == nil {
+ createEvt = ae
+ } else {
+ // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+
+ panic(fmt.Errorf("impossible case: multiple create events found in auth events"))
+ }
+ }
+ }
+ if createEvt == nil {
+ // This comes either from auth_events or room_id depending on the room version.
+ // The checks above make sure it's from the right source.
+ return ErrNoCreateEvent
+ }
+ if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() {
+ // 4. If the content of the m.room.create event in the room state has the property m.federate set to false,
+ // and the sender domain of the event does not match the sender domain of the create event, reject.
+ return ErrFederationDisabled
+ }
+ if evt.Type == event.StateMember.Type {
+ // 5. If type is m.room.member:
+ return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey)
+ }
+ senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
+ if senderMembership != event.MembershipJoin {
+ // 6. If the sender’s current membership state is not join, reject.
+ return ErrNotInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderPL := powerLevels.GetUserLevel(evt.Sender)
+ if evt.Type == event.StateThirdPartyInvite.Type {
+ // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level.
+ if senderPL >= powerLevels.Invite() {
+ return nil
+ }
+ return ErrInsufficientPowerForThirdPartyInvite
+ }
+ typeClass := event.MessageEventType
+ if evt.StateKey != nil {
+ typeClass = event.StateEventType
+ }
+ evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass})
+ if evtLevel > senderPL {
+ // 8. If the event type’s required power level is greater than the sender’s power level, reject.
+ return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL)
+ }
+
+ if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() {
+ // 9. If the event has a state_key that starts with an @ and does not match the sender, reject.
+ return ErrMismatchingPrivateStateKey
+ }
+
+ if evt.Type == event.StatePowerLevels.Type {
+ // 10. If type is m.room.power_levels:
+ return authorizePowerLevels(roomVersion, evt, createEvt, authEvents)
+ }
+
+ // 11. Otherwise, allow.
+ return nil
+}
+
+var ErrUserIDNotAString = errors.New("not a string")
+var ErrUserIDNotValid = errors.New("not a valid user ID")
+
+func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error {
+ if userID.Type != gjson.String {
+ return ErrUserIDNotAString
+ }
+ // In a future room version, user IDs will have stricter validation
+ _, _, err := id.UserID(userID.Str).Parse()
+ if err != nil {
+ return ErrUserIDNotValid
+ }
+ return nil
+}
+
+func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error {
+ if len(evt.PrevEvents) > 0 {
+ // 1.1. If it has any prev_events, reject.
+ return ErrCreateHasPrevEvents
+ }
+ if roomVersion.RoomIDIsCreateEventID() {
+ if evt.RoomID != "" {
+ // 1.2. If the event has a room_id, reject.
+ return ErrCreateHasRoomID
+ }
+ } else {
+ _, _, server := id.ParseCommonIdentifier(evt.RoomID)
+ if server == "" || server != evt.Sender.Homeserver() {
+ // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject.
+ return ErrRoomIDDoesntMatchSender
+ }
+ }
+ if !roomVersion.IsKnown() {
+ // 1.3. If content.room_version is present and is not a recognised version, reject.
+ return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion)
+ }
+ if roomVersion.PrivilegedRoomCreators() {
+ additionalCreators := gjson.GetBytes(evt.Content, "additional_creators")
+ if additionalCreators.Exists() {
+ if !additionalCreators.IsArray() {
+ return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators)
+ }
+ for i, item := range additionalCreators.Array() {
+ // 1.4. If additional_creators is present in content and is not an array of strings
+ // where each string passes the same user ID validation applied to sender, reject.
+ if err := isValidUserID(roomVersion, item); err != nil {
+ return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err)
+ }
+ }
+ }
+ }
+ if roomVersion.CreatorInContent() {
+ // 1.4. (v10 and below) If content has no creator property, reject.
+ if !gjson.GetBytes(evt.Content, "creator").Exists() {
+ return ErrMissingCreator
+ }
+ }
+ // 1.5. Otherwise, allow.
+ return nil
+}
+
+func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error {
+ membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str)
+ if evt.StateKey == nil {
+ // 5.1. If there is no state_key property, or no membership property in content, reject.
+ return ErrMemberNotState
+ }
+ authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str)
+ if authorizedVia != "" {
+ homeserver := authorizedVia.Homeserver()
+ err := evt.VerifySignature(roomVersion, homeserver, getKey)
+ if err != nil {
+ // 5.2. If content has a join_authorised_via_users_server key:
+ // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject.
+ return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err)
+ }
+ }
+ targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave"))
+ senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
+ switch membership {
+ case event.MembershipJoin:
+ createEvtID, err := createEvt.GetEventID(roomVersion)
+ if err != nil {
+ return fmt.Errorf("failed to get create event ID: %w", err)
+ }
+ creator := createEvt.Sender.String()
+ if roomVersion.CreatorInContent() {
+ creator = gjson.GetBytes(evt.Content, "creator").Str
+ }
+ if len(evt.PrevEvents) == 1 &&
+ len(evt.AuthEvents) <= 1 &&
+ evt.PrevEvents[0] == createEvtID &&
+ *evt.StateKey == creator {
+ // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow.
+ return nil
+ }
+ // Spec wart: this would make more sense before the check above.
+ // Now you can set anyone as the sender of the first join.
+ if evt.Sender.String() != *evt.StateKey {
+ // 5.3.2. If the sender does not match state_key, reject.
+ return ErrCantJoinOtherUser
+ }
+
+ if senderMembership == event.MembershipBan {
+ // 5.3.3. If the sender is banned, reject.
+ return ErrCantJoinBanned
+ }
+
+ joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
+ switch joinRule {
+ case event.JoinRuleKnock:
+ if !roomVersion.Knocks() {
+ return ErrInvalidJoinRule
+ }
+ fallthrough
+ case event.JoinRuleInvite:
+ // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join.
+ if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
+ return nil
+ }
+ return ErrCantJoinWithoutInvite
+ case event.JoinRuleKnockRestricted:
+ if !roomVersion.KnockRestricted() {
+ return ErrInvalidJoinRule
+ }
+ fallthrough
+ case event.JoinRuleRestricted:
+ if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() {
+ return ErrInvalidJoinRule
+ }
+ if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
+ // 5.3.5.1. If membership state is join or invite, allow.
+ return nil
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() {
+ // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject.
+ return ErrAuthoriserCantInvite
+ }
+ authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave)))
+ if authorizerMembership != event.MembershipJoin {
+ return ErrAuthoriserNotInRoom
+ }
+ // 5.3.5.3. Otherwise, allow.
+ return nil
+ case event.JoinRulePublic:
+ // 5.3.6. If the join_rule is public, allow.
+ return nil
+ default:
+ // 5.3.7. Otherwise, reject.
+ return ErrInvalidJoinRule
+ }
+ case event.MembershipInvite:
+ tpiVal := gjson.GetBytes(evt.Content, "third_party_invite")
+ if tpiVal.Exists() {
+ if targetPrevMembership == event.MembershipBan {
+ return ErrThirdPartyInviteBanned
+ }
+ signed := tpiVal.Get("signed")
+ mxid := signed.Get("mxid").Str
+ token := signed.Get("token").Str
+ if mxid == "" || token == "" {
+ // 5.4.1.2. If content.third_party_invite does not have a signed property, reject.
+ // 5.4.1.3. If signed does not have mxid and token properties, reject.
+ return ErrThirdPartyInviteMissingFields
+ }
+ if mxid != *evt.StateKey {
+ // 5.4.1.4. If mxid does not match state_key, reject.
+ return ErrThirdPartyInviteMXIDMismatch
+ }
+ tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token)
+ if tpiEvt == nil {
+ // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject.
+ return ErrThirdPartyInviteNotFound
+ }
+ if tpiEvt.Sender != evt.Sender {
+ // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject.
+ return ErrThirdPartyInviteSenderMismatch
+ }
+ var keys []id.Ed25519
+ const ed25519Base64Len = 43
+ oldPubKey := gjson.GetBytes(evt.Content, "public_key.token")
+ if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len {
+ keys = append(keys, id.Ed25519(oldPubKey.Str))
+ }
+ gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.Number {
+ return false
+ }
+ if value.Type == gjson.String && len(value.Str) == ed25519Base64Len {
+ keys = append(keys, id.Ed25519(value.Str))
+ }
+ return true
+ })
+ rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str))
+ var validated bool
+ for _, key := range keys {
+ if signutil.VerifyJSONAny(key, rawSigned) == nil {
+ validated = true
+ }
+ }
+ if validated {
+ // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow.
+ return nil
+ }
+ // 4.4.1.8. Otherwise, reject.
+ return ErrThirdPartyInviteNotSigned
+ }
+ if senderMembership != event.MembershipJoin {
+ // 5.4.2. If the sender’s current membership state is not join, reject.
+ return ErrInviterNotInRoom
+ }
+ // 5.4.3. If target user’s current membership state is join or ban, reject.
+ if targetPrevMembership == event.MembershipJoin {
+ return ErrInviteTargetAlreadyInRoom
+ } else if targetPrevMembership == event.MembershipBan {
+ return ErrInviteTargetBanned
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() {
+ // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow.
+ return nil
+ }
+ // 5.4.5. Otherwise, reject.
+ return ErrInsufficientPermissionForInvite
+ case event.MembershipLeave:
+ if evt.Sender.String() == *evt.StateKey {
+ // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock.
+ if senderMembership == event.MembershipInvite ||
+ senderMembership == event.MembershipJoin ||
+ (senderMembership == event.MembershipKnock && roomVersion.Knocks()) {
+ return nil
+ }
+ return ErrCantLeaveWithoutBeingInRoom
+ }
+ if senderMembership != event.MembershipJoin {
+ // 5.5.2. If the sender’s current membership state is not join, reject.
+ return ErrCantKickWithoutBeingInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderLevel := powerLevels.GetUserLevel(evt.Sender)
+ if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() {
+ // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject.
+ return ErrInsufficientPermissionForUnban
+ }
+ if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
+ // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow.
+ return nil
+ }
+ // TODO separate errors for < kick and < target user level?
+ // 5.5.5. Otherwise, reject.
+ return ErrInsufficientPermissionForKick
+ case event.MembershipBan:
+ if senderMembership != event.MembershipJoin {
+ // 5.6.1. If the sender’s current membership state is not join, reject.
+ return ErrCantBanWithoutBeingInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderLevel := powerLevels.GetUserLevel(evt.Sender)
+ if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
+ // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow.
+ return nil
+ }
+ // 5.6.3. Otherwise, reject.
+ return ErrInsufficientPermissionForBan
+ case event.MembershipKnock:
+ joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
+ validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock
+ validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted
+ if !validKnockRule && !validKnockRestrictedRule {
+ // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject.
+ return ErrNotKnockableRoom
+ }
+ if evt.Sender.String() != *evt.StateKey {
+ // 5.7.2. If the sender does not match state_key, reject.
+ return ErrCantKnockOtherUser
+ }
+ if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin {
+ // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow.
+ return nil
+ }
+ // 5.7.4. Otherwise, reject.
+ return ErrCantKnockWhileInRoom
+ default:
+ // 5.8. Otherwise, the membership is unknown. Reject.
+ return ErrUnknownMembership
+ }
+}
+
+func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error {
+ if roomVersion.ValidatePowerLevelInts() {
+ for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
+ res := gjson.GetBytes(evt.Content, key)
+ if !res.Exists() {
+ continue
+ }
+ if parseIntWithVersion(roomVersion, res) == nil {
+ // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject.
+ return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
+ }
+ }
+ for _, key := range []string{"events", "notifications"} {
+ obj := gjson.GetBytes(evt.Content, key)
+ if !obj.Exists() {
+ continue
+ }
+ // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject.
+ if !obj.IsObject() {
+ return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
+ }
+ var err error
+ // 10.2. [...] are not an object with values that are integers, reject.
+ obj.ForEach(func(innerKey, value gjson.Result) bool {
+ if parseIntWithVersion(roomVersion, value) == nil {
+ err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str)
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ }
+ }
+ var creators []id.UserID
+ if roomVersion.PrivilegedRoomCreators() {
+ creators = append(creators, createEvt.Sender)
+ gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool {
+ creators = append(creators, id.UserID(value.Str))
+ return true
+ })
+ }
+ users := gjson.GetBytes(evt.Content, "users")
+ if users.Exists() {
+ if !users.IsObject() {
+ // 10.3. If the users property in content is not an object [...], reject.
+ return fmt.Errorf("%w users", ErrTopLevelPLNotInteger)
+ }
+ var err error
+ users.ForEach(func(key, value gjson.Result) bool {
+ if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil {
+ // 10.3. [...] is not an object with keys that are valid user IDs [...], reject.
+ err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr)
+ return false
+ }
+ if parseIntWithVersion(roomVersion, value) == nil {
+ // 10.3. [...] is not an object [...] with values that are integers, reject.
+ err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str)
+ return false
+ }
+ // creators is only filled if the room version has privileged room creators
+ if slices.Contains(creators, id.UserID(key.Str)) {
+ // 10.4. If the users property in content contains the sender of the m.room.create event or any of
+ // the additional_creators array (if present) from the content of the m.room.create event, reject.
+ err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str)
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ }
+ oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "")
+ if oldPL == nil {
+ // 10.5. If there is no previous m.room.power_levels event in the room, allow.
+ return nil
+ }
+ if slices.Contains(creators, evt.Sender) {
+ // Skip remaining checks for creators
+ return nil
+ }
+ senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String())))
+ if senderPLPtr == nil {
+ senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default"))
+ if senderPLPtr == nil {
+ senderPLPtr = ptr.Ptr(0)
+ }
+ }
+ for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
+ oldVal := gjson.GetBytes(oldPL.Content, key)
+ newVal := gjson.GetBytes(evt.Content, key)
+ if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil {
+ return err
+ }
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "events", "",
+ gjson.GetBytes(oldPL.Content, "events"),
+ gjson.GetBytes(evt.Content, "events"),
+ ); err != nil {
+ return err
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "notifications", "",
+ gjson.GetBytes(oldPL.Content, "notifications"),
+ gjson.GetBytes(evt.Content, "notifications"),
+ ); err != nil {
+ return err
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "users", evt.Sender.String(),
+ gjson.GetBytes(oldPL.Content, "users"),
+ gjson.GetBytes(evt.Content, "users"),
+ ); err != nil {
+ return err
+ }
+ return nil
+}
+
+func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) {
+ old.ForEach(func(key, value gjson.Result) bool {
+ newVal := new.Get(exgjson.Path(key.Str))
+ err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal)
+ if err == nil && ownID != "" && key.Str != ownID {
+ parsedOldVal := parseIntWithVersion(roomVersion, value)
+ parsedNewVal := parseIntWithVersion(roomVersion, newVal)
+ if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal {
+ err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal)
+ }
+ }
+ return err == nil
+ })
+ if err != nil {
+ return
+ }
+ new.ForEach(func(key, value gjson.Result) bool {
+ err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value)
+ return err == nil
+ })
+ return
+}
+
+func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error {
+ oldVal := parseIntWithVersion(roomVersion, old)
+ newVal := parseIntWithVersion(roomVersion, new)
+ if oldVal == nil {
+ if newVal == nil || *newVal <= maxVal {
+ return nil
+ }
+ } else if newVal == nil {
+ if *oldVal <= maxVal {
+ return nil
+ }
+ } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) {
+ return nil
+ }
+ return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal)
+}
+
+func stringifyForError(val gjson.Result) string {
+ if !val.Exists() {
+ return "null"
+ }
+ return val.Raw
+}
+
+func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU {
+ for _, evt := range events {
+ if evt.Type == evtType && *evt.StateKey == stateKey {
+ return evt
+ }
+ }
+ return nil
+}
+
+func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T {
+ return reader(findEvent(events, evtType, stateKey))
+}
+
+func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string {
+ return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string {
+ if evt == nil {
+ return defVal
+ }
+ res := gjson.GetBytes(evt.Content, fieldPath)
+ if res.Type != gjson.String {
+ return defVal
+ }
+ return res.Str
+ })
+}
+
+func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) {
+ var err error
+ powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent {
+ if evt == nil {
+ return nil
+ }
+ content := evt.Content
+ out := &event.PowerLevelsEventContent{}
+ if !roomVersion.ValidatePowerLevelInts() {
+ safeParsePowerLevels(content, out)
+ } else {
+ err = json.Unmarshal(content, out)
+ }
+ return out
+ })
+ if err != nil {
+ // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ if roomVersion.PrivilegedRoomCreators() {
+ if powerLevels == nil {
+ powerLevels = &event.PowerLevelsEventContent{}
+ }
+ powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ } else if powerLevels == nil {
+ powerLevels = &event.PowerLevelsEventContent{
+ Users: map[id.UserID]int{
+ createEvt.Sender: 100,
+ },
+ }
+ }
+ return powerLevels, nil
+}
+
+func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int {
+ if roomVersion.ValidatePowerLevelInts() {
+ if val.Type != gjson.Number {
+ return nil
+ }
+ return ptr.Ptr(int(val.Int()))
+ }
+ return parsePythonInt(val)
+}
+
+func parsePythonInt(val gjson.Result) *int {
+ switch val.Type {
+ case gjson.True:
+ return ptr.Ptr(1)
+ case gjson.False:
+ return ptr.Ptr(0)
+ case gjson.Number:
+ return ptr.Ptr(int(val.Int()))
+ case gjson.String:
+ // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand
+ num, err := strconv.Atoi(strings.TrimSpace(val.Str))
+ if err != nil {
+ return nil
+ }
+ return &num
+ default:
+ // Python int() doesn't accept nulls, arrays or dicts
+ return nil
+ }
+}
+
+func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) {
+ *into = event.PowerLevelsEventContent{
+ Users: make(map[id.UserID]int),
+ UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))),
+ Events: make(map[string]int),
+ EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))),
+ Notifications: nil, // irrelevant for event auth
+ StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")),
+ InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")),
+ KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")),
+ BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")),
+ RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")),
+ }
+ gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.String {
+ return false
+ }
+ val := parsePythonInt(value)
+ if val != nil {
+ into.Events[key.Str] = *val
+ }
+ return true
+ })
+ gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.String {
+ return false
+ }
+ val := parsePythonInt(value)
+ if val == nil {
+ return false
+ }
+ userID := id.UserID(key.Str)
+ if _, _, err := userID.Parse(); err != nil {
+ return false
+ }
+ into.Users[userID] = *val
+ return true
+ })
+}
diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go
new file mode 100644
index 00000000..d316f3c8
--- /dev/null
+++ b/federation/eventauth/eventauth_internal_test.go
@@ -0,0 +1,66 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+type pythonIntTest struct {
+ Name string
+ Input string
+ Expected int64
+}
+
+var pythonIntTests = []pythonIntTest{
+ {"True", `true`, 1},
+ {"False", `false`, 0},
+ {"SmallFloat", `3.1415`, 3},
+ {"SmallFloatRoundDown", `10.999999999999999`, 10},
+ {"SmallFloatRoundUp", `10.9999999999999999`, 11},
+ {"BigFloatRoundDown", `1000000.9999999999`, 1000000},
+ {"BigFloatRoundUp", `1000000.99999999999`, 1000001},
+ {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992},
+ {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994},
+ {"Int64", `9223372036854775807`, 9223372036854775807},
+ {"Int64String", `"9223372036854775807"`, 9223372036854775807},
+ {"String", `"123"`, 123},
+ {"InvalidFloatInString", `"123.456"`, 0},
+ {"StringWithPlusSign", `"+123"`, 123},
+ {"StringWithMinusSign", `"-123"`, -123},
+ {"StringWithSpaces", `" 123 "`, 123},
+ {"StringWithSpacesAndSign", `" -123 "`, -123},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0},
+ {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0},
+ {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0},
+ //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456},
+}
+
+func TestParsePythonInt(t *testing.T) {
+ for _, test := range pythonIntTests {
+ t.Run(test.Name, func(t *testing.T) {
+ output := parsePythonInt(gjson.Parse(test.Input))
+ if strings.HasPrefix(test.Name, "Invalid") {
+ assert.Nil(t, output)
+ } else {
+ require.NotNil(t, output)
+ assert.Equal(t, int(test.Expected), *output)
+ }
+ })
+ }
+}
diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go
new file mode 100644
index 00000000..e3c5cd76
--- /dev/null
+++ b/federation/eventauth/eventauth_test.go
@@ -0,0 +1,85 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth_test
+
+import (
+ "embed"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "errors"
+ "io"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/federation/eventauth"
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+//go:embed *.jsonl
+var data embed.FS
+
+type eventMap map[id.EventID]*pdu.PDU
+
+func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) {
+ output := make([]*pdu.PDU, len(ids))
+ for i, evtID := range ids {
+ output[i] = em[evtID]
+ }
+ return output, nil
+}
+
+func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) {
+ return "", time.Time{}, nil
+}
+
+func TestAuthorize(t *testing.T) {
+ files := exerrors.Must(data.ReadDir("."))
+ for _, file := range files {
+ t.Run(file.Name(), func(t *testing.T) {
+ decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name())))
+ events := make(eventMap)
+ var roomVersion *id.RoomVersion
+ for i := 1; ; i++ {
+ var evt *pdu.PDU
+ err := json.UnmarshalDecode(decoder, &evt)
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ require.NoError(t, err)
+ if roomVersion == nil {
+ require.Equal(t, evt.Type, "m.room.create")
+ roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str))
+ }
+ expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str
+ evtID, err := evt.GetEventID(*roomVersion)
+ require.NoError(t, err)
+ require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i)
+
+ // TODO allow redacted events
+ assert.True(t, evt.VerifyContentHash(), i)
+
+ events[evtID] = evt
+ err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey)
+ if err != nil {
+ evt.InternalMeta.Rejected = true
+ }
+ // TODO allow testing intentionally rejected events
+ assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type)
+ }
+ })
+ }
+
+}
diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl
new file mode 100644
index 00000000..2b751de3
--- /dev/null
+++ b/federation/eventauth/testroom-v12-success.jsonl
@@ -0,0 +1,21 @@
+{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}}
+{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}}
+{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}}
+{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}}
+{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
+{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}}
+{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}}
+{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}}
+{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}}
+{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}}
diff --git a/federation/httpclient.go b/federation/httpclient.go
index d6d97280..2f8dbb4f 100644
--- a/federation/httpclient.go
+++ b/federation/httpclient.go
@@ -12,7 +12,6 @@ import (
"net"
"net/http"
"sync"
- "time"
)
// ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests.
@@ -22,17 +21,20 @@ type ServerResolvingTransport struct {
Transport *http.Transport
Dialer *net.Dialer
- cache map[string]*ResolvedServerName
- resolveLocks map[string]*sync.Mutex
- cacheLock sync.Mutex
+ cache ResolutionCache
+
+ resolveLocks map[string]*sync.Mutex
+ resolveLocksLock sync.Mutex
}
-func NewServerResolvingTransport() *ServerResolvingTransport {
+func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport {
+ if cache == nil {
+ cache = NewInMemoryCache()
+ }
srt := &ServerResolvingTransport{
- cache: make(map[string]*ResolvedServerName),
resolveLocks: make(map[string]*sync.Mutex),
-
- Dialer: &net.Dialer{},
+ cache: cache,
+ Dialer: &net.Dialer{},
}
srt.Transport = &http.Transport{
DialContext: srt.DialContext,
@@ -50,12 +52,6 @@ func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, a
return srt.Dialer.DialContext(ctx, network, addrs[0])
}
-type contextKey int
-
-const (
- contextKeyIPPort contextKey = iota
-)
-
func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) {
if request.URL.Scheme != "matrix-federation" {
return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme)
@@ -72,37 +68,25 @@ func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Res
}
func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) {
- res, lock := srt.getResolveCache(serverName)
- if res != nil {
- return res, nil
+ srt.resolveLocksLock.Lock()
+ lock, ok := srt.resolveLocks[serverName]
+ if !ok {
+ lock = &sync.Mutex{}
+ srt.resolveLocks[serverName] = lock
}
+ srt.resolveLocksLock.Unlock()
+
lock.Lock()
defer lock.Unlock()
- res, _ = srt.getResolveCache(serverName)
- if res != nil {
+ res, err := srt.cache.LoadResolution(serverName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache: %w", err)
+ } else if res != nil {
+ return res, nil
+ } else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil {
+ return nil, err
+ } else {
+ srt.cache.StoreResolution(res)
return res, nil
}
- var err error
- res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts)
- if err != nil {
- return nil, err
- }
- srt.cacheLock.Lock()
- srt.cache[serverName] = res
- srt.cacheLock.Unlock()
- return res, nil
-}
-
-func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) {
- srt.cacheLock.Lock()
- defer srt.cacheLock.Unlock()
- if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 {
- return val, nil
- }
- rl, ok := srt.resolveLocks[serverName]
- if !ok {
- rl = &sync.Mutex{}
- srt.resolveLocks[serverName] = rl
- }
- return nil, rl
}
diff --git a/federation/keyserver.go b/federation/keyserver.go
index 3e74bfdf..d32ba5cf 100644
--- a/federation/keyserver.go
+++ b/federation/keyserver.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,13 +8,17 @@ package federation
import (
"encoding/json"
- "fmt"
"net/http"
"strconv"
"time"
- "github.com/gorilla/mux"
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/hlog"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
+ "go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
@@ -47,34 +51,29 @@ type KeyServer struct {
KeyProvider ServerKeyProvider
Version ServerVersion
WellKnownTarget string
+ OtherKeys KeyCache
}
// Register registers the key server endpoints to the given router.
-func (ks *KeyServer) Register(r *mux.Router) {
- r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet)
- r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet)
- keyRouter := r.PathPrefix("/_matrix/key").Subrouter()
- keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet)
- keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet)
- keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost)
- keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MUnrecognized.ErrCode,
- Err: "Unrecognized endpoint",
- })
- })
- keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{
- ErrCode: mautrix.MUnrecognized.ErrCode,
- Err: "Invalid method for endpoint",
- })
- })
-}
-
-func jsonResponse(w http.ResponseWriter, code int, data any) {
- w.Header().Add("Content-Type", "application/json")
- w.WriteHeader(code)
- _ = json.NewEncoder(w).Encode(data)
+func (ks *KeyServer) Register(r *http.ServeMux, log zerolog.Logger) {
+ r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown)
+ r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion)
+ keyRouter := http.NewServeMux()
+ keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey)
+ keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys)
+ keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys)
+ errorBodies := exhttp.ErrorBodies{
+ NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
+ MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
+ }
+ r.Handle("/_matrix/key/", exhttp.ApplyMiddleware(
+ keyRouter,
+ exhttp.StripPrefix("/_matrix/key"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
}
// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint.
@@ -87,12 +86,9 @@ type RespWellKnown struct {
// https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver
func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) {
if ks.WellKnownTarget == "" {
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MNotFound.ErrCode,
- Err: "No well-known target set",
- })
+ mautrix.MNotFound.WithMessage("No well-known target set").Write(w)
} else {
- jsonResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget})
+ exhttp.WriteJSONResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget})
}
}
@@ -105,7 +101,7 @@ type RespServerVersion struct {
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version
func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version})
+ exhttp.WriteJSONResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version})
}
// GetServerKey implements the `GET /_matrix/key/v2/server` endpoint.
@@ -114,12 +110,9 @@ func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) {
func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) {
domain, key := ks.KeyProvider.Get(r)
if key == nil {
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MNotFound.ErrCode,
- Err: fmt.Sprintf("No signing key found for %q", r.Host),
- })
+ mautrix.MNotFound.WithMessage("No signing key found for %q", r.Host).Write(w)
} else {
- jsonResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil))
+ exhttp.WriteJSONResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil))
}
}
@@ -144,10 +137,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) {
var req ReqQueryKeys
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
- jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
- ErrCode: mautrix.MBadJSON.ErrCode,
- Err: fmt.Sprintf("failed to parse request: %v", err),
- })
+ mautrix.MBadJSON.WithMessage("failed to parse request: %v", err).Write(w)
return
}
@@ -165,7 +155,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) {
}
}
}
- jsonResponse(w, http.StatusOK, resp)
+ exhttp.WriteJSONResponse(w, http.StatusOK, resp)
}
// GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint
@@ -177,27 +167,39 @@ type GetQueryKeysResponse struct {
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
- serverName := mux.Vars(r)["serverName"]
+ serverName := r.PathValue("serverName")
minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
if err != nil && minimumValidUntilTSString != "" {
- jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
- ErrCode: mautrix.MInvalidParam.ErrCode,
- Err: fmt.Sprintf("failed to parse ?minimum_valid_until_ts: %v", err),
- })
+ mautrix.MInvalidParam.WithMessage("failed to parse ?minimum_valid_until_ts: %v", err).Write(w)
return
} else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) {
- jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
- ErrCode: mautrix.MInvalidParam.ErrCode,
- Err: "minimum_valid_until_ts may not be more than 24 hours in the future",
- })
+ mautrix.MInvalidParam.WithMessage("minimum_valid_until_ts may not be more than 24 hours in the future").Write(w)
return
}
resp := &GetQueryKeysResponse{
ServerKeys: []*ServerKeyResponse{},
}
- if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName {
- resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil))
+ domain, key := ks.KeyProvider.Get(r)
+ if domain == serverName {
+ if key != nil {
+ resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil))
+ }
+ } else if ks.OtherKeys != nil {
+ otherKey, err := ks.OtherKeys.LoadKeys(serverName)
+ if err != nil {
+ mautrix.MUnknown.WithMessage("Failed to load keys from cache").Write(w)
+ return
+ }
+ if key != nil && domain != "" {
+ signature, err := key.SignJSON(otherKey)
+ if err == nil {
+ otherKey.Signatures[domain] = map[id.KeyID]string{
+ key.ID: signature,
+ }
+ }
+ }
+ resp.ServerKeys = append(resp.ServerKeys, otherKey)
}
- jsonResponse(w, http.StatusOK, resp)
+ exhttp.WriteJSONResponse(w, http.StatusOK, resp)
}
diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go
new file mode 100644
index 00000000..16706fe5
--- /dev/null
+++ b/federation/pdu/auth.go
@@ -0,0 +1,71 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "slices"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exgjson"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type StateKey struct {
+ Type string
+ StateKey string
+}
+
+var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token")
+
+type AuthEventSelection []StateKey
+
+func (aes *AuthEventSelection) Add(evtType, stateKey string) {
+ key := StateKey{Type: evtType, StateKey: stateKey}
+ if !aes.Has(key) {
+ *aes = append(*aes, key)
+ }
+}
+
+func (aes *AuthEventSelection) Has(key StateKey) bool {
+ return slices.Contains(*aes, key)
+}
+
+func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) {
+ if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
+ return AuthEventSelection{}
+ }
+ keys = make(AuthEventSelection, 0, 3)
+ if !roomVersion.RoomIDIsCreateEventID() {
+ keys.Add(event.StateCreate.Type, "")
+ }
+ keys.Add(event.StatePowerLevels.Type, "")
+ keys.Add(event.StateMember.Type, pdu.Sender.String())
+ if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
+ keys.Add(event.StateMember.Type, *pdu.StateKey)
+ membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
+ if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
+ keys.Add(event.StateJoinRules.Type, "")
+ }
+ if membership == event.MembershipInvite {
+ thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
+ if thirdPartyInviteToken != "" {
+ keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
+ }
+ }
+ if membership == event.MembershipJoin && roomVersion.RestrictedJoins() {
+ authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str
+ if authorizedVia != "" {
+ keys.Add(event.StateMember.Type, authorizedVia)
+ }
+ }
+ }
+ return
+}
diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go
new file mode 100644
index 00000000..38ef83e9
--- /dev/null
+++ b/federation/pdu/hash.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+
+ "maunium.net/go/mautrix/id"
+)
+
+func (pdu *PDU) CalculateContentHash() ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ pduClone := pdu.Clone()
+ pduClone.Signatures = nil
+ pduClone.Unsigned = nil
+ pduClone.Hashes = nil
+ rawJSON, err := marshalCanonical(pduClone)
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *PDU) FillContentHash() error {
+ if pdu == nil {
+ return ErrPDUIsNil
+ } else if pdu.Hashes != nil {
+ return nil
+ } else if hash, err := pdu.CalculateContentHash(); err != nil {
+ return err
+ } else {
+ pdu.Hashes = &Hashes{SHA256: hash[:]}
+ return nil
+ }
+}
+
+func (pdu *PDU) VerifyContentHash() bool {
+ if pdu == nil || pdu.Hashes == nil {
+ return false
+ }
+ calculatedHash, err := pdu.CalculateContentHash()
+ if err != nil {
+ return false
+ }
+ return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
+}
+
+func (pdu *PDU) GetRoomID() (id.RoomID, error) {
+ if pdu == nil {
+ return "", ErrPDUIsNil
+ } else if pdu.Type != "m.room.create" {
+ return "", fmt.Errorf("room ID can only be calculated for m.room.create events")
+ } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() {
+ return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion)
+ } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil {
+ return "", fmt.Errorf("failed to calculate event ID: %w", err)
+ } else {
+ return id.RoomID(evtID), nil
+ }
+}
+
+var UseInternalMetaForGetEventID = false
+
+func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
+ if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" {
+ return pdu.InternalMeta.EventID, nil
+ }
+ return pdu.calculateEventID(roomVersion, '$')
+}
+
+func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
+ if err := pdu.FillContentHash(); err != nil {
+ return [32]byte{}, err
+ }
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) {
+ referenceHash, err := pdu.GetReferenceHash(roomVersion)
+ if err != nil {
+ return "", err
+ }
+ eventID := make([]byte, 44)
+ eventID[0] = prefix
+ switch roomVersion.EventIDFormat() {
+ case id.EventIDFormatCustom:
+ return "", fmt.Errorf("*pdu.PDU can only be used for room v3+")
+ case id.EventIDFormatBase64:
+ base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:])
+ case id.EventIDFormatURLSafeBase64:
+ base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:])
+ default:
+ return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat())
+ }
+ return id.EventID(eventID), nil
+}
diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go
new file mode 100644
index 00000000..17417e12
--- /dev/null
+++ b/federation/pdu/hash_test.go
@@ -0,0 +1,55 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/base64"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exerrors"
+)
+
+func TestPDU_CalculateContentHash(t *testing.T) {
+ for _, test := range testPDUs {
+ if test.redacted {
+ continue
+ }
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ contentHash := exerrors.Must(parsed.CalculateContentHash())
+ assert.Equal(
+ t,
+ base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
+ base64.RawStdEncoding.EncodeToString(contentHash[:]),
+ )
+ })
+ }
+}
+
+func TestPDU_VerifyContentHash(t *testing.T) {
+ for _, test := range testPDUs {
+ if test.redacted {
+ continue
+ }
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ assert.True(t, parsed.VerifyContentHash())
+ })
+ }
+}
+
+func TestPDU_GetEventID(t *testing.T) {
+ for _, test := range testPDUs {
+ t.Run(test.name, func(t *testing.T) {
+ gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion))
+ assert.Equal(t, test.eventID, gotEventID)
+ })
+ }
+}
diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go
new file mode 100644
index 00000000..17db6995
--- /dev/null
+++ b/federation/pdu/pdu.go
@@ -0,0 +1,156 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "bytes"
+ "crypto/ed25519"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/jsonbytes"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU.
+//
+// The input time is the timestamp of the event. The function should attempt to fetch a key that is
+// valid at or after this time, but if that is not possible, the latest available key should be
+// returned without an error. The verify function will do its own validity checking based on the
+// returned valid until timestamp.
+type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error)
+
+type AnyPDU interface {
+ GetRoomID() (id.RoomID, error)
+ GetEventID(roomVersion id.RoomVersion) (id.EventID, error)
+ GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error)
+ CalculateContentHash() ([32]byte, error)
+ FillContentHash() error
+ VerifyContentHash() bool
+ Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error
+ VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error
+ ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error)
+ AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection)
+}
+
+var (
+ _ AnyPDU = (*PDU)(nil)
+ _ AnyPDU = (*RoomV1PDU)(nil)
+)
+
+type InternalMeta struct {
+ EventID id.EventID `json:"event_id,omitempty"`
+ Rejected bool `json:"rejected,omitempty"`
+ Extra map[string]any `json:",unknown"`
+}
+
+type PDU struct {
+ AuthEvents []id.EventID `json:"auth_events"`
+ Content jsontext.Value `json:"content"`
+ Depth int64 `json:"depth"`
+ Hashes *Hashes `json:"hashes,omitzero"`
+ OriginServerTS int64 `json:"origin_server_ts"`
+ PrevEvents []id.EventID `json:"prev_events"`
+ Redacts *id.EventID `json:"redacts,omitzero"`
+ RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events
+ Sender id.UserID `json:"sender"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
+ StateKey *string `json:"state_key,omitzero"`
+ Type string `json:"type"`
+ Unsigned jsontext.Value `json:"unsigned,omitzero"`
+ InternalMeta InternalMeta `json:"-"`
+
+ Unknown jsontext.Value `json:",unknown"`
+
+ // Deprecated legacy fields
+ DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
+ DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
+ DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
+}
+
+var ErrPDUIsNil = errors.New("PDU is nil")
+
+type Hashes struct {
+ SHA256 jsonbytes.UnpaddedBytes `json:"sha256"`
+
+ Unknown jsontext.Value `json:",unknown"`
+}
+
+func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
+ if pdu.Type == "m.room.create" && roomVersion == "" {
+ roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str)
+ }
+ evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
+ if pdu.StateKey != nil {
+ evtType.Class = event.StateEventType
+ }
+ eventID, err := pdu.GetEventID(roomVersion)
+ if err != nil {
+ return nil, err
+ }
+ roomID := pdu.RoomID
+ if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() {
+ roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1))
+ }
+ evt := &event.Event{
+ StateKey: pdu.StateKey,
+ Sender: pdu.Sender,
+ Type: evtType,
+ Timestamp: pdu.OriginServerTS,
+ ID: eventID,
+ RoomID: roomID,
+ Redacts: ptr.Val(pdu.Redacts),
+ }
+ err = json.Unmarshal(pdu.Content, &evt.Content)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal content: %w", err)
+ }
+ return evt, nil
+}
+
+func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) {
+ if signature == "" {
+ return
+ }
+ if pdu.Signatures == nil {
+ pdu.Signatures = make(map[string]map[id.KeyID]string)
+ }
+ if _, ok := pdu.Signatures[serverName]; !ok {
+ pdu.Signatures[serverName] = make(map[id.KeyID]string)
+ }
+ pdu.Signatures[serverName][keyID] = signature
+}
+
+func marshalCanonical(data any) (jsontext.Value, error) {
+ marshaledBytes, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+ marshaled := jsontext.Value(marshaledBytes)
+ err = marshaled.Canonicalize()
+ if err != nil {
+ return nil, err
+ }
+ check := canonicaljson.CanonicalJSONAssumeValid(marshaled)
+ if !bytes.Equal(marshaled, check) {
+ fmt.Println(string(marshaled))
+ fmt.Println(string(check))
+ return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled))
+ }
+ return marshaled, nil
+}
diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go
new file mode 100644
index 00000000..59d7c3a6
--- /dev/null
+++ b/federation/pdu/pdu_test.go
@@ -0,0 +1,193 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/json/v2"
+ "time"
+
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+type serverKey struct {
+ key id.SigningKey
+ validUntilTS time.Time
+}
+
+type serverDetails struct {
+ serverName string
+ keys map[id.KeyID]serverKey
+}
+
+func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
+ if serverName != sd.serverName {
+ return "", time.Time{}, nil
+ }
+ key, ok := sd.keys[keyID]
+ if ok {
+ return key.key, key.validUntilTS, nil
+ }
+ return "", time.Time{}, nil
+}
+
+var mauniumNet = serverDetails{
+ serverName: "maunium.net",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_xxeS": {
+ key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var envsNet = serverDetails{
+ serverName: "envs.net",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_zIqy": {
+ key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0",
+ validUntilTS: time.UnixMilli(1722360538068),
+ },
+ "ed25519:wuJyKT": {
+ key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var matrixOrg = serverDetails{
+ serverName: "matrix.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:auto": {
+ key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw",
+ validUntilTS: time.UnixMilli(1576767829750),
+ },
+ "ed25519:a_RXGa": {
+ key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var continuwuityOrg = serverDetails{
+ serverName: "continuwuity.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:PwHlNsFu": {
+ key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var novaAstraltechOrg = serverDetails{
+ serverName: "nova.astraltech.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_afpo": {
+ key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+
+type testPDU struct {
+ name string
+ pdu string
+ eventID id.EventID
+ roomVersion id.RoomVersion
+ redacted bool
+ serverDetails
+}
+
+var roomV4MessageTestPDU = testPDU{
+ name: "m.room.message in v4 room",
+ pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`,
+ eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}
+
+var roomV12MessageTestPDU = testPDU{
+ name: "m.room.message in v12 room",
+ pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`,
+ eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}
+
+var testPDUs = []testPDU{roomV4MessageTestPDU, {
+ name: "m.room.message in v5 room",
+ pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`,
+ eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA",
+ roomVersion: id.RoomV5,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.message in v10 room",
+ pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`,
+ eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY",
+ roomVersion: id.RoomV10,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.message in v11 room",
+ pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`,
+ eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`,
+ roomVersion: id.RoomV11,
+ serverDetails: mauniumNet,
+}, roomV12MessageTestPDU, {
+ name: "m.room.create in v4 room",
+ pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`,
+ eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY",
+ roomVersion: id.RoomV4,
+ serverDetails: matrixOrg,
+}, {
+ name: "m.room.create in v10 room",
+ pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`,
+ eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ",
+ roomVersion: id.RoomV10,
+ serverDetails: envsNet,
+}, {
+ name: "m.room.create in v12 room",
+ pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`,
+ eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member in v4 room",
+ pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`,
+ eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member in v10 room",
+ pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`,
+ eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs",
+ roomVersion: id.RoomV10,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member of creator in v12 room",
+ pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`,
+ eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}, {
+ name: "custom message event in v4 room",
+ pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`,
+ eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}, {
+ name: "redacted m.room.member event in v11 room with 2 signatures",
+ pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`,
+ eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ",
+ roomVersion: id.RoomV11,
+ serverDetails: novaAstraltechOrg,
+ redacted: true,
+}}
+
+func parsePDU(pdu string) (out *pdu.PDU) {
+ exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
+ return
+}
diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go
new file mode 100644
index 00000000..d7ee0c15
--- /dev/null
+++ b/federation/pdu/redact.go
@@ -0,0 +1,111 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "encoding/json/jsontext"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.mau.fi/util/exgjson"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/id"
+)
+
+func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value {
+ filtered := jsontext.Value("{}")
+ var err error
+ for _, path := range allowedPaths {
+ res := gjson.GetBytes(object, path)
+ if res.Exists() {
+ var raw jsontext.Value
+ if res.Index > 0 {
+ raw = object[res.Index : res.Index+len(res.Raw)]
+ } else {
+ raw = jsontext.Value(res.Raw)
+ }
+ filtered, err = sjson.SetRawBytes(filtered, path, raw)
+ if err != nil {
+ panic(err)
+ }
+ }
+ }
+ return filtered
+}
+
+func (pdu *PDU) Clone() *PDU {
+ return ptr.Clone(pdu)
+}
+
+func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU {
+ pdu.Signatures = nil
+ return pdu.Redact(roomVersion)
+}
+
+var emptyObject = jsontext.Value("{}")
+
+func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value {
+ switch eventType {
+ case "m.room.member":
+ allowedPaths := []string{"membership"}
+ if roomVersion.RestrictedJoinsFix() {
+ allowedPaths = append(allowedPaths, "join_authorised_via_users_server")
+ }
+ if roomVersion.UpdatedRedactionRules() {
+ allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed"))
+ }
+ return filteredObject(content, allowedPaths...)
+ case "m.room.create":
+ if !roomVersion.UpdatedRedactionRules() {
+ return filteredObject(content, "creator")
+ }
+ return content
+ case "m.room.join_rules":
+ if roomVersion.RestrictedJoins() {
+ return filteredObject(content, "join_rule", "allow")
+ }
+ return filteredObject(content, "join_rule")
+ case "m.room.power_levels":
+ allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"}
+ if roomVersion.UpdatedRedactionRules() {
+ allowedKeys = append(allowedKeys, "invite")
+ }
+ return filteredObject(content, allowedKeys...)
+ case "m.room.history_visibility":
+ return filteredObject(content, "history_visibility")
+ case "m.room.redaction":
+ if roomVersion.RedactsInContent() {
+ return filteredObject(content, "redacts")
+ }
+ return emptyObject
+ case "m.room.aliases":
+ if roomVersion.SpecialCasedAliasesAuth() {
+ return filteredObject(content, "aliases")
+ }
+ return emptyObject
+ default:
+ return emptyObject
+ }
+}
+
+func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU {
+ pdu.Unknown = nil
+ pdu.Unsigned = nil
+ if roomVersion.UpdatedRedactionRules() {
+ pdu.DeprecatedPrevState = nil
+ pdu.DeprecatedOrigin = nil
+ pdu.DeprecatedMembership = nil
+ }
+ if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() {
+ pdu.Redacts = nil
+ }
+ pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
+ return pdu
+}
diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go
new file mode 100644
index 00000000..04e7c5ef
--- /dev/null
+++ b/federation/pdu/signature.go
@@ -0,0 +1,60 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "fmt"
+ "time"
+
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
+ err := pdu.FillContentHash()
+ if err != nil {
+ return err
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
+ }
+ signature := ed25519.Sign(privateKey, rawJSON)
+ pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature))
+ return nil
+}
+
+func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
+ }
+ verified := false
+ for keyID, sig := range pdu.Signatures[serverName] {
+ originServerTS := time.UnixMilli(pdu.OriginServerTS)
+ key, validUntil, err := getKey(serverName, keyID, originServerTS)
+ if err != nil {
+ return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
+ } else if key == "" {
+ return fmt.Errorf("key %s not found for %s", keyID, serverName)
+ } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() {
+ return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS)
+ } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
+ return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
+ } else {
+ verified = true
+ }
+ }
+ if !verified {
+ return fmt.Errorf("no verifiable signatures found for server %s", serverName)
+ }
+ return nil
+}
diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go
new file mode 100644
index 00000000..01df5076
--- /dev/null
+++ b/federation/pdu/signature_test.go
@@ -0,0 +1,102 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "encoding/json/jsontext"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+func TestPDU_VerifySignature(t *testing.T) {
+ for _, test := range testPDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
+ assert.NoError(t, err)
+ })
+ }
+}
+
+func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ return
+ })
+ assert.Error(t, err)
+}
+
+func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) {
+ test := roomV4MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ key = test.keys[keyID].key
+ validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ return
+ })
+ assert.NoError(t, err)
+}
+
+func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ key = test.keys[keyID].key
+ validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ return
+ })
+ assert.Error(t, err)
+}
+
+func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ for _, sigs := range parsed.Signatures {
+ for key := range sigs {
+ sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC"
+ }
+ }
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
+ assert.Error(t, err)
+}
+
+func TestPDU_Sign(t *testing.T) {
+ pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil))
+ evt := &pdu.PDU{
+ AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"},
+ Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`),
+ Depth: 123,
+ OriginServerTS: 1755384351627,
+ PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"},
+ RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ Sender: "@tulir:example.com",
+ Type: "m.room.message",
+ }
+ err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey)
+ require.NoError(t, err)
+ err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ if serverName == "example.com" && keyID == "ed25519:rand" {
+ key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey))
+ validUntil = time.Now()
+ }
+ return
+ })
+ require.NoError(t, err)
+
+}
diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go
new file mode 100644
index 00000000..9557f8ab
--- /dev/null
+++ b/federation/pdu/v1.go
@@ -0,0 +1,277 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/ed25519"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "fmt"
+ "time"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+type V1EventReference struct {
+ ID id.EventID
+ Hashes Hashes
+}
+
+var (
+ _ json.UnmarshalerFrom = (*V1EventReference)(nil)
+ _ json.MarshalerTo = (*V1EventReference)(nil)
+)
+
+func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error {
+ return json.MarshalEncode(enc, []any{er.ID, er.Hashes})
+}
+
+func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error {
+ var ref V1EventReference
+ var data []jsontext.Value
+ if err := json.UnmarshalDecode(dec, &data); err != nil {
+ return err
+ } else if len(data) != 2 {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data))
+ } else if err = json.Unmarshal(data[0], &ref.ID); err != nil {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err)
+ } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err)
+ }
+ *er = ref
+ return nil
+}
+
+type RoomV1PDU struct {
+ AuthEvents []V1EventReference `json:"auth_events"`
+ Content jsontext.Value `json:"content"`
+ Depth int64 `json:"depth"`
+ EventID id.EventID `json:"event_id"`
+ Hashes *Hashes `json:"hashes,omitzero"`
+ OriginServerTS int64 `json:"origin_server_ts"`
+ PrevEvents []V1EventReference `json:"prev_events"`
+ Redacts *id.EventID `json:"redacts,omitzero"`
+ RoomID id.RoomID `json:"room_id"`
+ Sender id.UserID `json:"sender"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
+ StateKey *string `json:"state_key,omitzero"`
+ Type string `json:"type"`
+ Unsigned jsontext.Value `json:"unsigned,omitzero"`
+
+ Unknown jsontext.Value `json:",unknown"`
+
+ // Deprecated legacy fields
+ DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
+ DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
+ DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
+}
+
+func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) {
+ return pdu.RoomID, nil
+}
+
+func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion)
+ }
+ return pdu.EventID, nil
+}
+
+func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU {
+ pdu.Signatures = nil
+ return pdu.Redact(roomVersion)
+}
+
+func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU {
+ pdu.Unknown = nil
+ pdu.Unsigned = nil
+ if pdu.Type != "m.room.redaction" {
+ pdu.Redacts = nil
+ }
+ pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
+ return pdu
+}
+
+func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion)
+ }
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
+ if err := pdu.FillContentHash(); err != nil {
+ return [32]byte{}, err
+ }
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ pduClone := pdu.Clone()
+ pduClone.Signatures = nil
+ pduClone.Unsigned = nil
+ pduClone.Hashes = nil
+ rawJSON, err := marshalCanonical(pduClone)
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *RoomV1PDU) FillContentHash() error {
+ if pdu == nil {
+ return ErrPDUIsNil
+ } else if pdu.Hashes != nil {
+ return nil
+ } else if hash, err := pdu.CalculateContentHash(); err != nil {
+ return err
+ } else {
+ pdu.Hashes = &Hashes{SHA256: hash[:]}
+ return nil
+ }
+}
+
+func (pdu *RoomV1PDU) VerifyContentHash() bool {
+ if pdu == nil || pdu.Hashes == nil {
+ return false
+ }
+ calculatedHash, err := pdu.CalculateContentHash()
+ if err != nil {
+ return false
+ }
+ return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
+}
+
+func (pdu *RoomV1PDU) Clone() *RoomV1PDU {
+ return ptr.Clone(pdu)
+}
+
+func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion)
+ }
+ err := pdu.FillContentHash()
+ if err != nil {
+ return err
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
+ }
+ signature := ed25519.Sign(privateKey, rawJSON)
+ if pdu.Signatures == nil {
+ pdu.Signatures = make(map[string]map[id.KeyID]string)
+ }
+ if _, ok := pdu.Signatures[serverName]; !ok {
+ pdu.Signatures[serverName] = make(map[id.KeyID]string)
+ }
+ pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature)
+ return nil
+}
+
+func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion)
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
+ }
+ verified := false
+ for keyID, sig := range pdu.Signatures[serverName] {
+ originServerTS := time.UnixMilli(pdu.OriginServerTS)
+ key, _, err := getKey(serverName, keyID, originServerTS)
+ if err != nil {
+ return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
+ } else if key == "" {
+ return fmt.Errorf("key %s not found for %s", keyID, serverName)
+ } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
+ return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
+ } else {
+ verified = true
+ }
+ }
+ if !verified {
+ return fmt.Errorf("no verifiable signatures found for server %s", serverName)
+ }
+ return nil
+}
+
+func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool {
+ switch roomVersion {
+ case id.RoomV0, id.RoomV1, id.RoomV2:
+ return true
+ default:
+ return false
+ }
+}
+
+func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion)
+ }
+ evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
+ if pdu.StateKey != nil {
+ evtType.Class = event.StateEventType
+ }
+ evt := &event.Event{
+ StateKey: pdu.StateKey,
+ Sender: pdu.Sender,
+ Type: evtType,
+ Timestamp: pdu.OriginServerTS,
+ ID: pdu.EventID,
+ RoomID: pdu.RoomID,
+ Redacts: ptr.Val(pdu.Redacts),
+ }
+ err := json.Unmarshal(pdu.Content, &evt.Content)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal content: %w", err)
+ }
+ return evt, nil
+}
+
+func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) {
+ if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
+ return AuthEventSelection{}
+ }
+ keys = make(AuthEventSelection, 0, 3)
+ keys.Add(event.StateCreate.Type, "")
+ keys.Add(event.StatePowerLevels.Type, "")
+ keys.Add(event.StateMember.Type, pdu.Sender.String())
+ if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
+ keys.Add(event.StateMember.Type, *pdu.StateKey)
+ membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
+ if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
+ keys.Add(event.StateJoinRules.Type, "")
+ }
+ if membership == event.MembershipInvite {
+ thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
+ if thirdPartyInviteToken != "" {
+ keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
+ }
+ }
+ }
+ return
+}
diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go
new file mode 100644
index 00000000..ecf2dbd2
--- /dev/null
+++ b/federation/pdu/v1_test.go
@@ -0,0 +1,86 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/base64"
+ "encoding/json/v2"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+var testV1PDUs = []testPDU{{
+ name: "m.room.message in v1 room",
+ pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`,
+ eventID: "$16738169022163bokdi:maunium.net",
+ roomVersion: id.RoomV1,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.create in v1 room",
+ pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`,
+ eventID: "$143454825711DhCxH:matrix.org",
+ roomVersion: id.RoomV1,
+ serverDetails: matrixOrg,
+}, {
+ name: "m.room.member in v1 room",
+ pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`,
+ eventID: "$15660585693272iEryv:maunium.net",
+ roomVersion: id.RoomV1,
+ serverDetails: mauniumNet,
+}}
+
+func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) {
+ exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
+ return
+}
+
+func TestRoomV1PDU_CalculateContentHash(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ contentHash := exerrors.Must(parsed.CalculateContentHash())
+ assert.Equal(
+ t,
+ base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
+ base64.RawStdEncoding.EncodeToString(contentHash[:]),
+ )
+ })
+ }
+}
+
+func TestRoomV1PDU_VerifyContentHash(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ assert.True(t, parsed.VerifyContentHash())
+ })
+ }
+}
+
+func TestRoomV1PDU_VerifySignature(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
+ key, ok := test.keys[keyID]
+ if ok {
+ return key.key, key.validUntilTS, nil
+ }
+ return "", time.Time{}, nil
+ })
+ assert.NoError(t, err)
+ })
+ }
+}
diff --git a/federation/resolution.go b/federation/resolution.go
index 24085282..a3188266 100644
--- a/federation/resolution.go
+++ b/federation/resolution.go
@@ -20,6 +20,8 @@ import (
"time"
"github.com/rs/zerolog"
+
+ "maunium.net/go/mautrix"
)
type ResolvedServerName struct {
@@ -78,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS
} else if wellKnown != nil {
output.Expires = expiry
output.HostHeader = wellKnown.Server
- hostname, port, ok = ParseServerName(wellKnown.Server)
+ wkHost, wkPort, ok := ParseServerName(wellKnown.Server)
+ if ok {
+ hostname, port = wkHost, wkPort
+ }
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {
@@ -120,6 +125,38 @@ func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net
return target, err
}
+func parseCacheControl(resp *http.Response) time.Duration {
+ cc := resp.Header.Get("Cache-Control")
+ if cc == "" {
+ return 0
+ }
+ parts := strings.Split(cc, ",")
+ for _, part := range parts {
+ kv := strings.SplitN(strings.TrimSpace(part), "=", 1)
+ switch kv[0] {
+ case "no-cache", "no-store":
+ return 0
+ case "max-age":
+ if len(kv) < 2 {
+ continue
+ }
+ maxAge, err := strconv.Atoi(kv[1])
+ if err != nil || maxAge < 0 {
+ continue
+ }
+ age, _ := strconv.Atoi(resp.Header.Get("Age"))
+ return time.Duration(maxAge-age) * time.Second
+ }
+ }
+ return 0
+}
+
+const (
+ MinCacheDuration = 1 * time.Hour
+ MaxCacheDuration = 72 * time.Hour
+ DefaultCacheDuration = 24 * time.Hour
+)
+
// RequestWellKnown sends a request to the well-known endpoint of a server and returns the response,
// plus the time when the cache should expire.
func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) {
@@ -139,14 +176,23 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
+ } else if resp.ContentLength > mautrix.WellKnownMaxSize {
+ return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength)
}
var respData RespWellKnown
- err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData)
+ err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
} else if respData.Server == "" {
return nil, time.Time{}, errors.New("server name not found in response")
}
- // TODO parse cache-control header
+ cacheDuration := parseCacheControl(resp)
+ if cacheDuration <= 0 {
+ cacheDuration = DefaultCacheDuration
+ } else if cacheDuration < MinCacheDuration {
+ cacheDuration = MinCacheDuration
+ } else if cacheDuration > MaxCacheDuration {
+ cacheDuration = MaxCacheDuration
+ }
return &respData, time.Now().Add(24 * time.Hour), nil
}
diff --git a/federation/serverauth.go b/federation/serverauth.go
new file mode 100644
index 00000000..cd300341
--- /dev/null
+++ b/federation/serverauth.go
@@ -0,0 +1,264 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "slices"
+ "strings"
+ "sync"
+
+ "github.com/rs/zerolog"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/id"
+)
+
+type ServerAuth struct {
+ Keys KeyCache
+ Client *Client
+ GetDestination func(XMatrixAuth) string
+ MaxBodySize int64
+
+ keyFetchLocks map[string]*sync.Mutex
+ keyFetchLocksLock sync.Mutex
+}
+
+func NewServerAuth(client *Client, keyCache KeyCache, getDestination func(auth XMatrixAuth) string) *ServerAuth {
+ return &ServerAuth{
+ Keys: keyCache,
+ Client: client,
+ GetDestination: getDestination,
+ MaxBodySize: 50 * 1024 * 1024,
+ keyFetchLocks: make(map[string]*sync.Mutex),
+ }
+}
+
+var MUnauthorized = mautrix.RespError{ErrCode: "M_UNAUTHORIZED", StatusCode: http.StatusUnauthorized}
+
+var (
+ errMissingAuthHeader = MUnauthorized.WithMessage("Missing Authorization header")
+ errInvalidAuthHeader = MUnauthorized.WithMessage("Authorization header does not start with X-Matrix")
+ errMalformedAuthHeader = MUnauthorized.WithMessage("X-Matrix value is missing required components")
+ errInvalidDestination = MUnauthorized.WithMessage("Invalid destination in X-Matrix header")
+ errFailedToQueryKeys = MUnauthorized.WithMessage("Failed to query server keys")
+ errInvalidSelfSignatures = MUnauthorized.WithMessage("Server keys don't have valid self-signatures")
+ errRequestBodyTooLarge = mautrix.MTooLarge.WithMessage("Request body too large")
+ errInvalidJSONBody = mautrix.MBadJSON.WithMessage("Request body is not valid JSON")
+ errBodyReadFailed = mautrix.MUnknown.WithMessage("Failed to read request body")
+ errInvalidRequestSignature = MUnauthorized.WithMessage("Failed to verify request signature")
+)
+
+type XMatrixAuth struct {
+ Origin string
+ Destination string
+ KeyID id.KeyID
+ Signature string
+}
+
+func (xma XMatrixAuth) String() string {
+ return fmt.Sprintf(
+ `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`,
+ xma.Origin,
+ xma.Destination,
+ xma.KeyID,
+ xma.Signature,
+ )
+}
+
+func ParseXMatrixAuth(auth string) (xma XMatrixAuth) {
+ auth = strings.TrimPrefix(auth, "X-Matrix ")
+ // TODO upgrade to strings.SplitSeq after Go 1.24 is the minimum
+ for _, part := range strings.Split(auth, ",") {
+ part = strings.TrimSpace(part)
+ eqIdx := strings.Index(part, "=")
+ if eqIdx == -1 || strings.Count(part, "=") > 1 {
+ continue
+ }
+ val := strings.Trim(part[eqIdx+1:], "\"")
+ switch strings.ToLower(part[:eqIdx]) {
+ case "origin":
+ xma.Origin = val
+ case "destination":
+ xma.Destination = val
+ case "key":
+ xma.KeyID = id.KeyID(val)
+ case "sig":
+ xma.Signature = val
+ }
+ }
+ return
+}
+
+func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string, keyID id.KeyID) (*ServerKeyResponse, error) {
+ res, err := sa.Keys.LoadKeys(serverName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache: %w", err)
+ } else if res.HasKey(keyID) {
+ return res, nil
+ }
+
+ sa.keyFetchLocksLock.Lock()
+ lock, ok := sa.keyFetchLocks[serverName]
+ if !ok {
+ lock = &sync.Mutex{}
+ sa.keyFetchLocks[serverName] = lock
+ }
+ sa.keyFetchLocksLock.Unlock()
+
+ lock.Lock()
+ defer lock.Unlock()
+ res, err = sa.Keys.LoadKeys(serverName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache: %w", err)
+ } else if res != nil {
+ if res.HasKey(keyID) {
+ return res, nil
+ } else if !sa.Keys.ShouldReQuery(serverName) {
+ zerolog.Ctx(ctx).Trace().
+ Str("server_name", serverName).
+ Stringer("key_id", keyID).
+ Msg("Not sending key request for missing key ID, last query was too recent")
+ return res, nil
+ }
+ }
+ res, err = sa.Client.ServerKeys(ctx, serverName)
+ if err != nil {
+ sa.Keys.StoreFetchError(serverName, err)
+ return nil, err
+ }
+ sa.Keys.StoreKeys(res)
+ return res, nil
+}
+
+type fixedLimitedReader struct {
+ R io.Reader
+ N int64
+ Err error
+}
+
+func (l *fixedLimitedReader) Read(p []byte) (n int, err error) {
+ if l.N <= 0 {
+ return 0, l.Err
+ }
+ if int64(len(p)) > l.N {
+ p = p[0:l.N]
+ }
+ n, err = l.R.Read(p)
+ l.N -= int64(n)
+ return
+}
+
+func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.RespError) {
+ defer func() {
+ _ = r.Body.Close()
+ }()
+ log := zerolog.Ctx(r.Context())
+ if r.ContentLength > sa.MaxBodySize {
+ return nil, &errRequestBodyTooLarge
+ }
+ auth := r.Header.Get("Authorization")
+ if auth == "" {
+ return nil, &errMissingAuthHeader
+ } else if !strings.HasPrefix(auth, "X-Matrix ") {
+ return nil, &errInvalidAuthHeader
+ }
+ parsed := ParseXMatrixAuth(auth)
+ if parsed.Origin == "" || parsed.KeyID == "" || parsed.Signature == "" {
+ log.Trace().Str("auth_header", auth).Msg("Malformed X-Matrix header")
+ return nil, &errMalformedAuthHeader
+ }
+ destination := sa.GetDestination(parsed)
+ if destination == "" || (parsed.Destination != "" && parsed.Destination != destination) {
+ log.Trace().
+ Str("got_destination", parsed.Destination).
+ Str("expected_destination", destination).
+ Msg("Invalid destination in X-Matrix header")
+ return nil, &errInvalidDestination
+ }
+ resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin, parsed.KeyID)
+ if err != nil {
+ if !errors.Is(err, ErrRecentKeyQueryFailed) {
+ log.Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to query keys to authenticate request")
+ } else {
+ log.Trace().Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to query keys to authenticate request (cached error)")
+ }
+ return nil, &errFailedToQueryKeys
+ } else if err := resp.VerifySelfSignature(); err != nil {
+ log.Trace().Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to validate self-signatures of server keys")
+ return nil, &errInvalidSelfSignatures
+ }
+ key, ok := resp.VerifyKeys[parsed.KeyID]
+ if !ok {
+ keys := slices.Collect(maps.Keys(resp.VerifyKeys))
+ log.Trace().
+ Stringer("expected_key_id", parsed.KeyID).
+ Any("found_key_ids", keys).
+ Msg("Didn't find expected key ID to verify request")
+ return nil, ptr.Ptr(MUnauthorized.WithMessage("Key ID %q not found (got %v)", parsed.KeyID, keys))
+ }
+ var reqBody []byte
+ if r.ContentLength != 0 && r.Method != http.MethodGet && r.Method != http.MethodHead {
+ reqBody, err = io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: errRequestBodyTooLarge})
+ if errors.Is(err, errRequestBodyTooLarge) {
+ return nil, &errRequestBodyTooLarge
+ } else if err != nil {
+ log.Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to read request body to authenticate")
+ return nil, &errBodyReadFailed
+ } else if !json.Valid(reqBody) {
+ return nil, &errInvalidJSONBody
+ }
+ }
+ err = (&signableRequest{
+ Method: r.Method,
+ URI: r.URL.RequestURI(),
+ Origin: parsed.Origin,
+ Destination: destination,
+ Content: reqBody,
+ }).Verify(key.Key, parsed.Signature)
+ if err != nil {
+ log.Trace().Err(err).Msg("Request has invalid signature")
+ return nil, &errInvalidRequestSignature
+ }
+ ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination)
+ ctx = context.WithValue(ctx, contextKeyOriginServer, parsed.Origin)
+ ctx = log.With().
+ Str("origin_server_name", parsed.Origin).
+ Str("destination_server_name", destination).
+ Logger().WithContext(ctx)
+ modifiedReq := r.WithContext(ctx)
+ if reqBody != nil {
+ modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody))
+ }
+ return modifiedReq, nil
+}
+
+func (sa *ServerAuth) AuthenticateMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if modifiedReq, err := sa.Authenticate(r); err != nil {
+ err.Write(w)
+ } else {
+ next.ServeHTTP(w, modifiedReq)
+ }
+ })
+}
diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go
new file mode 100644
index 00000000..f99fc6cf
--- /dev/null
+++ b/federation/serverauth_test.go
@@ -0,0 +1,29 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/federation"
+)
+
+func TestServerKeyResponse_VerifySelfSignature(t *testing.T) {
+ cli := federation.NewClient("", nil, nil)
+ ctx := context.Background()
+ for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} {
+ t.Run(name, func(t *testing.T) {
+ resp, err := cli.ServerKeys(ctx, name)
+ require.NoError(t, err)
+ assert.NoError(t, resp.VerifySelfSignature())
+ })
+ }
+}
diff --git a/federation/signingkey.go b/federation/signingkey.go
index 67751b48..a4ad9679 100644
--- a/federation/signingkey.go
+++ b/federation/signingkey.go
@@ -14,9 +14,11 @@ import (
"strings"
"time"
+ "github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/federation/signutil"
"maunium.net/go/mautrix/id"
)
@@ -31,8 +33,8 @@ type SigningKey struct {
//
// The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function.
func (sk *SigningKey) SynapseString() string {
- alg, id := sk.ID.Parse()
- return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
+ alg, keyID := sk.ID.Parse()
+ return fmt.Sprintf("%s %s %s", alg, keyID, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
}
// ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey.
@@ -77,6 +79,37 @@ type ServerKeyResponse struct {
OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"`
Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"`
+
+ Raw json.RawMessage `json:"-"`
+}
+
+type QueryKeysResponse struct {
+ ServerKeys []*ServerKeyResponse `json:"server_keys"`
+}
+
+func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool {
+ if skr == nil {
+ return false
+ } else if _, ok := skr.VerifyKeys[keyID]; ok {
+ return true
+ }
+ return false
+}
+
+func (skr *ServerKeyResponse) VerifySelfSignature() error {
+ for keyID, key := range skr.VerifyKeys {
+ if err := signutil.VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil {
+ return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err)
+ }
+ }
+ return nil
+}
+
+type marshalableSKR ServerKeyResponse
+
+func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error {
+ skr.Raw = data
+ return json.Unmarshal(data, (*marshalableSKR)(skr))
}
type ServerVerifyKey struct {
@@ -92,12 +125,16 @@ type OldVerifyKey struct {
ExpiredTS jsontime.UnixMilli `json:"expired_ts"`
}
-func (sk *SigningKey) SignJSON(data any) ([]byte, error) {
+func (sk *SigningKey) SignJSON(data any) (string, error) {
marshaled, err := json.Marshal(data)
if err != nil {
- return nil, err
+ return "", err
}
- return sk.SignRawJSON(marshaled), nil
+ marshaled, err = sjson.DeleteBytes(marshaled, "signatures")
+ if err != nil {
+ return "", err
+ }
+ return base64.RawStdEncoding.EncodeToString(sk.SignRawJSON(marshaled)), nil
}
func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte {
@@ -120,7 +157,7 @@ func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[i
}
skr.Signatures = map[string]map[id.KeyID]string{
serverName: {
- sk.ID: base64.RawURLEncoding.EncodeToString(signature),
+ sk.ID: signature,
},
}
return skr
diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go
new file mode 100644
index 00000000..ea0e7886
--- /dev/null
+++ b/federation/signutil/verify.go
@@ -0,0 +1,106 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package signutil
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.mau.fi/util/exgjson"
+
+ "maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/id"
+)
+
+var ErrSignatureNotFound = errors.New("signature not found")
+var ErrInvalidSignature = errors.New("invalid signature")
+
+func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error {
+ var err error
+ message, ok := data.(json.RawMessage)
+ if !ok {
+ message, err = json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ }
+ sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID)))
+ if sigVal.Type != gjson.String {
+ return ErrSignatureNotFound
+ }
+ message, err = sjson.DeleteBytes(message, "signatures")
+ if err != nil {
+ return fmt.Errorf("failed to delete signatures: %w", err)
+ }
+ message, err = sjson.DeleteBytes(message, "unsigned")
+ if err != nil {
+ return fmt.Errorf("failed to delete unsigned: %w", err)
+ }
+ return VerifyJSONRaw(key, sigVal.Str, message)
+}
+
+func VerifyJSONAny(key id.SigningKey, data any) error {
+ var err error
+ message, ok := data.(json.RawMessage)
+ if !ok {
+ message, err = json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ }
+ sigs := gjson.GetBytes(message, "signatures")
+ if !sigs.IsObject() {
+ return ErrSignatureNotFound
+ }
+ message, err = sjson.DeleteBytes(message, "signatures")
+ if err != nil {
+ return fmt.Errorf("failed to delete signatures: %w", err)
+ }
+ message, err = sjson.DeleteBytes(message, "unsigned")
+ if err != nil {
+ return fmt.Errorf("failed to delete unsigned: %w", err)
+ }
+ var validated bool
+ sigs.ForEach(func(_, value gjson.Result) bool {
+ if !value.IsObject() {
+ return true
+ }
+ value.ForEach(func(_, value gjson.Result) bool {
+ if value.Type != gjson.String {
+ return true
+ }
+ validated = VerifyJSONRaw(key, value.Str, message) == nil
+ return !validated
+ })
+ return !validated
+ })
+ if !validated {
+ return ErrInvalidSignature
+ }
+ return nil
+}
+
+func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error {
+ sigBytes, err := base64.RawStdEncoding.DecodeString(sig)
+ if err != nil {
+ return fmt.Errorf("failed to decode signature: %w", err)
+ }
+ keyBytes, err := base64.RawStdEncoding.DecodeString(string(key))
+ if err != nil {
+ return fmt.Errorf("failed to decode key: %w", err)
+ }
+ message = canonicaljson.CanonicalJSONAssumeValid(message)
+ if !ed25519.Verify(keyBytes, message, sigBytes) {
+ return ErrInvalidSignature
+ }
+ return nil
+}
diff --git a/filter.go b/filter.go
index 2603bfb9..54973dab 100644
--- a/filter.go
+++ b/filter.go
@@ -19,45 +19,45 @@ const (
// Filter is used by clients to specify how the server should filter responses to e.g. sync requests
// Specified by: https://spec.matrix.org/v1.2/client-server-api/#filtering
type Filter struct {
- AccountData FilterPart `json:"account_data,omitempty"`
+ AccountData *FilterPart `json:"account_data,omitempty"`
EventFields []string `json:"event_fields,omitempty"`
EventFormat EventFormat `json:"event_format,omitempty"`
- Presence FilterPart `json:"presence,omitempty"`
- Room RoomFilter `json:"room,omitempty"`
+ Presence *FilterPart `json:"presence,omitempty"`
+ Room *RoomFilter `json:"room,omitempty"`
BeeperToDevice *FilterPart `json:"com.beeper.to_device,omitempty"`
}
// RoomFilter is used to define filtering rules for room events
type RoomFilter struct {
- AccountData FilterPart `json:"account_data,omitempty"`
- Ephemeral FilterPart `json:"ephemeral,omitempty"`
+ AccountData *FilterPart `json:"account_data,omitempty"`
+ Ephemeral *FilterPart `json:"ephemeral,omitempty"`
IncludeLeave bool `json:"include_leave,omitempty"`
NotRooms []id.RoomID `json:"not_rooms,omitempty"`
Rooms []id.RoomID `json:"rooms,omitempty"`
- State FilterPart `json:"state,omitempty"`
- Timeline FilterPart `json:"timeline,omitempty"`
+ State *FilterPart `json:"state,omitempty"`
+ Timeline *FilterPart `json:"timeline,omitempty"`
}
// FilterPart is used to define filtering rules for specific categories of events
type FilterPart struct {
- NotRooms []id.RoomID `json:"not_rooms,omitempty"`
- Rooms []id.RoomID `json:"rooms,omitempty"`
- Limit int `json:"limit,omitempty"`
- NotSenders []id.UserID `json:"not_senders,omitempty"`
- NotTypes []event.Type `json:"not_types,omitempty"`
- Senders []id.UserID `json:"senders,omitempty"`
- Types []event.Type `json:"types,omitempty"`
- ContainsURL *bool `json:"contains_url,omitempty"`
-
- LazyLoadMembers bool `json:"lazy_load_members,omitempty"`
- IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"`
+ NotRooms []id.RoomID `json:"not_rooms,omitempty"`
+ Rooms []id.RoomID `json:"rooms,omitempty"`
+ Limit int `json:"limit,omitempty"`
+ NotSenders []id.UserID `json:"not_senders,omitempty"`
+ NotTypes []event.Type `json:"not_types,omitempty"`
+ Senders []id.UserID `json:"senders,omitempty"`
+ Types []event.Type `json:"types,omitempty"`
+ ContainsURL *bool `json:"contains_url,omitempty"`
+ LazyLoadMembers bool `json:"lazy_load_members,omitempty"`
+ IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"`
+ UnreadThreadNotifications bool `json:"unread_thread_notifications,omitempty"`
}
// Validate checks if the filter contains valid property values
func (filter *Filter) Validate() error {
if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation {
- return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]")
+ return errors.New("bad event_format value")
}
return nil
}
@@ -69,7 +69,7 @@ func DefaultFilter() Filter {
EventFields: nil,
EventFormat: "client",
Presence: DefaultFilterPart(),
- Room: RoomFilter{
+ Room: &RoomFilter{
AccountData: DefaultFilterPart(),
Ephemeral: DefaultFilterPart(),
IncludeLeave: false,
@@ -82,8 +82,8 @@ func DefaultFilter() Filter {
}
// DefaultFilterPart returns the default filter part used by the Matrix server if no filter is provided in the request
-func DefaultFilterPart() FilterPart {
- return FilterPart{
+func DefaultFilterPart() *FilterPart {
+ return &FilterPart{
NotRooms: nil,
Rooms: nil,
Limit: 20,
diff --git a/format/htmlparser.go b/format/htmlparser.go
index 7c3b3c88..e0507d93 100644
--- a/format/htmlparser.go
+++ b/format/htmlparser.go
@@ -13,6 +13,7 @@ import (
"strconv"
"strings"
+ "go.mau.fi/util/exstrings"
"golang.org/x/net/html"
"maunium.net/go/mautrix/event"
@@ -92,6 +93,30 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string
}
}
+func onlyBacktickCount(line string) (count int) {
+ for i := 0; i < len(line); i++ {
+ if line[i] != '`' {
+ return -1
+ }
+ count++
+ }
+ return
+}
+
+func DefaultMonospaceBlockConverter(code, language string, ctx Context) string {
+ if len(code) == 0 || code[len(code)-1] != '\n' {
+ code += "\n"
+ }
+ fence := "```"
+ for line := range strings.SplitSeq(code, "\n") {
+ count := onlyBacktickCount(strings.TrimSpace(line))
+ if count >= len(fence) {
+ fence = strings.Repeat("`", count+1)
+ }
+ }
+ return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence)
+}
+
// HTMLParser is a somewhat customizable Matrix HTML parser.
type HTMLParser struct {
PillConverter PillConverter
@@ -187,25 +212,6 @@ func (parser *HTMLParser) listToString(node *html.Node, ctx Context) string {
return strings.Join(children, "\n")
}
-func LongestSequence(in string, of rune) int {
- currentSeq := 0
- maxSeq := 0
- for _, chr := range in {
- if chr == of {
- currentSeq++
- } else {
- if currentSeq > maxSeq {
- maxSeq = currentSeq
- }
- currentSeq = 0
- }
- }
- if currentSeq > maxSeq {
- maxSeq = currentSeq
- }
- return maxSeq
-}
-
func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
switch node.Data {
@@ -232,8 +238,7 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) stri
if parser.MonospaceConverter != nil {
return parser.MonospaceConverter(str, ctx)
}
- surround := strings.Repeat("`", LongestSequence(str, '`')+1)
- return fmt.Sprintf("%s%s%s", surround, str, surround)
+ return SafeMarkdownCode(str)
}
return str
}
@@ -306,7 +311,10 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string {
}
if parser.LinkConverter != nil {
return parser.LinkConverter(str, href, ctx)
- } else if str == href {
+ } else if str == href ||
+ str == strings.TrimPrefix(href, "mailto:") ||
+ str == strings.TrimPrefix(href, "http://") ||
+ str == strings.TrimPrefix(href, "https://") {
return str
}
return fmt.Sprintf("%s (%s)", str, href)
@@ -348,6 +356,8 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
return parser.imgToString(node, ctx)
case "hr":
return parser.HorizontalLine
+ case "input":
+ return parser.inputToString(node, ctx)
case "pre":
var preStr, language string
if node.FirstChild != nil && node.FirstChild.Type == html.ElementNode && node.FirstChild.Data == "code" {
@@ -362,20 +372,28 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
if parser.MonospaceBlockConverter != nil {
return parser.MonospaceBlockConverter(preStr, language, ctx)
}
- if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' {
- preStr += "\n"
- }
- return fmt.Sprintf("```%s\n%s```", language, preStr)
+ return DefaultMonospaceBlockConverter(preStr, language, ctx)
default:
return parser.nodeToTagAwareString(node.FirstChild, ctx)
}
}
+func (parser *HTMLParser) inputToString(node *html.Node, ctx Context) string {
+ if len(ctx.TagStack) > 1 && ctx.TagStack[len(ctx.TagStack)-2] == "li" {
+ _, checked := parser.maybeGetAttribute(node, "checked")
+ if checked {
+ return "[x]"
+ }
+ return "[ ]"
+ }
+ return parser.nodeToTagAwareString(node.FirstChild, ctx)
+}
+
func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString {
switch node.Type {
case html.TextNode:
if !ctx.PreserveWhitespace {
- node.Data = strings.Replace(node.Data, "\n", "", -1)
+ node.Data = exstrings.CollapseSpaces(strings.ReplaceAll(node.Data, "\n", ""))
}
if parser.TextConverter != nil {
node.Data = parser.TextConverter(node.Data, ctx)
@@ -455,7 +473,7 @@ var MarkdownHTMLParser = &HTMLParser{
PillConverter: DefaultPillConverter,
LinkConverter: func(text, href string, ctx Context) string {
if text == href {
- return text
+ return fmt.Sprintf("<%s>", href)
}
return fmt.Sprintf("[%s](%s)", text, href)
},
diff --git a/format/markdown.go b/format/markdown.go
index d099ba00..77ced0dc 100644
--- a/format/markdown.go
+++ b/format/markdown.go
@@ -8,14 +8,17 @@ package format
import (
"fmt"
+ "regexp"
"strings"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/extension"
"github.com/yuin/goldmark/renderer/html"
+ "go.mau.fi/util/exstrings"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format/mdext"
+ "maunium.net/go/mautrix/id"
)
const paragraphStart = ""
@@ -39,6 +42,55 @@ func UnwrapSingleParagraph(html string) string {
return html
}
+var mdEscapeRegex = regexp.MustCompile("([\\\\`*_[\\]()])")
+
+func EscapeMarkdown(text string) string {
+ text = mdEscapeRegex.ReplaceAllString(text, "\\$1")
+ text = strings.ReplaceAll(text, ">", ">")
+ text = strings.ReplaceAll(text, "<", "<")
+ return text
+}
+
+type uriAble interface {
+ String() string
+ URI() *id.MatrixURI
+}
+
+func MarkdownMention(id uriAble) string {
+ return MarkdownMentionWithName(id.String(), id)
+}
+
+func MarkdownMentionWithName(name string, id uriAble) string {
+ return MarkdownLink(name, id.URI().MatrixToURL())
+}
+
+func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string {
+ if name == "" {
+ name = id.String()
+ }
+ return MarkdownLink(name, id.URI(via...).MatrixToURL())
+}
+
+func MarkdownLink(name string, url string) string {
+ return fmt.Sprintf("[%s](%s)", EscapeMarkdown(name), EscapeMarkdown(url))
+}
+
+func SafeMarkdownCode[T ~string](textInput T) string {
+ if textInput == "" {
+ return "` `"
+ }
+ text := strings.ReplaceAll(string(textInput), "\n", " ")
+ backtickCount := exstrings.LongestSequenceOf(text, '`')
+ if backtickCount == 0 {
+ return fmt.Sprintf("`%s`", text)
+ }
+ quotes := strings.Repeat("`", backtickCount+1)
+ if text[0] == '`' || text[len(text)-1] == '`' {
+ return fmt.Sprintf("%s %s %s", quotes, text, quotes)
+ }
+ return fmt.Sprintf("%s%s%s", quotes, text, quotes)
+}
+
func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.MessageEventContent {
var buf strings.Builder
err := renderer.Convert([]byte(text), &buf)
diff --git a/format/markdown_test.go b/format/markdown_test.go
index d4e7d716..46ea4886 100644
--- a/format/markdown_test.go
+++ b/format/markdown_test.go
@@ -196,3 +196,18 @@ func TestRenderMarkdown_CustomEmoji(t *testing.T) {
assert.Equal(t, html, rendered, "with input %q", markdown)
}
}
+
+var codeTests = map[string]string{
+ "meow": "`meow`",
+ "me`ow": "``me`ow``",
+ "`me`ow": "`` `me`ow ``",
+ "me`ow`": "`` me`ow` ``",
+ "`meow`": "`` `meow` ``",
+ "`````````": "`````````` ````````` ``````````",
+}
+
+func TestSafeMarkdownCode(t *testing.T) {
+ for input, expected := range codeTests {
+ assert.Equal(t, expected, format.SafeMarkdownCode(input), "with input %q", input)
+ }
+}
diff --git a/go.mod b/go.mod
index 8bf9baac..49a1d4e4 100644
--- a/go.mod
+++ b/go.mod
@@ -1,43 +1,42 @@
module maunium.net/go/mautrix
-go 1.22.0
+go 1.25.0
-toolchain go1.23.3
+toolchain go1.26.0
require (
- filippo.io/edwards25519 v1.1.0
+ filippo.io/edwards25519 v1.2.0
github.com/chzyer/readline v1.5.1
- github.com/gorilla/mux v1.8.0
- github.com/gorilla/websocket v1.5.0
- github.com/lib/pq v1.10.9
- github.com/mattn/go-sqlite3 v1.14.24
+ github.com/coder/websocket v1.8.14
+ github.com/lib/pq v1.11.2
+ github.com/mattn/go-sqlite3 v1.14.34
github.com/rs/xid v1.6.0
- github.com/rs/zerolog v1.33.0
+ github.com/rs/zerolog v1.34.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
- github.com/stretchr/testify v1.9.0
+ github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
- github.com/yuin/goldmark v1.7.8
- go.mau.fi/util v0.8.2
- go.mau.fi/zeroconfig v0.1.3
- golang.org/x/crypto v0.29.0
- golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f
- golang.org/x/net v0.31.0
- golang.org/x/sync v0.9.0
+ github.com/yuin/goldmark v1.7.16
+ go.mau.fi/util v0.9.6
+ go.mau.fi/zeroconfig v0.2.0
+ golang.org/x/crypto v0.48.0
+ golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa
+ golang.org/x/net v0.50.0
+ golang.org/x/sync v0.19.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0
)
require (
- github.com/coreos/go-systemd/v22 v22.5.0 // indirect
+ github.com/coreos/go-systemd/v22 v22.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
- github.com/mattn/go-colorable v0.1.13 // indirect
- github.com/mattn/go-isatty v0.0.19 // indirect
- github.com/petermattis/goid v0.0.0-20241025130422-66cb2e6d7274 // indirect
+ github.com/mattn/go-colorable v0.1.14 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
- github.com/tidwall/pretty v1.2.0 // indirect
- golang.org/x/sys v0.27.0 // indirect
- golang.org/x/text v0.20.0 // indirect
+ github.com/tidwall/pretty v1.2.1 // indirect
+ golang.org/x/sys v0.41.0 // indirect
+ golang.org/x/text v0.34.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)
diff --git a/go.sum b/go.sum
index 205cbfaf..871a5156 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,5 @@
-filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
-filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
+filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
+filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
@@ -8,69 +8,70 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
-github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
+github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
+github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
+github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo=
+github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
-github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
-github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
-github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
-github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
-github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
-github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
+github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs=
+github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
+github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
-github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
-github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-github.com/petermattis/goid v0.0.0-20241025130422-66cb2e6d7274 h1:qli3BGQK0tYDkSEvZ/FzZTi9ZrOX86Q6CIhKLGc489A=
-github.com/petermattis/goid v0.0.0-20241025130422-66cb2e6d7274/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
+github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
-github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
-github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
+github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
+github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
-github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
-github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
-github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
+github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
-github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic=
-github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
-go.mau.fi/util v0.8.2 h1:zWbVHwdRKwI6U9AusmZ8bwgcLosikwbb4GGqLrNr1YE=
-go.mau.fi/util v0.8.2/go.mod h1:BHHC9R2WLMJd1bwTZfTcFxUgRFmUgUmiWcT4RbzUgiA=
-go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
-go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
-golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ=
-golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg=
-golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo=
-golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak=
-golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo=
-golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM=
-golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ=
-golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE=
+github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
+go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts=
+go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI=
+go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU=
+go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w=
+golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
+golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
+golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
+golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
+golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
+golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
-golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
-golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
+golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
+golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
+golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
diff --git a/id/contenturi.go b/id/contenturi.go
index e6a313f5..67127b6c 100644
--- a/id/contenturi.go
+++ b/id/contenturi.go
@@ -17,8 +17,14 @@ import (
)
var (
- InvalidContentURI = errors.New("invalid Matrix content URI")
- InputNotJSONString = errors.New("input doesn't look like a JSON string")
+ ErrInvalidContentURI = errors.New("invalid Matrix content URI")
+ ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ InvalidContentURI = ErrInvalidContentURI
+ InputNotJSONString = ErrInputNotJSONString
)
// ContentURIString is a string that's expected to be a Matrix content URI.
@@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !strings.HasPrefix(uri, "mxc://") {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = uri[6 : 6+index]
parsed.FileID = uri[6+index+1:]
@@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !bytes.HasPrefix(uri, mxcBytes) {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = string(uri[6 : 6+index])
parsed.FileID = string(uri[6+index+1:])
@@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) {
*uri = ContentURI{}
return nil
} else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' {
- return InputNotJSONString
+ return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString)
}
parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1])
if err != nil {
diff --git a/id/crypto.go b/id/crypto.go
index 355a84a8..ee857f78 100644
--- a/id/crypto.go
+++ b/id/crypto.go
@@ -53,6 +53,34 @@ const (
KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
)
+type KeySource string
+
+func (source KeySource) String() string {
+ return string(source)
+}
+
+func (source KeySource) Int() int {
+ switch source {
+ case KeySourceDirect:
+ return 100
+ case KeySourceBackup:
+ return 90
+ case KeySourceImport:
+ return 80
+ case KeySourceForward:
+ return 50
+ default:
+ return 0
+ }
+}
+
+const (
+ KeySourceDirect KeySource = "direct"
+ KeySourceBackup KeySource = "backup"
+ KeySourceImport KeySource = "import"
+ KeySourceForward KeySource = "forward"
+)
+
// BackupVersion is an arbitrary string that identifies a server side key backup.
type KeyBackupVersion string
diff --git a/id/matrixuri.go b/id/matrixuri.go
index 2637d876..d5c78bc7 100644
--- a/id/matrixuri.go
+++ b/id/matrixuri.go
@@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{
func (uri *MatrixURI) getQuery() url.Values {
q := make(url.Values)
- if uri.Via != nil && len(uri.Via) > 0 {
+ if len(uri.Via) > 0 {
q["via"] = uri.Via
}
if len(uri.Action) > 0 {
@@ -210,7 +210,11 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) {
if len(parts[1]) == 0 {
return nil, ErrEmptySecondSegment
}
- parsed.MXID1 = parts[1]
+ var err error
+ parsed.MXID1, err = url.PathUnescape(parts[1])
+ if err != nil {
+ return nil, fmt.Errorf("failed to url decode second segment %q: %w", parts[1], err)
+ }
// Step 6: if the first part is a room and the URI has 4 segments, construct a second level identifier
if parsed.Sigil1 == '!' && len(parts) == 4 {
@@ -226,7 +230,10 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) {
if len(parts[3]) == 0 {
return nil, ErrEmptyFourthSegment
}
- parsed.MXID2 = parts[3]
+ parsed.MXID2, err = url.PathUnescape(parts[3])
+ if err != nil {
+ return nil, fmt.Errorf("failed to url decode fourth segment %q: %w", parts[3], err)
+ }
}
// Step 7: parse the query and extract via and action items
diff --git a/id/matrixuri_test.go b/id/matrixuri_test.go
index 8b1096cb..90a0754d 100644
--- a/id/matrixuri_test.go
+++ b/id/matrixuri_test.go
@@ -77,8 +77,12 @@ func TestParseMatrixURI_RoomID(t *testing.T) {
parsedVia, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org")
require.NoError(t, err)
require.NotNil(t, parsedVia)
+ parsedEncoded, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl%3Aexample.org")
+ require.NoError(t, err)
+ require.NotNil(t, parsedEncoded)
assert.Equal(t, roomIDLink, *parsed)
+ assert.Equal(t, roomIDLink, *parsedEncoded)
assert.Equal(t, roomIDViaLink, *parsedVia)
}
diff --git a/id/opaque.go b/id/opaque.go
index 1d9f0dcf..c1ad4988 100644
--- a/id/opaque.go
+++ b/id/opaque.go
@@ -32,6 +32,9 @@ type EventID string
// https://github.com/matrix-org/matrix-doc/pull/2716
type BatchID string
+// A DelayID is a string identifying a delayed event.
+type DelayID string
+
func (roomID RoomID) String() string {
return string(roomID)
}
diff --git a/id/roomversion.go b/id/roomversion.go
new file mode 100644
index 00000000..578c10bd
--- /dev/null
+++ b/id/roomversion.go
@@ -0,0 +1,265 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package id
+
+import (
+ "errors"
+ "fmt"
+ "slices"
+)
+
+type RoomVersion string
+
+const (
+ RoomV0 RoomVersion = "" // No room version, used for rooms created before room versions were introduced, equivalent to v1
+ RoomV1 RoomVersion = "1"
+ RoomV2 RoomVersion = "2"
+ RoomV3 RoomVersion = "3"
+ RoomV4 RoomVersion = "4"
+ RoomV5 RoomVersion = "5"
+ RoomV6 RoomVersion = "6"
+ RoomV7 RoomVersion = "7"
+ RoomV8 RoomVersion = "8"
+ RoomV9 RoomVersion = "9"
+ RoomV10 RoomVersion = "10"
+ RoomV11 RoomVersion = "11"
+ RoomV12 RoomVersion = "12"
+)
+
+func (rv RoomVersion) Equals(versions ...RoomVersion) bool {
+ return slices.Contains(versions, rv)
+}
+
+func (rv RoomVersion) NotEquals(versions ...RoomVersion) bool {
+ return !rv.Equals(versions...)
+}
+
+var ErrUnknownRoomVersion = errors.New("unknown room version")
+
+func (rv RoomVersion) unknownVersionError() error {
+ return fmt.Errorf("%w %s", ErrUnknownRoomVersion, rv)
+}
+
+func (rv RoomVersion) IsKnown() bool {
+ switch rv {
+ case RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11, RoomV12:
+ return true
+ default:
+ return false
+ }
+}
+
+type StateResVersion int
+
+const (
+ // StateResV1 is the original state resolution algorithm.
+ StateResV1 StateResVersion = 0
+ // StateResV2 is state resolution v2 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1759
+ StateResV2 StateResVersion = 1
+ // StateResV2_1 is state resolution v2.1 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/4297
+ StateResV2_1 StateResVersion = 2
+)
+
+// StateResVersion returns the version of the state resolution algorithm used by this room version.
+func (rv RoomVersion) StateResVersion() StateResVersion {
+ switch rv {
+ case RoomV0, RoomV1:
+ return StateResV1
+ case RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11:
+ return StateResV2
+ case RoomV12:
+ return StateResV2_1
+ default:
+ panic(rv.unknownVersionError())
+ }
+}
+
+type EventIDFormat int
+
+const (
+ // EventIDFormatCustom is the original format used by room v1 and v2.
+ // Event IDs in this format are an arbitrary string followed by a colon and the server name.
+ EventIDFormatCustom EventIDFormat = 0
+ // EventIDFormatBase64 is the format used by room v3 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1659.
+ // Event IDs in this format are the standard unpadded base64-encoded SHA256 reference hash of the event.
+ EventIDFormatBase64 EventIDFormat = 1
+ // EventIDFormatURLSafeBase64 is the format used by room v4 and later introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/2002.
+ // Event IDs in this format are the url-safe unpadded base64-encoded SHA256 reference hash of the event.
+ EventIDFormatURLSafeBase64 EventIDFormat = 2
+)
+
+// EventIDFormat returns the format of event IDs used by this room version.
+func (rv RoomVersion) EventIDFormat() EventIDFormat {
+ switch rv {
+ case RoomV0, RoomV1, RoomV2:
+ return EventIDFormatCustom
+ case RoomV3:
+ return EventIDFormatBase64
+ default:
+ return EventIDFormatURLSafeBase64
+ }
+}
+
+/////////////////////
+// Room v5 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2077
+
+// EnforceSigningKeyValidity returns true if the `valid_until_ts` field of federation signing keys
+// must be enforced on received events.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2076
+func (rv RoomVersion) EnforceSigningKeyValidity() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4)
+}
+
+/////////////////////
+// Room v6 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2240
+
+// SpecialCasedAliasesAuth returns true if the `m.room.aliases` event authorization is special cased
+// to only always allow servers to modify the state event with their own server name as state key.
+// This also implies that the `aliases` field is protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2432
+func (rv RoomVersion) SpecialCasedAliasesAuth() bool {
+ return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+// ForbidFloatsAndBigInts returns true if floats and integers greater than 2^53-1 or lower than -2^53+1 are forbidden everywhere.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2540
+func (rv RoomVersion) ForbidFloatsAndBigInts() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+// NotificationsPowerLevels returns true if the `notifications` field in `m.room.power_levels` is validated in event auth.
+// However, the field is not protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2209
+func (rv RoomVersion) NotificationsPowerLevels() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+/////////////////////
+// Room v7 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2998
+
+// Knocks returns true if the `knock` join rule is supported.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2403
+func (rv RoomVersion) Knocks() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6)
+}
+
+/////////////////////
+// Room v8 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3289
+
+// RestrictedJoins returns true if the `restricted` join rule is supported.
+// This also implies that the `allow` field in the `m.room.join_rules` event is supported and protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3083
+func (rv RoomVersion) RestrictedJoins() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7)
+}
+
+/////////////////////
+// Room v9 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3375
+
+// RestrictedJoinsFix returns true if the `join_authorised_via_users_server` field in `m.room.member` events is protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3375
+func (rv RoomVersion) RestrictedJoinsFix() bool {
+ return rv.RestrictedJoins() && rv != RoomV8
+}
+
+//////////////////////
+// Room v10 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3604
+
+// ValidatePowerLevelInts returns true if the known values in `m.room.power_levels` must be integers (and not strings).
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3667
+func (rv RoomVersion) ValidatePowerLevelInts() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
+}
+
+// KnockRestricted returns true if the `knock_restricted` join rule is supported.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3787
+func (rv RoomVersion) KnockRestricted() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
+}
+
+//////////////////////
+// Room v11 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3820
+
+// CreatorInContent returns true if the `m.room.create` event has a `creator` field in content.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2175
+func (rv RoomVersion) CreatorInContent() bool {
+ return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+// RedactsInContent returns true if the `m.room.redaction` event has the `redacts` field in content instead of at the top level.
+// The redaction protection is also moved from the top level to the content field.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2174
+// (and https://github.com/matrix-org/matrix-spec-proposals/pull/2176 for the redaction protection).
+func (rv RoomVersion) RedactsInContent() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+// UpdatedRedactionRules returns true if various updates to the redaction algorithm are applied.
+//
+// Specifically:
+//
+// * the `membership`, `origin`, and `prev_state` fields at the top level of all events are no longer protected.
+// * the entire content of `m.room.create` is protected.
+// * the `redacts` field in `m.room.redaction` content is protected instead of the top-level field.
+// * the `m.room.power_levels` event protects the `invite` field in content.
+// * the `signed` field inside the `third_party_invite` field in content of `m.room.member` events is protected.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2176,
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3821, and
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3989
+func (rv RoomVersion) UpdatedRedactionRules() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+//////////////////////
+// Room v12 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/4304
+
+// Return value of StateResVersion was changed to StateResV2_1
+
+// PrivilegedRoomCreators returns true if the creator(s) of a room always have infinite power level.
+// This also implies that the `m.room.create` event has an `additional_creators` field,
+// and that the creators can't be present in the `m.room.power_levels` event.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/4289
+func (rv RoomVersion) PrivilegedRoomCreators() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
+}
+
+// RoomIDIsCreateEventID returns true if the ID of rooms is the same as the ID of the `m.room.create` event.
+// This also implies that `m.room.create` events do not have a `room_id` field.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/4291
+func (rv RoomVersion) RoomIDIsCreateEventID() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
+}
diff --git a/id/servername.go b/id/servername.go
new file mode 100644
index 00000000..923705b6
--- /dev/null
+++ b/id/servername.go
@@ -0,0 +1,58 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package id
+
+import (
+ "regexp"
+ "strconv"
+)
+
+type ParsedServerNameType int
+
+const (
+ ServerNameDNS ParsedServerNameType = iota
+ ServerNameIPv4
+ ServerNameIPv6
+)
+
+type ParsedServerName struct {
+ Type ParsedServerNameType
+ Host string
+ Port int
+}
+
+var ServerNameRegex = regexp.MustCompile(`^(?:\[([0-9A-Fa-f:.]{2,45})]|(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|([0-9A-Za-z.-]{1,255}))(?::(\d{1,5}))?$`)
+
+func ValidateServerName(serverName string) bool {
+ return len(serverName) <= 255 && len(serverName) > 0 && ServerNameRegex.MatchString(serverName)
+}
+
+func ParseServerName(serverName string) *ParsedServerName {
+ if len(serverName) > 255 || len(serverName) < 1 {
+ return nil
+ }
+ match := ServerNameRegex.FindStringSubmatch(serverName)
+ if len(match) != 5 {
+ return nil
+ }
+ port, _ := strconv.Atoi(match[4])
+ parsed := &ParsedServerName{
+ Port: port,
+ }
+ switch {
+ case match[1] != "":
+ parsed.Type = ServerNameIPv6
+ parsed.Host = match[1]
+ case match[2] != "":
+ parsed.Type = ServerNameIPv4
+ parsed.Host = match[2]
+ case match[3] != "":
+ parsed.Type = ServerNameDNS
+ parsed.Host = match[3]
+ }
+ return parsed
+}
diff --git a/id/trust.go b/id/trust.go
index 04f6e36b..6255093e 100644
--- a/id/trust.go
+++ b/id/trust.go
@@ -16,6 +16,7 @@ type TrustState int
const (
TrustStateBlacklisted TrustState = -100
+ TrustStateDeviceKeyMismatch TrustState = -5
TrustStateUnset TrustState = 0
TrustStateUnknownDevice TrustState = 10
TrustStateForwarded TrustState = 20
@@ -23,7 +24,7 @@ const (
TrustStateCrossSignedTOFU TrustState = 100
TrustStateCrossSignedVerified TrustState = 200
TrustStateVerified TrustState = 300
- TrustStateInvalid TrustState = (1 << 31) - 1
+ TrustStateInvalid TrustState = -2147483647
)
func (ts *TrustState) UnmarshalText(data []byte) error {
@@ -44,6 +45,8 @@ func ParseTrustState(val string) TrustState {
switch strings.ToLower(val) {
case "blacklisted":
return TrustStateBlacklisted
+ case "device-key-mismatch":
+ return TrustStateDeviceKeyMismatch
case "unverified":
return TrustStateUnset
case "cross-signed-untrusted":
@@ -67,6 +70,8 @@ func (ts TrustState) String() string {
switch ts {
case TrustStateBlacklisted:
return "blacklisted"
+ case TrustStateDeviceKeyMismatch:
+ return "device-key-mismatch"
case TrustStateUnset:
return "unverified"
case TrustStateCrossSignedUntrusted:
diff --git a/id/userid.go b/id/userid.go
index 1e1f3b29..726a0d58 100644
--- a/id/userid.go
+++ b/id/userid.go
@@ -30,10 +30,11 @@ func NewEncodedUserID(localpart, homeserver string) UserID {
}
var (
- ErrInvalidUserID = errors.New("is not a valid user ID")
- ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed")
- ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters")
- ErrEmptyLocalpart = errors.New("empty localparts are not allowed")
+ ErrInvalidUserID = errors.New("is not a valid user ID")
+ ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed")
+ ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters")
+ ErrEmptyLocalpart = errors.New("empty localparts are not allowed")
+ ErrNoncompliantServerPart = errors.New("is not a valid server name")
)
// ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format
@@ -43,10 +44,10 @@ func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte,
}
sigil = identifier[0]
strIdentifier := string(identifier)
- if strings.ContainsRune(strIdentifier, ':') {
- parts := strings.SplitN(strIdentifier, ":", 2)
- localpart = parts[0][1:]
- homeserver = parts[1]
+ colonIdx := strings.IndexByte(strIdentifier, ':')
+ if colonIdx > 0 {
+ localpart = strIdentifier[1:colonIdx]
+ homeserver = strIdentifier[colonIdx+1:]
} else {
localpart = strIdentifier[1:]
}
@@ -103,21 +104,32 @@ func ValidateUserLocalpart(localpart string) error {
return nil
}
-// ParseAndValidate parses the user ID into the localpart and server name like Parse,
-// and also validates that the localpart is allowed according to the user identifiers spec.
-func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.Parse()
+// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts.
+// This should be used with care: there are real users still using historical localparts.
+func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) {
+ localpart, homeserver, err = userID.ParseAndValidateRelaxed()
if err == nil {
err = ValidateUserLocalpart(localpart)
}
- if err == nil && len(userID) > UserIDMaxLength {
+ return
+}
+
+// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse,
+// and also validates that the user ID is not too long and that the server name is valid.
+func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) {
+ if len(userID) > UserIDMaxLength {
err = ErrUserIDTooLong
+ return
+ }
+ localpart, homeserver, err = userID.Parse()
+ if err == nil && !ValidateServerName(homeserver) {
+ err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart)
}
return
}
func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.ParseAndValidate()
+ localpart, homeserver, err = userID.ParseAndValidateStrict()
if err == nil {
localpart, err = DecodeUserLocalpart(localpart)
}
@@ -207,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) {
for i := 0; i < len(strBytes); i++ {
b := strBytes[i]
if !isValidByte(b) {
- return "", fmt.Errorf("Byte pos %d: Invalid byte", i)
+ return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b)
}
if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _
if i+1 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after underscore at %d", i)
}
if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i)
+ return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i)
}
if strBytes[i+1] == '_' {
outputBuffer.WriteByte('_')
@@ -225,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) {
i++ // skip next byte since we just handled it
} else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8
if i+2 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after equals sign at %d", i)
}
dst := make([]byte, 1)
_, err := hex.Decode(dst, strBytes[i+1:i+3])
diff --git a/id/userid_test.go b/id/userid_test.go
index 359bc687..57a88066 100644
--- a/id/userid_test.go
+++ b/id/userid_test.go
@@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) {
assert.True(t, errors.Is(err, id.ErrInvalidUserID))
}
-func TestUserID_ParseAndValidate_Invalid(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) {
const inputUserID = "@s p a c e:maunium.net"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart))
}
-func TestUserID_ParseAndValidate_Empty(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) {
const inputUserID = "@:ponies.im"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrEmptyLocalpart))
}
-func TestUserID_ParseAndValidate_Long(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Long(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrUserIDTooLong))
}
-func TestUserID_ParseAndValidate_NotLong(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.NoError(t, err)
}
@@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) {
const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8"
const inputServerName = "example.com"
userID := id.NewEncodedUserID(inputLocalpart, inputServerName)
- parsedLocalpart, parsedServerName, err := userID.ParseAndValidate()
+ parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict()
assert.NoError(t, err)
assert.Equal(t, encodedLocalpart, parsedLocalpart)
assert.Equal(t, inputServerName, parsedServerName)
diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go
index ff8b2157..4d2bc7cf 100644
--- a/mediaproxy/mediaproxy.go
+++ b/mediaproxy/mediaproxy.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,7 +8,6 @@ package mediaproxy
import (
"context"
- "encoding/json"
"errors"
"fmt"
"io"
@@ -22,11 +21,16 @@ import (
"strings"
"time"
- "github.com/gorilla/mux"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/hlog"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
+ "go.mau.fi/util/ptr"
+ "go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/federation"
+ "maunium.net/go/mautrix/id"
)
type GetMediaResponse interface {
@@ -91,17 +95,20 @@ func (d *GetMediaResponseCallback) GetContentType() string {
return d.ContentType
}
+type FileMeta struct {
+ ContentType string
+ ReplacementFile string
+}
+
type GetMediaResponseFile struct {
- Callback func(w *os.File) error
- ContentType string
+ Callback func(w *os.File) (*FileMeta, error)
}
type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error)
type MediaProxy struct {
- KeyServer *federation.KeyServer
-
- ForceProxyLegacyFederation bool
+ KeyServer *federation.KeyServer
+ ServerAuth *federation.ServerAuth
GetMedia GetMediaFunc
PrepareProxyRequest func(*http.Request)
@@ -109,8 +116,8 @@ type MediaProxy struct {
serverName string
serverKey *federation.SigningKey
- FederationRouter *mux.Router
- ClientMediaRouter *mux.Router
+ FederationRouter *http.ServeMux
+ ClientMediaRouter *http.ServeMux
}
func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
@@ -118,7 +125,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
if err != nil {
return nil, err
}
- return &MediaProxy{
+ mp := &MediaProxy{
serverName: serverName,
serverKey: parsed,
GetMedia: getMedia,
@@ -133,12 +140,27 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
},
},
- }, nil
+ }
+ mp.FederationRouter = http.NewServeMux()
+ mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation)
+ mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation)
+ mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion)
+ mp.ClientMediaRouter = http.NewServeMux()
+ mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported)
+ return mp, nil
}
type BasicConfig struct {
ServerName string `yaml:"server_name" json:"server_name"`
ServerKey string `yaml:"server_key" json:"server_key"`
+ FederationAuth bool `yaml:"federation_auth" json:"federation_auth"`
WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"`
}
@@ -150,6 +172,9 @@ func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error)
if cfg.WellKnownResponse != "" {
mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse
}
+ if cfg.FederationAuth {
+ mp.EnableServerAuth(nil, nil)
+ }
return mp, nil
}
@@ -159,8 +184,8 @@ type ServerConfig struct {
}
func (mp *MediaProxy) Listen(cfg ServerConfig) error {
- router := mux.NewRouter()
- mp.RegisterRoutes(router)
+ router := http.NewServeMux()
+ mp.RegisterRoutes(router, zerolog.Nop())
return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
}
@@ -172,49 +197,42 @@ func (mp *MediaProxy) GetServerKey() *federation.SigningKey {
return mp.serverKey
}
-func (mp *MediaProxy) RegisterRoutes(router *mux.Router) {
- if mp.FederationRouter == nil {
- mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter()
+func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache federation.KeyCache) {
+ if keyCache == nil {
+ keyCache = federation.NewInMemoryCache()
}
- if mp.ClientMediaRouter == nil {
- mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter()
+ if client == nil {
+ resCache, _ := keyCache.(federation.ResolutionCache)
+ client = federation.NewClient(mp.serverName, mp.serverKey, resCache)
}
-
- mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet)
- mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut)
- mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost)
- mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost)
- mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet)
- mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
- mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
- mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
- mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
- corsMiddleware := func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Access-Control-Allow-Origin", "*")
- w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
- w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
- w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';")
- next.ServeHTTP(w, r)
- })
- }
- mp.ClientMediaRouter.Use(corsMiddleware)
- mp.KeyServer.Register(router)
+ mp.ServerAuth = federation.NewServerAuth(client, keyCache, func(auth federation.XMatrixAuth) string {
+ return mp.GetServerName()
+ })
}
-// Deprecated: use mautrix.RespError instead
-type ResponseError struct {
- Status int
- Data any
-}
-
-func (err *ResponseError) Error() string {
- return fmt.Sprintf("HTTP %d: %v", err.Status, err.Data)
+func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux, log zerolog.Logger) {
+ errorBodies := exhttp.ErrorBodies{
+ NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
+ MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
+ }
+ router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware(
+ mp.FederationRouter,
+ exhttp.StripPrefix("/_matrix/federation"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
+ router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware(
+ mp.ClientMediaRouter,
+ exhttp.StripPrefix("/_matrix/client/v1/media"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ exhttp.CORSMiddleware,
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
+ mp.KeyServer.Register(router, log)
}
var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
@@ -228,20 +246,18 @@ func queryToMap(vals url.Values) map[string]string {
}
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
- mediaID := mux.Vars(r)["mediaID"]
+ mediaID := r.PathValue("mediaID")
+ if !id.IsValidMediaID(mediaID) {
+ mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w)
+ return nil
+ }
resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query()))
if err != nil {
- //lint:ignore SA1019 deprecated types need to be supported until they're removed
- var respError *ResponseError
var mautrixRespError mautrix.RespError
if errors.Is(err, ErrInvalidMediaIDSyntax) {
mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
} else if errors.As(err, &mautrixRespError) {
mautrixRespError.Write(w)
- } else if errors.As(err, &respError) {
- w.Header().Add("Content-Type", "application/json")
- w.WriteHeader(respError.Status)
- _ = json.NewEncoder(w).Encode(respError.Data)
} else {
zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL")
mautrix.MNotFound.WithMessage("Media not found").Write(w)
@@ -271,9 +287,16 @@ func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Write
}
func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) {
+ if mp.ServerAuth != nil {
+ var err *mautrix.RespError
+ r, err = mp.ServerAuth.Authenticate(r)
+ if err != nil {
+ err.Write(w)
+ return
+ }
+ }
ctx := r.Context()
log := zerolog.Ctx(ctx)
- // TODO check destination header in X-Matrix auth
resp := mp.getMedia(w, r)
if resp == nil {
@@ -369,8 +392,7 @@ func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName strin
func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := zerolog.Ctx(ctx)
- vars := mux.Vars(r)
- if vars["serverName"] != mp.serverName {
+ if r.PathValue("serverName") != mp.serverName {
mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
return
}
@@ -393,7 +415,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTemporaryRedirect)
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
- mp.addHeaders(w, mimeType, vars["fileName"])
+ mp.addHeaders(w, mimeType, r.PathValue("fileName"))
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.WriteHeader(http.StatusOK)
_, err := wt.WriteTo(w)
@@ -410,13 +432,16 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
}
}
}
- } else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
- mp.addHeaders(w, dataResp.GetContentType(), vars["fileName"])
- if dataResp.GetContentLength() != 0 {
- w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10))
+ } else if writerResp, ok := resp.(GetMediaResponseWriter); ok {
+ if dataResp, ok := writerResp.(*GetMediaResponseData); ok {
+ defer dataResp.Reader.Close()
+ }
+ mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName"))
+ if writerResp.GetContentLength() != 0 {
+ w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10))
}
w.WriteHeader(http.StatusOK)
- _, err := dataResp.WriteTo(w)
+ _, err := writerResp.WriteTo(w)
if err != nil {
log.Err(err).Msg("Failed to write media data")
}
@@ -433,23 +458,35 @@ func doTempFileDownload(
if err != nil {
return false, fmt.Errorf("failed to create temp file: %w", err)
}
+ origTempFile := tempFile
defer func() {
- _ = tempFile.Close()
- _ = os.Remove(tempFile.Name())
+ _ = origTempFile.Close()
+ _ = os.Remove(origTempFile.Name())
}()
- err = data.Callback(tempFile)
+ meta, err := data.Callback(tempFile)
if err != nil {
return false, err
}
- _, err = tempFile.Seek(0, io.SeekStart)
- if err != nil {
- return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ if meta.ReplacementFile != "" {
+ tempFile, err = os.Open(meta.ReplacementFile)
+ if err != nil {
+ return false, fmt.Errorf("failed to open replacement file: %w", err)
+ }
+ defer func() {
+ _ = tempFile.Close()
+ _ = os.Remove(origTempFile.Name())
+ }()
+ } else {
+ _, err = tempFile.Seek(0, io.SeekStart)
+ if err != nil {
+ return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ }
}
fileInfo, err := tempFile.Stat()
if err != nil {
return false, fmt.Errorf("failed to stat temp file: %w", err)
}
- mimeType := data.ContentType
+ mimeType := meta.ContentType
if mimeType == "" {
buf := make([]byte, 512)
n, err := tempFile.Read(buf)
@@ -477,11 +514,6 @@ var (
ErrPreviewURLNotSupported = mautrix.MUnrecognized.
WithMessage("This is a media proxy and does not support URL previews.").
WithStatus(http.StatusNotImplemented)
- ErrUnknownEndpoint = mautrix.MUnrecognized.
- WithMessage("Unrecognized endpoint")
- ErrUnsupportedMethod = mautrix.MUnrecognized.
- WithMessage("Invalid method for endpoint").
- WithStatus(http.StatusMethodNotAllowed)
)
func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) {
@@ -491,11 +523,3 @@ func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request)
func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) {
ErrPreviewURLNotSupported.Write(w)
}
-
-func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) {
- ErrUnknownEndpoint.Write(w)
-}
-
-func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) {
- ErrUnsupportedMethod.Write(w)
-}
diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go
new file mode 100644
index 00000000..507c24a5
--- /dev/null
+++ b/mockserver/mockserver.go
@@ -0,0 +1,307 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mockserver
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
+ "github.com/stretchr/testify/require"
+ "go.mau.fi/util/dbutil"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
+ "go.mau.fi/util/random"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/crypto"
+ "maunium.net/go/mautrix/crypto/cryptohelper"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+func mustDecode(r *http.Request, data any) {
+ exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data))
+}
+
+type userAndDeviceID struct {
+ UserID id.UserID
+ DeviceID id.DeviceID
+}
+
+type MockServer struct {
+ Router *http.ServeMux
+ Server *httptest.Server
+
+ AccessTokenToUserID map[string]userAndDeviceID
+ DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
+ AccountData map[id.UserID]map[event.Type]json.RawMessage
+ DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
+ OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey
+ MasterKeys map[id.UserID]mautrix.CrossSigningKeys
+ SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+ UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+
+ PopOTKs bool
+ MemoryStore bool
+}
+
+func Create(t testing.TB) *MockServer {
+ t.Helper()
+
+ server := MockServer{
+ AccessTokenToUserID: map[string]userAndDeviceID{},
+ DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
+ AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ PopOTKs: true,
+ MemoryStore: true,
+ }
+
+ router := http.NewServeMux()
+ router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
+ router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
+ router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim)
+ router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
+ router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
+ router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
+ router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp)
+ router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload)
+ server.Router = router
+ server.Server = httptest.NewServer(router)
+ t.Cleanup(server.Server.Close)
+ return &server
+}
+
+func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID {
+ authHeader := r.Header.Get("Authorization")
+ authHeader = strings.TrimPrefix(authHeader, "Bearer ")
+ userID, ok := ms.AccessTokenToUserID[authHeader]
+ if !ok {
+ panic("no user ID found for access token " + authHeader)
+ }
+ return userID
+}
+
+func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
+ exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
+}
+
+func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
+ var loginReq mautrix.ReqLogin
+ mustDecode(r, &loginReq)
+
+ deviceID := loginReq.DeviceID
+ if deviceID == "" {
+ deviceID = id.DeviceID(random.String(10))
+ }
+
+ accessToken := random.String(30)
+ userID := id.UserID(loginReq.Identifier.User)
+ ms.AccessTokenToUserID[accessToken] = userAndDeviceID{
+ UserID: userID,
+ DeviceID: deviceID,
+ }
+
+ exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{
+ AccessToken: accessToken,
+ DeviceID: deviceID,
+ UserID: userID,
+ })
+}
+
+func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqSendToDevice
+ mustDecode(r, &req)
+ evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
+
+ for user, devices := range req.Messages {
+ for device, content := range devices {
+ if _, ok := ms.DeviceInbox[user]; !ok {
+ ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{}
+ }
+ content.ParseRaw(evtType)
+ ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{
+ Sender: ms.getUserID(r).UserID,
+ Type: evtType,
+ Content: *content,
+ })
+ }
+ }
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
+ userID := id.UserID(r.PathValue("userID"))
+ eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
+
+ jsonData, _ := io.ReadAll(r.Body)
+ if _, ok := ms.AccountData[userID]; !ok {
+ ms.AccountData[userID] = map[event.Type]json.RawMessage{}
+ }
+ ms.AccountData[userID][eventType] = json.RawMessage(jsonData)
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqQueryKeys
+ mustDecode(r, &req)
+ resp := mautrix.RespQueryKeys{
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ }
+ for user := range req.DeviceKeys {
+ resp.MasterKeys[user] = ms.MasterKeys[user]
+ resp.UserSigningKeys[user] = ms.UserSigningKeys[user]
+ resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user]
+ resp.DeviceKeys[user] = ms.DeviceKeys[user]
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
+}
+
+func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqClaimKeys
+ mustDecode(r, &req)
+ resp := mautrix.RespClaimKeys{
+ OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
+ }
+ for user, devices := range req.OneTimeKeys {
+ resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
+ for device := range devices {
+ keys := ms.OneTimeKeys[user][device]
+ for keyID, key := range keys {
+ if ms.PopOTKs {
+ delete(keys, keyID)
+ }
+ resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{
+ keyID: key,
+ }
+ break
+ }
+ }
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
+}
+
+func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqUploadKeys
+ mustDecode(r, &req)
+
+ uid := ms.getUserID(r)
+ userID := uid.UserID
+ if _, ok := ms.DeviceKeys[userID]; !ok {
+ ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
+ }
+ if _, ok := ms.OneTimeKeys[userID]; !ok {
+ ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
+ }
+
+ if req.DeviceKeys != nil {
+ ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys
+ }
+ otks, ok := ms.OneTimeKeys[userID][uid.DeviceID]
+ if !ok {
+ otks = map[id.KeyID]mautrix.OneTimeKey{}
+ ms.OneTimeKeys[userID][uid.DeviceID] = otks
+ }
+ if req.OneTimeKeys != nil {
+ maps.Copy(otks, req.OneTimeKeys)
+ }
+
+ exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{
+ OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)},
+ })
+}
+
+func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.UploadCrossSigningKeysReq[any]
+ mustDecode(r, &req)
+
+ userID := ms.getUserID(r).UserID
+ ms.MasterKeys[userID] = req.Master
+ ms.SelfSigningKeys[userID] = req.SelfSigning
+ ms.UserSigningKeys[userID] = req.UserSigning
+
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
+ t.Helper()
+ if ctx == nil {
+ ctx = context.TODO()
+ }
+ client, err := mautrix.NewClient(ms.Server.URL, "", "")
+ require.NoError(t, err)
+ client.Client = ms.Server.Client()
+
+ _, err = client.Login(ctx, &mautrix.ReqLogin{
+ Type: mautrix.AuthTypePassword,
+ Identifier: mautrix.UserIdentifier{
+ Type: mautrix.IdentifierTypeUser,
+ User: userID.String(),
+ },
+ DeviceID: deviceID,
+ Password: "password",
+ StoreCredentials: true,
+ })
+ require.NoError(t, err)
+
+ var store any
+ if ms.MemoryStore {
+ store = crypto.NewMemoryStore(nil)
+ client.StateStore = mautrix.NewMemoryStateStore()
+ } else {
+ store, err = dbutil.NewFromConfig("", dbutil.Config{
+ PoolConfig: dbutil.PoolConfig{
+ Type: "sqlite3-fk-wal",
+ URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)),
+ MaxOpenConns: 5,
+ MaxIdleConns: 1,
+ },
+ }, nil)
+ require.NoError(t, err)
+ }
+ cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store)
+ require.NoError(t, err)
+ client.Crypto = cryptoHelper
+
+ err = cryptoHelper.Init(ctx)
+ require.NoError(t, err)
+
+ machineLog := globallog.Logger.With().
+ Stringer("my_user_id", userID).
+ Stringer("my_device_id", deviceID).
+ Logger()
+ cryptoHelper.Machine().Log = &machineLog
+
+ err = cryptoHelper.Machine().ShareKeys(ctx, 50)
+ require.NoError(t, err)
+
+ return client, cryptoHelper.Machine().CryptoStore
+}
+
+func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) {
+ t.Helper()
+
+ for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {
+ client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt)
+ ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:]
+ }
+}
diff --git a/pushrules/action.go b/pushrules/action.go
index 9838e88b..b5a884b2 100644
--- a/pushrules/action.go
+++ b/pushrules/action.go
@@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error {
if ok {
action.Action = ActionSetTweak
action.Tweak = PushActionTweak(tweak)
- action.Value, _ = val["value"]
+ action.Value = val["value"]
}
}
return nil
diff --git a/pushrules/action_test.go b/pushrules/action_test.go
index a8f68415..3c0aa168 100644
--- a/pushrules/action_test.go
+++ b/pushrules/action_test.go
@@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
err = pa.UnmarshalJSON([]byte(`9001`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`"foo"`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.PushActionType("foo"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.ActionSetTweak, pa.Action)
assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak)
@@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data)
}
@@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, []byte(`"something else"`), data)
}
diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go
index 0d3eaf7a..37af3e34 100644
--- a/pushrules/condition_test.go
+++ b/pushrules/condition_test.go
@@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi
}
}
-func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition {
- return &pushrules.PushCondition{
- Kind: pushrules.KindEventPropertyContains,
- Key: key,
- Value: value,
- }
-}
-
func TestPushCondition_Match_InvalidKind(t *testing.T) {
condition := &pushrules.PushCondition{
Kind: pushrules.PushCondKind("invalid"),
diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go
index a531ca28..a5a0f5e7 100644
--- a/pushrules/pushrules_test.go
+++ b/pushrules/pushrules_test.go
@@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) {
},
}
pushRuleset, err := pushrules.EventToPushRules(evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.NotNil(t, pushRuleset)
assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{})
diff --git a/pushrules/rule.go b/pushrules/rule.go
index ee6d33c4..cf659695 100644
--- a/pushrules/rule.go
+++ b/pushrules/rule.go
@@ -8,7 +8,10 @@ package pushrules
import (
"encoding/gob"
+ "regexp"
+ "strings"
+ "go.mau.fi/util/exerrors"
"go.mau.fi/util/glob"
"maunium.net/go/mautrix/event"
@@ -165,13 +168,20 @@ func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool {
}
func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool {
- pattern := glob.CompileWithImplicitContains(rule.Pattern)
- if pattern == nil {
- return false
- }
msg, ok := evt.Content.Raw["body"].(string)
if !ok {
return false
}
- return pattern.Match(msg)
+ var buf strings.Builder
+ // As per https://spec.matrix.org/unstable/client-server-api/#push-rules, content rules are case-insensitive
+ // and must match whole words, so wrap the converted glob in (?i) and \b.
+ buf.WriteString(`(?i)\b`)
+ // strings.Builder will never return errors
+ exerrors.PanicIfNotNil(glob.ToRegexPattern(rule.Pattern, &buf))
+ buf.WriteString(`\b`)
+ pattern, err := regexp.Compile(buf.String())
+ if err != nil {
+ return false
+ }
+ return pattern.MatchString(msg)
}
diff --git a/pushrules/rule_test.go b/pushrules/rule_test.go
index 803c721e..7ff839a7 100644
--- a/pushrules/rule_test.go
+++ b/pushrules/rule_test.go
@@ -186,6 +186,34 @@ func TestPushRule_Match_Content(t *testing.T) {
assert.True(t, rule.Match(blankTestRoom, evt))
}
+func TestPushRule_Match_WordBoundary(t *testing.T) {
+ rule := &pushrules.PushRule{
+ Type: pushrules.ContentRule,
+ Enabled: true,
+ Pattern: "test",
+ }
+
+ evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
+ MsgType: event.MsgEmote,
+ Body: "is testing pushrules",
+ })
+ assert.False(t, rule.Match(blankTestRoom, evt))
+}
+
+func TestPushRule_Match_CaseInsensitive(t *testing.T) {
+ rule := &pushrules.PushRule{
+ Type: pushrules.ContentRule,
+ Enabled: true,
+ Pattern: "test",
+ }
+
+ evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
+ MsgType: event.MsgEmote,
+ Body: "is TeSt-InG pushrules",
+ })
+ assert.True(t, rule.Match(blankTestRoom, evt))
+}
+
func TestPushRule_Match_Content_Fail(t *testing.T) {
rule := &pushrules.PushRule{
Type: pushrules.ContentRule,
diff --git a/requests.go b/requests.go
index 595f1212..cc8b7266 100644
--- a/requests.go
+++ b/requests.go
@@ -2,7 +2,9 @@ package mautrix
import (
"encoding/json"
+ "fmt"
"strconv"
+ "time"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/event"
@@ -38,20 +40,40 @@ const (
type Direction rune
+func (d Direction) MarshalJSON() ([]byte, error) {
+ return json.Marshal(string(d))
+}
+
+func (d *Direction) UnmarshalJSON(data []byte) error {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ switch str {
+ case "f":
+ *d = DirectionForward
+ case "b":
+ *d = DirectionBackward
+ default:
+ return fmt.Errorf("invalid direction %q, must be 'f' or 'b'", str)
+ }
+ return nil
+}
+
const (
DirectionForward Direction = 'f'
DirectionBackward Direction = 'b'
)
// ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
-type ReqRegister struct {
+type ReqRegister[UIAType any] struct {
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"`
InhibitLogin bool `json:"inhibit_login,omitempty"`
RefreshToken bool `json:"refresh_token,omitempty"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
// Type for registration, only used for appservice user registrations
// https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions
@@ -91,6 +113,10 @@ type ReqLogin struct {
StoreHomeserverURL bool `json:"-"`
}
+type ReqPutDevice struct {
+ DisplayName string `json:"display_name,omitempty"`
+}
+
type ReqUIAuthFallback struct {
Session string `json:"session"`
User string `json:"user"`
@@ -115,11 +141,12 @@ type ReqCreateRoom struct {
InitialState []*event.Event `json:"initial_state,omitempty"`
Preset string `json:"preset,omitempty"`
IsDirect bool `json:"is_direct,omitempty"`
- RoomVersion string `json:"room_version,omitempty"`
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"`
MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"`
+ MeowCreateTS int64 `json:"fi.mau.origin_server_ts,omitempty"`
BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"`
BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"`
BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"`
@@ -134,12 +161,37 @@ type ReqRedact struct {
Extra map[string]interface{}
}
+type ReqRedactUser struct {
+ Reason string `json:"reason"`
+ Limit int `json:"-"`
+}
+
type ReqMembers struct {
At string `json:"at"`
Membership event.Membership `json:"membership,omitempty"`
NotMembership event.Membership `json:"not_membership,omitempty"`
}
+type ReqJoinRoom struct {
+ Via []string `json:"-"`
+ Reason string `json:"reason,omitempty"`
+ ThirdPartySigned any `json:"third_party_signed,omitempty"`
+}
+
+type ReqKnockRoom struct {
+ Via []string `json:"-"`
+ Reason string `json:"reason,omitempty"`
+}
+
+type ReqSearchUserDirectory struct {
+ SearchTerm string `json:"search_term"`
+ Limit int `json:"limit,omitempty"`
+}
+
+type ReqMutualRooms struct {
+ From string `json:"-"`
+}
+
// ReqInvite3PID is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1
// It is also a JSON object used in https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom
type ReqInvite3PID struct {
@@ -168,6 +220,8 @@ type ReqKickUser struct {
type ReqBanUser struct {
Reason string `json:"reason,omitempty"`
UserID id.UserID `json:"user_id"`
+
+ MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"`
}
// ReqUnbanUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban
@@ -183,7 +237,8 @@ type ReqTyping struct {
}
type ReqPresence struct {
- Presence event.Presence `json:"presence"`
+ Presence event.Presence `json:"presence"`
+ StatusMsg string `json:"status_msg,omitempty"`
}
type ReqAliasCreate struct {
@@ -265,11 +320,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 {
return ""
}
-type UploadCrossSigningKeysReq struct {
+type UploadCrossSigningKeysReq[UIAType any] struct {
Master CrossSigningKeys `json:"master_key"`
SelfSigning CrossSigningKeys `json:"self_signing_key"`
UserSigning CrossSigningKeys `json:"user_signing_key"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type KeyMap map[id.DeviceKeyID]string
@@ -311,20 +366,40 @@ type ReqSendToDevice struct {
Messages map[id.UserID]map[id.DeviceID]*event.Content `json:"messages"`
}
+type ReqSendEvent struct {
+ Timestamp int64
+ TransactionID string
+ UnstableDelay time.Duration
+ UnstableStickyDuration time.Duration
+ DontEncrypt bool
+ MeowEventID id.EventID
+}
+
+type ReqDelayedEvents struct {
+ DelayID id.DelayID `json:"-"`
+ Status event.DelayStatus `json:"-"`
+ NextBatch string `json:"-"`
+}
+
+type ReqUpdateDelayedEvent struct {
+ DelayID id.DelayID `json:"-"`
+ Action event.DelayAction `json:"action"`
+}
+
// ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid
type ReqDeviceInfo struct {
DisplayName string `json:"display_name,omitempty"`
}
// ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid
-type ReqDeleteDevice struct {
- Auth interface{} `json:"auth,omitempty"`
+type ReqDeleteDevice[UIAType any] struct {
+ Auth UIAType `json:"auth,omitempty"`
}
// ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices
-type ReqDeleteDevices struct {
+type ReqDeleteDevices[UIAType any] struct {
Devices []id.DeviceID `json:"devices"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type ReqPutPushRule struct {
@@ -336,18 +411,6 @@ type ReqPutPushRule struct {
Pattern string `json:"pattern"`
}
-// Deprecated: MSC2716 was abandoned
-type ReqBatchSend struct {
- PrevEventID id.EventID `json:"-"`
- BatchID id.BatchID `json:"-"`
-
- BeeperNewMessages bool `json:"-"`
- BeeperMarkReadBy id.UserID `json:"-"`
-
- StateEventsAtStart []*event.Event `json:"state_events_at_start"`
- Events []*event.Event `json:"events"`
-}
-
type ReqBeeperBatchSend struct {
// ForwardIfNoMessages should be set to true if the batch should be forward
// backfilled if there are no messages currently in the room.
@@ -383,6 +446,33 @@ type ReqSendReceipt struct {
ThreadID string `json:"thread_id,omitempty"`
}
+type ReqPublicRooms struct {
+ IncludeAllNetworks bool
+ Limit int
+ Since string
+ ThirdPartyInstanceID string
+}
+
+func (req *ReqPublicRooms) Query() map[string]string {
+ query := map[string]string{}
+ if req == nil {
+ return query
+ }
+ if req.IncludeAllNetworks {
+ query["include_all_networks"] = "true"
+ }
+ if req.Limit > 0 {
+ query["limit"] = strconv.Itoa(req.Limit)
+ }
+ if req.Since != "" {
+ query["since"] = req.Since
+ }
+ if req.ThirdPartyInstanceID != "" {
+ query["third_party_instance_id"] = req.ThirdPartyInstanceID
+ }
+ return query
+}
+
// ReqHierarchy contains the parameters for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
//
// As it's a GET method, there is no JSON body, so this is only query parameters.
@@ -475,3 +565,54 @@ type ReqReport struct {
Reason string `json:"reason,omitempty"`
Score int `json:"score,omitempty"`
}
+
+type ReqGetRelations struct {
+ RelationType event.RelationType
+ EventType event.Type
+
+ Dir Direction
+ From string
+ To string
+ Limit int
+ Recurse bool
+}
+
+func (rgr *ReqGetRelations) PathSuffix() ClientURLPath {
+ if rgr.RelationType != "" {
+ if rgr.EventType.Type != "" {
+ return ClientURLPath{rgr.RelationType, rgr.EventType.Type}
+ }
+ return ClientURLPath{rgr.RelationType}
+ }
+ return ClientURLPath{}
+}
+
+func (rgr *ReqGetRelations) Query() map[string]string {
+ query := map[string]string{}
+ if rgr.Dir != 0 {
+ query["dir"] = string(rgr.Dir)
+ }
+ if rgr.From != "" {
+ query["from"] = rgr.From
+ }
+ if rgr.To != "" {
+ query["to"] = rgr.To
+ }
+ if rgr.Limit > 0 {
+ query["limit"] = strconv.Itoa(rgr.Limit)
+ }
+ if rgr.Recurse {
+ query["recurse"] = "true"
+ }
+ return query
+}
+
+// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type ReqSuspend struct {
+ Suspended bool `json:"suspended"`
+}
+
+// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type ReqLocked struct {
+ Locked bool `json:"locked"`
+}
diff --git a/responses.go b/responses.go
index 26aaac77..4fbe1fbc 100644
--- a/responses.go
+++ b/responses.go
@@ -4,13 +4,16 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "maps"
"reflect"
+ "slices"
"strconv"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -32,6 +35,11 @@ type RespJoinRoom struct {
RoomID id.RoomID `json:"room_id"`
}
+// RespKnockRoom is the JSON response for https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias
+type RespKnockRoom struct {
+ RoomID id.RoomID `json:"room_id"`
+}
+
// RespLeaveRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave
type RespLeaveRoom struct{}
@@ -97,6 +105,29 @@ type RespContext struct {
// RespSendEvent is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid
type RespSendEvent struct {
EventID id.EventID `json:"event_id"`
+
+ UnstableDelayID id.DelayID `json:"delay_id,omitempty"`
+}
+
+type RespUpdateDelayedEvent struct{}
+
+type RespDelayedEvents struct {
+ Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"`
+ Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"`
+ NextBatch string `json:"next_batch,omitempty"`
+
+ // Deprecated: Synapse implementation still returns this
+ DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"`
+ // Deprecated: Synapse implementation still returns this
+ FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"`
+}
+
+type RespRedactUserEvents struct {
+ IsMoreEvents bool `json:"is_more_events"`
+ RedactedEvents struct {
+ Total int `json:"total"`
+ SoftFailed int `json:"soft_failed"`
+ } `json:"redacted_events"`
}
// RespMediaConfig is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixmediav3config
@@ -155,8 +186,89 @@ type RespUserDisplayName struct {
}
type RespUserProfile struct {
- DisplayName string `json:"displayname"`
- AvatarURL id.ContentURI `json:"avatar_url"`
+ DisplayName string `json:"displayname,omitempty"`
+ AvatarURL id.ContentURI `json:"avatar_url,omitempty"`
+ Extra map[string]any `json:"-"`
+}
+
+type marshalableUserProfile RespUserProfile
+
+func (r *RespUserProfile) UnmarshalJSON(data []byte) error {
+ err := json.Unmarshal(data, &r.Extra)
+ if err != nil {
+ return err
+ }
+ r.DisplayName, _ = r.Extra["displayname"].(string)
+ avatarURL, _ := r.Extra["avatar_url"].(string)
+ if avatarURL != "" {
+ r.AvatarURL, _ = id.ParseContentURI(avatarURL)
+ }
+ delete(r.Extra, "displayname")
+ delete(r.Extra, "avatar_url")
+ return nil
+}
+
+func (r *RespUserProfile) MarshalJSON() ([]byte, error) {
+ if len(r.Extra) == 0 {
+ return json.Marshal((*marshalableUserProfile)(r))
+ }
+ marshalMap := maps.Clone(r.Extra)
+ if r.DisplayName != "" {
+ marshalMap["displayname"] = r.DisplayName
+ } else {
+ delete(marshalMap, "displayname")
+ }
+ if !r.AvatarURL.IsEmpty() {
+ marshalMap["avatar_url"] = r.AvatarURL.String()
+ } else {
+ delete(marshalMap, "avatar_url")
+ }
+ return json.Marshal(marshalMap)
+}
+
+type RespSearchUserDirectory struct {
+ Limited bool `json:"limited"`
+ Results []*UserDirectoryEntry `json:"results"`
+}
+
+type UserDirectoryEntry struct {
+ RespUserProfile
+ UserID id.UserID `json:"user_id"`
+}
+
+func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error {
+ err := r.RespUserProfile.UnmarshalJSON(data)
+ if err != nil {
+ return err
+ }
+ userIDStr, _ := r.Extra["user_id"].(string)
+ r.UserID = id.UserID(userIDStr)
+ delete(r.Extra, "user_id")
+ return nil
+}
+
+func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) {
+ if r.Extra == nil {
+ r.Extra = make(map[string]any)
+ }
+ r.Extra["user_id"] = r.UserID.String()
+ return r.RespUserProfile.MarshalJSON()
+}
+
+type RespMutualRooms struct {
+ Joined []id.RoomID `json:"joined"`
+ NextBatch string `json:"next_batch,omitempty"`
+ Count int `json:"count,omitempty"`
+}
+
+type RespRoomSummary struct {
+ PublicRoomInfo
+
+ Membership event.Membership `json:"membership,omitempty"`
+
+ UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"`
+ UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"`
+ UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"`
}
// RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable
@@ -207,6 +319,9 @@ type RespLogin struct {
DeviceID id.DeviceID `json:"device_id"`
UserID id.UserID `json:"user_id"`
WellKnown *ClientWellKnown `json:"well_known,omitempty"`
+
+ RefreshToken string `json:"refresh_token,omitempty"`
+ ExpiresInMS int64 `json:"expires_in_ms,omitempty"`
}
// RespLogout is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout
@@ -227,6 +342,24 @@ type LazyLoadSummary struct {
InvitedMemberCount *int `json:"m.invited_member_count,omitempty"`
}
+func (lls *LazyLoadSummary) MemberCount() int {
+ if lls == nil {
+ return 0
+ }
+ return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount)
+}
+
+func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool {
+ if lls == other {
+ return true
+ } else if lls == nil || other == nil {
+ return false
+ }
+ return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) &&
+ ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) &&
+ slices.Equal(lls.Heroes, other.Heroes)
+}
+
type SyncEventsList struct {
Events []*event.Event `json:"events,omitempty"`
}
@@ -322,6 +455,7 @@ type BeeperInboxPreviewEvent struct {
type SyncJoinedRoom struct {
Summary LazyLoadSummary `json:"summary"`
State SyncEventsList `json:"state"`
+ StateAfter *SyncEventsList `json:"state_after,omitempty"`
Timeline SyncTimeline `json:"timeline"`
Ephemeral SyncEventsList `json:"ephemeral"`
AccountData SyncEventsList `json:"account_data"`
@@ -347,16 +481,7 @@ func (sjr SyncJoinedRoom) MarshalJSON() ([]byte, error) {
}
type SyncInvitedRoom struct {
- Summary LazyLoadSummary `json:"summary"`
- State SyncEventsList `json:"invite_state"`
-}
-
-type marshalableSyncInvitedRoom SyncInvitedRoom
-
-var syncInvitedRoomPathsToDelete = []string{"summary"}
-
-func (sir SyncInvitedRoom) MarshalJSON() ([]byte, error) {
- return marshalAndDeleteEmpty((marshalableSyncInvitedRoom)(sir), syncInvitedRoomPathsToDelete)
+ State SyncEventsList `json:"invite_state"`
}
type SyncKnockedRoom struct {
@@ -421,29 +546,19 @@ type RespDeviceInfo struct {
LastSeenTS int64 `json:"last_seen_ts"`
}
-// Deprecated: MSC2716 was abandoned
-type RespBatchSend struct {
- StateEventIDs []id.EventID `json:"state_event_ids"`
- EventIDs []id.EventID `json:"event_ids"`
-
- InsertionEventID id.EventID `json:"insertion_event_id"`
- BatchEventID id.EventID `json:"batch_event_id"`
- BaseInsertionEventID id.EventID `json:"base_insertion_event_id"`
-
- NextBatchID id.BatchID `json:"next_batch_id"`
-}
-
type RespBeeperBatchSend struct {
EventIDs []id.EventID `json:"event_ids"`
}
// RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities
type RespCapabilities struct {
- RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
- ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
- SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
- SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
- ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
+ RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
+ ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
+ SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
+ SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
+ ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
+ GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"`
+ UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"`
Custom map[string]interface{} `json:"-"`
}
@@ -552,29 +667,44 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool {
return available
}
+type CapUnstableAccountModeration struct {
+ Suspend bool `json:"suspend"`
+ Lock bool `json:"lock"`
+}
+
+type RespPublicRooms struct {
+ Chunk []*PublicRoomInfo `json:"chunk"`
+ NextBatch string `json:"next_batch,omitempty"`
+ PrevBatch string `json:"prev_batch,omitempty"`
+ TotalRoomCountEstimate int `json:"total_room_count_estimate"`
+}
+
+type PublicRoomInfo struct {
+ RoomID id.RoomID `json:"room_id"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"`
+ GuestCanJoin bool `json:"guest_can_join"`
+ JoinRule event.JoinRule `json:"join_rule,omitempty"`
+ Name string `json:"name,omitempty"`
+ NumJoinedMembers int `json:"num_joined_members"`
+ RoomType event.RoomType `json:"room_type"`
+ Topic string `json:"topic,omitempty"`
+ WorldReadable bool `json:"world_readable"`
+
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ Encryption id.Algorithm `json:"encryption,omitempty"`
+ AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"`
+}
+
// RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
type RespHierarchy struct {
- NextBatch string `json:"next_batch,omitempty"`
- Rooms []ChildRoomsChunk `json:"rooms"`
+ NextBatch string `json:"next_batch,omitempty"`
+ Rooms []*ChildRoomsChunk `json:"rooms"`
}
type ChildRoomsChunk struct {
- AvatarURL id.ContentURI `json:"avatar_url,omitempty"`
- CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"`
- ChildrenState []StrippedStateWithTime `json:"children_state"`
- GuestCanJoin bool `json:"guest_can_join"`
- JoinRule event.JoinRule `json:"join_rule,omitempty"`
- Name string `json:"name,omitempty"`
- NumJoinedMembers int `json:"num_joined_members"`
- RoomID id.RoomID `json:"room_id"`
- RoomType event.RoomType `json:"room_type"`
- Topic string `json:"topic,omitempty"`
- WorldReadble bool `json:"world_readable"`
-}
-
-type StrippedStateWithTime struct {
- event.StrippedState
- Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
+ PublicRoomInfo
+ ChildrenState []*event.Event `json:"children_state"`
}
type RespAppservicePing struct {
@@ -623,3 +753,47 @@ type RespRoomKeysUpdate struct {
Count int `json:"count"`
ETag string `json:"etag"`
}
+
+type RespOpenIDToken struct {
+ AccessToken string `json:"access_token"`
+ ExpiresIn int `json:"expires_in"`
+ MatrixServerName string `json:"matrix_server_name"`
+ TokenType string `json:"token_type"` // Always "Bearer"
+}
+
+type RespGetRelations struct {
+ Chunk []*event.Event `json:"chunk"`
+ NextBatch string `json:"next_batch,omitempty"`
+ PrevBatch string `json:"prev_batch,omitempty"`
+ RecursionDepth int `json:"recursion_depth,omitempty"`
+}
+
+// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type RespSuspended struct {
+ Suspended bool `json:"suspended"`
+}
+
+// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type RespLocked struct {
+ Locked bool `json:"locked"`
+}
+
+type ConnectionInfo struct {
+ IP string `json:"ip,omitempty"`
+ LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"`
+ UserAgent string `json:"user_agent,omitempty"`
+}
+
+type SessionInfo struct {
+ Connections []ConnectionInfo `json:"connections,omitempty"`
+}
+
+type DeviceInfo struct {
+ Sessions []SessionInfo `json:"sessions,omitempty"`
+}
+
+// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid
+type RespWhoIs struct {
+ UserID id.UserID `json:"user_id,omitempty"`
+ Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"`
+}
diff --git a/responses_test.go b/responses_test.go
index b23d85ad..73d82635 100644
--- a/responses_test.go
+++ b/responses_test.go
@@ -8,7 +8,6 @@ package mautrix_test
import (
"encoding/json"
- "fmt"
"testing"
"github.com/stretchr/testify/assert"
@@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) {
var caps mautrix.RespCapabilities
err := json.Unmarshal([]byte(sampleData), &caps)
require.NoError(t, err)
- fmt.Println(caps)
require.NotNil(t, caps.RoomVersions)
assert.Equal(t, "9", caps.RoomVersions.Default)
diff --git a/room.go b/room.go
index c3ddb7e6..4292bff5 100644
--- a/room.go
+++ b/room.go
@@ -5,8 +5,6 @@ import (
"maunium.net/go/mautrix/id"
)
-type RoomStateMap = map[event.Type]map[string]*event.Event
-
// Room represents a single Matrix room.
type Room struct {
ID id.RoomID
@@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) {
// GetStateEvent returns the state event for the given type/state_key combo, or nil.
func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event {
- stateEventMap, _ := room.State[eventType]
- evt, _ := stateEventMap[stateKey]
+ stateEventMap := room.State[eventType]
+ evt := stateEventMap[stateKey]
return evt
}
diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go
index 33c10c4c..11957dfa 100644
--- a/sqlstatestore/statestore.go
+++ b/sqlstatestore/statestore.go
@@ -62,6 +62,9 @@ func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID)
}
func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error {
+ if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
_, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
return err
}
@@ -85,14 +88,11 @@ func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID
query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ","))
}
rows, err := store.Query(ctx, query, args...)
- if err != nil {
- return nil, err
- }
members := make(map[id.UserID]*event.MemberEventContent)
- return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) {
+ return members, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (ret Member, err error) {
err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL)
return
- }).Iter(func(m Member) (bool, error) {
+ }, err).Iter(func(m Member) (bool, error) {
members[m.UserID] = &m.MemberEventContent
return true, nil
})
@@ -159,10 +159,7 @@ func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserI
`
}
rows, err := store.Query(ctx, query, userID)
- if err != nil {
- return nil, err
- }
- return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
+ return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
}
func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
@@ -188,6 +185,11 @@ func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID,
}
func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ } else if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
@@ -220,6 +222,11 @@ func (u *userProfileRow) GetMassInsertValues() [5]any {
var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ } else if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
var nameSkeleton []byte
if !store.DisableNameDisambiguation && len(member.Displayname) > 0 {
nameSkeletonArr := confusable.SkeletonHash(member.Displayname)
@@ -241,6 +248,9 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room
const userProfileMassInsertBatchSize = 500
func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
return store.DoTxn(ctx, nil, func(ctx context.Context) error {
err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
if err != nil {
@@ -311,6 +321,9 @@ func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.Roo
}
func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true)
ON CONFLICT (room_id) DO UPDATE SET members_fetched=true
@@ -340,6 +353,9 @@ func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID)
}
func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
contentBytes, err := json.Marshal(content)
if err != nil {
return fmt.Errorf("failed to marshal content JSON: %w", err)
@@ -354,7 +370,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro
func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
var data []byte
err := store.
- QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
+ QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID).
Scan(&data)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
@@ -377,6 +393,9 @@ func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (
}
func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
@@ -385,89 +404,92 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID
}
func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
+ levels = &event.PowerLevelsEventContent{}
err = store.
- QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
- Scan(&dbutil.JSON{Data: &levels})
+ QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent})
if errors.Is(err, sql.ErrNoRows) {
- err = nil
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ if levels.CreateEvent != nil {
+ err = levels.CreateEvent.Content.ParseRaw(event.StateCreate)
}
return
}
func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
- if store.Dialect == dbutil.Postgres {
- var powerLevel int
- err := store.
- QueryRow(ctx, `
- SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
- FROM mx_room_state WHERE room_id=$1
- `, roomID, userID).
- Scan(&powerLevel)
- return powerLevel, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
- }
- return levels.GetUserLevel(userID), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
}
+ return levels.GetUserLevel(userID), nil
}
func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
- if store.Dialect == dbutil.Postgres {
- defaultType := "events_default"
- defaultValue := 0
- if eventType.IsState() {
- defaultType = "state_default"
- defaultValue = 50
- }
- var powerLevel int
- err := store.
- QueryRow(ctx, `
- SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
- FROM mx_room_state WHERE room_id=$1
- `, roomID, eventType.Type, defaultType, defaultValue).
- Scan(&powerLevel)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- powerLevel = defaultValue
- }
- return powerLevel, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
- }
- return levels.GetEventLevel(eventType), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
}
+ return levels.GetEventLevel(eventType), nil
}
func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
- if store.Dialect == dbutil.Postgres {
- defaultType := "events_default"
- defaultValue := 0
- if eventType.IsState() {
- defaultType = "state_default"
- defaultValue = 50
- }
- var hasPower bool
- err := store.
- QueryRow(ctx, `SELECT
- COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
- >=
- COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
- FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
- Scan(&hasPower)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- hasPower = defaultValue == 0
- }
- return hasPower, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return false, err
- }
- return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return false, err
}
+ return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
+}
+
+func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
+ if evt.Type != event.StateCreate {
+ return fmt.Errorf("invalid event type for create event: %s", evt.Type)
+ } else if evt.RoomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ _, err := store.Exec(ctx, `
+ INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2)
+ ON CONFLICT (room_id) DO UPDATE SET create_event=excluded.create_event
+ `, evt.RoomID, dbutil.JSON{Data: evt})
+ return err
+}
+
+func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) {
+ err = store.
+ QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &evt})
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ if evt != nil {
+ err = evt.Content.ParseRaw(event.StateCreate)
+ }
+ return
+}
+
+func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ _, err := store.Exec(ctx, `
+ INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2)
+ ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules
+ `, roomID, dbutil.JSON{Data: rules})
+ return err
+}
+
+func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) {
+ levels = &event.JoinRulesEventContent{}
+ err = store.
+ QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &levels})
+ if errors.Is(err, sql.ErrNoRows) {
+ levels = nil
+ err = nil
+ }
+ return
}
diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql
index a58cc56a..4679f1c6 100644
--- a/sqlstatestore/v00-latest-revision.sql
+++ b/sqlstatestore/v00-latest-revision.sql
@@ -1,4 +1,4 @@
--- v0 -> v7 (compatible with v3+): Latest revision
+-- v0 -> v10 (compatible with v3+): Latest revision
CREATE TABLE mx_registrations (
user_id TEXT PRIMARY KEY
@@ -26,5 +26,7 @@ CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
power_levels jsonb,
encryption jsonb,
+ create_event jsonb,
+ join_rules jsonb,
members_fetched BOOLEAN NOT NULL DEFAULT false
);
diff --git a/sqlstatestore/v08-create-event.sql b/sqlstatestore/v08-create-event.sql
new file mode 100644
index 00000000..9f1b55c9
--- /dev/null
+++ b/sqlstatestore/v08-create-event.sql
@@ -0,0 +1,2 @@
+-- v8 (compatible with v3+): Add create event to room state table
+ALTER TABLE mx_room_state ADD COLUMN create_event jsonb;
diff --git a/sqlstatestore/v09-clear-empty-room-ids.sql b/sqlstatestore/v09-clear-empty-room-ids.sql
new file mode 100644
index 00000000..ca951068
--- /dev/null
+++ b/sqlstatestore/v09-clear-empty-room-ids.sql
@@ -0,0 +1,3 @@
+-- v9 (compatible with v3+): Clear invalid rows
+DELETE FROM mx_room_state WHERE room_id='';
+DELETE FROM mx_user_profile WHERE room_id='' OR user_id='';
diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql
new file mode 100644
index 00000000..3074c46a
--- /dev/null
+++ b/sqlstatestore/v10-join-rules.sql
@@ -0,0 +1,2 @@
+-- v10 (compatible with v3+): Add join rules to room state table
+ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb;
diff --git a/statestore.go b/statestore.go
index e728b885..2bd498dd 100644
--- a/statestore.go
+++ b/statestore.go
@@ -34,6 +34,12 @@ type StateStore interface {
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
+ SetCreate(ctx context.Context, evt *event.Event) error
+ GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error)
+
+ GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error)
+ SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error
+
HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error)
MarkMembersFetched(ctx context.Context, roomID id.RoomID) error
GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error)
@@ -68,9 +74,13 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
err = store.SetPowerLevels(ctx, evt.RoomID, content)
case *event.EncryptionEventContent:
err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
+ case *event.CreateEventContent:
+ err = store.SetCreate(ctx, evt)
+ case *event.JoinRulesEventContent:
+ err = store.SetJoinRules(ctx, evt.RoomID, content)
default:
switch evt.Type {
- case event.StateMember, event.StatePowerLevels, event.StateEncryption:
+ case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate:
zerolog.Ctx(ctx).Warn().
Stringer("event_id", evt.ID).
Str("event_type", evt.Type.Type).
@@ -101,11 +111,14 @@ type MemoryStateStore struct {
MembersFetched map[id.RoomID]bool `json:"members_fetched"`
PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"`
+ Create map[id.RoomID]*event.Event `json:"create"`
+ JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"`
registrationsLock sync.RWMutex
membersLock sync.RWMutex
powerLevelsLock sync.RWMutex
encryptionLock sync.RWMutex
+ joinRulesLock sync.RWMutex
}
func NewMemoryStateStore() StateStore {
@@ -115,6 +128,8 @@ func NewMemoryStateStore() StateStore {
MembersFetched: make(map[id.RoomID]bool),
PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
Encryption: make(map[id.RoomID]*event.EncryptionEventContent),
+ Create: make(map[id.RoomID]*event.Event),
+ JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent),
}
}
@@ -298,6 +313,9 @@ func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomI
func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
store.powerLevelsLock.RLock()
levels = store.PowerLevels[roomID]
+ if levels != nil && levels.CreateEvent == nil {
+ levels.CreateEvent = store.Create[roomID]
+ }
store.powerLevelsLock.RUnlock()
return
}
@@ -314,6 +332,23 @@ func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.Room
return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
}
+func (store *MemoryStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
+ store.powerLevelsLock.Lock()
+ store.Create[evt.RoomID] = evt
+ if pls, ok := store.PowerLevels[evt.RoomID]; ok && pls.CreateEvent == nil {
+ pls.CreateEvent = evt
+ }
+ store.powerLevelsLock.Unlock()
+ return nil
+}
+
+func (store *MemoryStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) {
+ store.powerLevelsLock.RLock()
+ evt := store.Create[roomID]
+ store.powerLevelsLock.RUnlock()
+ return evt, nil
+}
+
func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
store.encryptionLock.Lock()
store.Encryption[roomID] = content
@@ -327,6 +362,19 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R
return store.Encryption[roomID], nil
}
+func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error {
+ store.joinRulesLock.Lock()
+ store.JoinRules[roomID] = content
+ store.joinRulesLock.Unlock()
+ return nil
+}
+
+func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) {
+ store.joinRulesLock.RLock()
+ defer store.joinRulesLock.RUnlock()
+ return store.JoinRules[roomID], nil
+}
+
func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
cfg, err := store.GetEncryptionEvent(ctx, roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
diff --git a/synapseadmin/client.go b/synapseadmin/client.go
index 775b4b13..6925ca7d 100644
--- a/synapseadmin/client.go
+++ b/synapseadmin/client.go
@@ -14,9 +14,9 @@ import (
//
// https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html
type Client struct {
- *mautrix.Client
+ Client *mautrix.Client
}
func (cli *Client) BuildAdminURL(path ...any) string {
- return cli.BuildURL(mautrix.SynapseAdminURLPath(path))
+ return cli.Client.BuildURL(mautrix.SynapseAdminURLPath(path))
}
diff --git a/synapseadmin/register.go b/synapseadmin/register.go
index 641f9b56..05e0729a 100644
--- a/synapseadmin/register.go
+++ b/synapseadmin/register.go
@@ -73,7 +73,7 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string {
// This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided.
func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) {
var resp respGetRegisterNonce
- _, err := cli.MakeRequest(ctx, http.MethodGet, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), nil, &resp)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "register"), nil, &resp)
if err != nil {
return "", err
}
@@ -93,7 +93,7 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string
}
req.SHA1Checksum = req.Sign(sharedSecret)
var resp mautrix.RespRegister
- _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), &req, &resp)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodPost, cli.BuildAdminURL("v1", "register"), &req, &resp)
if err != nil {
return nil, err
}
diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go
index 6c072e23..0925b748 100644
--- a/synapseadmin/roomapi.go
+++ b/synapseadmin/roomapi.go
@@ -75,12 +75,17 @@ type RespListRooms struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api
func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) {
var resp RespListRooms
- var reqURL string
- reqURL = cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
- _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}
+func (cli *Client) RoomInfo(ctx context.Context, roomID id.RoomID) (resp *RoomInfo, err error) {
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomID)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ return
+}
+
type RespRoomMessages = mautrix.RespMessages
// RoomMessages returns a list of messages in a room.
@@ -104,13 +109,14 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to
if limit != 0 {
query["limit"] = strconv.Itoa(limit)
}
- urlPath := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query)
- _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
+ urlPath := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return resp, err
}
type ReqDeleteRoom struct {
Purge bool `json:"purge,omitempty"`
+ ForcePurge bool `json:"force_purge,omitempty"`
Block bool `json:"block,omitempty"`
Message string `json:"message,omitempty"`
RoomName string `json:"room_name,omitempty"`
@@ -121,6 +127,19 @@ type RespDeleteRoom struct {
DeleteID string `json:"delete_id"`
}
+type RespDeleteRoomResult struct {
+ KickedUsers []id.UserID `json:"kicked_users,omitempty"`
+ FailedToKickUsers []id.UserID `json:"failed_to_kick_users,omitempty"`
+ LocalAliases []id.RoomAlias `json:"local_aliases,omitempty"`
+ NewRoomID id.RoomID `json:"new_room_id,omitempty"`
+}
+
+type RespDeleteRoomStatus struct {
+ Status string `json:"status,omitempty"`
+ Error string `json:"error,omitempty"`
+ ShutdownRoom RespDeleteRoomResult `json:"shutdown_room,omitempty"`
+}
+
// DeleteRoom deletes a room from the server, optionally blocking it and/or purging all data from the database.
//
// This calls the async version of the endpoint, which will return immediately and delete the room in the background.
@@ -129,10 +148,37 @@ type RespDeleteRoom struct {
func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (RespDeleteRoom, error) {
reqURL := cli.BuildAdminURL("v2", "rooms", roomID)
var resp RespDeleteRoom
- _, err := cli.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp)
return resp, err
}
+func (cli *Client) DeleteRoomStatus(ctx context.Context, deleteID string) (resp RespDeleteRoomStatus, err error) {
+ reqURL := cli.BuildAdminURL("v2", "rooms", "delete_status", deleteID)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ return
+}
+
+// DeleteRoomSync deletes a room from the server, optionally blocking it and/or purging all data from the database.
+//
+// This calls the synchronous version of the endpoint, which will block until the room is deleted.
+//
+// https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-1-old-version
+func (cli *Client) DeleteRoomSync(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (resp RespDeleteRoomResult, err error) {
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomID)
+ httpClient := &http.Client{}
+ _, err = cli.Client.MakeFullRequest(ctx, mautrix.FullRequest{
+ Method: http.MethodDelete,
+ URL: reqURL,
+ RequestJSON: &req,
+ ResponseJSON: &resp,
+ MaxAttempts: 1,
+ // Use a fresh HTTP client without timeouts
+ Client: httpClient,
+ })
+ httpClient.CloseIdleConnections()
+ return
+}
+
type RespRoomsMembers struct {
Members []id.UserID `json:"members"`
Total int `json:"total"`
@@ -144,7 +190,7 @@ type RespRoomsMembers struct {
func (cli *Client) RoomMembers(ctx context.Context, roomID id.RoomID) (RespRoomsMembers, error) {
reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "members")
var resp RespRoomsMembers
- _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}
@@ -157,7 +203,7 @@ type ReqMakeRoomAdmin struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api
func (cli *Client) MakeRoomAdmin(ctx context.Context, roomIDOrAlias string, req ReqMakeRoomAdmin) error {
reqURL := cli.BuildAdminURL("v1", "rooms", roomIDOrAlias, "make_room_admin")
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
return err
}
@@ -170,7 +216,7 @@ type ReqJoinUserToRoom struct {
// https://matrix-org.github.io/synapse/latest/admin_api/room_membership.html
func (cli *Client) JoinUserToRoom(ctx context.Context, roomID id.RoomID, req ReqJoinUserToRoom) error {
reqURL := cli.BuildAdminURL("v1", "join", roomID)
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
return err
}
@@ -183,7 +229,7 @@ type ReqBlockRoom struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#block-room-api
func (cli *Client) BlockRoom(ctx context.Context, roomID id.RoomID, req ReqBlockRoom) error {
reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block")
- _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
return err
}
@@ -199,6 +245,6 @@ type RoomsBlockResponse struct {
func (cli *Client) GetRoomBlockStatus(ctx context.Context, roomID id.RoomID) (RoomsBlockResponse, error) {
var resp RoomsBlockResponse
reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block")
- _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}
diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go
index 9cbb17e4..b1de55b6 100644
--- a/synapseadmin/userapi.go
+++ b/synapseadmin/userapi.go
@@ -32,7 +32,7 @@ type ReqResetPassword struct {
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password
func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error {
reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID)
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
return err
}
@@ -43,8 +43,8 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability
func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) {
- u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username})
- _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
+ u := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username})
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
if err == nil && !resp.Available {
err = fmt.Errorf(`request returned OK status without "available": true`)
}
@@ -65,7 +65,7 @@ type RespListDevices struct {
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices
func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) {
- _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp)
return
}
@@ -89,7 +89,7 @@ type RespUserInfo struct {
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account
func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) {
- _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp)
return
}
@@ -102,7 +102,20 @@ type ReqDeleteUser struct {
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account
func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req ReqDeleteUser) error {
reqURL := cli.BuildAdminURL("v1", "deactivate", userID)
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ return err
+}
+
+type ReqSuspendUser struct {
+ Suspend bool `json:"suspend"`
+}
+
+// SuspendAccount suspends or unsuspends a specific local user account.
+//
+// https://element-hq.github.io/synapse/latest/admin_api/user_admin_api.html#suspendunsuspend-account
+func (cli *Client) SuspendAccount(ctx context.Context, userID id.UserID, req ReqSuspendUser) error {
+ reqURL := cli.BuildAdminURL("v1", "suspend", userID)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
return err
}
@@ -124,7 +137,7 @@ type ReqCreateOrModifyAccount struct {
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#create-or-modify-account
func (cli *Client) CreateOrModifyAccount(ctx context.Context, userID id.UserID, req ReqCreateOrModifyAccount) error {
reqURL := cli.BuildAdminURL("v2", "users", userID)
- _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
return err
}
@@ -140,7 +153,7 @@ type ReqSetRatelimit = RatelimitOverride
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#set-ratelimit
func (cli *Client) SetUserRatelimit(ctx context.Context, userID id.UserID, req ReqSetRatelimit) error {
reqURL := cli.BuildAdminURL("v1", "users", userID, "override_ratelimit")
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
return err
}
@@ -150,7 +163,7 @@ type RespUserRatelimit = RatelimitOverride
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#get-status-of-ratelimit
func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp RespUserRatelimit, err error) {
- _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp)
return
}
@@ -158,6 +171,6 @@ func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#delete-ratelimit
func (cli *Client) DeleteUserRatelimit(ctx context.Context, userID id.UserID) (err error) {
- _, err = cli.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil)
return
}
diff --git a/sync.go b/sync.go
index d4208404..598df8e0 100644
--- a/sync.go
+++ b/sync.go
@@ -90,6 +90,7 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc
err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack())
}
}()
+ ctx = context.WithValue(ctx, SyncTokenContextKey, since)
for _, listener := range s.syncListeners {
if !listener(ctx, res, since) {
@@ -97,33 +98,38 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc
}
}
- s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice)
- s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence)
- s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData)
+ s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice, false)
+ s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence, false)
+ s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData, false)
for roomID, roomData := range res.Rooms.Join {
- s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState)
- s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline)
- s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral)
- s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData)
+ if roomData.StateAfter == nil {
+ s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState, false)
+ s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, false)
+ } else {
+ s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, true)
+ s.processSyncEvents(ctx, roomID, roomData.StateAfter.Events, event.SourceJoin|event.SourceState, false)
+ }
+ s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral, false)
+ s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData, false)
}
for roomID, roomData := range res.Rooms.Invite {
- s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState)
+ s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState, false)
}
for roomID, roomData := range res.Rooms.Leave {
- s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState)
- s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline)
+ s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState, false)
+ s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline, false)
}
return
}
-func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) {
+func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source, ignoreState bool) {
for _, evt := range events {
- s.processSyncEvent(ctx, roomID, evt, source)
+ s.processSyncEvent(ctx, roomID, evt, source, ignoreState)
}
}
-func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) {
+func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source, ignoreState bool) {
evt.RoomID = roomID
// Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer.
@@ -149,6 +155,7 @@ func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID,
}
evt.Mautrix.EventSource = source
+ evt.Mautrix.IgnoreState = ignoreState
s.Dispatch(ctx, evt)
}
@@ -191,8 +198,8 @@ func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, e
}
var defaultFilter = Filter{
- Room: RoomFilter{
- Timeline: FilterPart{
+ Room: &RoomFilter{
+ Timeline: &FilterPart{
Limit: 50,
},
},
@@ -257,7 +264,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState)
func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool {
for _, meta := range resp.Rooms.Invite {
- var inviteState []event.StrippedState
+ var inviteState []*event.Event
var inviteEvt *event.Event
for _, evt := range meta.State.Events {
if evt.Type == event.StateMember && evt.GetStateKey() == cli.UserID.String() {
@@ -265,12 +272,7 @@ func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string
} else {
evt.Type.Class = event.StateEventType
_ = evt.Content.ParseRaw(evt.Type)
- inviteState = append(inviteState, event.StrippedState{
- Content: evt.Content,
- Type: evt.Type,
- StateKey: evt.GetStateKey(),
- Sender: evt.Sender,
- })
+ inviteState = append(inviteState, evt)
}
}
if inviteEvt != nil {
diff --git a/url.go b/url.go
index f35ae5e2..91b3d49d 100644
--- a/url.go
+++ b/url.go
@@ -57,13 +57,13 @@ func BuildURL(baseURL *url.URL, path ...any) *url.URL {
// BuildURL builds a URL with the Client's homeserver and appservice user ID set already.
func (cli *Client) BuildURL(urlPath PrefixableURLPath) string {
- return cli.BuildURLWithQuery(urlPath, nil)
+ return cli.BuildURLWithFullQuery(urlPath, nil)
}
// BuildClientURL builds a URL with the Client's homeserver and appservice user ID set already.
// This method also automatically prepends the client API prefix (/_matrix/client).
func (cli *Client) BuildClientURL(urlPath ...any) string {
- return cli.BuildURLWithQuery(ClientURLPath(urlPath), nil)
+ return cli.BuildURLWithFullQuery(ClientURLPath(urlPath), nil)
}
type PrefixableURLPath interface {
@@ -97,6 +97,19 @@ func (saup SynapseAdminURLPath) FullPath() []any {
// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver
// and appservice user ID set already.
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
+ return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) {
+ for k, v := range urlQuery {
+ q.Set(k, v)
+ }
+ })
+}
+
+// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver
+// and appservice user ID set already.
+func (cli *Client) BuildURLWithFullQuery(urlPath PrefixableURLPath, fn func(q url.Values)) string {
+ if cli == nil {
+ return "client is nil"
+ }
hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...)
query := hsURL.Query()
if cli.SetAppServiceUserID {
@@ -106,10 +119,8 @@ func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[str
query.Set("device_id", string(cli.DeviceID))
query.Set("org.matrix.msc3202.device_id", string(cli.DeviceID))
}
- if urlQuery != nil {
- for k, v := range urlQuery {
- query.Set(k, v)
- }
+ if fn != nil {
+ fn(query)
}
hsURL.RawQuery = query.Encode()
return hsURL.String()
diff --git a/version.go b/version.go
index dd70d55b..f00bbf39 100644
--- a/version.go
+++ b/version.go
@@ -4,10 +4,11 @@ import (
"fmt"
"regexp"
"runtime"
+ "runtime/debug"
"strings"
)
-const Version = "v0.22.0"
+const Version = "v0.26.3"
var GoModVersion = ""
var Commit = ""
@@ -15,11 +16,20 @@ var VersionWithCommit = Version
var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go")
-var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`)
-
func init() {
+ if GoModVersion == "" {
+ info, _ := debug.ReadBuildInfo()
+ if info != nil {
+ for _, mod := range info.Deps {
+ if mod.Path == "maunium.net/go/mautrix" {
+ GoModVersion = mod.Version
+ break
+ }
+ }
+ }
+ }
if GoModVersion != "" {
- match := goModVersionRegex.FindStringSubmatch(GoModVersion)
+ match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion)
if match != nil {
Commit = match[1]
}
diff --git a/versions.go b/versions.go
index 672018ff..61b2e4ea 100644
--- a/versions.go
+++ b/versions.go
@@ -60,17 +60,28 @@ type UnstableFeature struct {
}
var (
- FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
- FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
- FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
+ FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
+ FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
+ FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
+ FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"}
+ FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"}
+ FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"}
+ FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"}
+ FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"}
+ FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116}
+ FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"}
- BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
- BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
- BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
- BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
- BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
- BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
- BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
+ BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
+ BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
+ BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
+ BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
+ BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
+ BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
+ BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
+ BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"}
+ BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"}
)
func (versions *RespVersions) Supports(feature UnstableFeature) bool {
@@ -110,6 +121,12 @@ var (
SpecV19 = MustParseSpecVersion("v1.9")
SpecV110 = MustParseSpecVersion("v1.10")
SpecV111 = MustParseSpecVersion("v1.11")
+ SpecV112 = MustParseSpecVersion("v1.12")
+ SpecV113 = MustParseSpecVersion("v1.13")
+ SpecV114 = MustParseSpecVersion("v1.14")
+ SpecV115 = MustParseSpecVersion("v1.15")
+ SpecV116 = MustParseSpecVersion("v1.16")
+ SpecV117 = MustParseSpecVersion("v1.17")
)
func (svf SpecVersionFormat) String() string {