%s", html.EscapeString(qr)),
+ Info: &event.FileInfo{
+ MimeType: "image/png",
+ Width: qrSizePx,
+ Height: qrSizePx,
+ Size: len(qrData),
+ },
}
if *prevEventID != "" {
content.SetEdit(*prevEventID)
@@ -217,18 +279,55 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return nil
}
+func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error {
+ for _, att := range atts {
+ if att.FileName == "" {
+ return fmt.Errorf("missing attachment filename")
+ }
+ mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType)
+ if err != nil {
+ return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err)
+ }
+ content := &event.MessageEventContent{
+ MsgType: att.Type,
+ FileName: att.FileName,
+ URL: mxc,
+ File: file,
+ Info: &event.FileInfo{
+ MimeType: att.Info.MimeType,
+ Width: att.Info.Width,
+ Height: att.Info.Height,
+ Size: att.Info.Size,
+ },
+ Body: att.FileName,
+ }
+ _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil)
+ if err != nil {
+ return nil
+ }
+ }
+ return nil
+}
+
type contextKey int
const (
contextKeyPrevEventID contextKey = iota
)
-func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep) {
+func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
prevEvent, ok := ce.Ctx.Value(contextKeyPrevEventID).(*id.EventID)
if !ok {
prevEvent = new(id.EventID)
ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent)
}
+ cancelCtx, cancelFunc := context.WithCancel(ce.Ctx)
+ defer cancelFunc()
+ StoreCommandState(ce.User, &CommandState{
+ Action: "Login",
+ Cancel: cancelFunc,
+ })
+ defer StoreCommandState(ce.User, nil)
switch step.DisplayAndWaitParams.Type {
case bridgev2.LoginDisplayTypeQR:
err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent)
@@ -248,7 +347,7 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait,
login.Cancel()
return
}
- nextStep, err := login.Wait(ce.Ctx)
+ nextStep, err := login.Wait(cancelCtx)
// Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited)
if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) {
_, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{
@@ -262,12 +361,13 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait,
ce.Reply("Login failed: %v", err)
return
}
- doLoginStep(ce, login, nextStep)
+ doLoginStep(ce, login, nextStep, override)
}
type cookieLoginCommandState struct {
- Login bridgev2.LoginProcessCookies
- Data *bridgev2.LoginCookiesParams
+ Login bridgev2.LoginProcessCookies
+ Data *bridgev2.LoginCookiesParams
+ Override *bridgev2.UserLogin
}
func (clcs *cookieLoginCommandState) prompt(ce *Event) {
@@ -292,7 +392,7 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) {
}
reqCookies := make(map[string]string)
for _, cookie := range parsed.Cookies() {
- reqCookies[cookie.Name], err = url.QueryUnescape(cookie.Value)
+ reqCookies[cookie.Name], err = url.PathUnescape(cookie.Value)
if err != nil {
ce.Reply("Failed to parse cookie %s: %v", cookie.Name, err)
return
@@ -365,7 +465,7 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) {
missingKeys = append(missingKeys, field.ID)
}
if match, _ := regexp.MatchString(field.Pattern, val); !match {
- ce.Reply("Invalid value for %s: doesn't match regex `%s`", field.ID, field.Pattern)
+ ce.Reply("Invalid value for %s: `%s` doesn't match regex `%s`", field.ID, val, field.Pattern)
return
}
}
@@ -379,7 +479,7 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) {
ce.Reply("Login failed: %v", err)
return
}
- doLoginStep(ce, clcs.Login, nextStep)
+ doLoginStep(ce, clcs.Login, nextStep, clcs.Override)
}
func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
@@ -392,38 +492,50 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
if !isCookie {
return val
}
- match, _ := regexp.MatchString(field.Pattern, val)
- if !match {
- return val
- }
- decoded, err := url.QueryUnescape(val)
+ decoded, err := url.PathUnescape(val)
if err != nil {
return val
}
return decoded
}
-func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep) {
+func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
+ ce.Log.Debug().Any("next_step", step).Msg("Got next login step")
if step.Instructions != "" {
ce.Reply(step.Instructions)
}
switch step.Type {
case bridgev2.LoginStepTypeDisplayAndWait:
- doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step)
+ doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step, override)
case bridgev2.LoginStepTypeCookies:
(&cookieLoginCommandState{
- Login: login.(bridgev2.LoginProcessCookies),
- Data: step.CookiesParams,
+ Login: login.(bridgev2.LoginProcessCookies),
+ Data: step.CookiesParams,
+ Override: override,
}).prompt(ce)
case bridgev2.LoginStepTypeUserInput:
+ err := sendUserInputAttachments(ce, step.UserInputParams.Attachments)
+ if err != nil {
+ ce.Reply("Failed to send attachments: %v", err)
+ }
(&userInputLoginCommandState{
Login: login.(bridgev2.LoginProcessUserInput),
RemainingFields: step.UserInputParams.Fields,
Data: make(map[string]string),
+ Override: override,
}).promptNext(ce)
case bridgev2.LoginStepTypeComplete:
- // Nothing to do other than instructions
+ if override != nil && override.ID != step.CompleteParams.UserLoginID {
+ ce.Log.Info().
+ Str("old_login_id", string(override.ID)).
+ Str("new_login_id", string(step.CompleteParams.UserLoginID)).
+ Msg("Login resulted in different remote ID than what was being overridden. Deleting previous login")
+ override.Delete(ce.Ctx, status.BridgeState{
+ StateEvent: status.StateLoggedOut,
+ Reason: "LOGIN_OVERRIDDEN",
+ }, bridgev2.DeleteOpts{LogoutRemote: true})
+ }
default:
panic(fmt.Errorf("unknown login step type %q", step.Type))
}
diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go
index 1aca596c..391c3685 100644
--- a/bridgev2/commands/processor.go
+++ b/bridgev2/commands/processor.go
@@ -17,8 +17,7 @@ import (
"github.com/rs/zerolog"
"maunium.net/go/mautrix/bridgev2"
-
- "maunium.net/go/mautrix/bridge/status"
+ "maunium.net/go/mautrix/bridgev2/status"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -42,10 +41,11 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor {
}
proc.AddHandlers(
CommandHelp, CommandCancel,
- CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals,
- CommandLogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin,
+ CommandRegisterPush, CommandSendAccountData, CommandResetNetwork,
+ CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
+ CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin,
CommandSetRelay, CommandUnsetRelay,
- CommandResolveIdentifier, CommandStartChat, CommandSearch,
+ CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute,
CommandSudo, CommandDoIn,
)
return proc
diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go
index af756c87..94c19739 100644
--- a/bridgev2/commands/relay.go
+++ b/bridgev2/commands/relay.go
@@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) {
}
onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly
var relay *bridgev2.UserLogin
- if len(ce.Args) == 0 {
+ if len(ce.Args) == 0 && ce.Portal.Receiver == "" {
relay = ce.User.GetDefaultLogin()
isLoggedIn := relay != nil
if onlySetDefaultRelays {
@@ -73,9 +73,19 @@ func fnSetRelay(ce *Event) {
}
}
} else {
- relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+ var targetID networkid.UserLoginID
+ if ce.Portal.Receiver != "" {
+ targetID = ce.Portal.Receiver
+ if len(ce.Args) > 0 && ce.Args[0] != string(targetID) {
+ ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID)
+ return
+ }
+ } else {
+ targetID = networkid.UserLoginID(ce.Args[0])
+ }
+ relay = ce.Bridge.GetCachedUserLoginByID(targetID)
if relay == nil {
- ce.Reply("User login with ID `%s` not found", ce.Args[0])
+ ce.Reply("User login with ID `%s` not found", targetID)
return
} else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) {
// All good
diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go
index 53c07530..c7b05a6e 100644
--- a/bridgev2/commands/startchat.go
+++ b/bridgev2/commands/startchat.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,14 +8,21 @@ package commands
import (
"context"
+ "errors"
"fmt"
+ "html"
+ "maps"
+ "slices"
"strings"
"time"
- "golang.org/x/net/html"
+ "github.com/rs/zerolog"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/bridgev2/provisionutil"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
)
@@ -28,6 +35,36 @@ var CommandResolveIdentifier = &FullHandler{
Args: "[_login ID_] <_identifier_>",
},
RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
+}
+
+var CommandSyncChat = &FullHandler{
+ Func: func(ce *Event) {
+ login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to find login for sync")
+ ce.Reply("Failed to find login: %v", err)
+ return
+ } else if login == nil {
+ ce.Reply("No login found for sync")
+ return
+ }
+ info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to get chat info for sync")
+ ce.Reply("Failed to get chat info: %v", err)
+ return
+ }
+ ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{})
+ ce.React("✅️")
+ },
+ Name: "sync-portal",
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Sync the current portal room",
+ },
+ RequiresPortal: true,
+ RequiresLogin: true,
}
var CommandStartChat = &FullHandler{
@@ -40,11 +77,18 @@ var CommandStartChat = &FullHandler{
Args: "[_login ID_] <_identifier_>",
},
RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
}
-func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
- remainingArgs := ce.Args[1:]
- login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
+ var remainingArgs []string
+ if len(ce.Args) > 1 {
+ remainingArgs = ce.Args[1:]
+ }
+ var login *bridgev2.UserLogin
+ if len(ce.Args) > 0 {
+ login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+ }
if login == nil || login.UserMXID != ce.User.MXID {
remainingArgs = ce.Args
login = ce.User.GetDefaultLogin()
@@ -56,24 +100,13 @@ func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Even
return login, api, remainingArgs
}
-func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string {
- var targetName string
- var targetMXID id.UserID
- if resp.Ghost != nil {
- if resp.UserInfo != nil {
- resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
- }
- targetName = resp.Ghost.Name
- targetMXID = resp.Ghost.Intent.GetMXID()
- } else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
- targetName = *resp.UserInfo.Name
- }
- if targetMXID != "" {
- return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL())
- } else if targetName != "" {
- return fmt.Sprintf("`%s` / %s", resp.UserID, targetName)
+func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string {
+ if resp.MXID != "" {
+ return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL())
+ } else if resp.Name != "" {
+ return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name)
} else {
- return fmt.Sprintf("`%s`", resp.UserID)
+ return fmt.Sprintf("`%s`", resp.ID)
}
}
@@ -86,65 +119,137 @@ func fnResolveIdentifier(ce *Event) {
if api == nil {
return
}
- createChat := ce.Command == "start-chat"
+ allLogins := ce.User.GetUserLogins()
+ createChat := ce.Command == "start-chat" || ce.Command == "pm"
identifier := strings.Join(identifierParts, " ")
- resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat)
+ resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat)
+ for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ {
+ resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat)
+ }
if err != nil {
- ce.Log.Err(err).Msg("Failed to resolve identifier")
ce.Reply("Failed to resolve identifier: %v", err)
return
} else if resp == nil {
ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true)
return
}
- formattedName := formatResolveIdentifierResult(ce.Ctx, resp)
+ formattedName := formatResolveIdentifierResult(resp)
if createChat {
- if resp.Chat == nil {
- ce.Reply("Interface error: network connector did not return chat for create chat request")
- return
+ name := resp.Portal.Name
+ if name == "" {
+ name = resp.Portal.MXID.String()
}
- portal := resp.Chat.Portal
- if portal == nil {
- portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to get portal")
- ce.Reply("Failed to get portal: %v", err)
- return
- }
- }
- if resp.Chat.PortalInfo == nil {
- resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to get portal info")
- ce.Reply("Failed to get portal info: %v", err)
- return
- }
- }
- if portal.MXID != "" {
- name := portal.Name
- if name == "" {
- name = portal.MXID.String()
- }
- portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{})
- ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
+ if !resp.JustCreated {
+ ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
} else {
- err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to create room")
- ce.Reply("Failed to create room: %v", err)
- return
- }
- name := portal.Name
- if name == "" {
- name = portal.MXID.String()
- }
- ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
+ ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
}
} else {
ce.Reply("Found %s", formattedName)
}
}
+var CommandCreateGroup = &FullHandler{
+ Func: fnCreateGroup,
+ Name: "create-group",
+ Aliases: []string{"create"},
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Create a new group chat for the current Matrix room",
+ Args: "[_group type_]",
+ },
+ RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI],
+}
+
+func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) {
+ evt, err := provider.GetStateEvent(ctx, roomID, evtType, "")
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation")
+ } else if evt != nil {
+ content, _ = evt.Content.Parsed.(T)
+ }
+ return
+}
+
+func fnCreateGroup(ce *Event) {
+ ce.Bridge.Matrix.GetCapabilities()
+ login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group")
+ if api == nil {
+ return
+ }
+ stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState)
+ if !ok {
+ ce.Reply("Matrix connector doesn't support fetching room state")
+ return
+ }
+ members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to get room members for group creation")
+ ce.Reply("Failed to get room members: %v", err)
+ return
+ }
+ caps := ce.Bridge.Network.GetCapabilities()
+ params := &bridgev2.GroupCreateParams{
+ Username: "",
+ Participants: make([]networkid.UserID, 0, len(members)-2),
+ Parent: nil, // TODO check space parent event
+ Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider),
+ Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider),
+ Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider),
+ Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider),
+ RoomID: ce.RoomID,
+ }
+ for userID, member := range members {
+ if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() {
+ continue
+ }
+ if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok {
+ params.Participants = append(params.Participants, parsedUserID)
+ } else if !ce.Bridge.Config.SplitPortals {
+ if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil {
+ ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member")
+ } else if user != nil {
+ // TODO add user logins to participants
+ //for _, login := range user.GetUserLogins() {
+ // params.Participants = append(params.Participants, login.GetUserID())
+ //}
+ }
+ }
+ }
+
+ if len(caps.Provisioning.GroupCreation) == 0 {
+ ce.Reply("No group creation types defined in network capabilities")
+ return
+ } else if len(remainingArgs) > 0 {
+ params.Type = remainingArgs[0]
+ } else if len(caps.Provisioning.GroupCreation) == 1 {
+ for params.Type = range caps.Provisioning.GroupCreation {
+ // The loop assigns the variable we want
+ }
+ } else {
+ types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `")
+ ce.Reply("Please specify type of group to create: `%s`", types)
+ return
+ }
+ resp, err := provisionutil.CreateGroup(ce.Ctx, login, params)
+ if err != nil {
+ ce.Reply("Failed to create group: %v", err)
+ return
+ }
+ var postfix string
+ if len(resp.FailedParticipants) > 0 {
+ failedParticipantsStrings := make([]string, len(resp.FailedParticipants))
+ i := 0
+ for participantID, meta := range resp.FailedParticipants {
+ failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason)
+ i++
+ }
+ postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n")
+ }
+ ce.Reply("Successfully created group `%s`%s", resp.ID, postfix)
+}
+
var CommandSearch = &FullHandler{
Func: fnSearch,
Name: "search",
@@ -154,6 +259,7 @@ var CommandSearch = &FullHandler{
Args: "<_query_>",
},
RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.UserSearchingNetworkAPI],
}
func fnSearch(ce *Event) {
@@ -161,35 +267,67 @@ func fnSearch(ce *Event) {
ce.Reply("Usage: `$cmdprefix search
+ Blockquote = "blockquote", // blockquote
+ InlineLink = "inline_link", // a
+ UserLink = "user_link", //
+ RoomLink = "room_link", //
+ EventLink = "event_link", //
+ AtRoomMention = "at_room_mention", // @room (no html tag)
+ UnorderedList = "unordered_list", // ul + li
+ OrderedList = "ordered_list", // ol + li
+ ListStart = "ordered_list.start", //
+ ListJumpValue = "ordered_list.jump_value", // -
+ CustomEmoji = "custom_emoji", //
+ Spoiler = "spoiler", //
+ SpoilerReason = "spoiler.reason", //
+ TextForegroundColor = "color.foreground", //
+ TextBackgroundColor = "color.background", //
+ HorizontalLine = "horizontal_line", // hr
+ Headers = "headers", // h1, h2, h3, h4, h5, h6
+ Superscript = "superscript", // sup
+ Subscript = "subscript", // sub
+ Math = "math", //
+ DetailsSummary = "details_summary", // ...
...
+ Table = "table", // table, thead, tbody, tr, th, td
+}
diff --git a/event/capabilities.go b/event/capabilities.go
new file mode 100644
index 00000000..a86c726b
--- /dev/null
+++ b/event/capabilities.go
@@ -0,0 +1,414 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package event
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "mime"
+ "slices"
+ "strings"
+
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
+ "golang.org/x/exp/constraints"
+ "golang.org/x/exp/maps"
+)
+
+type RoomFeatures struct {
+ ID string `json:"id,omitempty"`
+
+ // N.B. New fields need to be added to the Hash function to be included in the deduplication hash.
+
+ Formatting FormattingFeatureMap `json:"formatting,omitempty"`
+ File FileFeatureMap `json:"file,omitempty"`
+ State StateFeatureMap `json:"state,omitempty"`
+ MemberActions MemberFeatureMap `json:"member_actions,omitempty"`
+
+ MaxTextLength int `json:"max_text_length,omitempty"`
+
+ LocationMessage CapabilitySupportLevel `json:"location_message,omitempty"`
+ Poll CapabilitySupportLevel `json:"poll,omitempty"`
+ Thread CapabilitySupportLevel `json:"thread,omitempty"`
+ Reply CapabilitySupportLevel `json:"reply,omitempty"`
+
+ Edit CapabilitySupportLevel `json:"edit,omitempty"`
+ EditMaxCount int `json:"edit_max_count,omitempty"`
+ EditMaxAge *jsontime.Seconds `json:"edit_max_age,omitempty"`
+ Delete CapabilitySupportLevel `json:"delete,omitempty"`
+ DeleteForMe bool `json:"delete_for_me,omitempty"`
+ DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"`
+
+ DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"`
+
+ Reaction CapabilitySupportLevel `json:"reaction,omitempty"`
+ ReactionCount int `json:"reaction_count,omitempty"`
+ AllowedReactions []string `json:"allowed_reactions,omitempty"`
+ CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"`
+
+ ReadReceipts bool `json:"read_receipts,omitempty"`
+ TypingNotifications bool `json:"typing_notifications,omitempty"`
+ Archive bool `json:"archive,omitempty"`
+ MarkAsUnread bool `json:"mark_as_unread,omitempty"`
+ DeleteChat bool `json:"delete_chat,omitempty"`
+ DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"`
+
+ MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"`
+
+ PerMessageProfileRelay bool `json:"-"`
+}
+
+func (rf *RoomFeatures) GetID() string {
+ if rf.ID != "" {
+ return rf.ID
+ }
+ return base64.RawURLEncoding.EncodeToString(rf.Hash())
+}
+
+func (rf *RoomFeatures) Clone() *RoomFeatures {
+ if rf == nil {
+ return nil
+ }
+ clone := *rf
+ clone.File = clone.File.Clone()
+ clone.Formatting = maps.Clone(clone.Formatting)
+ clone.State = clone.State.Clone()
+ clone.MemberActions = clone.MemberActions.Clone()
+ clone.EditMaxAge = ptr.Clone(clone.EditMaxAge)
+ clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge)
+ clone.DisappearingTimer = clone.DisappearingTimer.Clone()
+ clone.AllowedReactions = slices.Clone(clone.AllowedReactions)
+ clone.MessageRequest = clone.MessageRequest.Clone()
+ return &clone
+}
+
+type MemberFeatureMap map[MemberAction]CapabilitySupportLevel
+
+func (mfm MemberFeatureMap) Clone() MemberFeatureMap {
+ return maps.Clone(mfm)
+}
+
+type MemberAction string
+
+const (
+ MemberActionBan MemberAction = "ban"
+ MemberActionKick MemberAction = "kick"
+ MemberActionLeave MemberAction = "leave"
+ MemberActionRevokeInvite MemberAction = "revoke_invite"
+ MemberActionInvite MemberAction = "invite"
+)
+
+type StateFeatureMap map[string]*StateFeatures
+
+func (sfm StateFeatureMap) Clone() StateFeatureMap {
+ dup := maps.Clone(sfm)
+ for key, value := range dup {
+ dup[key] = value.Clone()
+ }
+ return dup
+}
+
+type StateFeatures struct {
+ Level CapabilitySupportLevel `json:"level"`
+}
+
+func (sf *StateFeatures) Clone() *StateFeatures {
+ if sf == nil {
+ return nil
+ }
+ clone := *sf
+ return &clone
+}
+
+func (sf *StateFeatures) Hash() []byte {
+ return sf.Level.Hash()
+}
+
+type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel
+
+type FileFeatureMap map[CapabilityMsgType]*FileFeatures
+
+func (ffm FileFeatureMap) Clone() FileFeatureMap {
+ dup := maps.Clone(ffm)
+ for key, value := range dup {
+ dup[key] = value.Clone()
+ }
+ return dup
+}
+
+type DisappearingTimerCapability struct {
+ Types []DisappearingType `json:"types"`
+ Timers []jsontime.Milliseconds `json:"timers,omitempty"`
+
+ OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"`
+}
+
+func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability {
+ if dtc == nil {
+ return nil
+ }
+ clone := *dtc
+ clone.Types = slices.Clone(clone.Types)
+ clone.Timers = slices.Clone(clone.Timers)
+ return &clone
+}
+
+func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool {
+ if dtc == nil || content == nil || content.Type == DisappearingTypeNone {
+ return true
+ }
+ return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer))
+}
+
+type MessageRequestFeatures struct {
+ AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"`
+ AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"`
+}
+
+func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures {
+ return ptr.Clone(mrf)
+}
+
+func (mrf *MessageRequestFeatures) Hash() []byte {
+ if mrf == nil {
+ return nil
+ }
+ hasher := sha256.New()
+ hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage)
+ hashValue(hasher, "accept_with_button", mrf.AcceptWithButton)
+ return hasher.Sum(nil)
+}
+
+type CapabilityMsgType = MessageType
+
+// Message types which are used for event capability signaling, but aren't real values for the msgtype field.
+const (
+ CapMsgVoice CapabilityMsgType = "org.matrix.msc3245.voice"
+ CapMsgGIF CapabilityMsgType = "fi.mau.gif"
+ CapMsgSticker CapabilityMsgType = "m.sticker"
+)
+
+type CapabilitySupportLevel int
+
+func (csl CapabilitySupportLevel) Partial() bool {
+ return csl >= CapLevelPartialSupport
+}
+
+func (csl CapabilitySupportLevel) Full() bool {
+ return csl >= CapLevelFullySupported
+}
+
+func (csl CapabilitySupportLevel) Reject() bool {
+ return csl <= CapLevelRejected
+}
+
+const (
+ CapLevelRejected CapabilitySupportLevel = -2 // The feature is unsupported and messages using it will be rejected.
+ CapLevelDropped CapabilitySupportLevel = -1 // The feature is unsupported and has no fallback. The message will go through, but data may be lost.
+ CapLevelUnsupported CapabilitySupportLevel = 0 // The feature is unsupported, but may have a fallback.
+ CapLevelPartialSupport CapabilitySupportLevel = 1 // The feature is partially supported (e.g. it may be converted to a different format).
+ CapLevelFullySupported CapabilitySupportLevel = 2 // The feature is fully supported and can be safely used.
+)
+
+type FormattingFeature string
+
+const (
+ FmtBold FormattingFeature = "bold" // strong, b
+ FmtItalic FormattingFeature = "italic" // em, i
+ FmtUnderline FormattingFeature = "underline" // u
+ FmtStrikethrough FormattingFeature = "strikethrough" // del, s
+ FmtInlineCode FormattingFeature = "inline_code" // code
+ FmtCodeBlock FormattingFeature = "code_block" // pre + code
+ FmtSyntaxHighlighting FormattingFeature = "code_block.syntax_highlighting" //
+ FmtBlockquote FormattingFeature = "blockquote" // blockquote
+ FmtInlineLink FormattingFeature = "inline_link" // a
+ FmtUserLink FormattingFeature = "user_link" //
+ FmtRoomLink FormattingFeature = "room_link" //
+ FmtEventLink FormattingFeature = "event_link" //
+ FmtAtRoomMention FormattingFeature = "at_room_mention" // @room (no html tag)
+ FmtUnorderedList FormattingFeature = "unordered_list" // ul + li
+ FmtOrderedList FormattingFeature = "ordered_list" // ol + li
+ FmtListStart FormattingFeature = "ordered_list.start" //
+ FmtListJumpValue FormattingFeature = "ordered_list.jump_value" // -
+ FmtCustomEmoji FormattingFeature = "custom_emoji" //
+ FmtSpoiler FormattingFeature = "spoiler" //
+ FmtSpoilerReason FormattingFeature = "spoiler.reason" //
+ FmtTextForegroundColor FormattingFeature = "color.foreground" //
+ FmtTextBackgroundColor FormattingFeature = "color.background" //
+ FmtHorizontalLine FormattingFeature = "horizontal_line" // hr
+ FmtHeaders FormattingFeature = "headers" // h1, h2, h3, h4, h5, h6
+ FmtSuperscript FormattingFeature = "superscript" // sup
+ FmtSubscript FormattingFeature = "subscript" // sub
+ FmtMath FormattingFeature = "math" //
+ FmtDetailsSummary FormattingFeature = "details_summary" // ...
...
+ FmtTable FormattingFeature = "table" // table, thead, tbody, tr, th, td
+)
+
+type FileFeatures struct {
+ // N.B. New fields need to be added to the Hash function to be included in the deduplication hash.
+
+ MimeTypes map[string]CapabilitySupportLevel `json:"mime_types"`
+
+ Caption CapabilitySupportLevel `json:"caption,omitempty"`
+ MaxCaptionLength int `json:"max_caption_length,omitempty"`
+
+ MaxSize int64 `json:"max_size,omitempty"`
+ MaxWidth int `json:"max_width,omitempty"`
+ MaxHeight int `json:"max_height,omitempty"`
+ MaxDuration *jsontime.Seconds `json:"max_duration,omitempty"`
+
+ ViewOnce bool `json:"view_once,omitempty"`
+}
+
+func (ff *FileFeatures) GetMimeSupport(inputType string) CapabilitySupportLevel {
+ match, ok := ff.MimeTypes[inputType]
+ if ok {
+ return match
+ }
+ if strings.IndexByte(inputType, ';') != -1 {
+ plainMime, _, _ := mime.ParseMediaType(inputType)
+ if plainMime != "" {
+ if match, ok = ff.MimeTypes[plainMime]; ok {
+ return match
+ }
+ }
+ }
+ if slash := strings.IndexByte(inputType, '/'); slash > 0 {
+ generalType := fmt.Sprintf("%s/*", inputType[:slash])
+ if match, ok = ff.MimeTypes[generalType]; ok {
+ return match
+ }
+ }
+ match, ok = ff.MimeTypes["*/*"]
+ if ok {
+ return match
+ }
+ return CapLevelRejected
+}
+
+type hashable interface {
+ Hash() []byte
+}
+
+func hashMap[Key ~string, Value hashable](w io.Writer, name string, data map[Key]Value) {
+ keys := maps.Keys(data)
+ slices.Sort(keys)
+ exerrors.Must(w.Write([]byte(name)))
+ for _, key := range keys {
+ exerrors.Must(w.Write([]byte(key)))
+ exerrors.Must(w.Write(data[key].Hash()))
+ exerrors.Must(w.Write([]byte{0}))
+ }
+}
+
+func hashValue(w io.Writer, name string, data hashable) {
+ exerrors.Must(w.Write([]byte(name)))
+ exerrors.Must(w.Write(data.Hash()))
+}
+
+func hashInt[T constraints.Integer](w io.Writer, name string, data T) {
+ exerrors.Must(w.Write(binary.BigEndian.AppendUint64([]byte(name), uint64(data))))
+}
+
+func hashBool[T ~bool](w io.Writer, name string, data T) {
+ exerrors.Must(w.Write([]byte(name)))
+ if data {
+ exerrors.Must(w.Write([]byte{1}))
+ } else {
+ exerrors.Must(w.Write([]byte{0}))
+ }
+}
+
+func (csl CapabilitySupportLevel) Hash() []byte {
+ return []byte{byte(csl + 128)}
+}
+
+func (rf *RoomFeatures) Hash() []byte {
+ hasher := sha256.New()
+
+ hashMap(hasher, "formatting", rf.Formatting)
+ hashMap(hasher, "file", rf.File)
+ hashMap(hasher, "state", rf.State)
+ hashMap(hasher, "member_actions", rf.MemberActions)
+
+ hashInt(hasher, "max_text_length", rf.MaxTextLength)
+
+ hashValue(hasher, "location_message", rf.LocationMessage)
+ hashValue(hasher, "poll", rf.Poll)
+ hashValue(hasher, "thread", rf.Thread)
+ hashValue(hasher, "reply", rf.Reply)
+
+ hashValue(hasher, "edit", rf.Edit)
+ hashInt(hasher, "edit_max_count", rf.EditMaxCount)
+ hashInt(hasher, "edit_max_age", rf.EditMaxAge.Get())
+
+ hashValue(hasher, "delete", rf.Delete)
+ hashBool(hasher, "delete_for_me", rf.DeleteForMe)
+ hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get())
+ hashValue(hasher, "disappearing_timer", rf.DisappearingTimer)
+
+ hashValue(hasher, "reaction", rf.Reaction)
+ hashInt(hasher, "reaction_count", rf.ReactionCount)
+ hasher.Write([]byte("allowed_reactions"))
+ for _, reaction := range rf.AllowedReactions {
+ hasher.Write([]byte(reaction))
+ }
+ hashBool(hasher, "custom_emoji_reactions", rf.CustomEmojiReactions)
+
+ hashBool(hasher, "read_receipts", rf.ReadReceipts)
+ hashBool(hasher, "typing_notifications", rf.TypingNotifications)
+ hashBool(hasher, "archive", rf.Archive)
+ hashBool(hasher, "mark_as_unread", rf.MarkAsUnread)
+ hashBool(hasher, "delete_chat", rf.DeleteChat)
+ hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone)
+ hashValue(hasher, "message_request", rf.MessageRequest)
+
+ return hasher.Sum(nil)
+}
+
+func (dtc *DisappearingTimerCapability) Hash() []byte {
+ if dtc == nil {
+ return nil
+ }
+ hasher := sha256.New()
+ hasher.Write([]byte("types"))
+ for _, t := range dtc.Types {
+ hasher.Write([]byte(t))
+ }
+ hasher.Write([]byte("timers"))
+ for _, timer := range dtc.Timers {
+ hashInt(hasher, "", timer.Milliseconds())
+ }
+ return hasher.Sum(nil)
+}
+
+func (ff *FileFeatures) Hash() []byte {
+ hasher := sha256.New()
+ hashMap(hasher, "mime_types", ff.MimeTypes)
+ hashValue(hasher, "caption", ff.Caption)
+ hashInt(hasher, "max_caption_length", ff.MaxCaptionLength)
+ hashInt(hasher, "max_size", ff.MaxSize)
+ hashInt(hasher, "max_width", ff.MaxWidth)
+ hashInt(hasher, "max_height", ff.MaxHeight)
+ hashInt(hasher, "max_duration", ff.MaxDuration.Get())
+ hashBool(hasher, "view_once", ff.ViewOnce)
+ return hasher.Sum(nil)
+}
+
+func (ff *FileFeatures) Clone() *FileFeatures {
+ if ff == nil {
+ return nil
+ }
+ clone := *ff
+ clone.MimeTypes = maps.Clone(clone.MimeTypes)
+ clone.MaxDuration = ptr.Clone(clone.MaxDuration)
+ return &clone
+}
diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go
new file mode 100644
index 00000000..ce07c4c0
--- /dev/null
+++ b/event/cmdschema/content.go
@@ -0,0 +1,78 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "reflect"
+ "slices"
+
+ "go.mau.fi/util/exsync"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type EventContent struct {
+ Command string `json:"command"`
+ Aliases []string `json:"aliases,omitempty"`
+ Parameters []*Parameter `json:"parameters,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ TailParam string `json:"fi.mau.tail_parameter,omitempty"`
+}
+
+func (ec *EventContent) Validate() error {
+ if ec == nil {
+ return fmt.Errorf("event content is nil")
+ } else if ec.Command == "" {
+ return fmt.Errorf("command is empty")
+ }
+ var tailFound bool
+ dupMap := exsync.NewSet[string]()
+ for i, p := range ec.Parameters {
+ if err := p.Validate(); err != nil {
+ return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err)
+ } else if !dupMap.Add(p.Key) {
+ return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1)
+ } else if p.Key == ec.TailParam {
+ tailFound = true
+ } else if tailFound && !p.Optional {
+ return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam)
+ }
+ }
+ if ec.TailParam != "" && !tailFound {
+ return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam)
+ }
+ return nil
+}
+
+func (ec *EventContent) IsValid() bool {
+ return ec.Validate() == nil
+}
+
+func (ec *EventContent) StateKey(owner id.UserID) string {
+ hash := sha256.Sum256([]byte(ec.Command + owner.String()))
+ return base64.StdEncoding.EncodeToString(hash[:])
+}
+
+func (ec *EventContent) Equals(other *EventContent) bool {
+ if ec == nil || other == nil {
+ return ec == other
+ }
+ return ec.Command == other.Command &&
+ slices.Equal(ec.Aliases, other.Aliases) &&
+ slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) &&
+ ec.Description.Equals(other.Description) &&
+ ec.TailParam == other.TailParam
+}
+
+func init() {
+ event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{})
+}
diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go
new file mode 100644
index 00000000..4193b297
--- /dev/null
+++ b/event/cmdschema/parameter.go
@@ -0,0 +1,286 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+
+ "go.mau.fi/util/exslices"
+
+ "maunium.net/go/mautrix/event"
+)
+
+type Parameter struct {
+ Key string `json:"key"`
+ Schema *ParameterSchema `json:"schema"`
+ Optional bool `json:"optional,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ DefaultValue any `json:"fi.mau.default_value,omitempty"`
+}
+
+func (p *Parameter) Equals(other *Parameter) bool {
+ if p == nil || other == nil {
+ return p == other
+ }
+ return p.Key == other.Key &&
+ p.Schema.Equals(other.Schema) &&
+ p.Optional == other.Optional &&
+ p.Description.Equals(other.Description) &&
+ p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values
+}
+
+func (p *Parameter) Validate() error {
+ if p == nil {
+ return fmt.Errorf("parameter is nil")
+ } else if p.Key == "" {
+ return fmt.Errorf("key is empty")
+ }
+ return p.Schema.Validate()
+}
+
+func (p *Parameter) IsValid() bool {
+ return p.Validate() == nil
+}
+
+func (p *Parameter) GetDefaultValue() any {
+ if p != nil && p.DefaultValue != nil {
+ return p.DefaultValue
+ } else if p == nil || p.Optional {
+ return nil
+ }
+ return p.Schema.GetDefaultValue()
+}
+
+type PrimitiveType string
+
+const (
+ PrimitiveTypeString PrimitiveType = "string"
+ PrimitiveTypeInteger PrimitiveType = "integer"
+ PrimitiveTypeBoolean PrimitiveType = "boolean"
+ PrimitiveTypeServerName PrimitiveType = "server_name"
+ PrimitiveTypeUserID PrimitiveType = "user_id"
+ PrimitiveTypeRoomID PrimitiveType = "room_id"
+ PrimitiveTypeRoomAlias PrimitiveType = "room_alias"
+ PrimitiveTypeEventID PrimitiveType = "event_id"
+)
+
+func (pt PrimitiveType) Schema() *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypePrimitive,
+ Type: pt,
+ }
+}
+
+func (pt PrimitiveType) IsValid() bool {
+ switch pt {
+ case PrimitiveTypeString,
+ PrimitiveTypeInteger,
+ PrimitiveTypeBoolean,
+ PrimitiveTypeServerName,
+ PrimitiveTypeUserID,
+ PrimitiveTypeRoomID,
+ PrimitiveTypeRoomAlias,
+ PrimitiveTypeEventID:
+ return true
+ default:
+ return false
+ }
+}
+
+type SchemaType string
+
+const (
+ SchemaTypePrimitive SchemaType = "primitive"
+ SchemaTypeArray SchemaType = "array"
+ SchemaTypeUnion SchemaType = "union"
+ SchemaTypeLiteral SchemaType = "literal"
+)
+
+type ParameterSchema struct {
+ SchemaType SchemaType `json:"schema_type"`
+ Type PrimitiveType `json:"type,omitempty"` // Only for primitive
+ Items *ParameterSchema `json:"items,omitempty"` // Only for array
+ Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union
+ Value any `json:"value,omitempty"` // Only for literal
+}
+
+func Literal(value any) *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypeLiteral,
+ Value: value,
+ }
+}
+
+func Enum(values ...any) *ParameterSchema {
+ return Union(exslices.CastFunc(values, Literal)...)
+}
+
+func flattenUnion(variants []*ParameterSchema) []*ParameterSchema {
+ var flattened []*ParameterSchema
+ for _, variant := range variants {
+ switch variant.SchemaType {
+ case SchemaTypeArray:
+ panic(fmt.Errorf("illegal array schema in union"))
+ case SchemaTypeUnion:
+ flattened = append(flattened, flattenUnion(variant.Variants)...)
+ default:
+ flattened = append(flattened, variant)
+ }
+ }
+ return flattened
+}
+
+func Union(variants ...*ParameterSchema) *ParameterSchema {
+ needsFlattening := false
+ for _, variant := range variants {
+ if variant.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in union"))
+ } else if variant.SchemaType == SchemaTypeUnion {
+ needsFlattening = true
+ }
+ }
+ if needsFlattening {
+ variants = flattenUnion(variants)
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeUnion,
+ Variants: variants,
+ }
+}
+
+func Array(items *ParameterSchema) *ParameterSchema {
+ if items.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in array"))
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeArray,
+ Items: items,
+ }
+}
+
+func (ps *ParameterSchema) GetDefaultValue() any {
+ if ps == nil {
+ return nil
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ switch ps.Type {
+ case PrimitiveTypeInteger:
+ return 0
+ case PrimitiveTypeBoolean:
+ return false
+ default:
+ return ""
+ }
+ case SchemaTypeArray:
+ return []any{}
+ case SchemaTypeUnion:
+ if len(ps.Variants) > 0 {
+ return ps.Variants[0].GetDefaultValue()
+ }
+ return nil
+ case SchemaTypeLiteral:
+ return ps.Value
+ default:
+ return nil
+ }
+}
+
+func (ps *ParameterSchema) IsValid() bool {
+ return ps.validate("") == nil
+}
+
+func (ps *ParameterSchema) Validate() error {
+ return ps.validate("")
+}
+
+func (ps *ParameterSchema) validate(parent SchemaType) error {
+ if ps == nil {
+ return fmt.Errorf("schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ if !ps.Type.IsValid() {
+ return fmt.Errorf("invalid primitive type %s", ps.Type)
+ } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("primitive schema has extra fields")
+ }
+ return nil
+ case SchemaTypeArray:
+ if parent != "" {
+ return fmt.Errorf("arrays can't be nested in other types")
+ } else if err := ps.Items.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("item schema is invalid: %w", err)
+ } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("array schema has extra fields")
+ }
+ return nil
+ case SchemaTypeUnion:
+ if len(ps.Variants) == 0 {
+ return fmt.Errorf("no variants specified for union")
+ } else if parent != "" && parent != SchemaTypeArray {
+ return fmt.Errorf("unions can't be nested in anything other than arrays")
+ }
+ for i, v := range ps.Variants {
+ if err := v.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("variant #%d is invalid: %w", i+1, err)
+ }
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Value != nil {
+ return fmt.Errorf("union schema has extra fields")
+ }
+ return nil
+ case SchemaTypeLiteral:
+ switch typedVal := ps.Value.(type) {
+ case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue:
+ // ok
+ case map[string]any:
+ if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" {
+ return fmt.Errorf("literal value has invalid map data")
+ }
+ default:
+ return fmt.Errorf("literal value has unsupported type %T", ps.Value)
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Variants != nil {
+ return fmt.Errorf("literal schema has extra fields")
+ }
+ return nil
+ default:
+ return fmt.Errorf("invalid schema type %s", ps.SchemaType)
+ }
+}
+
+func (ps *ParameterSchema) Equals(other *ParameterSchema) bool {
+ if ps == nil || other == nil {
+ return ps == other
+ }
+ return ps.SchemaType == other.SchemaType &&
+ ps.Type == other.Type &&
+ ps.Items.Equals(other.Items) &&
+ slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) &&
+ ps.Value == other.Value // TODO this won't work for room/event ID values
+}
+
+func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool {
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type == prim
+ case SchemaTypeUnion:
+ for _, variant := range ps.Variants {
+ if variant.AllowsPrimitive(prim) {
+ return true
+ }
+ }
+ return false
+ case SchemaTypeArray:
+ return ps.Items.AllowsPrimitive(prim)
+ default:
+ return false
+ }
+}
diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go
new file mode 100644
index 00000000..92e69b60
--- /dev/null
+++ b/event/cmdschema/parse.go
@@ -0,0 +1,478 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+const botArrayOpener = "<"
+const botArrayCloser = ">"
+
+func parseQuoted(val string) (parsed, remaining string, quoted bool) {
+ if len(val) == 0 {
+ return
+ }
+ if !strings.HasPrefix(val, `"`) {
+ spaceIdx := strings.IndexByte(val, ' ')
+ if spaceIdx == -1 {
+ parsed = val
+ } else {
+ parsed = val[:spaceIdx]
+ remaining = strings.TrimLeft(val[spaceIdx+1:], " ")
+ }
+ return
+ }
+ val = val[1:]
+ var buf strings.Builder
+ for {
+ quoteIdx := strings.IndexByte(val, '"')
+ var valUntilQuote string
+ if quoteIdx == -1 {
+ valUntilQuote = val
+ } else {
+ valUntilQuote = val[:quoteIdx]
+ }
+ escapeIdx := strings.IndexByte(valUntilQuote, '\\')
+ if escapeIdx >= 0 {
+ buf.WriteString(val[:escapeIdx])
+ if len(val) > escapeIdx+1 {
+ buf.WriteByte(val[escapeIdx+1])
+ }
+ val = val[min(escapeIdx+2, len(val)):]
+ } else if quoteIdx >= 0 {
+ buf.WriteString(val[:quoteIdx])
+ val = val[quoteIdx+1:]
+ break
+ } else if buf.Len() == 0 {
+ // Unterminated quote, no escape characters, val is the whole input
+ return val, "", true
+ } else {
+ // Unterminated quote, but there were escape characters previously
+ buf.WriteString(val)
+ val = ""
+ break
+ }
+ }
+ return buf.String(), strings.TrimLeft(val, " "), true
+}
+
+// ParseInput tries to parse the given text into a bot command event matching this command definition.
+//
+// If the prefix doesn't match, this will return a nil content and nil error.
+// If the prefix does match, some content is always returned, but there may still be an error if parsing failed.
+func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) {
+ prefix := ec.parsePrefix(input, sigils, owner.String())
+ if prefix == "" {
+ return nil, nil
+ }
+ content = &event.MessageEventContent{
+ MsgType: event.MsgText,
+ Body: input,
+ Mentions: &event.Mentions{UserIDs: []id.UserID{owner}},
+ MSC4391BotCommand: &event.MSC4391BotCommandInput{
+ Command: ec.Command,
+ },
+ }
+ content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):])
+ return content, err
+}
+
+func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) {
+ args := make(map[string]any)
+ var retErr error
+ setError := func(err error) {
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+ processParameter := func(param *Parameter, isLast, isTail, isNamed bool) {
+ origInput := input
+ var nextVal string
+ var wasQuoted bool
+ if param.Schema.SchemaType == SchemaTypeArray {
+ hasOpener := strings.HasPrefix(input, botArrayOpener)
+ arrayClosed := false
+ if hasOpener {
+ input = input[len(botArrayOpener):]
+ if strings.HasPrefix(input, botArrayCloser) {
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ }
+ }
+ var collector []any
+ for len(input) > 0 && !arrayClosed {
+ //origInput = input
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) {
+ // The value wasn't quoted and has the array delimiter at the end, close the array
+ nextVal = strings.TrimRight(nextVal, botArrayCloser)
+ arrayClosed = true
+ } else if hasOpener && strings.HasPrefix(input, botArrayCloser) {
+ // The value was quoted or there was a space, and the next character is the
+ // array delimiter, close the array
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ } else if !hasOpener && !isLast {
+ // For array arguments in the middle without the <> delimiters, stop after the first item
+ arrayClosed = true
+ }
+ parsedVal, err := param.Schema.Items.ParseString(nextVal)
+ if err == nil {
+ collector = append(collector, parsedVal)
+ } else if hasOpener || isLast {
+ setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err))
+ } else {
+ //input = origInput
+ }
+ }
+ args[param.Key] = collector
+ } else {
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if (isLast || isTail) && !wasQuoted && len(input) > 0 {
+ // If the last argument is not quoted, just treat the rest of the string
+ // as the argument without escapes (arguments with escapes should be quoted).
+ nextVal += " " + input
+ input = ""
+ }
+ // Special case for named boolean parameters: if no value is given, treat it as true
+ if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) {
+ args[param.Key] = true
+ return
+ }
+ if nextVal == "" && !wasQuoted && !isNamed && !param.Optional {
+ setError(fmt.Errorf("missing value for required parameter %s", param.Key))
+ }
+ parsedVal, err := param.Schema.ParseString(nextVal)
+ if err != nil {
+ args[param.Key] = param.GetDefaultValue()
+ // For optional parameters that fail to parse, restore the input and try passing it as the next parameter
+ if param.Optional && !isLast && !isNamed {
+ input = strings.TrimLeft(origInput, " ")
+ } else if !param.Optional || isNamed {
+ setError(fmt.Errorf("failed to parse %s: %w", param.Key, err))
+ }
+ } else {
+ args[param.Key] = parsedVal
+ }
+ }
+ }
+ skipParams := make([]bool, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ for strings.HasPrefix(input, "--") {
+ nameEndIdx := strings.IndexAny(input, " =")
+ if nameEndIdx == -1 {
+ nameEndIdx = len(input)
+ }
+ overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx])
+ if overrideParam != nil {
+ // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input
+ input = strings.TrimPrefix(input[nameEndIdx:], "=")
+ skipParams[paramIdx] = true
+ processParameter(overrideParam, false, false, true)
+ } else {
+ break
+ }
+ }
+ isTail := param.Key == ec.TailParam
+ if skipParams[i] || (param.Optional && !isTail) {
+ continue
+ }
+ processParameter(param, i == len(ec.Parameters)-1, isTail, false)
+ }
+ jsonArgs, marshalErr := json.Marshal(args)
+ if marshalErr != nil {
+ return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr)
+ }
+ return jsonArgs, retErr
+}
+
+func (ec *EventContent) parameterByName(name string) (*Parameter, int) {
+ for i, param := range ec.Parameters {
+ if strings.EqualFold(param.Key, name) {
+ return param, i
+ }
+ }
+ return nil, -1
+}
+
+func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) {
+ input := origInput
+ var chosenSigil string
+ for _, sigil := range sigils {
+ if strings.HasPrefix(input, sigil) {
+ chosenSigil = sigil
+ break
+ }
+ }
+ if chosenSigil == "" {
+ return ""
+ }
+ input = input[len(chosenSigil):]
+ var chosenAlias string
+ if !strings.HasPrefix(input, ec.Command) {
+ for _, alias := range ec.Aliases {
+ if strings.HasPrefix(input, alias) {
+ chosenAlias = alias
+ break
+ }
+ }
+ if chosenAlias == "" {
+ return ""
+ }
+ } else {
+ chosenAlias = ec.Command
+ }
+ input = strings.TrimPrefix(input[len(chosenAlias):], owner)
+ if input == "" || input[0] == ' ' {
+ input = strings.TrimLeft(input, " ")
+ return origInput[:len(origInput)-len(input)]
+ }
+ return ""
+}
+
+func (pt PrimitiveType) ValidateValue(value any) bool {
+ _, err := pt.NormalizeValue(value)
+ return err == nil
+}
+
+func normalizeNumber(value any) (int, error) {
+ switch typedValue := value.(type) {
+ case int:
+ return typedValue, nil
+ case int64:
+ return int(typedValue), nil
+ case float64:
+ return int(typedValue), nil
+ case json.Number:
+ if i, err := typedValue.Int64(); err != nil {
+ return 0, fmt.Errorf("failed to parse json.Number: %w", err)
+ } else {
+ return int(i), nil
+ }
+ default:
+ return 0, fmt.Errorf("unsupported type %T for integer", value)
+ }
+}
+
+func (pt PrimitiveType) NormalizeValue(value any) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return normalizeNumber(value)
+ case PrimitiveTypeBoolean:
+ bv, ok := value.(bool)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for boolean", value)
+ }
+ return bv, nil
+ case PrimitiveTypeString, PrimitiveTypeServerName:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for string", value)
+ }
+ return str, pt.validateStringValue(str)
+ case PrimitiveTypeUserID, PrimitiveTypeRoomAlias:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value)
+ } else if plainErr := pt.validateStringValue(str); plainErr == nil {
+ return str, nil
+ } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID {
+ return parsed.UserID(), nil
+ } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias {
+ return parsed.RoomAlias(), nil
+ } else {
+ return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1)
+ }
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ riv, err := NormalizeRoomIDValue(value)
+ if err != nil {
+ return nil, err
+ }
+ return riv, riv.Validate()
+ default:
+ return nil, fmt.Errorf("cannot normalize value for argument type %s", pt)
+ }
+}
+
+func (pt PrimitiveType) validateStringValue(value string) error {
+ switch pt {
+ case PrimitiveTypeString:
+ return nil
+ case PrimitiveTypeServerName:
+ if !id.ValidateServerName(value) {
+ return fmt.Errorf("invalid server name: %q", value)
+ }
+ return nil
+ case PrimitiveTypeUserID:
+ _, _, err := id.UserID(value).ParseAndValidateRelaxed()
+ return err
+ case PrimitiveTypeRoomAlias:
+ sigil, localpart, serverName := id.ParseCommonIdentifier(value)
+ if sigil != '#' || localpart == "" || serverName == "" {
+ return fmt.Errorf("invalid room alias: %q", value)
+ } else if !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name in room alias: %q", serverName)
+ }
+ return nil
+ default:
+ panic(fmt.Errorf("validateStringValue called with invalid type %s", pt))
+ }
+}
+
+func parseBoolean(val string) (bool, error) {
+ if len(val) == 0 {
+ return false, fmt.Errorf("cannot parse empty string as boolean")
+ }
+ switch strings.ToLower(val) {
+ case "t", "true", "y", "yes", "1":
+ return true, nil
+ case "f", "false", "n", "no", "0":
+ return false, nil
+ default:
+ return false, fmt.Errorf("invalid boolean string: %q", val)
+ }
+}
+
+var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`)
+
+func parseRoomOrEventID(value string) (*RoomIDValue, error) {
+ if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") {
+ matches := markdownLinkRegex.FindStringSubmatch(value)
+ if len(matches) == 2 {
+ value = matches[1]
+ }
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil && strings.HasPrefix(value, "!") {
+ return &RoomIDValue{
+ Type: PrimitiveTypeRoomID,
+ RoomID: id.RoomID(value),
+ }, nil
+ }
+ if err != nil {
+ return nil, err
+ } else if parsed.Sigil1 != '!' {
+ return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1)
+ } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' {
+ return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2)
+ }
+ valType := PrimitiveTypeRoomID
+ if parsed.MXID2 != "" {
+ valType = PrimitiveTypeEventID
+ }
+ return &RoomIDValue{
+ Type: valType,
+ RoomID: parsed.RoomID(),
+ Via: parsed.Via,
+ EventID: parsed.EventID(),
+ }, nil
+}
+
+func (pt PrimitiveType) ParseString(value string) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return strconv.Atoi(value)
+ case PrimitiveTypeBoolean:
+ return parseBoolean(value)
+ case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID:
+ return value, pt.validateStringValue(value)
+ case PrimitiveTypeRoomAlias:
+ plainErr := pt.validateStringValue(value)
+ if plainErr == nil {
+ return value, nil
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 != '#' {
+ return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1)
+ }
+ return parsed.RoomAlias(), nil
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, err
+ } else if pt != parsed.Type {
+ return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("cannot parse string for argument type %s", pt)
+ }
+}
+
+func (ps *ParameterSchema) ParseString(value string) (any, error) {
+ if ps == nil {
+ return nil, fmt.Errorf("parameter schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type.ParseString(value)
+ case SchemaTypeLiteral:
+ switch typedValue := ps.Value.(type) {
+ case string:
+ if value == typedValue {
+ return typedValue, nil
+ } else {
+ return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value)
+ }
+ case int, int64, float64, json.Number:
+ expectedVal, _ := normalizeNumber(typedValue)
+ intVal, err := strconv.Atoi(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse integer literal: %w", err)
+ } else if intVal != expectedVal {
+ return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal)
+ }
+ return intVal, nil
+ case bool:
+ boolVal, err := parseBoolean(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse boolean literal: %w", err)
+ } else if boolVal != typedValue {
+ return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal)
+ }
+ return boolVal, nil
+ case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage:
+ expectedVal, _ := NormalizeRoomIDValue(typedValue)
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err)
+ } else if !parsed.Equals(expectedVal) {
+ return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("unsupported literal type %T", ps.Value)
+ }
+ case SchemaTypeUnion:
+ var errs []error
+ for _, variant := range ps.Variants {
+ if parsed, err := variant.ParseString(value); err == nil {
+ return parsed, nil
+ } else {
+ errs = append(errs, err)
+ }
+ }
+ return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...))
+ case SchemaTypeArray:
+ return nil, fmt.Errorf("cannot parse string for array schema type")
+ default:
+ return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType)
+ }
+}
diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go
new file mode 100644
index 00000000..1e0d1817
--- /dev/null
+++ b/event/cmdschema/parse_test.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "bytes"
+ "encoding/json"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exbytes"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/event/cmdschema/testdata"
+)
+
+type QuoteParseOutput struct {
+ Parsed string
+ Remaining string
+ Quoted bool
+}
+
+func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error {
+ var arr []any
+ if err := json.Unmarshal(data, &arr); err != nil {
+ return err
+ }
+ qpo.Parsed = arr[0].(string)
+ qpo.Remaining = arr[1].(string)
+ qpo.Quoted = arr[2].(bool)
+ return nil
+}
+
+type QuoteParseTestData struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Output QuoteParseOutput `json:"output"`
+}
+
+func loadFile[T any](name string) (into T) {
+ quoteData := exerrors.Must(testdata.FS.ReadFile(name))
+ exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into))
+ return
+}
+
+func TestParseQuoted(t *testing.T) {
+ qptd := loadFile[[]QuoteParseTestData]("parse_quote.json")
+ for _, test := range qptd {
+ t.Run(test.Name, func(t *testing.T) {
+ parsed, remaining, quoted := parseQuoted(test.Input)
+ assert.Equalf(t, test.Output, QuoteParseOutput{
+ Parsed: parsed,
+ Remaining: remaining,
+ Quoted: quoted,
+ }, "Failed with input `%s`", test.Input)
+ // Note: can't just test that requoted == input, because some inputs
+ // have unnecessary escapes which won't survive roundtripping
+ t.Run("roundtrip", func(t *testing.T) {
+ requoted := quoteString(parsed) + " " + remaining
+ reparsed, newRemaining, _ := parseQuoted(requoted)
+ assert.Equal(t, parsed, reparsed)
+ assert.Equal(t, remaining, newRemaining)
+ })
+ })
+ }
+}
+
+type CommandTestData struct {
+ Spec *EventContent
+ Tests []*CommandTestUnit
+}
+
+type CommandTestUnit struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Broken string `json:"broken,omitempty"`
+ Error bool `json:"error"`
+ Output json.RawMessage `json:"output"`
+}
+
+func compactJSON(input json.RawMessage) json.RawMessage {
+ var buf bytes.Buffer
+ exerrors.PanicIfNotNil(json.Compact(&buf, input))
+ return buf.Bytes()
+}
+
+func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) {
+ for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) {
+ t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) {
+ ctd := loadFile[CommandTestData]("commands/" + cmd.Name())
+ for _, test := range ctd.Tests {
+ outputStr := exbytes.UnsafeString(compactJSON(test.Output))
+ t.Run(test.Name, func(t *testing.T) {
+ if test.Broken != "" {
+ t.Skip(test.Broken)
+ }
+ output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input)
+ if test.Error {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ if outputStr == "null" {
+ assert.Nil(t, output)
+ } else {
+ assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command)
+ assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go
new file mode 100644
index 00000000..98c421fc
--- /dev/null
+++ b/event/cmdschema/roomid.go
@@ -0,0 +1,135 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+ "strings"
+
+ "maunium.net/go/mautrix/id"
+)
+
+var ParameterSchemaJoinableRoom = Union(
+ PrimitiveTypeRoomID.Schema(),
+ PrimitiveTypeRoomAlias.Schema(),
+)
+
+type RoomIDValue struct {
+ Type PrimitiveType `json:"type"`
+ RoomID id.RoomID `json:"id"`
+ Via []string `json:"via,omitempty"`
+ EventID id.EventID `json:"event_id,omitempty"`
+}
+
+func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) {
+ switch typedValue := input.(type) {
+ case map[string]any, json.RawMessage:
+ var raw json.RawMessage
+ if raw, err = json.Marshal(input); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ } else if err = json.Unmarshal(raw, &riv); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ }
+ case *RoomIDValue:
+ riv = typedValue
+ case RoomIDValue:
+ riv = &typedValue
+ default:
+ err = fmt.Errorf("unsupported type %T for room or event ID", input)
+ }
+ return
+}
+
+func (riv *RoomIDValue) String() string {
+ return riv.URI().String()
+}
+
+func (riv *RoomIDValue) URI() *id.MatrixURI {
+ if riv == nil {
+ return nil
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ return riv.RoomID.URI(riv.Via...)
+ case PrimitiveTypeEventID:
+ return riv.RoomID.EventURI(riv.EventID, riv.Via...)
+ default:
+ return nil
+ }
+}
+
+func (riv *RoomIDValue) Equals(other *RoomIDValue) bool {
+ if riv == nil || other == nil {
+ return riv == other
+ }
+ return riv.Type == other.Type &&
+ riv.RoomID == other.RoomID &&
+ riv.EventID == other.EventID &&
+ slices.Equal(riv.Via, other.Via)
+}
+
+func (riv *RoomIDValue) Validate() error {
+ if riv == nil {
+ return fmt.Errorf("value is nil")
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ if riv.EventID != "" {
+ return fmt.Errorf("event ID must be empty for room ID type")
+ }
+ case PrimitiveTypeEventID:
+ if !strings.HasPrefix(riv.EventID.String(), "$") {
+ return fmt.Errorf("event ID not valid: %q", riv.EventID)
+ }
+ default:
+ return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type)
+ }
+ for _, via := range riv.Via {
+ if !id.ValidateServerName(via) {
+ return fmt.Errorf("invalid server name %q in vias", via)
+ }
+ }
+ sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID)
+ if sigil != '!' {
+ return fmt.Errorf("room ID does not start with !: %q", riv.RoomID)
+ } else if localpart == "" && serverName == "" {
+ return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID)
+ } else if serverName != "" && !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name %q in room ID", serverName)
+ }
+ return nil
+}
+
+func (riv *RoomIDValue) IsValid() bool {
+ return riv.Validate() == nil
+}
+
+type RoomIDOrString string
+
+func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error {
+ if len(data) == 0 {
+ return fmt.Errorf("empty data for room ID or string")
+ }
+ if data[0] == '"' {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(str)
+ return nil
+ }
+ var riv RoomIDValue
+ if err := json.Unmarshal(data, &riv); err != nil {
+ return err
+ } else if err = riv.Validate(); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(riv.String())
+ return nil
+}
diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go
new file mode 100644
index 00000000..c5c57c53
--- /dev/null
+++ b/event/cmdschema/stringify.go
@@ -0,0 +1,122 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "strconv"
+ "strings"
+)
+
+var quoteEscaper = strings.NewReplacer(
+ `"`, `\"`,
+ `\`, `\\`,
+)
+
+const charsToQuote = ` \` + botArrayOpener + botArrayCloser
+
+func quoteString(val string) string {
+ if val == "" {
+ return `""`
+ }
+ val = quoteEscaper.Replace(val)
+ if strings.ContainsAny(val, charsToQuote) {
+ return `"` + val + `"`
+ }
+ return val
+}
+
+func (ec *EventContent) StringifyArgs(args any) string {
+ var argMap map[string]any
+ switch typedArgs := args.(type) {
+ case json.RawMessage:
+ err := json.Unmarshal(typedArgs, &argMap)
+ if err != nil {
+ return ""
+ }
+ case map[string]any:
+ argMap = typedArgs
+ default:
+ if b, err := json.Marshal(args); err != nil {
+ return ""
+ } else if err = json.Unmarshal(b, &argMap); err != nil {
+ return ""
+ }
+ }
+ parts := make([]string, 0, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ isLast := i == len(ec.Parameters)-1
+ val := argMap[param.Key]
+ if val == nil {
+ val = param.DefaultValue
+ if val == nil && !param.Optional {
+ val = param.Schema.GetDefaultValue()
+ }
+ }
+ if val == nil {
+ continue
+ }
+ var stringified string
+ if param.Schema.SchemaType == SchemaTypeArray {
+ stringified = arrayArgumentToString(val, isLast)
+ } else {
+ stringified = singleArgumentToString(val)
+ }
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ return strings.Join(parts, " ")
+}
+
+func arrayArgumentToString(val any, isLast bool) string {
+ valArr, ok := val.([]any)
+ if !ok {
+ return ""
+ }
+ parts := make([]string, 0, len(valArr))
+ for _, elem := range valArr {
+ stringified := singleArgumentToString(elem)
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ joinedParts := strings.Join(parts, " ")
+ if isLast && len(parts) > 0 {
+ return joinedParts
+ }
+ return botArrayOpener + joinedParts + botArrayCloser
+}
+
+func singleArgumentToString(val any) string {
+ switch typedVal := val.(type) {
+ case string:
+ return quoteString(typedVal)
+ case json.Number:
+ return typedVal.String()
+ case bool:
+ return strconv.FormatBool(typedVal)
+ case int:
+ return strconv.Itoa(typedVal)
+ case int64:
+ return strconv.FormatInt(typedVal, 10)
+ case float64:
+ return strconv.FormatInt(int64(typedVal), 10)
+ case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue:
+ normalized, err := NormalizeRoomIDValue(typedVal)
+ if err != nil {
+ return ""
+ }
+ uri := normalized.URI()
+ if uri == nil {
+ return ""
+ }
+ return quoteString(uri.String())
+ default:
+ return ""
+ }
+}
diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json
new file mode 100644
index 00000000..e53382db
--- /dev/null
+++ b/event/cmdschema/testdata/commands.schema.json
@@ -0,0 +1,281 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "commands.schema.json",
+ "title": "ParseInput test cases",
+ "description": "JSON schema for test case files containing command specifications and test cases",
+ "type": "object",
+ "required": [
+ "spec",
+ "tests"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "spec": {
+ "title": "MSC4391 Command Description",
+ "description": "JSON schema defining the structure of a bot command event content",
+ "type": "object",
+ "required": [
+ "command"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "command": {
+ "type": "string",
+ "description": "The command name that triggers this bot command"
+ },
+ "aliases": {
+ "type": "array",
+ "description": "Alternative names/aliases for this command",
+ "items": {
+ "type": "string"
+ }
+ },
+ "parameters": {
+ "type": "array",
+ "description": "List of parameters accepted by this command",
+ "items": {
+ "$ref": "#/$defs/Parameter"
+ }
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of the command"
+ },
+ "fi.mau.tail_parameter": {
+ "type": "string",
+ "description": "The key of the parameter that accepts remaining arguments as tail text"
+ },
+ "source": {
+ "type": "string",
+ "description": "The user ID of the bot that responds to this command"
+ }
+ }
+ },
+ "tests": {
+ "type": "array",
+ "description": "Array of test cases for the command",
+ "items": {
+ "type": "object",
+ "description": "A single test case for command parsing",
+ "required": [
+ "name",
+ "input"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "The command input string to parse"
+ },
+ "output": {
+ "description": "The expected parsed parameter values, or null if the parsing is expected to fail",
+ "oneOf": [
+ {
+ "type": "object",
+ "additionalProperties": true
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
+ "error": {
+ "type": "boolean",
+ "description": "Whether parsing should result in an error. May still produce output.",
+ "default": false
+ }
+ }
+ }
+ }
+ },
+ "$defs": {
+ "ExtensibleTextContainer": {
+ "type": "object",
+ "description": "Container for text that can have multiple representations",
+ "required": [
+ "m.text"
+ ],
+ "properties": {
+ "m.text": {
+ "type": "array",
+ "description": "Array of text representations in different formats",
+ "items": {
+ "$ref": "#/$defs/ExtensibleText"
+ }
+ }
+ }
+ },
+ "ExtensibleText": {
+ "type": "object",
+ "description": "A text representation with a specific MIME type",
+ "required": [
+ "body"
+ ],
+ "properties": {
+ "body": {
+ "type": "string",
+ "description": "The text content"
+ },
+ "mimetype": {
+ "type": "string",
+ "description": "The MIME type of the text (e.g., text/plain, text/html)",
+ "default": "text/plain",
+ "examples": [
+ "text/plain",
+ "text/html"
+ ]
+ }
+ }
+ },
+ "Parameter": {
+ "type": "object",
+ "description": "A parameter definition for a command",
+ "required": [
+ "key",
+ "schema"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "key": {
+ "type": "string",
+ "description": "The identifier for this parameter"
+ },
+ "schema": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema defining the type and structure of this parameter"
+ },
+ "optional": {
+ "type": "boolean",
+ "description": "Whether this parameter is optional",
+ "default": false
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of this parameter"
+ },
+ "fi.mau.default_value": {
+ "description": "Default value for this parameter if not provided"
+ }
+ }
+ },
+ "ParameterSchema": {
+ "type": "object",
+ "description": "Schema definition for a parameter value",
+ "required": [
+ "schema_type"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "schema_type": {
+ "type": "string",
+ "enum": [
+ "primitive",
+ "array",
+ "union",
+ "literal"
+ ],
+ "description": "The type of schema"
+ }
+ },
+ "allOf": [
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "primitive"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "type"
+ ],
+ "properties": {
+ "type": {
+ "type": "string",
+ "enum": [
+ "string",
+ "integer",
+ "boolean",
+ "server_name",
+ "user_id",
+ "room_id",
+ "room_alias",
+ "event_id"
+ ],
+ "description": "The primitive type (only for schema_type: primitive)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "array"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "items"
+ ],
+ "properties": {
+ "items": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema for array items (only for schema_type: array)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "union"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "variants"
+ ],
+ "properties": {
+ "variants": {
+ "type": "array",
+ "description": "The possible variants (only for schema_type: union)",
+ "items": {
+ "$ref": "#/$defs/ParameterSchema"
+ },
+ "minItems": 1
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "literal"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "value"
+ ],
+ "properties": {
+ "value": {
+ "description": "The literal value (only for schema_type: literal)"
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+}
diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json
new file mode 100644
index 00000000..6ce1f4da
--- /dev/null
+++ b/event/cmdschema/testdata/commands/flags.json
@@ -0,0 +1,126 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "flag",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "user",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "user_id"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true,
+ "fi.mau.default_value": false
+ }
+ ],
+ "fi.mau.tail_parameter": "user"
+ },
+ "tests": [
+ {
+ "name": "no flags",
+ "input": "/flag mrrp",
+ "output": {
+ "meow": "mrrp",
+ "user": null
+ }
+ },
+ {
+ "name": "no flags, has tail",
+ "input": "/flag mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com"
+ }
+ },
+ {
+ "name": "named flag at start",
+ "input": "/flag --woof=yes mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "boolean flag without value",
+ "input": "/flag --woof mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "user id flag without value",
+ "input": "/flag --user --woof mrrp",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle",
+ "input": "/flag mrrp --woof=yes @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle with different value",
+ "input": "/flag mrrp --woof=no @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named",
+ "input": "/flag --woof=no --meow=mrrp --user=@user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named with quotes",
+ "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"",
+ "output": {
+ "meow": "meow meow mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "invalid value for named parameter",
+ "input": "/flag --user=meowings mrrp --woof",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json
new file mode 100644
index 00000000..1351c292
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_id_or_alias.json
@@ -0,0 +1,85 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "room",
+ "schema": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "room": "#test:matrix.org"
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link with url encoding",
+ "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net",
+ "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!#test/room\nversion 11, with @🐈️:maunium.net",
+ "via": [
+ "maunium.net"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json
new file mode 100644
index 00000000..aa266054
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_reference_list.json
@@ -0,0 +1,106 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "rooms",
+ "schema": {
+ "schema_type": "array",
+ "items": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "rooms": [
+ "#test:matrix.org"
+ ]
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "two room ids",
+ "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"
+ },
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI and matrix.to URL",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ },
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json
new file mode 100644
index 00000000..94667323
--- /dev/null
+++ b/event/cmdschema/testdata/commands/simple.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test simple",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "success",
+ "input": "/test simple mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "directed success",
+ "input": "/test simple@testbot mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "missing parameter",
+ "input": "/test simple",
+ "error": true,
+ "output": {
+ "meow": ""
+ }
+ },
+ {
+ "name": "directed at another bot",
+ "input": "/test simple@anotherbot mrrp",
+ "error": false,
+ "output": null
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json
new file mode 100644
index 00000000..9782f8ec
--- /dev/null
+++ b/event/cmdschema/testdata/commands/tail.json
@@ -0,0 +1,60 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "tail",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "reason",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true
+ }
+ ],
+ "fi.mau.tail_parameter": "reason"
+ },
+ "tests": [
+ {
+ "name": "no tail or flag",
+ "input": "/tail mrrp",
+ "output": {
+ "meow": "mrrp",
+ "reason": ""
+ }
+ },
+ {
+ "name": "tail, no flag",
+ "input": "/tail mrrp meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow"
+ }
+ },
+ {
+ "name": "flag before tail",
+ "input": "/tail mrrp --woof meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow",
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/hicli/database/upgrades/upgrades.go b/event/cmdschema/testdata/data.go
similarity index 54%
rename from hicli/database/upgrades/upgrades.go
rename to event/cmdschema/testdata/data.go
index 9d0bd1a0..eceea3d2 100644
--- a/hicli/database/upgrades/upgrades.go
+++ b/event/cmdschema/testdata/data.go
@@ -1,22 +1,14 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-package upgrades
+package testdata
import (
"embed"
-
- "go.mau.fi/util/dbutil"
)
-var Table dbutil.UpgradeTable
-
-//go:embed *.sql
-var upgrades embed.FS
-
-func init() {
- Table.RegisterFS(upgrades)
-}
+//go:embed *
+var FS embed.FS
diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json
new file mode 100644
index 00000000..8f52b7f5
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.json
@@ -0,0 +1,30 @@
+[
+ {"name": "empty string", "input": "", "output": ["", "", false]},
+ {"name": "single word", "input": "meow", "output": ["meow", "", false]},
+ {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]},
+ {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]},
+ {"name": "only spaces", "input": " ", "output": ["", "", false]},
+ {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]},
+ {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]},
+ {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]},
+ {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]},
+ {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]},
+ {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]},
+ {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]},
+ {"name": "just opening quote", "input": "\"", "output": ["", "", true]},
+ {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]},
+ {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]},
+ {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]},
+ {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]},
+ {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]},
+ {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]},
+ {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]},
+ {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]},
+ {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]},
+ {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]},
+ {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]}
+]
diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json
new file mode 100644
index 00000000..9f249116
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.schema.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "parse_quote.schema.json",
+ "title": "parseQuote test cases",
+ "description": "Test cases for the parseQuoted function",
+ "type": "array",
+ "items": {
+ "type": "object",
+ "required": [
+ "name",
+ "input",
+ "output"
+ ],
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "Name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "Input string to be parsed"
+ },
+ "output": {
+ "type": "array",
+ "description": "Expected output of parsing: [first word, remaining text, was quoted]",
+ "minItems": 3,
+ "maxItems": 3,
+ "prefixItems": [
+ {
+ "type": "string",
+ "description": "First parsed word"
+ },
+ {
+ "type": "string",
+ "description": "Remaining text after the first word"
+ },
+ {
+ "type": "boolean",
+ "description": "Whether the first word was quoted"
+ }
+ ]
+ }
+ },
+ "additionalProperties": false
+ }
+}
diff --git a/event/content.go b/event/content.go
index 882d3368..814aeec4 100644
--- a/event/content.go
+++ b/event/content.go
@@ -18,6 +18,7 @@ import (
// This is used by Content.ParseRaw() for creating the correct type of struct.
var TypeMap = map[Type]reflect.Type{
StateMember: reflect.TypeOf(MemberEventContent{}),
+ StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}),
StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}),
StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}),
StateRoomName: reflect.TypeOf(RoomNameEventContent{}),
@@ -38,7 +39,9 @@ var TypeMap = map[Type]reflect.Type{
StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}),
StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}),
StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}),
- StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}),
+
+ StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
+ StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}),
StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}),
@@ -48,6 +51,8 @@ var TypeMap = map[Type]reflect.Type{
StateUnstablePolicyUser: reflect.TypeOf(ModPolicyContent{}),
StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}),
+ StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}),
+ StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}),
EventMessage: reflect.TypeOf(MessageEventContent{}),
EventSticker: reflect.TypeOf(MessageEventContent{}),
@@ -58,7 +63,11 @@ var TypeMap = map[Type]reflect.Type{
EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}),
EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}),
- BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
+ BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
+ BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}),
+ BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}),
+ BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}),
+ BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}),
AccountDataRoomTags: reflect.TypeOf(TagEventContent{}),
AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}),
@@ -67,9 +76,11 @@ var TypeMap = map[Type]reflect.Type{
AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}),
AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}),
- EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
- EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
- EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
+ EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
+ EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}),
+ BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}),
InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}),
InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}),
@@ -121,7 +132,7 @@ var TypeMap = map[Type]reflect.Type{
// When being marshaled into JSON, the data in Parsed will be marshaled first and then recursively merged
// with the data in Raw. Values in Raw are preferred, but nested objects will be recursed into before merging,
// rather than overriding the whole object with the one in Raw).
-// If one of them is nil, the only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead.
+// If one of them is nil, then only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead.
type Content struct {
VeryRaw json.RawMessage
Raw map[string]interface{}
@@ -188,6 +199,13 @@ func IsUnsupportedContentType(err error) bool {
var ErrContentAlreadyParsed = errors.New("content is already parsed")
var ErrUnsupportedContentType = errors.New("unsupported event type")
+func (content *Content) GetRaw() map[string]interface{} {
+ if content.Raw == nil {
+ content.Raw = make(map[string]interface{})
+ }
+ return content.Raw
+}
+
func (content *Content) ParseRaw(evtType Type) error {
if content.Parsed != nil {
return ErrContentAlreadyParsed
diff --git a/event/delayed.go b/event/delayed.go
new file mode 100644
index 00000000..fefb62af
--- /dev/null
+++ b/event/delayed.go
@@ -0,0 +1,70 @@
+package event
+
+import (
+ "encoding/json"
+
+ "go.mau.fi/util/jsontime"
+
+ "maunium.net/go/mautrix/id"
+)
+
+type ScheduledDelayedEvent struct {
+ DelayID id.DelayID `json:"delay_id"`
+ RoomID id.RoomID `json:"room_id"`
+ Type Type `json:"type"`
+ StateKey *string `json:"state_key,omitempty"`
+ Delay int64 `json:"delay"`
+ RunningSince jsontime.UnixMilli `json:"running_since"`
+ Content Content `json:"content"`
+}
+
+func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) {
+ evt := &Event{
+ ID: eventID,
+ RoomID: e.RoomID,
+ Type: e.Type,
+ StateKey: e.StateKey,
+ Content: e.Content,
+ Timestamp: ts.UnixMilli(),
+ }
+ return evt, evt.Content.ParseRaw(evt.Type)
+}
+
+type FinalisedDelayedEvent struct {
+ DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"`
+ Outcome DelayOutcome `json:"outcome"`
+ Reason DelayReason `json:"reason"`
+ Error json.RawMessage `json:"error,omitempty"`
+ EventID id.EventID `json:"event_id,omitempty"`
+ Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
+}
+
+type DelayStatus string
+
+var (
+ DelayStatusScheduled DelayStatus = "scheduled"
+ DelayStatusFinalised DelayStatus = "finalised"
+)
+
+type DelayAction string
+
+var (
+ DelayActionSend DelayAction = "send"
+ DelayActionCancel DelayAction = "cancel"
+ DelayActionRestart DelayAction = "restart"
+)
+
+type DelayOutcome string
+
+var (
+ DelayOutcomeSend DelayOutcome = "send"
+ DelayOutcomeCancel DelayOutcome = "cancel"
+)
+
+type DelayReason string
+
+var (
+ DelayReasonAction DelayReason = "action"
+ DelayReasonError DelayReason = "error"
+ DelayReasonDelay DelayReason = "delay"
+)
diff --git a/event/encryption.go b/event/encryption.go
index cf9c2814..c60cb91a 100644
--- a/event/encryption.go
+++ b/event/encryption.go
@@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error {
return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext)
case id.AlgorithmMegolmV1:
if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' {
- return id.InputNotJSONString
+ return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString)
}
content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1]
}
@@ -132,8 +132,9 @@ type RoomKeyRequestEventContent struct {
type RequestedKeyInfo struct {
Algorithm id.Algorithm `json:"algorithm"`
RoomID id.RoomID `json:"room_id"`
- SenderKey id.SenderKey `json:"sender_key"`
SessionID id.SessionID `json:"session_id"`
+ // Deprecated: Matrix v1.3
+ SenderKey id.SenderKey `json:"sender_key"`
}
type RoomKeyWithheldCode string
diff --git a/event/events.go b/event/events.go
index 23769ae8..72c1e161 100644
--- a/event/events.go
+++ b/event/events.go
@@ -118,6 +118,9 @@ type MautrixInfo struct {
DecryptionDuration time.Duration
CheckpointSent bool
+ // When using MSC4222 and the state_after field, this field is set
+ // for timeline events to indicate they shouldn't update room state.
+ IgnoreState bool
}
func (evt *Event) GetStateKey() string {
@@ -127,30 +130,29 @@ func (evt *Event) GetStateKey() string {
return ""
}
-type StrippedState struct {
- Content Content `json:"content"`
- Type Type `json:"type"`
- StateKey string `json:"state_key"`
- Sender id.UserID `json:"sender"`
-}
-
type Unsigned struct {
- PrevContent *Content `json:"prev_content,omitempty"`
- PrevSender id.UserID `json:"prev_sender,omitempty"`
- ReplacesState id.EventID `json:"replaces_state,omitempty"`
- Age int64 `json:"age,omitempty"`
- TransactionID string `json:"transaction_id,omitempty"`
- Relations *Relations `json:"m.relations,omitempty"`
- RedactedBecause *Event `json:"redacted_because,omitempty"`
- InviteRoomState []StrippedState `json:"invite_room_state,omitempty"`
+ PrevContent *Content `json:"prev_content,omitempty"`
+ PrevSender id.UserID `json:"prev_sender,omitempty"`
+ Membership Membership `json:"membership,omitempty"`
+ ReplacesState id.EventID `json:"replaces_state,omitempty"`
+ Age int64 `json:"age,omitempty"`
+ TransactionID string `json:"transaction_id,omitempty"`
+ Relations *Relations `json:"m.relations,omitempty"`
+ RedactedBecause *Event `json:"redacted_because,omitempty"`
+ InviteRoomState []*Event `json:"invite_room_state,omitempty"`
- BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"`
- BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"`
- BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"`
+ BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"`
+ BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"`
+ BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"`
+ BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"`
+
+ ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"`
+ ElementPolicyServerSpammy bool `json:"io.element.synapse.policy_server_spammy,omitempty"`
}
func (us *Unsigned) IsEmpty() bool {
- return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 &&
+ return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" &&
us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil &&
- us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0
+ us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() &&
+ !us.ElementSoftFailed
}
diff --git a/event/member.go b/event/member.go
index ebafdcb7..9956a36b 100644
--- a/event/member.go
+++ b/event/member.go
@@ -7,8 +7,6 @@
package event
import (
- "encoding/json"
-
"maunium.net/go/mautrix/id"
)
@@ -35,19 +33,37 @@ const (
// MemberEventContent represents the content of a m.room.member state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroommember
type MemberEventContent struct {
- Membership Membership `json:"membership"`
- AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
- Displayname string `json:"displayname,omitempty"`
- IsDirect bool `json:"is_direct,omitempty"`
- ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
- Reason string `json:"reason,omitempty"`
+ Membership Membership `json:"membership"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ Displayname string `json:"displayname,omitempty"`
+ IsDirect bool `json:"is_direct,omitempty"`
+ ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
+ Reason string `json:"reason,omitempty"`
+ JoinAuthorisedViaUsersServer id.UserID `json:"join_authorised_via_users_server,omitempty"`
+ MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"`
+
+ MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"`
+}
+
+type SignedThirdPartyInvite struct {
+ Token string `json:"token"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
+ MXID string `json:"mxid"`
}
type ThirdPartyInvite struct {
- DisplayName string `json:"display_name"`
- Signed struct {
- Token string `json:"token"`
- Signatures json.RawMessage `json:"signatures"`
- MXID string `json:"mxid"`
- }
+ DisplayName string `json:"display_name"`
+ Signed SignedThirdPartyInvite `json:"signed"`
+}
+
+type ThirdPartyInviteEventContent struct {
+ DisplayName string `json:"display_name"`
+ KeyValidityURL string `json:"key_validity_url"`
+ PublicKey id.Ed25519 `json:"public_key"`
+ PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"`
+}
+
+type ThirdPartyInviteKey struct {
+ KeyValidityURL string `json:"key_validity_url,omitempty"`
+ PublicKey id.Ed25519 `json:"public_key"`
}
diff --git a/event/message.go b/event/message.go
index 9badd9a2..3fb3dc82 100644
--- a/event/message.go
+++ b/event/message.go
@@ -8,12 +8,11 @@ package event
import (
"encoding/json"
+ "html"
"slices"
"strconv"
"strings"
- "golang.org/x/net/html"
-
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/id"
)
@@ -33,7 +32,7 @@ func (mt MessageType) IsText() bool {
func (mt MessageType) IsMedia() bool {
switch mt {
- case MsgImage, MsgVideo, MsgAudio, MsgFile, MessageType(EventSticker.Type):
+ case MsgImage, MsgVideo, MsgAudio, MsgFile, CapMsgSticker:
return true
default:
return false
@@ -136,11 +135,42 @@ type MessageEventContent struct {
BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"`
BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"`
BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"`
+ BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"`
BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"`
+ BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"`
+
MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"`
MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"`
+
+ MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"`
+}
+
+func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType {
+ switch content.MsgType {
+ case CapMsgSticker:
+ return CapMsgSticker
+ case "":
+ if content.URL != "" || content.File != nil {
+ return CapMsgSticker
+ }
+ case MsgImage:
+ return MsgImage
+ case MsgAudio:
+ if content.MSC3245Voice != nil {
+ return CapMsgVoice
+ }
+ return MsgAudio
+ case MsgVideo:
+ if content.Info != nil && content.Info.MauGIF {
+ return CapMsgGIF
+ }
+ return MsgVideo
+ case MsgFile:
+ return MsgFile
+ }
+ return ""
}
func (content *MessageEventContent) GetFileName() string {
@@ -185,6 +215,7 @@ func (content *MessageEventContent) SetEdit(original id.EventID) {
content.RelatesTo = (&RelatesTo{}).SetReplace(original)
if content.MsgType == MsgText || content.MsgType == MsgNotice {
content.Body = "* " + content.Body
+ content.Mentions = &Mentions{}
if content.Format == FormatHTML && len(content.FormattedBody) > 0 {
content.FormattedBody = "* " + content.FormattedBody
}
@@ -245,24 +276,46 @@ func (m *Mentions) Has(userID id.UserID) bool {
return m != nil && slices.Contains(m.UserIDs, userID)
}
+func (m *Mentions) Merge(other *Mentions) *Mentions {
+ if m == nil {
+ return other
+ } else if other == nil {
+ return m
+ }
+ return &Mentions{
+ UserIDs: slices.Concat(m.UserIDs, other.UserIDs),
+ Room: m.Room || other.Room,
+ }
+}
+
+type MSC4391BotCommandInputCustom[T any] struct {
+ Command string `json:"command"`
+ Arguments T `json:"arguments,omitempty"`
+}
+
+type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage]
+
type EncryptedFileInfo struct {
attachment.EncryptedFile
URL id.ContentURIString `json:"url"`
}
type FileInfo struct {
- MimeType string `json:"mimetype,omitempty"`
- ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"`
- ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"`
- ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"`
+ MimeType string
+ ThumbnailInfo *FileInfo
+ ThumbnailURL id.ContentURIString
+ ThumbnailFile *EncryptedFileInfo
- Blurhash string `json:"blurhash,omitempty"`
- AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
+ Blurhash string
+ AnoaBlurhash string
- Width int `json:"-"`
- Height int `json:"-"`
- Duration int `json:"-"`
- Size int `json:"-"`
+ MauGIF bool
+ IsAnimated bool
+
+ Width int
+ Height int
+ Duration int
+ Size int
}
type serializableFileInfo struct {
@@ -274,6 +327,9 @@ type serializableFileInfo struct {
Blurhash string `json:"blurhash,omitempty"`
AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
+ MauGIF bool `json:"fi.mau.gif,omitempty"`
+ IsAnimated bool `json:"is_animated,omitempty"`
+
Width json.Number `json:"w,omitempty"`
Height json.Number `json:"h,omitempty"`
Duration json.Number `json:"duration,omitempty"`
@@ -290,6 +346,9 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI
ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo),
ThumbnailFile: fileInfo.ThumbnailFile,
+ MauGIF: fileInfo.MauGIF,
+ IsAnimated: fileInfo.IsAnimated,
+
Blurhash: fileInfo.Blurhash,
AnoaBlurhash: fileInfo.AnoaBlurhash,
}
@@ -318,6 +377,8 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) {
MimeType: sfi.MimeType,
ThumbnailURL: sfi.ThumbnailURL,
ThumbnailFile: sfi.ThumbnailFile,
+ MauGIF: sfi.MauGIF,
+ IsAnimated: sfi.IsAnimated,
Blurhash: sfi.Blurhash,
AnoaBlurhash: sfi.AnoaBlurhash,
}
diff --git a/event/message_test.go b/event/message_test.go
index 562a6622..c721df35 100644
--- a/event/message_test.go
+++ b/event/message_test.go
@@ -33,7 +33,7 @@ const invalidMessageEvent = `{
func TestMessageEventContent__ParseInvalid(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(invalidMessageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) {
assert.Equal(t, id.RoomID("!bar"), evt.RoomID)
err = evt.Content.ParseRaw(evt.Type)
- assert.NotNil(t, err)
+ assert.Error(t, err)
}
const messageEvent = `{
@@ -68,7 +68,7 @@ const messageEvent = `{
func TestMessageEventContent__ParseEdit(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(messageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -110,7 +110,7 @@ const imageMessageEvent = `{
func TestMessageEventContent__ParseMedia(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(imageMessageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) {
content := evt.Content.Parsed.(*event.MessageEventContent)
assert.Equal(t, event.MsgImage, content.MsgType)
parsedURL, err := content.URL.Parse()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL)
assert.Nil(t, content.NewContent)
assert.Equal(t, "image/png", content.GetInfo().MimeType)
@@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}`
func TestMessageEventContent__Marshal(t *testing.T) {
data, err := json.Marshal(parsedMessage)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expectedMarshalResult, string(data))
}
@@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun
func TestMessageEventContent__Marshal_Custom(t *testing.T) {
data, err := json.Marshal(customParsedMessage)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expectedCustomMarshalResult, string(data))
}
diff --git a/event/poll.go b/event/poll.go
index 37333015..9082f65e 100644
--- a/event/poll.go
+++ b/event/poll.go
@@ -29,16 +29,13 @@ func (content *PollResponseEventContent) SetRelatesTo(rel *RelatesTo) {
}
type MSC1767Message struct {
- Text string `json:"org.matrix.msc1767.text,omitempty"`
- HTML string `json:"org.matrix.msc1767.html,omitempty"`
- Message []struct {
- MimeType string `json:"mimetype"`
- Body string `json:"body"`
- } `json:"org.matrix.msc1767.message,omitempty"`
+ Text string `json:"org.matrix.msc1767.text,omitempty"`
+ HTML string `json:"org.matrix.msc1767.html,omitempty"`
+ Message []ExtensibleText `json:"org.matrix.msc1767.message,omitempty"`
}
type PollStartEventContent struct {
- RelatesTo *RelatesTo `json:"m.relates_to"`
+ RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
Mentions *Mentions `json:"m.mentions,omitempty"`
PollStart struct {
Kind string `json:"kind"`
diff --git a/event/powerlevels.go b/event/powerlevels.go
index 2f4d4573..668eb6d3 100644
--- a/event/powerlevels.go
+++ b/event/powerlevels.go
@@ -7,6 +7,8 @@
package event
import (
+ "math"
+ "slices"
"sync"
"go.mau.fi/util/ptr"
@@ -26,6 +28,9 @@ type PowerLevelsEventContent struct {
Events map[string]int `json:"events,omitempty"`
EventsDefault int `json:"events_default,omitempty"`
+ beeperEphemeralLock sync.RWMutex
+ BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"`
+
Notifications *NotificationPowerLevels `json:"notifications,omitempty"`
StateDefaultPtr *int `json:"state_default,omitempty"`
@@ -34,6 +39,12 @@ type PowerLevelsEventContent struct {
KickPtr *int `json:"kick,omitempty"`
BanPtr *int `json:"ban,omitempty"`
RedactPtr *int `json:"redact,omitempty"`
+
+ BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"`
+
+ // This is not a part of power levels, it's added by mautrix-go internally in certain places
+ // in order to detect creator power accurately.
+ CreateEvent *Event `json:"-"`
}
func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
@@ -45,6 +56,7 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
UsersDefault: pl.UsersDefault,
Events: maps.Clone(pl.Events),
EventsDefault: pl.EventsDefault,
+ BeeperEphemeral: maps.Clone(pl.BeeperEphemeral),
StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr),
Notifications: pl.Notifications.Clone(),
@@ -53,6 +65,10 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
KickPtr: ptr.Clone(pl.KickPtr),
BanPtr: ptr.Clone(pl.BanPtr),
RedactPtr: ptr.Clone(pl.RedactPtr),
+
+ BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr),
+
+ CreateEvent: pl.CreateEvent,
}
}
@@ -111,7 +127,17 @@ func (pl *PowerLevelsEventContent) StateDefault() int {
return 50
}
+func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int {
+ if pl.BeeperEphemeralDefaultPtr != nil {
+ return *pl.BeeperEphemeralDefaultPtr
+ }
+ return pl.EventsDefault
+}
+
func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
+ if pl.isCreator(userID) {
+ return math.MaxInt
+ }
pl.usersLock.RLock()
defer pl.usersLock.RUnlock()
level, ok := pl.Users[userID]
@@ -121,9 +147,19 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
return level
}
+const maxPL = 1<<53 - 1
+
func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) {
pl.usersLock.Lock()
defer pl.usersLock.Unlock()
+ if pl.isCreator(userID) {
+ return
+ }
+ if level == math.MaxInt && maxPL < math.MaxInt {
+ // Hack to avoid breaking on 32-bit systems (they're only slightly supported)
+ x := int64(maxPL)
+ level = int(x)
+ }
if level == pl.UsersDefault {
delete(pl.Users, userID)
} else {
@@ -138,9 +174,24 @@ func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int)
return pl.EnsureUserLevelAs("", target, level)
}
+func (pl *PowerLevelsEventContent) createContent() *CreateEventContent {
+ if pl.CreateEvent == nil {
+ return &CreateEventContent{}
+ }
+ return pl.CreateEvent.Content.AsCreate()
+}
+
+func (pl *PowerLevelsEventContent) isCreator(userID id.UserID) bool {
+ cc := pl.createContent()
+ return cc.SupportsCreatorPower() && (userID == pl.CreateEvent.Sender || slices.Contains(cc.AdditionalCreators, userID))
+}
+
func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool {
+ if pl.isCreator(target) {
+ return false
+ }
existingLevel := pl.GetUserLevel(target)
- if actor != "" {
+ if actor != "" && !pl.isCreator(actor) {
actorLevel := pl.GetUserLevel(actor)
if actorLevel <= existingLevel || actorLevel < level {
return false
@@ -166,6 +217,29 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int {
return level
}
+func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int {
+ pl.beeperEphemeralLock.RLock()
+ defer pl.beeperEphemeralLock.RUnlock()
+ level, ok := pl.BeeperEphemeral[eventType.String()]
+ if !ok {
+ return pl.BeeperEphemeralDefault()
+ }
+ return level
+}
+
+func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) {
+ pl.beeperEphemeralLock.Lock()
+ defer pl.beeperEphemeralLock.Unlock()
+ if level == pl.BeeperEphemeralDefault() {
+ delete(pl.BeeperEphemeral, eventType.String())
+ } else {
+ if pl.BeeperEphemeral == nil {
+ pl.BeeperEphemeral = make(map[string]int)
+ }
+ pl.BeeperEphemeral[eventType.String()] = level
+ }
+}
+
func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) {
pl.eventsLock.Lock()
defer pl.eventsLock.Unlock()
@@ -185,7 +259,7 @@ func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) b
func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool {
existingLevel := pl.GetEventLevel(eventType)
- if actor != "" {
+ if actor != "" && !pl.isCreator(actor) {
actorLevel := pl.GetUserLevel(actor)
if existingLevel > actorLevel || level > actorLevel {
return false
diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go
new file mode 100644
index 00000000..f5861583
--- /dev/null
+++ b/event/powerlevels_ephemeral_test.go
@@ -0,0 +1,67 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package event_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/event"
+)
+
+func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 45,
+ }
+
+ assert.Equal(t, 45, pl.BeeperEphemeralDefault())
+
+ override := 60
+ pl.BeeperEphemeralDefaultPtr = &override
+ assert.Equal(t, 60, pl.BeeperEphemeralDefault())
+}
+
+func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 25,
+ }
+ evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
+
+ assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType))
+
+ pl.SetBeeperEphemeralLevel(evtType, 50)
+ assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType))
+ require.NotNil(t, pl.BeeperEphemeral)
+ assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()])
+
+ pl.SetBeeperEphemeralLevel(evtType, 25)
+ _, exists := pl.BeeperEphemeral[evtType.String()]
+ assert.False(t, exists)
+}
+
+func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) {
+ override := 70
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 35,
+ BeeperEphemeral: map[string]int{"com.example.ephemeral": 90},
+ BeeperEphemeralDefaultPtr: &override,
+ }
+
+ cloned := pl.Clone()
+ require.NotNil(t, cloned)
+ require.NotNil(t, cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"])
+
+ cloned.BeeperEphemeral["com.example.ephemeral"] = 99
+ *cloned.BeeperEphemeralDefaultPtr = 71
+
+ assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"])
+ assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr)
+}
diff --git a/event/relations.go b/event/relations.go
index ea40cc06..2316cbc7 100644
--- a/event/relations.go
+++ b/event/relations.go
@@ -15,10 +15,11 @@ import (
type RelationType string
const (
- RelReplace RelationType = "m.replace"
- RelReference RelationType = "m.reference"
- RelAnnotation RelationType = "m.annotation"
- RelThread RelationType = "m.thread"
+ RelReplace RelationType = "m.replace"
+ RelReference RelationType = "m.reference"
+ RelAnnotation RelationType = "m.annotation"
+ RelThread RelationType = "m.thread"
+ RelBeeperTranscription RelationType = "com.beeper.transcription"
)
type RelatesTo struct {
@@ -33,7 +34,7 @@ type RelatesTo struct {
type InReplyTo struct {
EventID id.EventID `json:"event_id,omitempty"`
- UnstableRoomID id.RoomID `json:"room_id,omitempty"`
+ UnstableRoomID id.RoomID `json:"com.beeper.cross_room_id,omitempty"`
}
func (rel *RelatesTo) Copy() *RelatesTo {
@@ -100,6 +101,10 @@ func (rel *RelatesTo) SetReplace(mxid id.EventID) *RelatesTo {
}
func (rel *RelatesTo) SetReplyTo(mxid id.EventID) *RelatesTo {
+ if rel.Type != RelThread {
+ rel.Type = ""
+ rel.EventID = ""
+ }
rel.InReplyTo = &InReplyTo{EventID: mxid}
rel.IsFallingBack = false
return rel
diff --git a/event/reply.go b/event/reply.go
index 73f8cfc7..5f55bb80 100644
--- a/event/reply.go
+++ b/event/reply.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2020 Tulir Asokan
+// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,7 +7,6 @@
package event
import (
- "fmt"
"regexp"
"strings"
@@ -33,12 +32,13 @@ func TrimReplyFallbackText(text string) string {
}
func (content *MessageEventContent) RemoveReplyFallback() {
- if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved {
- if content.Format == FormatHTML {
- content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
+ if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML {
+ origHTML := content.FormattedBody
+ content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
+ if content.FormattedBody != origHTML {
+ content.Body = TrimReplyFallbackText(content.Body)
+ content.replyFallbackRemoved = true
}
- content.Body = TrimReplyFallbackText(content.Body)
- content.replyFallbackRemoved = true
}
}
@@ -47,52 +47,28 @@ func (content *MessageEventContent) GetReplyTo() id.EventID {
return content.RelatesTo.GetReplyTo()
}
-const ReplyFormat = `In reply to %s
%s
`
-
-func (evt *Event) GenerateReplyFallbackHTML() string {
- parsedContent, ok := evt.Content.Parsed.(*MessageEventContent)
- if !ok {
- return ""
- }
- parsedContent.RemoveReplyFallback()
- body := parsedContent.FormattedBody
- if len(body) == 0 {
- body = TextToHTML(parsedContent.Body)
- }
-
- senderDisplayName := evt.Sender
-
- return fmt.Sprintf(ReplyFormat, evt.RoomID, evt.ID, evt.Sender, senderDisplayName, body)
-}
-
-func (evt *Event) GenerateReplyFallbackText() string {
- parsedContent, ok := evt.Content.Parsed.(*MessageEventContent)
- if !ok {
- return ""
- }
- parsedContent.RemoveReplyFallback()
- body := parsedContent.Body
- lines := strings.Split(strings.TrimSpace(body), "\n")
- firstLine, lines := lines[0], lines[1:]
-
- senderDisplayName := evt.Sender
-
- var fallbackText strings.Builder
- _, _ = fmt.Fprintf(&fallbackText, "> <%s> %s", senderDisplayName, firstLine)
- for _, line := range lines {
- _, _ = fmt.Fprintf(&fallbackText, "\n> %s", line)
- }
- fallbackText.WriteString("\n\n")
- return fallbackText.String()
-}
-
func (content *MessageEventContent) SetReply(inReplyTo *Event) {
- content.RelatesTo = (&RelatesTo{}).SetReplyTo(inReplyTo.ID)
-
- if content.MsgType == MsgText || content.MsgType == MsgNotice {
- content.EnsureHasHTML()
- content.FormattedBody = inReplyTo.GenerateReplyFallbackHTML() + content.FormattedBody
- content.Body = inReplyTo.GenerateReplyFallbackText() + content.Body
- content.replyFallbackRemoved = false
+ if content.RelatesTo == nil {
+ content.RelatesTo = &RelatesTo{}
}
+ content.RelatesTo.SetReplyTo(inReplyTo.ID)
+ if content.Mentions == nil {
+ content.Mentions = &Mentions{}
+ }
+ content.Mentions.Add(inReplyTo.Sender)
+}
+
+func (content *MessageEventContent) SetThread(inReplyTo *Event) {
+ root := inReplyTo.ID
+ relatable, ok := inReplyTo.Content.Parsed.(Relatable)
+ if ok {
+ targetRoot := relatable.OptionalGetRelatesTo().GetThreadParent()
+ if targetRoot != "" {
+ root = targetRoot
+ }
+ }
+ if content.RelatesTo == nil {
+ content.RelatesTo = &RelatesTo{}
+ }
+ content.RelatesTo.SetThread(root, inReplyTo.ID)
}
diff --git a/event/state.go b/event/state.go
index 15972892..ace170a5 100644
--- a/event/state.go
+++ b/event/state.go
@@ -7,6 +7,12 @@
package event
import (
+ "encoding/base64"
+ "encoding/json"
+ "slices"
+
+ "go.mau.fi/util/jsontime"
+
"maunium.net/go/mautrix/id"
)
@@ -42,7 +48,52 @@ type ServerACLEventContent struct {
// TopicEventContent represents the content of a m.room.topic state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomtopic
type TopicEventContent struct {
- Topic string `json:"topic"`
+ Topic string `json:"topic"`
+ ExtensibleTopic *ExtensibleTopic `json:"m.topic,omitempty"`
+}
+
+// ExtensibleTopic represents the contents of the m.topic field within the
+// m.room.topic state event as described in [MSC3765].
+//
+// [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765
+type ExtensibleTopic = ExtensibleTextContainer
+
+type ExtensibleTextContainer struct {
+ Text []ExtensibleText `json:"m.text"`
+}
+
+func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool {
+ if c == nil || description == nil {
+ return c == description
+ }
+ return slices.Equal(c.Text, description.Text)
+}
+
+func MakeExtensibleText(text string) *ExtensibleTextContainer {
+ return &ExtensibleTextContainer{
+ Text: []ExtensibleText{{
+ Body: text,
+ MimeType: "text/plain",
+ }},
+ }
+}
+
+func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer {
+ return &ExtensibleTextContainer{
+ Text: []ExtensibleText{{
+ Body: plaintext,
+ MimeType: "text/plain",
+ }, {
+ Body: html,
+ MimeType: "text/html",
+ }},
+ }
+}
+
+// ExtensibleText represents the contents of an m.text field.
+type ExtensibleText struct {
+ MimeType string `json:"mimetype,omitempty"`
+ Body string `json:"body"`
}
// TombstoneEventContent represents the content of a m.room.tombstone state event.
@@ -52,35 +103,64 @@ type TombstoneEventContent struct {
ReplacementRoom id.RoomID `json:"replacement_room"`
}
+func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID {
+ if tec == nil {
+ return ""
+ }
+ return tec.ReplacementRoom
+}
+
type Predecessor struct {
RoomID id.RoomID `json:"room_id"`
EventID id.EventID `json:"event_id"`
}
-type RoomVersion string
+// Deprecated: use id.RoomVersion instead
+type RoomVersion = id.RoomVersion
+// Deprecated: use id.RoomVX constants instead
const (
- RoomV1 RoomVersion = "1"
- RoomV2 RoomVersion = "2"
- RoomV3 RoomVersion = "3"
- RoomV4 RoomVersion = "4"
- RoomV5 RoomVersion = "5"
- RoomV6 RoomVersion = "6"
- RoomV7 RoomVersion = "7"
- RoomV8 RoomVersion = "8"
- RoomV9 RoomVersion = "9"
- RoomV10 RoomVersion = "10"
- RoomV11 RoomVersion = "11"
+ RoomV1 = id.RoomV1
+ RoomV2 = id.RoomV2
+ RoomV3 = id.RoomV3
+ RoomV4 = id.RoomV4
+ RoomV5 = id.RoomV5
+ RoomV6 = id.RoomV6
+ RoomV7 = id.RoomV7
+ RoomV8 = id.RoomV8
+ RoomV9 = id.RoomV9
+ RoomV10 = id.RoomV10
+ RoomV11 = id.RoomV11
+ RoomV12 = id.RoomV12
)
// CreateEventContent represents the content of a m.room.create state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomcreate
type CreateEventContent struct {
- Type RoomType `json:"type,omitempty"`
- Creator id.UserID `json:"creator,omitempty"`
- Federate bool `json:"m.federate,omitempty"`
- RoomVersion RoomVersion `json:"room_version,omitempty"`
- Predecessor *Predecessor `json:"predecessor,omitempty"`
+ Type RoomType `json:"type,omitempty"`
+ Federate *bool `json:"m.federate,omitempty"`
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ Predecessor *Predecessor `json:"predecessor,omitempty"`
+
+ // Room v12+ only
+ AdditionalCreators []id.UserID `json:"additional_creators,omitempty"`
+
+ // Deprecated: use the event sender instead
+ Creator id.UserID `json:"creator,omitempty"`
+}
+
+func (cec *CreateEventContent) GetPredecessor() (p Predecessor) {
+ if cec != nil && cec.Predecessor != nil {
+ p = *cec.Predecessor
+ }
+ return
+}
+
+func (cec *CreateEventContent) SupportsCreatorPower() bool {
+ if cec == nil {
+ return false
+ }
+ return cec.RoomVersion.PrivilegedRoomCreators()
}
// JoinRule specifies how open a room is to new members.
@@ -158,7 +238,8 @@ type BridgeInfoSection struct {
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
ExternalURL string `json:"external_url,omitempty"`
- Receiver string `json:"fi.mau.receiver,omitempty"`
+ Receiver string `json:"fi.mau.receiver,omitempty"`
+ MessageRequest bool `json:"com.beeper.message_request,omitempty"`
}
// BridgeEventContent represents the content of a m.bridge state event.
@@ -172,6 +253,32 @@ type BridgeEventContent struct {
BeeperRoomType string `json:"com.beeper.room_type,omitempty"`
BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"`
+
+ TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"`
+ TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"`
+}
+
+// DisappearingType represents the type of a disappearing message timer.
+type DisappearingType string
+
+const (
+ DisappearingTypeNone DisappearingType = ""
+ DisappearingTypeAfterRead DisappearingType = "after_read"
+ DisappearingTypeAfterSend DisappearingType = "after_send"
+)
+
+type BeeperDisappearingTimer struct {
+ Type DisappearingType `json:"type"`
+ Timer jsontime.Milliseconds `json:"timer"`
+}
+
+type marshalableBeeperDisappearingTimer BeeperDisappearingTimer
+
+func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) {
+ if bdt == nil || bdt.Type == DisappearingTypeNone {
+ return []byte("{}"), nil
+ }
+ return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt))
}
type SpaceChildEventContent struct {
@@ -188,25 +295,63 @@ type SpaceParentEventContent struct {
type PolicyRecommendation string
const (
- PolicyRecommendationBan PolicyRecommendation = "m.ban"
- PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban"
- PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban"
+ PolicyRecommendationBan PolicyRecommendation = "m.ban"
+ PolicyRecommendationUnstableTakedown PolicyRecommendation = "org.matrix.msc4204.takedown"
+ PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban"
+ PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban"
)
+type PolicyHashes struct {
+ SHA256 string `json:"sha256"`
+}
+
+func (ph *PolicyHashes) DecodeSHA256() *[32]byte {
+ if ph == nil || ph.SHA256 == "" {
+ return nil
+ }
+ decoded, _ := base64.StdEncoding.DecodeString(ph.SHA256)
+ if len(decoded) == 32 {
+ return (*[32]byte)(decoded)
+ }
+ return nil
+}
+
// ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event.
// https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists
type ModPolicyContent struct {
- Entity string `json:"entity"`
+ Entity string `json:"entity,omitempty"`
Reason string `json:"reason"`
Recommendation PolicyRecommendation `json:"recommendation"`
+ UnstableHashes *PolicyHashes `json:"org.matrix.msc4205.hashes,omitempty"`
}
-// Deprecated: MSC2716 has been abandoned
-type InsertionMarkerContent struct {
- InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"`
- Timestamp int64 `json:"com.beeper.timestamp,omitempty"`
+func (mpc *ModPolicyContent) EntityOrHash() string {
+ if mpc.UnstableHashes != nil && mpc.UnstableHashes.SHA256 != "" {
+ return mpc.UnstableHashes.SHA256
+ }
+ return mpc.Entity
}
type ElementFunctionalMembersContent struct {
ServiceMembers []id.UserID `json:"service_members"`
}
+
+func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool {
+ if slices.Contains(efmc.ServiceMembers, mxid) {
+ return false
+ }
+ efmc.ServiceMembers = append(efmc.ServiceMembers, mxid)
+ return true
+}
+
+type PolicyServerPublicKeys struct {
+ Ed25519 id.Ed25519 `json:"ed25519,omitempty"`
+}
+
+type RoomPolicyEventContent struct {
+ Via string `json:"via,omitempty"`
+ PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"`
+
+ // Deprecated, only for legacy use
+ PublicKey id.Ed25519 `json:"public_key,omitempty"`
+}
diff --git a/event/type.go b/event/type.go
index 4396c9cc..80b86728 100644
--- a/event/type.go
+++ b/event/type.go
@@ -108,13 +108,14 @@ func (et *Type) IsCustom() bool {
func (et *Type) GuessClass() TypeClass {
switch et.Type {
- case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type,
+ case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type,
StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type,
StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type,
StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type,
- StateInsertionMarker.Type, StateElementFunctionalMembers.Type:
+ StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type,
+ StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type:
return StateEventType
- case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type:
+ case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type:
return EphemeralEventType
case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type,
AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type,
@@ -126,7 +127,8 @@ func (et *Type) GuessClass() TypeClass {
InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type,
InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type,
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
- CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type:
+ CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type,
+ EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type:
return MessageEventType
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
ToDeviceBeeperRoomKeyAck.Type:
@@ -149,7 +151,7 @@ func (et *Type) MarshalJSON() ([]byte, error) {
return json.Marshal(&et.Type)
}
-func (et Type) UnmarshalText(data []byte) error {
+func (et *Type) UnmarshalText(data []byte) error {
et.Type = string(data)
et.Class = et.GuessClass()
return nil
@@ -159,11 +161,11 @@ func (et Type) MarshalText() ([]byte, error) {
return []byte(et.Type), nil
}
-func (et *Type) String() string {
+func (et Type) String() string {
return et.Type
}
-func (et *Type) Repr() string {
+func (et Type) Repr() string {
return fmt.Sprintf("%s (%s)", et.Type, et.Class.Name())
}
@@ -176,6 +178,7 @@ var (
StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType}
StateGuestAccess = Type{"m.room.guest_access", StateEventType}
StateMember = Type{"m.room.member", StateEventType}
+ StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType}
StatePowerLevels = Type{"m.room.power_levels", StateEventType}
StateRoomName = Type{"m.room.name", StateEventType}
StateTopic = Type{"m.room.topic", StateEventType}
@@ -192,6 +195,9 @@ var (
StateSpaceChild = Type{"m.space.child", StateEventType}
StateSpaceParent = Type{"m.space.parent", StateEventType}
+ StateRoomPolicy = Type{"m.room.policy", StateEventType}
+ StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType}
+
StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType}
StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType}
StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType}
@@ -199,10 +205,10 @@ var (
StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType}
StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType}
- // Deprecated: MSC2716 has been abandoned
- StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType}
-
StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType}
+ StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType}
+ StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType}
+ StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType}
)
// Message events
@@ -231,17 +237,24 @@ var (
CallNegotiate = Type{"m.call.negotiate", MessageEventType}
CallHangup = Type{"m.call.hangup", MessageEventType}
- BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
+ BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
+ BeeperTranscription = Type{"com.beeper.transcription", MessageEventType}
+ BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType}
+ BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType}
+ BeeperSendState = Type{"com.beeper.send_state", MessageEventType}
EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType}
EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType}
+ EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType}
)
// Ephemeral events
var (
- EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
- EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
- EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
+ EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
+ EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType}
+ BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType}
)
// Account data events
diff --git a/event/voip.go b/event/voip.go
index 28f56c95..cd8364a1 100644
--- a/event/voip.go
+++ b/event/voip.go
@@ -76,7 +76,7 @@ func (cv *CallVersion) Int() (int, error) {
type BaseCallEventContent struct {
CallID string `json:"call_id"`
PartyID string `json:"party_id"`
- Version CallVersion `json:"version"`
+ Version CallVersion `json:"version,omitempty"`
}
type CallInviteEventContent struct {
diff --git a/example/main.go b/example/main.go
index d8006d46..2bf4bef3 100644
--- a/example/main.go
+++ b/example/main.go
@@ -143,7 +143,7 @@ func main() {
if err != nil {
log.Error().Err(err).Msg("Failed to send event")
} else {
- log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent")
+ log.Info().Stringer("event_id", resp.EventID).Msg("Event sent")
}
}
cancelSync()
diff --git a/federation/cache.go b/federation/cache.go
new file mode 100644
index 00000000..24154974
--- /dev/null
+++ b/federation/cache.go
@@ -0,0 +1,153 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "sync"
+ "time"
+)
+
+// ResolutionCache is an interface for caching resolved server names.
+type ResolutionCache interface {
+ StoreResolution(*ResolvedServerName)
+ // LoadResolution loads a resolved server name from the cache.
+ // Expired entries MUST NOT be returned.
+ LoadResolution(serverName string) (*ResolvedServerName, error)
+}
+
+type KeyCache interface {
+ StoreKeys(*ServerKeyResponse)
+ StoreFetchError(serverName string, err error)
+ ShouldReQuery(serverName string) bool
+ LoadKeys(serverName string) (*ServerKeyResponse, error)
+}
+
+type InMemoryCache struct {
+ MinKeyRefetchDelay time.Duration
+
+ resolutions map[string]*ResolvedServerName
+ resolutionsLock sync.RWMutex
+ keys map[string]*ServerKeyResponse
+ lastReQueryAt map[string]time.Time
+ lastError map[string]*resolutionErrorCache
+ keysLock sync.RWMutex
+}
+
+var (
+ _ ResolutionCache = (*InMemoryCache)(nil)
+ _ KeyCache = (*InMemoryCache)(nil)
+)
+
+func NewInMemoryCache() *InMemoryCache {
+ return &InMemoryCache{
+ resolutions: make(map[string]*ResolvedServerName),
+ keys: make(map[string]*ServerKeyResponse),
+ lastReQueryAt: make(map[string]time.Time),
+ lastError: make(map[string]*resolutionErrorCache),
+ MinKeyRefetchDelay: 1 * time.Hour,
+ }
+}
+
+func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) {
+ c.resolutionsLock.Lock()
+ defer c.resolutionsLock.Unlock()
+ c.resolutions[resolution.ServerName] = resolution
+}
+
+func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) {
+ c.resolutionsLock.RLock()
+ defer c.resolutionsLock.RUnlock()
+ resolution, ok := c.resolutions[serverName]
+ if !ok || time.Until(resolution.Expires) < 0 {
+ return nil, nil
+ }
+ return resolution, nil
+}
+
+func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) {
+ c.keysLock.Lock()
+ defer c.keysLock.Unlock()
+ c.keys[keys.ServerName] = keys
+ delete(c.lastError, keys.ServerName)
+}
+
+type resolutionErrorCache struct {
+ Error error
+ Time time.Time
+ Count int
+}
+
+const MaxBackoff = 7 * 24 * time.Hour
+
+func (rec *resolutionErrorCache) ShouldRetry() bool {
+ backoff := time.Duration(math.Exp(float64(rec.Count))) * time.Second
+ return time.Since(rec.Time) > backoff
+}
+
+var ErrRecentKeyQueryFailed = errors.New("last retry was too recent")
+
+func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) {
+ c.keysLock.RLock()
+ defer c.keysLock.RUnlock()
+ keys, ok := c.keys[serverName]
+ if !ok || time.Until(keys.ValidUntilTS.Time) < 0 {
+ err, ok := c.lastError[serverName]
+ if ok && !err.ShouldRetry() {
+ return nil, fmt.Errorf(
+ "%w (%s ago) and failed with %w",
+ ErrRecentKeyQueryFailed,
+ time.Since(err.Time).String(),
+ err.Error,
+ )
+ }
+ return nil, nil
+ }
+ return keys, nil
+}
+
+func (c *InMemoryCache) StoreFetchError(serverName string, err error) {
+ c.keysLock.Lock()
+ defer c.keysLock.Unlock()
+ errorCache, ok := c.lastError[serverName]
+ if ok {
+ errorCache.Time = time.Now()
+ errorCache.Error = err
+ errorCache.Count++
+ } else {
+ c.lastError[serverName] = &resolutionErrorCache{Error: err, Time: time.Now(), Count: 1}
+ }
+}
+
+func (c *InMemoryCache) ShouldReQuery(serverName string) bool {
+ c.keysLock.Lock()
+ defer c.keysLock.Unlock()
+ lastQuery, ok := c.lastReQueryAt[serverName]
+ if ok && time.Since(lastQuery) < c.MinKeyRefetchDelay {
+ return false
+ }
+ c.lastReQueryAt[serverName] = time.Now()
+ return true
+}
+
+type noopCache struct{}
+
+func (*noopCache) StoreKeys(_ *ServerKeyResponse) {}
+func (*noopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil }
+func (*noopCache) StoreFetchError(_ string, _ error) {}
+func (*noopCache) ShouldReQuery(_ string) bool { return true }
+func (*noopCache) StoreResolution(_ *ResolvedServerName) {}
+func (*noopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil }
+
+var (
+ _ ResolutionCache = (*noopCache)(nil)
+ _ KeyCache = (*noopCache)(nil)
+)
+
+var NoopCache *noopCache
diff --git a/federation/client.go b/federation/client.go
index 098df095..183fb5d1 100644
--- a/federation/client.go
+++ b/federation/client.go
@@ -9,7 +9,6 @@ package federation
import (
"bytes"
"context"
- "encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -22,6 +21,7 @@ import (
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/federation/signutil"
"maunium.net/go/mautrix/id"
)
@@ -30,17 +30,25 @@ type Client struct {
ServerName string
UserAgent string
Key *SigningKey
+
+ ResponseSizeLimit int64
}
-func NewClient(serverName string, key *SigningKey) *Client {
+func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client {
return &Client{
HTTP: &http.Client{
- Transport: NewServerResolvingTransport(),
+ Transport: NewServerResolvingTransport(cache),
Timeout: 120 * time.Second,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ // Federation requests do not allow redirects.
+ return http.ErrUseLastResponse
+ },
},
UserAgent: mautrix.DefaultUserAgent,
ServerName: serverName,
Key: key,
+
+ ResponseSizeLimit: mautrix.DefaultResponseSizeLimit,
}
}
@@ -54,7 +62,7 @@ func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *Serve
return
}
-func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *ServerKeyResponse, err error) {
+func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *QueryKeysResponse, err error) {
err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp)
return
}
@@ -81,7 +89,7 @@ type RespSendTransaction struct {
}
func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) {
- err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp)
+ err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp)
return
}
@@ -220,6 +228,26 @@ func (c *Client) Query(ctx context.Context, serverName, queryType string, queryP
return
}
+func queryToValues(query map[string]string) url.Values {
+ values := make(url.Values, len(query))
+ for k, v := range query {
+ values[k] = []string{v}
+ }
+ return values
+}
+
+func (c *Client) PublicRooms(ctx context.Context, serverName string, req *mautrix.ReqPublicRooms) (resp *mautrix.RespPublicRooms, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: serverName,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "publicRooms"},
+ Query: queryToValues(req.Query()),
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
type RespOpenIDUserInfo struct {
Sub id.UserID `json:"sub"`
}
@@ -235,6 +263,169 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken
return
}
+type ReqMakeJoin struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+ SupportedVersions []id.RoomVersion
+}
+
+type RespMakeJoin struct {
+ RoomVersion id.RoomVersion `json:"room_version"`
+ Event PDU `json:"event"`
+}
+
+type ReqSendJoin struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ OmitMembers bool
+ Event PDU
+ Via string
+}
+
+type ReqSendKnock struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type RespSendJoin struct {
+ AuthChain []PDU `json:"auth_chain"`
+ Event PDU `json:"event"`
+ MembersOmitted bool `json:"members_omitted"`
+ ServersInRoom []string `json:"servers_in_room"`
+ State []PDU `json:"state"`
+}
+
+type RespSendKnock struct {
+ KnockRoomState []PDU `json:"knock_room_state"`
+}
+
+type ReqSendInvite struct {
+ RoomID id.RoomID `json:"-"`
+ UserID id.UserID `json:"-"`
+ Event PDU `json:"event"`
+ InviteRoomState []PDU `json:"invite_room_state"`
+ RoomVersion id.RoomVersion `json:"room_version"`
+}
+
+type RespSendInvite struct {
+ Event PDU `json:"event"`
+}
+
+type ReqMakeLeave struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+}
+
+type ReqSendLeave struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type (
+ ReqMakeKnock = ReqMakeJoin
+ RespMakeKnock = RespMakeJoin
+ RespMakeLeave = RespMakeJoin
+)
+
+func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_join", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_join", req.RoomID, req.EventID},
+ Query: url.Values{
+ "omit_members": {strconv.FormatBool(req.OmitMembers)},
+ },
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.UserID.Homeserver(),
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "invite", req.RoomID, req.UserID},
+ Authenticate: true,
+ RequestJSON: req,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ })
+ return
+}
+
type URLPath []any
func (fup URLPath) FullPath() []any {
@@ -286,15 +477,27 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b
WrappedError: err,
}
}
- defer func() {
- _ = resp.Body.Close()
- }()
+ if !params.DontReadBody {
+ defer resp.Body.Close()
+ }
var body []byte
- if resp.StatusCode >= 400 {
+ if resp.StatusCode >= 300 {
body, err = mautrix.ParseErrorResponse(req, resp)
return body, resp, err
} else if params.ResponseJSON != nil || !params.DontReadBody {
- body, err = io.ReadAll(resp.Body)
+ if resp.ContentLength > c.ResponseSizeLimit {
+ return body, resp, mautrix.HTTPError{
+ Request: req,
+ Response: resp,
+
+ Message: "not reading response",
+ WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024),
+ }
+ }
+ body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1))
+ if err == nil && len(body) > int(c.ResponseSizeLimit) {
+ err = mautrix.ErrBodyReadReachedLimit
+ }
if err != nil {
return body, resp, mautrix.HTTPError{
Request: req,
@@ -354,16 +557,12 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt
Message: "client not configured for authentication",
}
}
- var contentAny any
- if reqJSON != nil {
- contentAny = reqJSON
- }
auth, err := (&signableRequest{
Method: req.Method,
URI: reqURL.RequestURI(),
Origin: c.ServerName,
Destination: params.ServerName,
- Content: contentAny,
+ Content: reqJSON,
}).Sign(c.Key)
if err != nil {
return nil, mautrix.HTTPError{
@@ -377,11 +576,19 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt
}
type signableRequest struct {
- Method string `json:"method"`
- URI string `json:"uri"`
- Origin string `json:"origin"`
- Destination string `json:"destination"`
- Content any `json:"content,omitempty"`
+ Method string `json:"method"`
+ URI string `json:"uri"`
+ Origin string `json:"origin"`
+ Destination string `json:"destination"`
+ Content json.RawMessage `json:"content,omitempty"`
+}
+
+func (r *signableRequest) Verify(key id.SigningKey, sig string) error {
+ message, err := json.Marshal(r)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ return signutil.VerifyJSONRaw(key, sig, message)
}
func (r *signableRequest) Sign(key *SigningKey) (string, error) {
@@ -389,11 +596,10 @@ func (r *signableRequest) Sign(key *SigningKey) (string, error) {
if err != nil {
return "", err
}
- return fmt.Sprintf(
- `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`,
- r.Origin,
- r.Destination,
- key.ID,
- base64.RawURLEncoding.EncodeToString(sig),
- ), nil
+ return XMatrixAuth{
+ Origin: r.Origin,
+ Destination: r.Destination,
+ KeyID: key.ID,
+ Signature: sig,
+ }.String(), nil
}
diff --git a/federation/client_test.go b/federation/client_test.go
index ba3c3ed4..ece399ea 100644
--- a/federation/client_test.go
+++ b/federation/client_test.go
@@ -16,7 +16,7 @@ import (
)
func TestClient_Version(t *testing.T) {
- cli := federation.NewClient("", nil)
+ cli := federation.NewClient("", nil, nil)
resp, err := cli.Version(context.TODO(), "maunium.net")
require.NoError(t, err)
require.Equal(t, "Synapse", resp.Server.Name)
diff --git a/federation/context.go b/federation/context.go
new file mode 100644
index 00000000..eedb2dc1
--- /dev/null
+++ b/federation/context.go
@@ -0,0 +1,42 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "context"
+ "net/http"
+)
+
+type contextKey int
+
+const (
+ contextKeyIPPort contextKey = iota
+ contextKeyDestinationServer
+ contextKeyOriginServer
+)
+
+func DestinationServerNameFromRequest(r *http.Request) string {
+ return DestinationServerName(r.Context())
+}
+
+func DestinationServerName(ctx context.Context) string {
+ if dest, ok := ctx.Value(contextKeyDestinationServer).(string); ok {
+ return dest
+ }
+ return ""
+}
+
+func OriginServerNameFromRequest(r *http.Request) string {
+ return OriginServerName(r.Context())
+}
+
+func OriginServerName(ctx context.Context) string {
+ if origin, ok := ctx.Value(contextKeyOriginServer).(string); ok {
+ return origin
+ }
+ return ""
+}
diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go
new file mode 100644
index 00000000..c72933c2
--- /dev/null
+++ b/federation/eventauth/eventauth.go
@@ -0,0 +1,851 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth
+
+import (
+ "encoding/json"
+ "encoding/json/jsontext"
+ "errors"
+ "fmt"
+ "slices"
+ "strconv"
+ "strings"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exgjson"
+ "go.mau.fi/util/exstrings"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+type AuthFailError struct {
+ Index string
+ Message string
+ Wrapped error
+}
+
+func (afe AuthFailError) Error() string {
+ if afe.Message != "" {
+ return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message)
+ } else if afe.Wrapped != nil {
+ return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error())
+ }
+ return fmt.Sprintf("fail %s", afe.Index)
+}
+
+func (afe AuthFailError) Unwrap() error {
+ return afe.Wrapped
+}
+
+var mFederatePath = exgjson.Path("m.federate")
+
+var (
+ ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"}
+ ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"}
+ ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"}
+ ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion}
+ ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"}
+ ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"}
+
+ ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"}
+ ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"}
+ ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"}
+ ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"}
+
+ ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"}
+ ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"}
+ ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"}
+ ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"}
+ ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"}
+ ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"}
+ ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"}
+ ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"}
+ ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"}
+
+ ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"}
+
+ ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"}
+ ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"}
+ ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"}
+ ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"}
+ ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"}
+ ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"}
+ ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"}
+ ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"}
+ ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"}
+ ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"}
+ ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"}
+ ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"}
+ ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"}
+ ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"}
+ ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"}
+ ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"}
+ ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"}
+ ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"}
+ ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"}
+ ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"}
+ ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"}
+ ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"}
+ ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"}
+ ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"}
+ ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"}
+ ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"}
+ ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"}
+ ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"}
+
+ ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"}
+
+ ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"}
+
+ ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"}
+
+ ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"}
+
+ ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"}
+ ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"}
+ ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"}
+ ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"}
+ ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"}
+ ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"}
+ ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"}
+)
+
+func isRejected(evt *pdu.PDU) bool {
+ return evt.InternalMeta.Rejected
+}
+
+type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error)
+
+func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error {
+ if evt.Type == event.StateCreate.Type {
+ // 1. If type is m.room.create:
+ return authorizeCreate(roomVersion, evt)
+ }
+ var createEvt *pdu.PDU
+ if roomVersion.RoomIDIsCreateEventID() {
+ // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event,
+ // with the sigil ! instead of $, reject.
+ if len(evt.RoomID) != 44 {
+ return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID))
+ } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil {
+ return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err)
+ } else if len(createEvts) != 1 {
+ return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID)
+ } else if isRejected(createEvts[0]) {
+ return ErrRejectedCreateEvent
+ } else {
+ createEvt = createEvts[0]
+ }
+ }
+ authEvents, err := getEvents(evt.AuthEvents)
+ if err != nil {
+ return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err)
+ }
+ expectedAuthEvents := evt.AuthEventSelection(roomVersion)
+ deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents))
+ // 3. Considering the event’s auth_events:
+ for i, ae := range authEvents {
+ authEvtID := evt.AuthEvents[i]
+ if ae == nil {
+ return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID)
+ } else if ae.StateKey == nil {
+ // This approximately falls under rule 3.2.
+ return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID)
+ }
+ key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey}
+ if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound {
+ // 3.1. If there are duplicate entries for a given type and state_key pair, reject.
+ return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID)
+ } else if !expectedAuthEvents.Has(key) {
+ // 3.2. If there are entries whose type and state_key don’t match those specified by
+ // the auth events selection algorithm described in the server specification, reject.
+ return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey)
+ } else if isRejected(ae) {
+ // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject.
+ return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID)
+ } else if ae.RoomID != evt.RoomID {
+ // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject.
+ return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID)
+ } else {
+ deduplicator[key] = authEvtID
+ }
+ if ae.Type == event.StateCreate.Type {
+ if createEvt == nil {
+ createEvt = ae
+ } else {
+ // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+
+ panic(fmt.Errorf("impossible case: multiple create events found in auth events"))
+ }
+ }
+ }
+ if createEvt == nil {
+ // This comes either from auth_events or room_id depending on the room version.
+ // The checks above make sure it's from the right source.
+ return ErrNoCreateEvent
+ }
+ if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() {
+ // 4. If the content of the m.room.create event in the room state has the property m.federate set to false,
+ // and the sender domain of the event does not match the sender domain of the create event, reject.
+ return ErrFederationDisabled
+ }
+ if evt.Type == event.StateMember.Type {
+ // 5. If type is m.room.member:
+ return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey)
+ }
+ senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
+ if senderMembership != event.MembershipJoin {
+ // 6. If the sender’s current membership state is not join, reject.
+ return ErrNotInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderPL := powerLevels.GetUserLevel(evt.Sender)
+ if evt.Type == event.StateThirdPartyInvite.Type {
+ // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level.
+ if senderPL >= powerLevels.Invite() {
+ return nil
+ }
+ return ErrInsufficientPowerForThirdPartyInvite
+ }
+ typeClass := event.MessageEventType
+ if evt.StateKey != nil {
+ typeClass = event.StateEventType
+ }
+ evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass})
+ if evtLevel > senderPL {
+ // 8. If the event type’s required power level is greater than the sender’s power level, reject.
+ return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL)
+ }
+
+ if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() {
+ // 9. If the event has a state_key that starts with an @ and does not match the sender, reject.
+ return ErrMismatchingPrivateStateKey
+ }
+
+ if evt.Type == event.StatePowerLevels.Type {
+ // 10. If type is m.room.power_levels:
+ return authorizePowerLevels(roomVersion, evt, createEvt, authEvents)
+ }
+
+ // 11. Otherwise, allow.
+ return nil
+}
+
+var ErrUserIDNotAString = errors.New("not a string")
+var ErrUserIDNotValid = errors.New("not a valid user ID")
+
+func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error {
+ if userID.Type != gjson.String {
+ return ErrUserIDNotAString
+ }
+ // In a future room version, user IDs will have stricter validation
+ _, _, err := id.UserID(userID.Str).Parse()
+ if err != nil {
+ return ErrUserIDNotValid
+ }
+ return nil
+}
+
+func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error {
+ if len(evt.PrevEvents) > 0 {
+ // 1.1. If it has any prev_events, reject.
+ return ErrCreateHasPrevEvents
+ }
+ if roomVersion.RoomIDIsCreateEventID() {
+ if evt.RoomID != "" {
+ // 1.2. If the event has a room_id, reject.
+ return ErrCreateHasRoomID
+ }
+ } else {
+ _, _, server := id.ParseCommonIdentifier(evt.RoomID)
+ if server == "" || server != evt.Sender.Homeserver() {
+ // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject.
+ return ErrRoomIDDoesntMatchSender
+ }
+ }
+ if !roomVersion.IsKnown() {
+ // 1.3. If content.room_version is present and is not a recognised version, reject.
+ return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion)
+ }
+ if roomVersion.PrivilegedRoomCreators() {
+ additionalCreators := gjson.GetBytes(evt.Content, "additional_creators")
+ if additionalCreators.Exists() {
+ if !additionalCreators.IsArray() {
+ return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators)
+ }
+ for i, item := range additionalCreators.Array() {
+ // 1.4. If additional_creators is present in content and is not an array of strings
+ // where each string passes the same user ID validation applied to sender, reject.
+ if err := isValidUserID(roomVersion, item); err != nil {
+ return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err)
+ }
+ }
+ }
+ }
+ if roomVersion.CreatorInContent() {
+ // 1.4. (v10 and below) If content has no creator property, reject.
+ if !gjson.GetBytes(evt.Content, "creator").Exists() {
+ return ErrMissingCreator
+ }
+ }
+ // 1.5. Otherwise, allow.
+ return nil
+}
+
+func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error {
+ membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str)
+ if evt.StateKey == nil {
+ // 5.1. If there is no state_key property, or no membership property in content, reject.
+ return ErrMemberNotState
+ }
+ authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str)
+ if authorizedVia != "" {
+ homeserver := authorizedVia.Homeserver()
+ err := evt.VerifySignature(roomVersion, homeserver, getKey)
+ if err != nil {
+ // 5.2. If content has a join_authorised_via_users_server key:
+ // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject.
+ return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err)
+ }
+ }
+ targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave"))
+ senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
+ switch membership {
+ case event.MembershipJoin:
+ createEvtID, err := createEvt.GetEventID(roomVersion)
+ if err != nil {
+ return fmt.Errorf("failed to get create event ID: %w", err)
+ }
+ creator := createEvt.Sender.String()
+ if roomVersion.CreatorInContent() {
+ creator = gjson.GetBytes(evt.Content, "creator").Str
+ }
+ if len(evt.PrevEvents) == 1 &&
+ len(evt.AuthEvents) <= 1 &&
+ evt.PrevEvents[0] == createEvtID &&
+ *evt.StateKey == creator {
+ // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow.
+ return nil
+ }
+ // Spec wart: this would make more sense before the check above.
+ // Now you can set anyone as the sender of the first join.
+ if evt.Sender.String() != *evt.StateKey {
+ // 5.3.2. If the sender does not match state_key, reject.
+ return ErrCantJoinOtherUser
+ }
+
+ if senderMembership == event.MembershipBan {
+ // 5.3.3. If the sender is banned, reject.
+ return ErrCantJoinBanned
+ }
+
+ joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
+ switch joinRule {
+ case event.JoinRuleKnock:
+ if !roomVersion.Knocks() {
+ return ErrInvalidJoinRule
+ }
+ fallthrough
+ case event.JoinRuleInvite:
+ // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join.
+ if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
+ return nil
+ }
+ return ErrCantJoinWithoutInvite
+ case event.JoinRuleKnockRestricted:
+ if !roomVersion.KnockRestricted() {
+ return ErrInvalidJoinRule
+ }
+ fallthrough
+ case event.JoinRuleRestricted:
+ if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() {
+ return ErrInvalidJoinRule
+ }
+ if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
+ // 5.3.5.1. If membership state is join or invite, allow.
+ return nil
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() {
+ // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject.
+ return ErrAuthoriserCantInvite
+ }
+ authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave)))
+ if authorizerMembership != event.MembershipJoin {
+ return ErrAuthoriserNotInRoom
+ }
+ // 5.3.5.3. Otherwise, allow.
+ return nil
+ case event.JoinRulePublic:
+ // 5.3.6. If the join_rule is public, allow.
+ return nil
+ default:
+ // 5.3.7. Otherwise, reject.
+ return ErrInvalidJoinRule
+ }
+ case event.MembershipInvite:
+ tpiVal := gjson.GetBytes(evt.Content, "third_party_invite")
+ if tpiVal.Exists() {
+ if targetPrevMembership == event.MembershipBan {
+ return ErrThirdPartyInviteBanned
+ }
+ signed := tpiVal.Get("signed")
+ mxid := signed.Get("mxid").Str
+ token := signed.Get("token").Str
+ if mxid == "" || token == "" {
+ // 5.4.1.2. If content.third_party_invite does not have a signed property, reject.
+ // 5.4.1.3. If signed does not have mxid and token properties, reject.
+ return ErrThirdPartyInviteMissingFields
+ }
+ if mxid != *evt.StateKey {
+ // 5.4.1.4. If mxid does not match state_key, reject.
+ return ErrThirdPartyInviteMXIDMismatch
+ }
+ tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token)
+ if tpiEvt == nil {
+ // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject.
+ return ErrThirdPartyInviteNotFound
+ }
+ if tpiEvt.Sender != evt.Sender {
+ // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject.
+ return ErrThirdPartyInviteSenderMismatch
+ }
+ var keys []id.Ed25519
+ const ed25519Base64Len = 43
+ oldPubKey := gjson.GetBytes(evt.Content, "public_key.token")
+ if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len {
+ keys = append(keys, id.Ed25519(oldPubKey.Str))
+ }
+ gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.Number {
+ return false
+ }
+ if value.Type == gjson.String && len(value.Str) == ed25519Base64Len {
+ keys = append(keys, id.Ed25519(value.Str))
+ }
+ return true
+ })
+ rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str))
+ var validated bool
+ for _, key := range keys {
+ if signutil.VerifyJSONAny(key, rawSigned) == nil {
+ validated = true
+ }
+ }
+ if validated {
+ // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow.
+ return nil
+ }
+ // 4.4.1.8. Otherwise, reject.
+ return ErrThirdPartyInviteNotSigned
+ }
+ if senderMembership != event.MembershipJoin {
+ // 5.4.2. If the sender’s current membership state is not join, reject.
+ return ErrInviterNotInRoom
+ }
+ // 5.4.3. If target user’s current membership state is join or ban, reject.
+ if targetPrevMembership == event.MembershipJoin {
+ return ErrInviteTargetAlreadyInRoom
+ } else if targetPrevMembership == event.MembershipBan {
+ return ErrInviteTargetBanned
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() {
+ // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow.
+ return nil
+ }
+ // 5.4.5. Otherwise, reject.
+ return ErrInsufficientPermissionForInvite
+ case event.MembershipLeave:
+ if evt.Sender.String() == *evt.StateKey {
+ // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock.
+ if senderMembership == event.MembershipInvite ||
+ senderMembership == event.MembershipJoin ||
+ (senderMembership == event.MembershipKnock && roomVersion.Knocks()) {
+ return nil
+ }
+ return ErrCantLeaveWithoutBeingInRoom
+ }
+ if senderMembership != event.MembershipJoin {
+ // 5.5.2. If the sender’s current membership state is not join, reject.
+ return ErrCantKickWithoutBeingInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderLevel := powerLevels.GetUserLevel(evt.Sender)
+ if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() {
+ // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject.
+ return ErrInsufficientPermissionForUnban
+ }
+ if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
+ // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow.
+ return nil
+ }
+ // TODO separate errors for < kick and < target user level?
+ // 5.5.5. Otherwise, reject.
+ return ErrInsufficientPermissionForKick
+ case event.MembershipBan:
+ if senderMembership != event.MembershipJoin {
+ // 5.6.1. If the sender’s current membership state is not join, reject.
+ return ErrCantBanWithoutBeingInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderLevel := powerLevels.GetUserLevel(evt.Sender)
+ if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
+ // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow.
+ return nil
+ }
+ // 5.6.3. Otherwise, reject.
+ return ErrInsufficientPermissionForBan
+ case event.MembershipKnock:
+ joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
+ validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock
+ validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted
+ if !validKnockRule && !validKnockRestrictedRule {
+ // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject.
+ return ErrNotKnockableRoom
+ }
+ if evt.Sender.String() != *evt.StateKey {
+ // 5.7.2. If the sender does not match state_key, reject.
+ return ErrCantKnockOtherUser
+ }
+ if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin {
+ // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow.
+ return nil
+ }
+ // 5.7.4. Otherwise, reject.
+ return ErrCantKnockWhileInRoom
+ default:
+ // 5.8. Otherwise, the membership is unknown. Reject.
+ return ErrUnknownMembership
+ }
+}
+
+func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error {
+ if roomVersion.ValidatePowerLevelInts() {
+ for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
+ res := gjson.GetBytes(evt.Content, key)
+ if !res.Exists() {
+ continue
+ }
+ if parseIntWithVersion(roomVersion, res) == nil {
+ // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject.
+ return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
+ }
+ }
+ for _, key := range []string{"events", "notifications"} {
+ obj := gjson.GetBytes(evt.Content, key)
+ if !obj.Exists() {
+ continue
+ }
+ // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject.
+ if !obj.IsObject() {
+ return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
+ }
+ var err error
+ // 10.2. [...] are not an object with values that are integers, reject.
+ obj.ForEach(func(innerKey, value gjson.Result) bool {
+ if parseIntWithVersion(roomVersion, value) == nil {
+ err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str)
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ }
+ }
+ var creators []id.UserID
+ if roomVersion.PrivilegedRoomCreators() {
+ creators = append(creators, createEvt.Sender)
+ gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool {
+ creators = append(creators, id.UserID(value.Str))
+ return true
+ })
+ }
+ users := gjson.GetBytes(evt.Content, "users")
+ if users.Exists() {
+ if !users.IsObject() {
+ // 10.3. If the users property in content is not an object [...], reject.
+ return fmt.Errorf("%w users", ErrTopLevelPLNotInteger)
+ }
+ var err error
+ users.ForEach(func(key, value gjson.Result) bool {
+ if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil {
+ // 10.3. [...] is not an object with keys that are valid user IDs [...], reject.
+ err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr)
+ return false
+ }
+ if parseIntWithVersion(roomVersion, value) == nil {
+ // 10.3. [...] is not an object [...] with values that are integers, reject.
+ err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str)
+ return false
+ }
+ // creators is only filled if the room version has privileged room creators
+ if slices.Contains(creators, id.UserID(key.Str)) {
+ // 10.4. If the users property in content contains the sender of the m.room.create event or any of
+ // the additional_creators array (if present) from the content of the m.room.create event, reject.
+ err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str)
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ }
+ oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "")
+ if oldPL == nil {
+ // 10.5. If there is no previous m.room.power_levels event in the room, allow.
+ return nil
+ }
+ if slices.Contains(creators, evt.Sender) {
+ // Skip remaining checks for creators
+ return nil
+ }
+ senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String())))
+ if senderPLPtr == nil {
+ senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default"))
+ if senderPLPtr == nil {
+ senderPLPtr = ptr.Ptr(0)
+ }
+ }
+ for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
+ oldVal := gjson.GetBytes(oldPL.Content, key)
+ newVal := gjson.GetBytes(evt.Content, key)
+ if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil {
+ return err
+ }
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "events", "",
+ gjson.GetBytes(oldPL.Content, "events"),
+ gjson.GetBytes(evt.Content, "events"),
+ ); err != nil {
+ return err
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "notifications", "",
+ gjson.GetBytes(oldPL.Content, "notifications"),
+ gjson.GetBytes(evt.Content, "notifications"),
+ ); err != nil {
+ return err
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "users", evt.Sender.String(),
+ gjson.GetBytes(oldPL.Content, "users"),
+ gjson.GetBytes(evt.Content, "users"),
+ ); err != nil {
+ return err
+ }
+ return nil
+}
+
+func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) {
+ old.ForEach(func(key, value gjson.Result) bool {
+ newVal := new.Get(exgjson.Path(key.Str))
+ err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal)
+ if err == nil && ownID != "" && key.Str != ownID {
+ parsedOldVal := parseIntWithVersion(roomVersion, value)
+ parsedNewVal := parseIntWithVersion(roomVersion, newVal)
+ if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal {
+ err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal)
+ }
+ }
+ return err == nil
+ })
+ if err != nil {
+ return
+ }
+ new.ForEach(func(key, value gjson.Result) bool {
+ err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value)
+ return err == nil
+ })
+ return
+}
+
+func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error {
+ oldVal := parseIntWithVersion(roomVersion, old)
+ newVal := parseIntWithVersion(roomVersion, new)
+ if oldVal == nil {
+ if newVal == nil || *newVal <= maxVal {
+ return nil
+ }
+ } else if newVal == nil {
+ if *oldVal <= maxVal {
+ return nil
+ }
+ } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) {
+ return nil
+ }
+ return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal)
+}
+
+func stringifyForError(val gjson.Result) string {
+ if !val.Exists() {
+ return "null"
+ }
+ return val.Raw
+}
+
+func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU {
+ for _, evt := range events {
+ if evt.Type == evtType && *evt.StateKey == stateKey {
+ return evt
+ }
+ }
+ return nil
+}
+
+func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T {
+ return reader(findEvent(events, evtType, stateKey))
+}
+
+func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string {
+ return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string {
+ if evt == nil {
+ return defVal
+ }
+ res := gjson.GetBytes(evt.Content, fieldPath)
+ if res.Type != gjson.String {
+ return defVal
+ }
+ return res.Str
+ })
+}
+
+func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) {
+ var err error
+ powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent {
+ if evt == nil {
+ return nil
+ }
+ content := evt.Content
+ out := &event.PowerLevelsEventContent{}
+ if !roomVersion.ValidatePowerLevelInts() {
+ safeParsePowerLevels(content, out)
+ } else {
+ err = json.Unmarshal(content, out)
+ }
+ return out
+ })
+ if err != nil {
+ // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ if roomVersion.PrivilegedRoomCreators() {
+ if powerLevels == nil {
+ powerLevels = &event.PowerLevelsEventContent{}
+ }
+ powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ } else if powerLevels == nil {
+ powerLevels = &event.PowerLevelsEventContent{
+ Users: map[id.UserID]int{
+ createEvt.Sender: 100,
+ },
+ }
+ }
+ return powerLevels, nil
+}
+
+func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int {
+ if roomVersion.ValidatePowerLevelInts() {
+ if val.Type != gjson.Number {
+ return nil
+ }
+ return ptr.Ptr(int(val.Int()))
+ }
+ return parsePythonInt(val)
+}
+
+func parsePythonInt(val gjson.Result) *int {
+ switch val.Type {
+ case gjson.True:
+ return ptr.Ptr(1)
+ case gjson.False:
+ return ptr.Ptr(0)
+ case gjson.Number:
+ return ptr.Ptr(int(val.Int()))
+ case gjson.String:
+ // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand
+ num, err := strconv.Atoi(strings.TrimSpace(val.Str))
+ if err != nil {
+ return nil
+ }
+ return &num
+ default:
+ // Python int() doesn't accept nulls, arrays or dicts
+ return nil
+ }
+}
+
+func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) {
+ *into = event.PowerLevelsEventContent{
+ Users: make(map[id.UserID]int),
+ UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))),
+ Events: make(map[string]int),
+ EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))),
+ Notifications: nil, // irrelevant for event auth
+ StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")),
+ InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")),
+ KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")),
+ BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")),
+ RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")),
+ }
+ gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.String {
+ return false
+ }
+ val := parsePythonInt(value)
+ if val != nil {
+ into.Events[key.Str] = *val
+ }
+ return true
+ })
+ gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.String {
+ return false
+ }
+ val := parsePythonInt(value)
+ if val == nil {
+ return false
+ }
+ userID := id.UserID(key.Str)
+ if _, _, err := userID.Parse(); err != nil {
+ return false
+ }
+ into.Users[userID] = *val
+ return true
+ })
+}
diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go
new file mode 100644
index 00000000..d316f3c8
--- /dev/null
+++ b/federation/eventauth/eventauth_internal_test.go
@@ -0,0 +1,66 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+type pythonIntTest struct {
+ Name string
+ Input string
+ Expected int64
+}
+
+var pythonIntTests = []pythonIntTest{
+ {"True", `true`, 1},
+ {"False", `false`, 0},
+ {"SmallFloat", `3.1415`, 3},
+ {"SmallFloatRoundDown", `10.999999999999999`, 10},
+ {"SmallFloatRoundUp", `10.9999999999999999`, 11},
+ {"BigFloatRoundDown", `1000000.9999999999`, 1000000},
+ {"BigFloatRoundUp", `1000000.99999999999`, 1000001},
+ {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992},
+ {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994},
+ {"Int64", `9223372036854775807`, 9223372036854775807},
+ {"Int64String", `"9223372036854775807"`, 9223372036854775807},
+ {"String", `"123"`, 123},
+ {"InvalidFloatInString", `"123.456"`, 0},
+ {"StringWithPlusSign", `"+123"`, 123},
+ {"StringWithMinusSign", `"-123"`, -123},
+ {"StringWithSpaces", `" 123 "`, 123},
+ {"StringWithSpacesAndSign", `" -123 "`, -123},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0},
+ {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0},
+ {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0},
+ //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456},
+}
+
+func TestParsePythonInt(t *testing.T) {
+ for _, test := range pythonIntTests {
+ t.Run(test.Name, func(t *testing.T) {
+ output := parsePythonInt(gjson.Parse(test.Input))
+ if strings.HasPrefix(test.Name, "Invalid") {
+ assert.Nil(t, output)
+ } else {
+ require.NotNil(t, output)
+ assert.Equal(t, int(test.Expected), *output)
+ }
+ })
+ }
+}
diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go
new file mode 100644
index 00000000..e3c5cd76
--- /dev/null
+++ b/federation/eventauth/eventauth_test.go
@@ -0,0 +1,85 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth_test
+
+import (
+ "embed"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "errors"
+ "io"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/federation/eventauth"
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+//go:embed *.jsonl
+var data embed.FS
+
+type eventMap map[id.EventID]*pdu.PDU
+
+func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) {
+ output := make([]*pdu.PDU, len(ids))
+ for i, evtID := range ids {
+ output[i] = em[evtID]
+ }
+ return output, nil
+}
+
+func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) {
+ return "", time.Time{}, nil
+}
+
+func TestAuthorize(t *testing.T) {
+ files := exerrors.Must(data.ReadDir("."))
+ for _, file := range files {
+ t.Run(file.Name(), func(t *testing.T) {
+ decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name())))
+ events := make(eventMap)
+ var roomVersion *id.RoomVersion
+ for i := 1; ; i++ {
+ var evt *pdu.PDU
+ err := json.UnmarshalDecode(decoder, &evt)
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ require.NoError(t, err)
+ if roomVersion == nil {
+ require.Equal(t, evt.Type, "m.room.create")
+ roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str))
+ }
+ expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str
+ evtID, err := evt.GetEventID(*roomVersion)
+ require.NoError(t, err)
+ require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i)
+
+ // TODO allow redacted events
+ assert.True(t, evt.VerifyContentHash(), i)
+
+ events[evtID] = evt
+ err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey)
+ if err != nil {
+ evt.InternalMeta.Rejected = true
+ }
+ // TODO allow testing intentionally rejected events
+ assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type)
+ }
+ })
+ }
+
+}
diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl
new file mode 100644
index 00000000..2b751de3
--- /dev/null
+++ b/federation/eventauth/testroom-v12-success.jsonl
@@ -0,0 +1,21 @@
+{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}}
+{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}}
+{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}}
+{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}}
+{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
+{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}}
+{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}}
+{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}}
+{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}}
+{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}}
diff --git a/federation/httpclient.go b/federation/httpclient.go
index d6d97280..2f8dbb4f 100644
--- a/federation/httpclient.go
+++ b/federation/httpclient.go
@@ -12,7 +12,6 @@ import (
"net"
"net/http"
"sync"
- "time"
)
// ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests.
@@ -22,17 +21,20 @@ type ServerResolvingTransport struct {
Transport *http.Transport
Dialer *net.Dialer
- cache map[string]*ResolvedServerName
- resolveLocks map[string]*sync.Mutex
- cacheLock sync.Mutex
+ cache ResolutionCache
+
+ resolveLocks map[string]*sync.Mutex
+ resolveLocksLock sync.Mutex
}
-func NewServerResolvingTransport() *ServerResolvingTransport {
+func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport {
+ if cache == nil {
+ cache = NewInMemoryCache()
+ }
srt := &ServerResolvingTransport{
- cache: make(map[string]*ResolvedServerName),
resolveLocks: make(map[string]*sync.Mutex),
-
- Dialer: &net.Dialer{},
+ cache: cache,
+ Dialer: &net.Dialer{},
}
srt.Transport = &http.Transport{
DialContext: srt.DialContext,
@@ -50,12 +52,6 @@ func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, a
return srt.Dialer.DialContext(ctx, network, addrs[0])
}
-type contextKey int
-
-const (
- contextKeyIPPort contextKey = iota
-)
-
func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) {
if request.URL.Scheme != "matrix-federation" {
return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme)
@@ -72,37 +68,25 @@ func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Res
}
func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) {
- res, lock := srt.getResolveCache(serverName)
- if res != nil {
- return res, nil
+ srt.resolveLocksLock.Lock()
+ lock, ok := srt.resolveLocks[serverName]
+ if !ok {
+ lock = &sync.Mutex{}
+ srt.resolveLocks[serverName] = lock
}
+ srt.resolveLocksLock.Unlock()
+
lock.Lock()
defer lock.Unlock()
- res, _ = srt.getResolveCache(serverName)
- if res != nil {
+ res, err := srt.cache.LoadResolution(serverName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache: %w", err)
+ } else if res != nil {
+ return res, nil
+ } else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil {
+ return nil, err
+ } else {
+ srt.cache.StoreResolution(res)
return res, nil
}
- var err error
- res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts)
- if err != nil {
- return nil, err
- }
- srt.cacheLock.Lock()
- srt.cache[serverName] = res
- srt.cacheLock.Unlock()
- return res, nil
-}
-
-func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) {
- srt.cacheLock.Lock()
- defer srt.cacheLock.Unlock()
- if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 {
- return val, nil
- }
- rl, ok := srt.resolveLocks[serverName]
- if !ok {
- rl = &sync.Mutex{}
- srt.resolveLocks[serverName] = rl
- }
- return nil, rl
}
diff --git a/federation/keyserver.go b/federation/keyserver.go
index 3e74bfdf..d32ba5cf 100644
--- a/federation/keyserver.go
+++ b/federation/keyserver.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,13 +8,17 @@ package federation
import (
"encoding/json"
- "fmt"
"net/http"
"strconv"
"time"
- "github.com/gorilla/mux"
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/hlog"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
+ "go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
@@ -47,34 +51,29 @@ type KeyServer struct {
KeyProvider ServerKeyProvider
Version ServerVersion
WellKnownTarget string
+ OtherKeys KeyCache
}
// Register registers the key server endpoints to the given router.
-func (ks *KeyServer) Register(r *mux.Router) {
- r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet)
- r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet)
- keyRouter := r.PathPrefix("/_matrix/key").Subrouter()
- keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet)
- keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet)
- keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost)
- keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MUnrecognized.ErrCode,
- Err: "Unrecognized endpoint",
- })
- })
- keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{
- ErrCode: mautrix.MUnrecognized.ErrCode,
- Err: "Invalid method for endpoint",
- })
- })
-}
-
-func jsonResponse(w http.ResponseWriter, code int, data any) {
- w.Header().Add("Content-Type", "application/json")
- w.WriteHeader(code)
- _ = json.NewEncoder(w).Encode(data)
+func (ks *KeyServer) Register(r *http.ServeMux, log zerolog.Logger) {
+ r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown)
+ r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion)
+ keyRouter := http.NewServeMux()
+ keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey)
+ keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys)
+ keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys)
+ errorBodies := exhttp.ErrorBodies{
+ NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
+ MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
+ }
+ r.Handle("/_matrix/key/", exhttp.ApplyMiddleware(
+ keyRouter,
+ exhttp.StripPrefix("/_matrix/key"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
}
// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint.
@@ -87,12 +86,9 @@ type RespWellKnown struct {
// https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver
func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) {
if ks.WellKnownTarget == "" {
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MNotFound.ErrCode,
- Err: "No well-known target set",
- })
+ mautrix.MNotFound.WithMessage("No well-known target set").Write(w)
} else {
- jsonResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget})
+ exhttp.WriteJSONResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget})
}
}
@@ -105,7 +101,7 @@ type RespServerVersion struct {
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version
func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version})
+ exhttp.WriteJSONResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version})
}
// GetServerKey implements the `GET /_matrix/key/v2/server` endpoint.
@@ -114,12 +110,9 @@ func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) {
func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) {
domain, key := ks.KeyProvider.Get(r)
if key == nil {
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MNotFound.ErrCode,
- Err: fmt.Sprintf("No signing key found for %q", r.Host),
- })
+ mautrix.MNotFound.WithMessage("No signing key found for %q", r.Host).Write(w)
} else {
- jsonResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil))
+ exhttp.WriteJSONResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil))
}
}
@@ -144,10 +137,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) {
var req ReqQueryKeys
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
- jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
- ErrCode: mautrix.MBadJSON.ErrCode,
- Err: fmt.Sprintf("failed to parse request: %v", err),
- })
+ mautrix.MBadJSON.WithMessage("failed to parse request: %v", err).Write(w)
return
}
@@ -165,7 +155,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) {
}
}
}
- jsonResponse(w, http.StatusOK, resp)
+ exhttp.WriteJSONResponse(w, http.StatusOK, resp)
}
// GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint
@@ -177,27 +167,39 @@ type GetQueryKeysResponse struct {
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
- serverName := mux.Vars(r)["serverName"]
+ serverName := r.PathValue("serverName")
minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
if err != nil && minimumValidUntilTSString != "" {
- jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
- ErrCode: mautrix.MInvalidParam.ErrCode,
- Err: fmt.Sprintf("failed to parse ?minimum_valid_until_ts: %v", err),
- })
+ mautrix.MInvalidParam.WithMessage("failed to parse ?minimum_valid_until_ts: %v", err).Write(w)
return
} else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) {
- jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
- ErrCode: mautrix.MInvalidParam.ErrCode,
- Err: "minimum_valid_until_ts may not be more than 24 hours in the future",
- })
+ mautrix.MInvalidParam.WithMessage("minimum_valid_until_ts may not be more than 24 hours in the future").Write(w)
return
}
resp := &GetQueryKeysResponse{
ServerKeys: []*ServerKeyResponse{},
}
- if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName {
- resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil))
+ domain, key := ks.KeyProvider.Get(r)
+ if domain == serverName {
+ if key != nil {
+ resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil))
+ }
+ } else if ks.OtherKeys != nil {
+ otherKey, err := ks.OtherKeys.LoadKeys(serverName)
+ if err != nil {
+ mautrix.MUnknown.WithMessage("Failed to load keys from cache").Write(w)
+ return
+ }
+ if key != nil && domain != "" {
+ signature, err := key.SignJSON(otherKey)
+ if err == nil {
+ otherKey.Signatures[domain] = map[id.KeyID]string{
+ key.ID: signature,
+ }
+ }
+ }
+ resp.ServerKeys = append(resp.ServerKeys, otherKey)
}
- jsonResponse(w, http.StatusOK, resp)
+ exhttp.WriteJSONResponse(w, http.StatusOK, resp)
}
diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go
new file mode 100644
index 00000000..16706fe5
--- /dev/null
+++ b/federation/pdu/auth.go
@@ -0,0 +1,71 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "slices"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exgjson"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type StateKey struct {
+ Type string
+ StateKey string
+}
+
+var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token")
+
+type AuthEventSelection []StateKey
+
+func (aes *AuthEventSelection) Add(evtType, stateKey string) {
+ key := StateKey{Type: evtType, StateKey: stateKey}
+ if !aes.Has(key) {
+ *aes = append(*aes, key)
+ }
+}
+
+func (aes *AuthEventSelection) Has(key StateKey) bool {
+ return slices.Contains(*aes, key)
+}
+
+func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) {
+ if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
+ return AuthEventSelection{}
+ }
+ keys = make(AuthEventSelection, 0, 3)
+ if !roomVersion.RoomIDIsCreateEventID() {
+ keys.Add(event.StateCreate.Type, "")
+ }
+ keys.Add(event.StatePowerLevels.Type, "")
+ keys.Add(event.StateMember.Type, pdu.Sender.String())
+ if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
+ keys.Add(event.StateMember.Type, *pdu.StateKey)
+ membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
+ if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
+ keys.Add(event.StateJoinRules.Type, "")
+ }
+ if membership == event.MembershipInvite {
+ thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
+ if thirdPartyInviteToken != "" {
+ keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
+ }
+ }
+ if membership == event.MembershipJoin && roomVersion.RestrictedJoins() {
+ authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str
+ if authorizedVia != "" {
+ keys.Add(event.StateMember.Type, authorizedVia)
+ }
+ }
+ }
+ return
+}
diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go
new file mode 100644
index 00000000..38ef83e9
--- /dev/null
+++ b/federation/pdu/hash.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+
+ "maunium.net/go/mautrix/id"
+)
+
+func (pdu *PDU) CalculateContentHash() ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ pduClone := pdu.Clone()
+ pduClone.Signatures = nil
+ pduClone.Unsigned = nil
+ pduClone.Hashes = nil
+ rawJSON, err := marshalCanonical(pduClone)
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *PDU) FillContentHash() error {
+ if pdu == nil {
+ return ErrPDUIsNil
+ } else if pdu.Hashes != nil {
+ return nil
+ } else if hash, err := pdu.CalculateContentHash(); err != nil {
+ return err
+ } else {
+ pdu.Hashes = &Hashes{SHA256: hash[:]}
+ return nil
+ }
+}
+
+func (pdu *PDU) VerifyContentHash() bool {
+ if pdu == nil || pdu.Hashes == nil {
+ return false
+ }
+ calculatedHash, err := pdu.CalculateContentHash()
+ if err != nil {
+ return false
+ }
+ return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
+}
+
+func (pdu *PDU) GetRoomID() (id.RoomID, error) {
+ if pdu == nil {
+ return "", ErrPDUIsNil
+ } else if pdu.Type != "m.room.create" {
+ return "", fmt.Errorf("room ID can only be calculated for m.room.create events")
+ } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() {
+ return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion)
+ } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil {
+ return "", fmt.Errorf("failed to calculate event ID: %w", err)
+ } else {
+ return id.RoomID(evtID), nil
+ }
+}
+
+var UseInternalMetaForGetEventID = false
+
+func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
+ if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" {
+ return pdu.InternalMeta.EventID, nil
+ }
+ return pdu.calculateEventID(roomVersion, '$')
+}
+
+func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
+ if err := pdu.FillContentHash(); err != nil {
+ return [32]byte{}, err
+ }
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) {
+ referenceHash, err := pdu.GetReferenceHash(roomVersion)
+ if err != nil {
+ return "", err
+ }
+ eventID := make([]byte, 44)
+ eventID[0] = prefix
+ switch roomVersion.EventIDFormat() {
+ case id.EventIDFormatCustom:
+ return "", fmt.Errorf("*pdu.PDU can only be used for room v3+")
+ case id.EventIDFormatBase64:
+ base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:])
+ case id.EventIDFormatURLSafeBase64:
+ base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:])
+ default:
+ return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat())
+ }
+ return id.EventID(eventID), nil
+}
diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go
new file mode 100644
index 00000000..17417e12
--- /dev/null
+++ b/federation/pdu/hash_test.go
@@ -0,0 +1,55 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/base64"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exerrors"
+)
+
+func TestPDU_CalculateContentHash(t *testing.T) {
+ for _, test := range testPDUs {
+ if test.redacted {
+ continue
+ }
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ contentHash := exerrors.Must(parsed.CalculateContentHash())
+ assert.Equal(
+ t,
+ base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
+ base64.RawStdEncoding.EncodeToString(contentHash[:]),
+ )
+ })
+ }
+}
+
+func TestPDU_VerifyContentHash(t *testing.T) {
+ for _, test := range testPDUs {
+ if test.redacted {
+ continue
+ }
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ assert.True(t, parsed.VerifyContentHash())
+ })
+ }
+}
+
+func TestPDU_GetEventID(t *testing.T) {
+ for _, test := range testPDUs {
+ t.Run(test.name, func(t *testing.T) {
+ gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion))
+ assert.Equal(t, test.eventID, gotEventID)
+ })
+ }
+}
diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go
new file mode 100644
index 00000000..17db6995
--- /dev/null
+++ b/federation/pdu/pdu.go
@@ -0,0 +1,156 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "bytes"
+ "crypto/ed25519"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/jsonbytes"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU.
+//
+// The input time is the timestamp of the event. The function should attempt to fetch a key that is
+// valid at or after this time, but if that is not possible, the latest available key should be
+// returned without an error. The verify function will do its own validity checking based on the
+// returned valid until timestamp.
+type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error)
+
+type AnyPDU interface {
+ GetRoomID() (id.RoomID, error)
+ GetEventID(roomVersion id.RoomVersion) (id.EventID, error)
+ GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error)
+ CalculateContentHash() ([32]byte, error)
+ FillContentHash() error
+ VerifyContentHash() bool
+ Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error
+ VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error
+ ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error)
+ AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection)
+}
+
+var (
+ _ AnyPDU = (*PDU)(nil)
+ _ AnyPDU = (*RoomV1PDU)(nil)
+)
+
+type InternalMeta struct {
+ EventID id.EventID `json:"event_id,omitempty"`
+ Rejected bool `json:"rejected,omitempty"`
+ Extra map[string]any `json:",unknown"`
+}
+
+type PDU struct {
+ AuthEvents []id.EventID `json:"auth_events"`
+ Content jsontext.Value `json:"content"`
+ Depth int64 `json:"depth"`
+ Hashes *Hashes `json:"hashes,omitzero"`
+ OriginServerTS int64 `json:"origin_server_ts"`
+ PrevEvents []id.EventID `json:"prev_events"`
+ Redacts *id.EventID `json:"redacts,omitzero"`
+ RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events
+ Sender id.UserID `json:"sender"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
+ StateKey *string `json:"state_key,omitzero"`
+ Type string `json:"type"`
+ Unsigned jsontext.Value `json:"unsigned,omitzero"`
+ InternalMeta InternalMeta `json:"-"`
+
+ Unknown jsontext.Value `json:",unknown"`
+
+ // Deprecated legacy fields
+ DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
+ DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
+ DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
+}
+
+var ErrPDUIsNil = errors.New("PDU is nil")
+
+type Hashes struct {
+ SHA256 jsonbytes.UnpaddedBytes `json:"sha256"`
+
+ Unknown jsontext.Value `json:",unknown"`
+}
+
+func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
+ if pdu.Type == "m.room.create" && roomVersion == "" {
+ roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str)
+ }
+ evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
+ if pdu.StateKey != nil {
+ evtType.Class = event.StateEventType
+ }
+ eventID, err := pdu.GetEventID(roomVersion)
+ if err != nil {
+ return nil, err
+ }
+ roomID := pdu.RoomID
+ if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() {
+ roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1))
+ }
+ evt := &event.Event{
+ StateKey: pdu.StateKey,
+ Sender: pdu.Sender,
+ Type: evtType,
+ Timestamp: pdu.OriginServerTS,
+ ID: eventID,
+ RoomID: roomID,
+ Redacts: ptr.Val(pdu.Redacts),
+ }
+ err = json.Unmarshal(pdu.Content, &evt.Content)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal content: %w", err)
+ }
+ return evt, nil
+}
+
+func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) {
+ if signature == "" {
+ return
+ }
+ if pdu.Signatures == nil {
+ pdu.Signatures = make(map[string]map[id.KeyID]string)
+ }
+ if _, ok := pdu.Signatures[serverName]; !ok {
+ pdu.Signatures[serverName] = make(map[id.KeyID]string)
+ }
+ pdu.Signatures[serverName][keyID] = signature
+}
+
+func marshalCanonical(data any) (jsontext.Value, error) {
+ marshaledBytes, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+ marshaled := jsontext.Value(marshaledBytes)
+ err = marshaled.Canonicalize()
+ if err != nil {
+ return nil, err
+ }
+ check := canonicaljson.CanonicalJSONAssumeValid(marshaled)
+ if !bytes.Equal(marshaled, check) {
+ fmt.Println(string(marshaled))
+ fmt.Println(string(check))
+ return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled))
+ }
+ return marshaled, nil
+}
diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go
new file mode 100644
index 00000000..59d7c3a6
--- /dev/null
+++ b/federation/pdu/pdu_test.go
@@ -0,0 +1,193 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/json/v2"
+ "time"
+
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+type serverKey struct {
+ key id.SigningKey
+ validUntilTS time.Time
+}
+
+type serverDetails struct {
+ serverName string
+ keys map[id.KeyID]serverKey
+}
+
+func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
+ if serverName != sd.serverName {
+ return "", time.Time{}, nil
+ }
+ key, ok := sd.keys[keyID]
+ if ok {
+ return key.key, key.validUntilTS, nil
+ }
+ return "", time.Time{}, nil
+}
+
+var mauniumNet = serverDetails{
+ serverName: "maunium.net",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_xxeS": {
+ key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var envsNet = serverDetails{
+ serverName: "envs.net",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_zIqy": {
+ key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0",
+ validUntilTS: time.UnixMilli(1722360538068),
+ },
+ "ed25519:wuJyKT": {
+ key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var matrixOrg = serverDetails{
+ serverName: "matrix.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:auto": {
+ key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw",
+ validUntilTS: time.UnixMilli(1576767829750),
+ },
+ "ed25519:a_RXGa": {
+ key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var continuwuityOrg = serverDetails{
+ serverName: "continuwuity.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:PwHlNsFu": {
+ key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var novaAstraltechOrg = serverDetails{
+ serverName: "nova.astraltech.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_afpo": {
+ key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+
+type testPDU struct {
+ name string
+ pdu string
+ eventID id.EventID
+ roomVersion id.RoomVersion
+ redacted bool
+ serverDetails
+}
+
+var roomV4MessageTestPDU = testPDU{
+ name: "m.room.message in v4 room",
+ pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`,
+ eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}
+
+var roomV12MessageTestPDU = testPDU{
+ name: "m.room.message in v12 room",
+ pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`,
+ eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}
+
+var testPDUs = []testPDU{roomV4MessageTestPDU, {
+ name: "m.room.message in v5 room",
+ pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`,
+ eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA",
+ roomVersion: id.RoomV5,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.message in v10 room",
+ pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`,
+ eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY",
+ roomVersion: id.RoomV10,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.message in v11 room",
+ pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`,
+ eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`,
+ roomVersion: id.RoomV11,
+ serverDetails: mauniumNet,
+}, roomV12MessageTestPDU, {
+ name: "m.room.create in v4 room",
+ pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`,
+ eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY",
+ roomVersion: id.RoomV4,
+ serverDetails: matrixOrg,
+}, {
+ name: "m.room.create in v10 room",
+ pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`,
+ eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ",
+ roomVersion: id.RoomV10,
+ serverDetails: envsNet,
+}, {
+ name: "m.room.create in v12 room",
+ pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`,
+ eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member in v4 room",
+ pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`,
+ eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member in v10 room",
+ pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`,
+ eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs",
+ roomVersion: id.RoomV10,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member of creator in v12 room",
+ pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`,
+ eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}, {
+ name: "custom message event in v4 room",
+ pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`,
+ eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}, {
+ name: "redacted m.room.member event in v11 room with 2 signatures",
+ pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`,
+ eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ",
+ roomVersion: id.RoomV11,
+ serverDetails: novaAstraltechOrg,
+ redacted: true,
+}}
+
+func parsePDU(pdu string) (out *pdu.PDU) {
+ exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
+ return
+}
diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go
new file mode 100644
index 00000000..d7ee0c15
--- /dev/null
+++ b/federation/pdu/redact.go
@@ -0,0 +1,111 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "encoding/json/jsontext"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.mau.fi/util/exgjson"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/id"
+)
+
+func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value {
+ filtered := jsontext.Value("{}")
+ var err error
+ for _, path := range allowedPaths {
+ res := gjson.GetBytes(object, path)
+ if res.Exists() {
+ var raw jsontext.Value
+ if res.Index > 0 {
+ raw = object[res.Index : res.Index+len(res.Raw)]
+ } else {
+ raw = jsontext.Value(res.Raw)
+ }
+ filtered, err = sjson.SetRawBytes(filtered, path, raw)
+ if err != nil {
+ panic(err)
+ }
+ }
+ }
+ return filtered
+}
+
+func (pdu *PDU) Clone() *PDU {
+ return ptr.Clone(pdu)
+}
+
+func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU {
+ pdu.Signatures = nil
+ return pdu.Redact(roomVersion)
+}
+
+var emptyObject = jsontext.Value("{}")
+
+func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value {
+ switch eventType {
+ case "m.room.member":
+ allowedPaths := []string{"membership"}
+ if roomVersion.RestrictedJoinsFix() {
+ allowedPaths = append(allowedPaths, "join_authorised_via_users_server")
+ }
+ if roomVersion.UpdatedRedactionRules() {
+ allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed"))
+ }
+ return filteredObject(content, allowedPaths...)
+ case "m.room.create":
+ if !roomVersion.UpdatedRedactionRules() {
+ return filteredObject(content, "creator")
+ }
+ return content
+ case "m.room.join_rules":
+ if roomVersion.RestrictedJoins() {
+ return filteredObject(content, "join_rule", "allow")
+ }
+ return filteredObject(content, "join_rule")
+ case "m.room.power_levels":
+ allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"}
+ if roomVersion.UpdatedRedactionRules() {
+ allowedKeys = append(allowedKeys, "invite")
+ }
+ return filteredObject(content, allowedKeys...)
+ case "m.room.history_visibility":
+ return filteredObject(content, "history_visibility")
+ case "m.room.redaction":
+ if roomVersion.RedactsInContent() {
+ return filteredObject(content, "redacts")
+ }
+ return emptyObject
+ case "m.room.aliases":
+ if roomVersion.SpecialCasedAliasesAuth() {
+ return filteredObject(content, "aliases")
+ }
+ return emptyObject
+ default:
+ return emptyObject
+ }
+}
+
+func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU {
+ pdu.Unknown = nil
+ pdu.Unsigned = nil
+ if roomVersion.UpdatedRedactionRules() {
+ pdu.DeprecatedPrevState = nil
+ pdu.DeprecatedOrigin = nil
+ pdu.DeprecatedMembership = nil
+ }
+ if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() {
+ pdu.Redacts = nil
+ }
+ pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
+ return pdu
+}
diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go
new file mode 100644
index 00000000..04e7c5ef
--- /dev/null
+++ b/federation/pdu/signature.go
@@ -0,0 +1,60 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "fmt"
+ "time"
+
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
+ err := pdu.FillContentHash()
+ if err != nil {
+ return err
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
+ }
+ signature := ed25519.Sign(privateKey, rawJSON)
+ pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature))
+ return nil
+}
+
+func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
+ }
+ verified := false
+ for keyID, sig := range pdu.Signatures[serverName] {
+ originServerTS := time.UnixMilli(pdu.OriginServerTS)
+ key, validUntil, err := getKey(serverName, keyID, originServerTS)
+ if err != nil {
+ return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
+ } else if key == "" {
+ return fmt.Errorf("key %s not found for %s", keyID, serverName)
+ } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() {
+ return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS)
+ } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
+ return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
+ } else {
+ verified = true
+ }
+ }
+ if !verified {
+ return fmt.Errorf("no verifiable signatures found for server %s", serverName)
+ }
+ return nil
+}
diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go
new file mode 100644
index 00000000..01df5076
--- /dev/null
+++ b/federation/pdu/signature_test.go
@@ -0,0 +1,102 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "encoding/json/jsontext"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+func TestPDU_VerifySignature(t *testing.T) {
+ for _, test := range testPDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
+ assert.NoError(t, err)
+ })
+ }
+}
+
+func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ return
+ })
+ assert.Error(t, err)
+}
+
+func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) {
+ test := roomV4MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ key = test.keys[keyID].key
+ validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ return
+ })
+ assert.NoError(t, err)
+}
+
+func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ key = test.keys[keyID].key
+ validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ return
+ })
+ assert.Error(t, err)
+}
+
+func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ for _, sigs := range parsed.Signatures {
+ for key := range sigs {
+ sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC"
+ }
+ }
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
+ assert.Error(t, err)
+}
+
+func TestPDU_Sign(t *testing.T) {
+ pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil))
+ evt := &pdu.PDU{
+ AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"},
+ Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`),
+ Depth: 123,
+ OriginServerTS: 1755384351627,
+ PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"},
+ RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ Sender: "@tulir:example.com",
+ Type: "m.room.message",
+ }
+ err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey)
+ require.NoError(t, err)
+ err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ if serverName == "example.com" && keyID == "ed25519:rand" {
+ key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey))
+ validUntil = time.Now()
+ }
+ return
+ })
+ require.NoError(t, err)
+
+}
diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go
new file mode 100644
index 00000000..9557f8ab
--- /dev/null
+++ b/federation/pdu/v1.go
@@ -0,0 +1,277 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/ed25519"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "fmt"
+ "time"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+type V1EventReference struct {
+ ID id.EventID
+ Hashes Hashes
+}
+
+var (
+ _ json.UnmarshalerFrom = (*V1EventReference)(nil)
+ _ json.MarshalerTo = (*V1EventReference)(nil)
+)
+
+func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error {
+ return json.MarshalEncode(enc, []any{er.ID, er.Hashes})
+}
+
+func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error {
+ var ref V1EventReference
+ var data []jsontext.Value
+ if err := json.UnmarshalDecode(dec, &data); err != nil {
+ return err
+ } else if len(data) != 2 {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data))
+ } else if err = json.Unmarshal(data[0], &ref.ID); err != nil {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err)
+ } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err)
+ }
+ *er = ref
+ return nil
+}
+
+type RoomV1PDU struct {
+ AuthEvents []V1EventReference `json:"auth_events"`
+ Content jsontext.Value `json:"content"`
+ Depth int64 `json:"depth"`
+ EventID id.EventID `json:"event_id"`
+ Hashes *Hashes `json:"hashes,omitzero"`
+ OriginServerTS int64 `json:"origin_server_ts"`
+ PrevEvents []V1EventReference `json:"prev_events"`
+ Redacts *id.EventID `json:"redacts,omitzero"`
+ RoomID id.RoomID `json:"room_id"`
+ Sender id.UserID `json:"sender"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
+ StateKey *string `json:"state_key,omitzero"`
+ Type string `json:"type"`
+ Unsigned jsontext.Value `json:"unsigned,omitzero"`
+
+ Unknown jsontext.Value `json:",unknown"`
+
+ // Deprecated legacy fields
+ DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
+ DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
+ DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
+}
+
+func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) {
+ return pdu.RoomID, nil
+}
+
+func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion)
+ }
+ return pdu.EventID, nil
+}
+
+func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU {
+ pdu.Signatures = nil
+ return pdu.Redact(roomVersion)
+}
+
+func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU {
+ pdu.Unknown = nil
+ pdu.Unsigned = nil
+ if pdu.Type != "m.room.redaction" {
+ pdu.Redacts = nil
+ }
+ pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
+ return pdu
+}
+
+func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion)
+ }
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
+ if err := pdu.FillContentHash(); err != nil {
+ return [32]byte{}, err
+ }
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ pduClone := pdu.Clone()
+ pduClone.Signatures = nil
+ pduClone.Unsigned = nil
+ pduClone.Hashes = nil
+ rawJSON, err := marshalCanonical(pduClone)
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *RoomV1PDU) FillContentHash() error {
+ if pdu == nil {
+ return ErrPDUIsNil
+ } else if pdu.Hashes != nil {
+ return nil
+ } else if hash, err := pdu.CalculateContentHash(); err != nil {
+ return err
+ } else {
+ pdu.Hashes = &Hashes{SHA256: hash[:]}
+ return nil
+ }
+}
+
+func (pdu *RoomV1PDU) VerifyContentHash() bool {
+ if pdu == nil || pdu.Hashes == nil {
+ return false
+ }
+ calculatedHash, err := pdu.CalculateContentHash()
+ if err != nil {
+ return false
+ }
+ return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
+}
+
+func (pdu *RoomV1PDU) Clone() *RoomV1PDU {
+ return ptr.Clone(pdu)
+}
+
+func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion)
+ }
+ err := pdu.FillContentHash()
+ if err != nil {
+ return err
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
+ }
+ signature := ed25519.Sign(privateKey, rawJSON)
+ if pdu.Signatures == nil {
+ pdu.Signatures = make(map[string]map[id.KeyID]string)
+ }
+ if _, ok := pdu.Signatures[serverName]; !ok {
+ pdu.Signatures[serverName] = make(map[id.KeyID]string)
+ }
+ pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature)
+ return nil
+}
+
+func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion)
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
+ }
+ verified := false
+ for keyID, sig := range pdu.Signatures[serverName] {
+ originServerTS := time.UnixMilli(pdu.OriginServerTS)
+ key, _, err := getKey(serverName, keyID, originServerTS)
+ if err != nil {
+ return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
+ } else if key == "" {
+ return fmt.Errorf("key %s not found for %s", keyID, serverName)
+ } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
+ return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
+ } else {
+ verified = true
+ }
+ }
+ if !verified {
+ return fmt.Errorf("no verifiable signatures found for server %s", serverName)
+ }
+ return nil
+}
+
+func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool {
+ switch roomVersion {
+ case id.RoomV0, id.RoomV1, id.RoomV2:
+ return true
+ default:
+ return false
+ }
+}
+
+func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion)
+ }
+ evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
+ if pdu.StateKey != nil {
+ evtType.Class = event.StateEventType
+ }
+ evt := &event.Event{
+ StateKey: pdu.StateKey,
+ Sender: pdu.Sender,
+ Type: evtType,
+ Timestamp: pdu.OriginServerTS,
+ ID: pdu.EventID,
+ RoomID: pdu.RoomID,
+ Redacts: ptr.Val(pdu.Redacts),
+ }
+ err := json.Unmarshal(pdu.Content, &evt.Content)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal content: %w", err)
+ }
+ return evt, nil
+}
+
+func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) {
+ if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
+ return AuthEventSelection{}
+ }
+ keys = make(AuthEventSelection, 0, 3)
+ keys.Add(event.StateCreate.Type, "")
+ keys.Add(event.StatePowerLevels.Type, "")
+ keys.Add(event.StateMember.Type, pdu.Sender.String())
+ if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
+ keys.Add(event.StateMember.Type, *pdu.StateKey)
+ membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
+ if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
+ keys.Add(event.StateJoinRules.Type, "")
+ }
+ if membership == event.MembershipInvite {
+ thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
+ if thirdPartyInviteToken != "" {
+ keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
+ }
+ }
+ }
+ return
+}
diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go
new file mode 100644
index 00000000..ecf2dbd2
--- /dev/null
+++ b/federation/pdu/v1_test.go
@@ -0,0 +1,86 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/base64"
+ "encoding/json/v2"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+var testV1PDUs = []testPDU{{
+ name: "m.room.message in v1 room",
+ pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`,
+ eventID: "$16738169022163bokdi:maunium.net",
+ roomVersion: id.RoomV1,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.create in v1 room",
+ pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`,
+ eventID: "$143454825711DhCxH:matrix.org",
+ roomVersion: id.RoomV1,
+ serverDetails: matrixOrg,
+}, {
+ name: "m.room.member in v1 room",
+ pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`,
+ eventID: "$15660585693272iEryv:maunium.net",
+ roomVersion: id.RoomV1,
+ serverDetails: mauniumNet,
+}}
+
+func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) {
+ exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
+ return
+}
+
+func TestRoomV1PDU_CalculateContentHash(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ contentHash := exerrors.Must(parsed.CalculateContentHash())
+ assert.Equal(
+ t,
+ base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
+ base64.RawStdEncoding.EncodeToString(contentHash[:]),
+ )
+ })
+ }
+}
+
+func TestRoomV1PDU_VerifyContentHash(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ assert.True(t, parsed.VerifyContentHash())
+ })
+ }
+}
+
+func TestRoomV1PDU_VerifySignature(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
+ key, ok := test.keys[keyID]
+ if ok {
+ return key.key, key.validUntilTS, nil
+ }
+ return "", time.Time{}, nil
+ })
+ assert.NoError(t, err)
+ })
+ }
+}
diff --git a/federation/resolution.go b/federation/resolution.go
index 24085282..a3188266 100644
--- a/federation/resolution.go
+++ b/federation/resolution.go
@@ -20,6 +20,8 @@ import (
"time"
"github.com/rs/zerolog"
+
+ "maunium.net/go/mautrix"
)
type ResolvedServerName struct {
@@ -78,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS
} else if wellKnown != nil {
output.Expires = expiry
output.HostHeader = wellKnown.Server
- hostname, port, ok = ParseServerName(wellKnown.Server)
+ wkHost, wkPort, ok := ParseServerName(wellKnown.Server)
+ if ok {
+ hostname, port = wkHost, wkPort
+ }
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {
@@ -120,6 +125,38 @@ func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net
return target, err
}
+func parseCacheControl(resp *http.Response) time.Duration {
+ cc := resp.Header.Get("Cache-Control")
+ if cc == "" {
+ return 0
+ }
+ parts := strings.Split(cc, ",")
+ for _, part := range parts {
+ kv := strings.SplitN(strings.TrimSpace(part), "=", 1)
+ switch kv[0] {
+ case "no-cache", "no-store":
+ return 0
+ case "max-age":
+ if len(kv) < 2 {
+ continue
+ }
+ maxAge, err := strconv.Atoi(kv[1])
+ if err != nil || maxAge < 0 {
+ continue
+ }
+ age, _ := strconv.Atoi(resp.Header.Get("Age"))
+ return time.Duration(maxAge-age) * time.Second
+ }
+ }
+ return 0
+}
+
+const (
+ MinCacheDuration = 1 * time.Hour
+ MaxCacheDuration = 72 * time.Hour
+ DefaultCacheDuration = 24 * time.Hour
+)
+
// RequestWellKnown sends a request to the well-known endpoint of a server and returns the response,
// plus the time when the cache should expire.
func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) {
@@ -139,14 +176,23 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
+ } else if resp.ContentLength > mautrix.WellKnownMaxSize {
+ return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength)
}
var respData RespWellKnown
- err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData)
+ err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
} else if respData.Server == "" {
return nil, time.Time{}, errors.New("server name not found in response")
}
- // TODO parse cache-control header
+ cacheDuration := parseCacheControl(resp)
+ if cacheDuration <= 0 {
+ cacheDuration = DefaultCacheDuration
+ } else if cacheDuration < MinCacheDuration {
+ cacheDuration = MinCacheDuration
+ } else if cacheDuration > MaxCacheDuration {
+ cacheDuration = MaxCacheDuration
+ }
return &respData, time.Now().Add(24 * time.Hour), nil
}
diff --git a/federation/serverauth.go b/federation/serverauth.go
new file mode 100644
index 00000000..cd300341
--- /dev/null
+++ b/federation/serverauth.go
@@ -0,0 +1,264 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "slices"
+ "strings"
+ "sync"
+
+ "github.com/rs/zerolog"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/id"
+)
+
+type ServerAuth struct {
+ Keys KeyCache
+ Client *Client
+ GetDestination func(XMatrixAuth) string
+ MaxBodySize int64
+
+ keyFetchLocks map[string]*sync.Mutex
+ keyFetchLocksLock sync.Mutex
+}
+
+func NewServerAuth(client *Client, keyCache KeyCache, getDestination func(auth XMatrixAuth) string) *ServerAuth {
+ return &ServerAuth{
+ Keys: keyCache,
+ Client: client,
+ GetDestination: getDestination,
+ MaxBodySize: 50 * 1024 * 1024,
+ keyFetchLocks: make(map[string]*sync.Mutex),
+ }
+}
+
+var MUnauthorized = mautrix.RespError{ErrCode: "M_UNAUTHORIZED", StatusCode: http.StatusUnauthorized}
+
+var (
+ errMissingAuthHeader = MUnauthorized.WithMessage("Missing Authorization header")
+ errInvalidAuthHeader = MUnauthorized.WithMessage("Authorization header does not start with X-Matrix")
+ errMalformedAuthHeader = MUnauthorized.WithMessage("X-Matrix value is missing required components")
+ errInvalidDestination = MUnauthorized.WithMessage("Invalid destination in X-Matrix header")
+ errFailedToQueryKeys = MUnauthorized.WithMessage("Failed to query server keys")
+ errInvalidSelfSignatures = MUnauthorized.WithMessage("Server keys don't have valid self-signatures")
+ errRequestBodyTooLarge = mautrix.MTooLarge.WithMessage("Request body too large")
+ errInvalidJSONBody = mautrix.MBadJSON.WithMessage("Request body is not valid JSON")
+ errBodyReadFailed = mautrix.MUnknown.WithMessage("Failed to read request body")
+ errInvalidRequestSignature = MUnauthorized.WithMessage("Failed to verify request signature")
+)
+
+type XMatrixAuth struct {
+ Origin string
+ Destination string
+ KeyID id.KeyID
+ Signature string
+}
+
+func (xma XMatrixAuth) String() string {
+ return fmt.Sprintf(
+ `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`,
+ xma.Origin,
+ xma.Destination,
+ xma.KeyID,
+ xma.Signature,
+ )
+}
+
+func ParseXMatrixAuth(auth string) (xma XMatrixAuth) {
+ auth = strings.TrimPrefix(auth, "X-Matrix ")
+ // TODO upgrade to strings.SplitSeq after Go 1.24 is the minimum
+ for _, part := range strings.Split(auth, ",") {
+ part = strings.TrimSpace(part)
+ eqIdx := strings.Index(part, "=")
+ if eqIdx == -1 || strings.Count(part, "=") > 1 {
+ continue
+ }
+ val := strings.Trim(part[eqIdx+1:], "\"")
+ switch strings.ToLower(part[:eqIdx]) {
+ case "origin":
+ xma.Origin = val
+ case "destination":
+ xma.Destination = val
+ case "key":
+ xma.KeyID = id.KeyID(val)
+ case "sig":
+ xma.Signature = val
+ }
+ }
+ return
+}
+
+func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string, keyID id.KeyID) (*ServerKeyResponse, error) {
+ res, err := sa.Keys.LoadKeys(serverName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache: %w", err)
+ } else if res.HasKey(keyID) {
+ return res, nil
+ }
+
+ sa.keyFetchLocksLock.Lock()
+ lock, ok := sa.keyFetchLocks[serverName]
+ if !ok {
+ lock = &sync.Mutex{}
+ sa.keyFetchLocks[serverName] = lock
+ }
+ sa.keyFetchLocksLock.Unlock()
+
+ lock.Lock()
+ defer lock.Unlock()
+ res, err = sa.Keys.LoadKeys(serverName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache: %w", err)
+ } else if res != nil {
+ if res.HasKey(keyID) {
+ return res, nil
+ } else if !sa.Keys.ShouldReQuery(serverName) {
+ zerolog.Ctx(ctx).Trace().
+ Str("server_name", serverName).
+ Stringer("key_id", keyID).
+ Msg("Not sending key request for missing key ID, last query was too recent")
+ return res, nil
+ }
+ }
+ res, err = sa.Client.ServerKeys(ctx, serverName)
+ if err != nil {
+ sa.Keys.StoreFetchError(serverName, err)
+ return nil, err
+ }
+ sa.Keys.StoreKeys(res)
+ return res, nil
+}
+
+type fixedLimitedReader struct {
+ R io.Reader
+ N int64
+ Err error
+}
+
+func (l *fixedLimitedReader) Read(p []byte) (n int, err error) {
+ if l.N <= 0 {
+ return 0, l.Err
+ }
+ if int64(len(p)) > l.N {
+ p = p[0:l.N]
+ }
+ n, err = l.R.Read(p)
+ l.N -= int64(n)
+ return
+}
+
+func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.RespError) {
+ defer func() {
+ _ = r.Body.Close()
+ }()
+ log := zerolog.Ctx(r.Context())
+ if r.ContentLength > sa.MaxBodySize {
+ return nil, &errRequestBodyTooLarge
+ }
+ auth := r.Header.Get("Authorization")
+ if auth == "" {
+ return nil, &errMissingAuthHeader
+ } else if !strings.HasPrefix(auth, "X-Matrix ") {
+ return nil, &errInvalidAuthHeader
+ }
+ parsed := ParseXMatrixAuth(auth)
+ if parsed.Origin == "" || parsed.KeyID == "" || parsed.Signature == "" {
+ log.Trace().Str("auth_header", auth).Msg("Malformed X-Matrix header")
+ return nil, &errMalformedAuthHeader
+ }
+ destination := sa.GetDestination(parsed)
+ if destination == "" || (parsed.Destination != "" && parsed.Destination != destination) {
+ log.Trace().
+ Str("got_destination", parsed.Destination).
+ Str("expected_destination", destination).
+ Msg("Invalid destination in X-Matrix header")
+ return nil, &errInvalidDestination
+ }
+ resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin, parsed.KeyID)
+ if err != nil {
+ if !errors.Is(err, ErrRecentKeyQueryFailed) {
+ log.Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to query keys to authenticate request")
+ } else {
+ log.Trace().Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to query keys to authenticate request (cached error)")
+ }
+ return nil, &errFailedToQueryKeys
+ } else if err := resp.VerifySelfSignature(); err != nil {
+ log.Trace().Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to validate self-signatures of server keys")
+ return nil, &errInvalidSelfSignatures
+ }
+ key, ok := resp.VerifyKeys[parsed.KeyID]
+ if !ok {
+ keys := slices.Collect(maps.Keys(resp.VerifyKeys))
+ log.Trace().
+ Stringer("expected_key_id", parsed.KeyID).
+ Any("found_key_ids", keys).
+ Msg("Didn't find expected key ID to verify request")
+ return nil, ptr.Ptr(MUnauthorized.WithMessage("Key ID %q not found (got %v)", parsed.KeyID, keys))
+ }
+ var reqBody []byte
+ if r.ContentLength != 0 && r.Method != http.MethodGet && r.Method != http.MethodHead {
+ reqBody, err = io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: errRequestBodyTooLarge})
+ if errors.Is(err, errRequestBodyTooLarge) {
+ return nil, &errRequestBodyTooLarge
+ } else if err != nil {
+ log.Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to read request body to authenticate")
+ return nil, &errBodyReadFailed
+ } else if !json.Valid(reqBody) {
+ return nil, &errInvalidJSONBody
+ }
+ }
+ err = (&signableRequest{
+ Method: r.Method,
+ URI: r.URL.RequestURI(),
+ Origin: parsed.Origin,
+ Destination: destination,
+ Content: reqBody,
+ }).Verify(key.Key, parsed.Signature)
+ if err != nil {
+ log.Trace().Err(err).Msg("Request has invalid signature")
+ return nil, &errInvalidRequestSignature
+ }
+ ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination)
+ ctx = context.WithValue(ctx, contextKeyOriginServer, parsed.Origin)
+ ctx = log.With().
+ Str("origin_server_name", parsed.Origin).
+ Str("destination_server_name", destination).
+ Logger().WithContext(ctx)
+ modifiedReq := r.WithContext(ctx)
+ if reqBody != nil {
+ modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody))
+ }
+ return modifiedReq, nil
+}
+
+func (sa *ServerAuth) AuthenticateMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if modifiedReq, err := sa.Authenticate(r); err != nil {
+ err.Write(w)
+ } else {
+ next.ServeHTTP(w, modifiedReq)
+ }
+ })
+}
diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go
new file mode 100644
index 00000000..f99fc6cf
--- /dev/null
+++ b/federation/serverauth_test.go
@@ -0,0 +1,29 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/federation"
+)
+
+func TestServerKeyResponse_VerifySelfSignature(t *testing.T) {
+ cli := federation.NewClient("", nil, nil)
+ ctx := context.Background()
+ for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} {
+ t.Run(name, func(t *testing.T) {
+ resp, err := cli.ServerKeys(ctx, name)
+ require.NoError(t, err)
+ assert.NoError(t, resp.VerifySelfSignature())
+ })
+ }
+}
diff --git a/federation/signingkey.go b/federation/signingkey.go
index 67751b48..a4ad9679 100644
--- a/federation/signingkey.go
+++ b/federation/signingkey.go
@@ -14,9 +14,11 @@ import (
"strings"
"time"
+ "github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/federation/signutil"
"maunium.net/go/mautrix/id"
)
@@ -31,8 +33,8 @@ type SigningKey struct {
//
// The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function.
func (sk *SigningKey) SynapseString() string {
- alg, id := sk.ID.Parse()
- return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
+ alg, keyID := sk.ID.Parse()
+ return fmt.Sprintf("%s %s %s", alg, keyID, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
}
// ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey.
@@ -77,6 +79,37 @@ type ServerKeyResponse struct {
OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"`
Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"`
+
+ Raw json.RawMessage `json:"-"`
+}
+
+type QueryKeysResponse struct {
+ ServerKeys []*ServerKeyResponse `json:"server_keys"`
+}
+
+func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool {
+ if skr == nil {
+ return false
+ } else if _, ok := skr.VerifyKeys[keyID]; ok {
+ return true
+ }
+ return false
+}
+
+func (skr *ServerKeyResponse) VerifySelfSignature() error {
+ for keyID, key := range skr.VerifyKeys {
+ if err := signutil.VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil {
+ return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err)
+ }
+ }
+ return nil
+}
+
+type marshalableSKR ServerKeyResponse
+
+func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error {
+ skr.Raw = data
+ return json.Unmarshal(data, (*marshalableSKR)(skr))
}
type ServerVerifyKey struct {
@@ -92,12 +125,16 @@ type OldVerifyKey struct {
ExpiredTS jsontime.UnixMilli `json:"expired_ts"`
}
-func (sk *SigningKey) SignJSON(data any) ([]byte, error) {
+func (sk *SigningKey) SignJSON(data any) (string, error) {
marshaled, err := json.Marshal(data)
if err != nil {
- return nil, err
+ return "", err
}
- return sk.SignRawJSON(marshaled), nil
+ marshaled, err = sjson.DeleteBytes(marshaled, "signatures")
+ if err != nil {
+ return "", err
+ }
+ return base64.RawStdEncoding.EncodeToString(sk.SignRawJSON(marshaled)), nil
}
func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte {
@@ -120,7 +157,7 @@ func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[i
}
skr.Signatures = map[string]map[id.KeyID]string{
serverName: {
- sk.ID: base64.RawURLEncoding.EncodeToString(signature),
+ sk.ID: signature,
},
}
return skr
diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go
new file mode 100644
index 00000000..ea0e7886
--- /dev/null
+++ b/federation/signutil/verify.go
@@ -0,0 +1,106 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package signutil
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.mau.fi/util/exgjson"
+
+ "maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/id"
+)
+
+var ErrSignatureNotFound = errors.New("signature not found")
+var ErrInvalidSignature = errors.New("invalid signature")
+
+func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error {
+ var err error
+ message, ok := data.(json.RawMessage)
+ if !ok {
+ message, err = json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ }
+ sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID)))
+ if sigVal.Type != gjson.String {
+ return ErrSignatureNotFound
+ }
+ message, err = sjson.DeleteBytes(message, "signatures")
+ if err != nil {
+ return fmt.Errorf("failed to delete signatures: %w", err)
+ }
+ message, err = sjson.DeleteBytes(message, "unsigned")
+ if err != nil {
+ return fmt.Errorf("failed to delete unsigned: %w", err)
+ }
+ return VerifyJSONRaw(key, sigVal.Str, message)
+}
+
+func VerifyJSONAny(key id.SigningKey, data any) error {
+ var err error
+ message, ok := data.(json.RawMessage)
+ if !ok {
+ message, err = json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ }
+ sigs := gjson.GetBytes(message, "signatures")
+ if !sigs.IsObject() {
+ return ErrSignatureNotFound
+ }
+ message, err = sjson.DeleteBytes(message, "signatures")
+ if err != nil {
+ return fmt.Errorf("failed to delete signatures: %w", err)
+ }
+ message, err = sjson.DeleteBytes(message, "unsigned")
+ if err != nil {
+ return fmt.Errorf("failed to delete unsigned: %w", err)
+ }
+ var validated bool
+ sigs.ForEach(func(_, value gjson.Result) bool {
+ if !value.IsObject() {
+ return true
+ }
+ value.ForEach(func(_, value gjson.Result) bool {
+ if value.Type != gjson.String {
+ return true
+ }
+ validated = VerifyJSONRaw(key, value.Str, message) == nil
+ return !validated
+ })
+ return !validated
+ })
+ if !validated {
+ return ErrInvalidSignature
+ }
+ return nil
+}
+
+func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error {
+ sigBytes, err := base64.RawStdEncoding.DecodeString(sig)
+ if err != nil {
+ return fmt.Errorf("failed to decode signature: %w", err)
+ }
+ keyBytes, err := base64.RawStdEncoding.DecodeString(string(key))
+ if err != nil {
+ return fmt.Errorf("failed to decode key: %w", err)
+ }
+ message = canonicaljson.CanonicalJSONAssumeValid(message)
+ if !ed25519.Verify(keyBytes, message, sigBytes) {
+ return ErrInvalidSignature
+ }
+ return nil
+}
diff --git a/filter.go b/filter.go
index 2603bfb9..54973dab 100644
--- a/filter.go
+++ b/filter.go
@@ -19,45 +19,45 @@ const (
// Filter is used by clients to specify how the server should filter responses to e.g. sync requests
// Specified by: https://spec.matrix.org/v1.2/client-server-api/#filtering
type Filter struct {
- AccountData FilterPart `json:"account_data,omitempty"`
+ AccountData *FilterPart `json:"account_data,omitempty"`
EventFields []string `json:"event_fields,omitempty"`
EventFormat EventFormat `json:"event_format,omitempty"`
- Presence FilterPart `json:"presence,omitempty"`
- Room RoomFilter `json:"room,omitempty"`
+ Presence *FilterPart `json:"presence,omitempty"`
+ Room *RoomFilter `json:"room,omitempty"`
BeeperToDevice *FilterPart `json:"com.beeper.to_device,omitempty"`
}
// RoomFilter is used to define filtering rules for room events
type RoomFilter struct {
- AccountData FilterPart `json:"account_data,omitempty"`
- Ephemeral FilterPart `json:"ephemeral,omitempty"`
+ AccountData *FilterPart `json:"account_data,omitempty"`
+ Ephemeral *FilterPart `json:"ephemeral,omitempty"`
IncludeLeave bool `json:"include_leave,omitempty"`
NotRooms []id.RoomID `json:"not_rooms,omitempty"`
Rooms []id.RoomID `json:"rooms,omitempty"`
- State FilterPart `json:"state,omitempty"`
- Timeline FilterPart `json:"timeline,omitempty"`
+ State *FilterPart `json:"state,omitempty"`
+ Timeline *FilterPart `json:"timeline,omitempty"`
}
// FilterPart is used to define filtering rules for specific categories of events
type FilterPart struct {
- NotRooms []id.RoomID `json:"not_rooms,omitempty"`
- Rooms []id.RoomID `json:"rooms,omitempty"`
- Limit int `json:"limit,omitempty"`
- NotSenders []id.UserID `json:"not_senders,omitempty"`
- NotTypes []event.Type `json:"not_types,omitempty"`
- Senders []id.UserID `json:"senders,omitempty"`
- Types []event.Type `json:"types,omitempty"`
- ContainsURL *bool `json:"contains_url,omitempty"`
-
- LazyLoadMembers bool `json:"lazy_load_members,omitempty"`
- IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"`
+ NotRooms []id.RoomID `json:"not_rooms,omitempty"`
+ Rooms []id.RoomID `json:"rooms,omitempty"`
+ Limit int `json:"limit,omitempty"`
+ NotSenders []id.UserID `json:"not_senders,omitempty"`
+ NotTypes []event.Type `json:"not_types,omitempty"`
+ Senders []id.UserID `json:"senders,omitempty"`
+ Types []event.Type `json:"types,omitempty"`
+ ContainsURL *bool `json:"contains_url,omitempty"`
+ LazyLoadMembers bool `json:"lazy_load_members,omitempty"`
+ IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"`
+ UnreadThreadNotifications bool `json:"unread_thread_notifications,omitempty"`
}
// Validate checks if the filter contains valid property values
func (filter *Filter) Validate() error {
if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation {
- return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]")
+ return errors.New("bad event_format value")
}
return nil
}
@@ -69,7 +69,7 @@ func DefaultFilter() Filter {
EventFields: nil,
EventFormat: "client",
Presence: DefaultFilterPart(),
- Room: RoomFilter{
+ Room: &RoomFilter{
AccountData: DefaultFilterPart(),
Ephemeral: DefaultFilterPart(),
IncludeLeave: false,
@@ -82,8 +82,8 @@ func DefaultFilter() Filter {
}
// DefaultFilterPart returns the default filter part used by the Matrix server if no filter is provided in the request
-func DefaultFilterPart() FilterPart {
- return FilterPart{
+func DefaultFilterPart() *FilterPart {
+ return &FilterPart{
NotRooms: nil,
Rooms: nil,
Limit: 20,
diff --git a/format/htmlparser.go b/format/htmlparser.go
index d099e8a7..e0507d93 100644
--- a/format/htmlparser.go
+++ b/format/htmlparser.go
@@ -13,6 +13,7 @@ import (
"strconv"
"strings"
+ "go.mau.fi/util/exstrings"
"golang.org/x/net/html"
"maunium.net/go/mautrix/event"
@@ -66,6 +67,7 @@ type LinkConverter func(text, href string, ctx Context) string
type ColorConverter func(text, fg, bg string, ctx Context) string
type CodeBlockConverter func(code, language string, ctx Context) string
type PillConverter func(displayname, mxid, eventID string, ctx Context) string
+type ImageConverter func(src, alt, title, width, height string, isEmoji bool) string
const ContextKeyMentions = "_mentions"
@@ -91,6 +93,30 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string
}
}
+func onlyBacktickCount(line string) (count int) {
+ for i := 0; i < len(line); i++ {
+ if line[i] != '`' {
+ return -1
+ }
+ count++
+ }
+ return
+}
+
+func DefaultMonospaceBlockConverter(code, language string, ctx Context) string {
+ if len(code) == 0 || code[len(code)-1] != '\n' {
+ code += "\n"
+ }
+ fence := "```"
+ for line := range strings.SplitSeq(code, "\n") {
+ count := onlyBacktickCount(strings.TrimSpace(line))
+ if count >= len(fence) {
+ fence = strings.Repeat("`", count+1)
+ }
+ }
+ return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence)
+}
+
// HTMLParser is a somewhat customizable Matrix HTML parser.
type HTMLParser struct {
PillConverter PillConverter
@@ -101,12 +127,15 @@ type HTMLParser struct {
ItalicConverter TextConverter
StrikethroughConverter TextConverter
UnderlineConverter TextConverter
+ MathConverter TextConverter
+ MathBlockConverter TextConverter
LinkConverter LinkConverter
SpoilerConverter SpoilerConverter
ColorConverter ColorConverter
MonospaceBlockConverter CodeBlockConverter
MonospaceConverter TextConverter
TextConverter TextConverter
+ ImageConverter ImageConverter
}
// TaggedString is a string that also contains a HTML tag.
@@ -183,25 +212,6 @@ func (parser *HTMLParser) listToString(node *html.Node, ctx Context) string {
return strings.Join(children, "\n")
}
-func LongestSequence(in string, of rune) int {
- currentSeq := 0
- maxSeq := 0
- for _, chr := range in {
- if chr == of {
- currentSeq++
- } else {
- if currentSeq > maxSeq {
- maxSeq = currentSeq
- }
- currentSeq = 0
- }
- }
- if currentSeq > maxSeq {
- maxSeq = currentSeq
- }
- return maxSeq
-}
-
func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
switch node.Data {
@@ -228,14 +238,23 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) stri
if parser.MonospaceConverter != nil {
return parser.MonospaceConverter(str, ctx)
}
- surround := strings.Repeat("`", LongestSequence(str, '`')+1)
- return fmt.Sprintf("%s%s%s", surround, str, surround)
+ return SafeMarkdownCode(str)
}
return str
}
func (parser *HTMLParser) spanToString(node *html.Node, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
+ if node.Data == "span" || node.Data == "div" {
+ math, _ := parser.maybeGetAttribute(node, "data-mx-maths")
+ if math != "" && parser.MathConverter != nil {
+ if node.Data == "div" && parser.MathBlockConverter != nil {
+ str = parser.MathBlockConverter(math, ctx)
+ } else {
+ str = parser.MathConverter(math, ctx)
+ }
+ }
+ }
if node.Data == "span" {
reason, isSpoiler := parser.maybeGetAttribute(node, "data-mx-spoiler")
if isSpoiler {
@@ -292,12 +311,28 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string {
}
if parser.LinkConverter != nil {
return parser.LinkConverter(str, href, ctx)
- } else if str == href {
+ } else if str == href ||
+ str == strings.TrimPrefix(href, "mailto:") ||
+ str == strings.TrimPrefix(href, "http://") ||
+ str == strings.TrimPrefix(href, "https://") {
return str
}
return fmt.Sprintf("%s (%s)", str, href)
}
+func (parser *HTMLParser) imgToString(node *html.Node, ctx Context) string {
+ src := parser.getAttribute(node, "src")
+ alt := parser.getAttribute(node, "alt")
+ title := parser.getAttribute(node, "title")
+ width := parser.getAttribute(node, "width")
+ height := parser.getAttribute(node, "height")
+ _, isEmoji := parser.maybeGetAttribute(node, "data-mx-emoticon")
+ if parser.ImageConverter != nil {
+ return parser.ImageConverter(src, alt, title, width, height, isEmoji)
+ }
+ return alt
+}
+
func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
ctx = ctx.WithTag(node.Data)
switch node.Data {
@@ -317,8 +352,12 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
return parser.linkToString(node, ctx)
case "p":
return parser.nodeToTagAwareString(node.FirstChild, ctx)
+ case "img":
+ return parser.imgToString(node, ctx)
case "hr":
return parser.HorizontalLine
+ case "input":
+ return parser.inputToString(node, ctx)
case "pre":
var preStr, language string
if node.FirstChild != nil && node.FirstChild.Type == html.ElementNode && node.FirstChild.Data == "code" {
@@ -333,20 +372,28 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
if parser.MonospaceBlockConverter != nil {
return parser.MonospaceBlockConverter(preStr, language, ctx)
}
- if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' {
- preStr += "\n"
- }
- return fmt.Sprintf("```%s\n%s```", language, preStr)
+ return DefaultMonospaceBlockConverter(preStr, language, ctx)
default:
return parser.nodeToTagAwareString(node.FirstChild, ctx)
}
}
+func (parser *HTMLParser) inputToString(node *html.Node, ctx Context) string {
+ if len(ctx.TagStack) > 1 && ctx.TagStack[len(ctx.TagStack)-2] == "li" {
+ _, checked := parser.maybeGetAttribute(node, "checked")
+ if checked {
+ return "[x]"
+ }
+ return "[ ]"
+ }
+ return parser.nodeToTagAwareString(node.FirstChild, ctx)
+}
+
func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString {
switch node.Type {
case html.TextNode:
if !ctx.PreserveWhitespace {
- node.Data = strings.Replace(node.Data, "\n", "", -1)
+ node.Data = exstrings.CollapseSpaces(strings.ReplaceAll(node.Data, "\n", ""))
}
if parser.TextConverter != nil {
node.Data = parser.TextConverter(node.Data, ctx)
@@ -412,6 +459,35 @@ func (parser *HTMLParser) Parse(htmlData string, ctx Context) string {
return parser.nodeToTagAwareString(node, ctx)
}
+var TextHTMLParser = &HTMLParser{
+ TabsToSpaces: 4,
+ Newline: "\n",
+ HorizontalLine: "\n---\n",
+ PillConverter: DefaultPillConverter,
+}
+
+var MarkdownHTMLParser = &HTMLParser{
+ TabsToSpaces: 4,
+ Newline: "\n",
+ HorizontalLine: "\n---\n",
+ PillConverter: DefaultPillConverter,
+ LinkConverter: func(text, href string, ctx Context) string {
+ if text == href {
+ return fmt.Sprintf("<%s>", href)
+ }
+ return fmt.Sprintf("[%s](%s)", text, href)
+ },
+ MathConverter: func(s string, c Context) string {
+ return fmt.Sprintf("$%s$", s)
+ },
+ MathBlockConverter: func(s string, c Context) string {
+ return fmt.Sprintf("$$\n%s\n$$", s)
+ },
+ UnderlineConverter: func(s string, c Context) string {
+ return fmt.Sprintf("%s", s)
+ },
+}
+
// HTMLToText converts Matrix HTML into text with the default settings.
func HTMLToText(html string) string {
return (&HTMLParser{
@@ -422,20 +498,12 @@ func HTMLToText(html string) string {
}).Parse(html, NewContext(context.TODO()))
}
-func HTMLToMarkdownAndMentions(html string) (parsed string, mentions *event.Mentions) {
+func HTMLToMarkdownFull(parser *HTMLParser, html string) (parsed string, mentions *event.Mentions) {
+ if parser == nil {
+ parser = MarkdownHTMLParser
+ }
ctx := NewContext(context.TODO())
- parsed = (&HTMLParser{
- TabsToSpaces: 4,
- Newline: "\n",
- HorizontalLine: "\n---\n",
- PillConverter: DefaultPillConverter,
- LinkConverter: func(text, href string, ctx Context) string {
- if text == href {
- return text
- }
- return fmt.Sprintf("[%s](%s)", text, href)
- },
- }).Parse(html, ctx)
+ parsed = parser.Parse(html, ctx)
mentionList, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID)
mentions = &event.Mentions{
UserIDs: mentionList,
@@ -447,6 +515,6 @@ func HTMLToMarkdownAndMentions(html string) (parsed string, mentions *event.Ment
//
// Currently, the only difference to HTMLToText is how links are formatted.
func HTMLToMarkdown(html string) string {
- parsed, _ := HTMLToMarkdownAndMentions(html)
+ parsed, _ := HTMLToMarkdownFull(nil, html)
return parsed
}
diff --git a/format/markdown.go b/format/markdown.go
index 11f9f684..77ced0dc 100644
--- a/format/markdown.go
+++ b/format/markdown.go
@@ -8,14 +8,17 @@ package format
import (
"fmt"
+ "regexp"
"strings"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/extension"
"github.com/yuin/goldmark/renderer/html"
+ "go.mau.fi/util/exstrings"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format/mdext"
+ "maunium.net/go/mautrix/id"
)
const paragraphStart = ""
@@ -39,6 +42,55 @@ func UnwrapSingleParagraph(html string) string {
return html
}
+var mdEscapeRegex = regexp.MustCompile("([\\\\`*_[\\]()])")
+
+func EscapeMarkdown(text string) string {
+ text = mdEscapeRegex.ReplaceAllString(text, "\\$1")
+ text = strings.ReplaceAll(text, ">", ">")
+ text = strings.ReplaceAll(text, "<", "<")
+ return text
+}
+
+type uriAble interface {
+ String() string
+ URI() *id.MatrixURI
+}
+
+func MarkdownMention(id uriAble) string {
+ return MarkdownMentionWithName(id.String(), id)
+}
+
+func MarkdownMentionWithName(name string, id uriAble) string {
+ return MarkdownLink(name, id.URI().MatrixToURL())
+}
+
+func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string {
+ if name == "" {
+ name = id.String()
+ }
+ return MarkdownLink(name, id.URI(via...).MatrixToURL())
+}
+
+func MarkdownLink(name string, url string) string {
+ return fmt.Sprintf("[%s](%s)", EscapeMarkdown(name), EscapeMarkdown(url))
+}
+
+func SafeMarkdownCode[T ~string](textInput T) string {
+ if textInput == "" {
+ return "` `"
+ }
+ text := strings.ReplaceAll(string(textInput), "\n", " ")
+ backtickCount := exstrings.LongestSequenceOf(text, '`')
+ if backtickCount == 0 {
+ return fmt.Sprintf("`%s`", text)
+ }
+ quotes := strings.Repeat("`", backtickCount+1)
+ if text[0] == '`' || text[len(text)-1] == '`' {
+ return fmt.Sprintf("%s %s %s", quotes, text, quotes)
+ }
+ return fmt.Sprintf("%s%s%s", quotes, text, quotes)
+}
+
func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.MessageEventContent {
var buf strings.Builder
err := renderer.Convert([]byte(text), &buf)
@@ -49,8 +101,16 @@ func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.Message
return HTMLToContent(htmlBody)
}
-func HTMLToContent(html string) event.MessageEventContent {
- text, mentions := HTMLToMarkdownAndMentions(html)
+func TextToContent(text string) event.MessageEventContent {
+ return event.MessageEventContent{
+ MsgType: event.MsgText,
+ Body: text,
+ Mentions: &event.Mentions{},
+ }
+}
+
+func HTMLToContentFull(renderer *HTMLParser, html string) event.MessageEventContent {
+ text, mentions := HTMLToMarkdownFull(renderer, html)
if html != text {
return event.MessageEventContent{
FormattedBody: html,
@@ -60,11 +120,11 @@ func HTMLToContent(html string) event.MessageEventContent {
Mentions: mentions,
}
}
- return event.MessageEventContent{
- MsgType: event.MsgText,
- Body: text,
- Mentions: &event.Mentions{},
- }
+ return TextToContent(text)
+}
+
+func HTMLToContent(html string) event.MessageEventContent {
+ return HTMLToContentFull(nil, html)
}
func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEventContent {
@@ -80,10 +140,6 @@ func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEve
htmlBody = strings.Replace(text, "\n", "
", -1)
return HTMLToContent(htmlBody)
} else {
- return event.MessageEventContent{
- MsgType: event.MsgText,
- Body: text,
- Mentions: &event.Mentions{},
- }
+ return TextToContent(text)
}
}
diff --git a/format/markdown_test.go b/format/markdown_test.go
index 10ae270c..46ea4886 100644
--- a/format/markdown_test.go
+++ b/format/markdown_test.go
@@ -158,3 +158,56 @@ func TestRenderMarkdown_DiscordUnderline(t *testing.T) {
assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", ""))
}
}
+
+var mathTests = map[string]string{
+ "$foo$": `foo`,
+ "hello $foo$ world": `hello foo world`,
+ "$$\nfoo\nbar\n$$": `
foo
bar`,
+ "`$foo$`": `$foo$`,
+ "```\n$foo$\n```": `$foo$\n
`,
+ "~~meow $foo$ asd~~": `meow foo asd`,
+ "$5 or $10": `$5 or $10`,
+ "5$ or 10$": `5$ or 10$`,
+ "$5 or 10$": `5 or 10`,
+ "$*500*$": `*500*`,
+ "$$\n*500*\n$$": `*500*`,
+
+ // TODO: This doesn't work :(
+ // Maybe same reason as the spoiler wrapping not working?
+ //"~~$foo$~~": `foo`,
+}
+
+func TestRenderMarkdown_Math(t *testing.T) {
+ renderer := goldmark.New(goldmark.WithExtensions(extension.Strikethrough, mdext.Math, mdext.EscapeHTML), format.HTMLOptions)
+ for markdown, html := range mathTests {
+ rendered := format.UnwrapSingleParagraph(render(renderer, markdown))
+ assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "\\n"), "with input %q", markdown)
+ }
+}
+
+var customEmojiTests = map[string]string{
+ ``: `
`,
+}
+
+func TestRenderMarkdown_CustomEmoji(t *testing.T) {
+ renderer := goldmark.New(goldmark.WithExtensions(mdext.CustomEmoji), format.HTMLOptions)
+ for markdown, html := range customEmojiTests {
+ rendered := format.UnwrapSingleParagraph(render(renderer, markdown))
+ assert.Equal(t, html, rendered, "with input %q", markdown)
+ }
+}
+
+var codeTests = map[string]string{
+ "meow": "`meow`",
+ "me`ow": "``me`ow``",
+ "`me`ow": "`` `me`ow ``",
+ "me`ow`": "`` me`ow` ``",
+ "`meow`": "`` `meow` ``",
+ "`````````": "`````````` ````````` ``````````",
+}
+
+func TestSafeMarkdownCode(t *testing.T) {
+ for input, expected := range codeTests {
+ assert.Equal(t, expected, format.SafeMarkdownCode(input), "with input %q", input)
+ }
+}
diff --git a/format/mdext/customemoji.go b/format/mdext/customemoji.go
new file mode 100644
index 00000000..2884a5ea
--- /dev/null
+++ b/format/mdext/customemoji.go
@@ -0,0 +1,73 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mdext
+
+import (
+ "bytes"
+
+ "github.com/yuin/goldmark"
+ "github.com/yuin/goldmark/ast"
+ "github.com/yuin/goldmark/renderer"
+ "github.com/yuin/goldmark/util"
+)
+
+type extCustomEmoji struct{}
+type customEmojiRenderer struct {
+ funcs functionCapturer
+}
+
+// CustomEmoji is an extension that converts certain markdown images into Matrix custom emojis.
+var CustomEmoji = &extCustomEmoji{}
+
+type functionCapturer struct {
+ renderImage renderer.NodeRendererFunc
+ renderText renderer.NodeRendererFunc
+ renderString renderer.NodeRendererFunc
+}
+
+func (fc *functionCapturer) Register(kind ast.NodeKind, rendererFunc renderer.NodeRendererFunc) {
+ switch kind {
+ case ast.KindImage:
+ fc.renderImage = rendererFunc
+ case ast.KindText:
+ fc.renderText = rendererFunc
+ case ast.KindString:
+ fc.renderString = rendererFunc
+ }
+}
+
+var (
+ _ renderer.NodeRendererFuncRegisterer = (*functionCapturer)(nil)
+ _ renderer.Option = (*functionCapturer)(nil)
+)
+
+func (fc *functionCapturer) SetConfig(cfg *renderer.Config) {
+ cfg.NodeRenderers[0].Value.(renderer.NodeRenderer).RegisterFuncs(fc)
+}
+
+func (eeh *extCustomEmoji) Extend(m goldmark.Markdown) {
+ var fc functionCapturer
+ m.Renderer().AddOptions(&fc)
+ m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(&customEmojiRenderer{fc}, 0)))
+}
+
+func (cer *customEmojiRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
+ reg.Register(ast.KindImage, cer.renderImage)
+}
+
+var emojiPrefix = []byte("Emoji: ")
+var mxcPrefix = []byte("mxc://")
+
+func (cer *customEmojiRenderer) renderImage(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
+ n, ok := node.(*ast.Image)
+ if ok && entering && bytes.HasPrefix(n.Title, emojiPrefix) && bytes.HasPrefix(n.Destination, mxcPrefix) {
+ n.Title = bytes.TrimPrefix(n.Title, emojiPrefix)
+ n.SetAttributeString("data-mx-emoticon", nil)
+ n.SetAttributeString("height", "32")
+ }
+ return cer.funcs.renderImage(w, source, node, entering)
+}
diff --git a/format/mdext/math.go b/format/mdext/math.go
new file mode 100644
index 00000000..e6a6ecc5
--- /dev/null
+++ b/format/mdext/math.go
@@ -0,0 +1,240 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mdext
+
+import (
+ "bytes"
+ "fmt"
+ stdhtml "html"
+ "regexp"
+ "strings"
+ "unicode"
+
+ "github.com/yuin/goldmark"
+ "github.com/yuin/goldmark/ast"
+ "github.com/yuin/goldmark/parser"
+ "github.com/yuin/goldmark/renderer"
+ "github.com/yuin/goldmark/renderer/html"
+ "github.com/yuin/goldmark/text"
+ "github.com/yuin/goldmark/util"
+)
+
+var astKindMath = ast.NewNodeKind("Math")
+
+type astMath struct {
+ ast.BaseInline
+ value []byte
+}
+
+func (n *astMath) Dump(source []byte, level int) {
+ ast.DumpHelper(n, source, level, nil, nil)
+}
+
+func (n *astMath) Kind() ast.NodeKind {
+ return astKindMath
+}
+
+type astMathBlock struct {
+ ast.BaseBlock
+}
+
+func (n *astMathBlock) Dump(source []byte, level int) {
+ ast.DumpHelper(n, source, level, nil, nil)
+}
+
+func (n *astMathBlock) Kind() ast.NodeKind {
+ return astKindMath
+}
+
+type inlineMathParser struct{}
+
+var defaultInlineMathParser = &inlineMathParser{}
+
+func NewInlineMathParser() parser.InlineParser {
+ return defaultInlineMathParser
+}
+
+const mathDelimiter = '$'
+
+func (s *inlineMathParser) Trigger() []byte {
+ return []byte{mathDelimiter}
+}
+
+// This ignores lines where there's no space after the closing $ to avoid false positives
+var latexInlineRegexp = regexp.MustCompile(`^(\$[^$]*\$)(?:$|\s)`)
+
+func (s *inlineMathParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node {
+ before := block.PrecendingCharacter()
+ // Ignore lines where the opening $ comes after a letter or number to avoid false positives
+ if unicode.IsLetter(before) || unicode.IsNumber(before) {
+ return nil
+ }
+ line, segment := block.PeekLine()
+ idx := latexInlineRegexp.FindSubmatchIndex(line)
+ if idx == nil {
+ return nil
+ }
+ block.Advance(idx[3])
+ return &astMath{
+ value: block.Value(text.NewSegment(segment.Start+1, segment.Start+idx[3]-1)),
+ }
+}
+
+func (s *inlineMathParser) CloseBlock(parent ast.Node, pc parser.Context) {
+ // nothing to do
+}
+
+type blockMathParser struct{}
+
+var defaultBlockMathParser = &blockMathParser{}
+
+func NewBlockMathParser() parser.BlockParser {
+ return defaultBlockMathParser
+}
+
+var mathBlockInfoKey = parser.NewContextKey()
+
+type mathBlockData struct {
+ indent int
+ length int
+ node ast.Node
+}
+
+func (b *blockMathParser) Trigger() []byte {
+ return []byte{'$'}
+}
+
+func (b *blockMathParser) Open(parent ast.Node, reader text.Reader, pc parser.Context) (ast.Node, parser.State) {
+ line, _ := reader.PeekLine()
+ pos := pc.BlockOffset()
+ if pos < 0 || (line[pos] != mathDelimiter) {
+ return nil, parser.NoChildren
+ }
+ findent := pos
+ i := pos
+ for ; i < len(line) && line[i] == mathDelimiter; i++ {
+ }
+ oFenceLength := i - pos
+ if oFenceLength < 2 {
+ return nil, parser.NoChildren
+ }
+ if i < len(line)-1 {
+ rest := line[i:]
+ left := util.TrimLeftSpaceLength(rest)
+ right := util.TrimRightSpaceLength(rest)
+ if left < len(rest)-right {
+ value := rest[left : len(rest)-right]
+ if bytes.IndexByte(value, mathDelimiter) > -1 {
+ return nil, parser.NoChildren
+ }
+ }
+ }
+ node := &astMathBlock{}
+ pc.Set(mathBlockInfoKey, &mathBlockData{findent, oFenceLength, node})
+ return node, parser.NoChildren
+
+}
+
+func (b *blockMathParser) Continue(node ast.Node, reader text.Reader, pc parser.Context) parser.State {
+ line, segment := reader.PeekLine()
+ fdata := pc.Get(mathBlockInfoKey).(*mathBlockData)
+
+ w, pos := util.IndentWidth(line, reader.LineOffset())
+ if w < 4 {
+ i := pos
+ for ; i < len(line) && line[i] == mathDelimiter; i++ {
+ }
+ length := i - pos
+ if length >= fdata.length && util.IsBlank(line[i:]) {
+ newline := 1
+ if line[len(line)-1] != '\n' {
+ newline = 0
+ }
+ reader.Advance(segment.Stop - segment.Start - newline + segment.Padding)
+ return parser.Close
+ }
+ }
+ pos, padding := util.IndentPositionPadding(line, reader.LineOffset(), segment.Padding, fdata.indent)
+ if pos < 0 {
+ pos = util.FirstNonSpacePosition(line)
+ if pos < 0 {
+ pos = 0
+ }
+ padding = 0
+ }
+ seg := text.NewSegmentPadding(segment.Start+pos, segment.Stop, padding)
+ seg.ForceNewline = true // EOF as newline
+ node.Lines().Append(seg)
+ reader.AdvanceAndSetPadding(segment.Stop-segment.Start-pos-1, padding)
+ return parser.Continue | parser.NoChildren
+}
+
+func (b *blockMathParser) Close(node ast.Node, reader text.Reader, pc parser.Context) {
+ fdata := pc.Get(mathBlockInfoKey).(*mathBlockData)
+ if fdata.node == node {
+ pc.Set(mathBlockInfoKey, nil)
+ }
+}
+
+func (b *blockMathParser) CanInterruptParagraph() bool {
+ return true
+}
+
+func (b *blockMathParser) CanAcceptIndentedLine() bool {
+ return false
+}
+
+type mathHTMLRenderer struct {
+ html.Config
+}
+
+func NewMathHTMLRenderer(opts ...html.Option) renderer.NodeRenderer {
+ r := &mathHTMLRenderer{
+ Config: html.NewConfig(),
+ }
+ for _, opt := range opts {
+ opt.SetHTMLOption(&r.Config)
+ }
+ return r
+}
+
+func (r *mathHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
+ reg.Register(astKindMath, r.renderMath)
+}
+
+func (r *mathHTMLRenderer) renderMath(w util.BufWriter, source []byte, n ast.Node, entering bool) (ast.WalkStatus, error) {
+ if entering {
+ tag := "span"
+ var tex string
+ switch typed := n.(type) {
+ case *astMathBlock:
+ tag = "div"
+ tex = string(n.Lines().Value(source))
+ case *astMath:
+ tex = string(typed.value)
+ }
+ tex = stdhtml.EscapeString(strings.TrimSpace(tex))
+ _, _ = fmt.Fprintf(w, `<%s data-mx-maths="%s">%s%s>`, tag, tex, strings.ReplaceAll(tex, "\n", "
"), tag)
+ }
+ return ast.WalkSkipChildren, nil
+}
+
+type math struct{}
+
+// Math is an extension that allow you to use math like '$$text$$'.
+var Math = &math{}
+
+func (e *math) Extend(m goldmark.Markdown) {
+ m.Parser().AddOptions(parser.WithInlineParsers(
+ util.Prioritized(NewInlineMathParser(), 500),
+ ), parser.WithBlockParsers(
+ util.Prioritized(NewBlockMathParser(), 850),
+ ))
+ m.Renderer().AddOptions(renderer.WithNodeRenderers(
+ util.Prioritized(NewMathHTMLRenderer(), 500),
+ ))
+}
diff --git a/format/mdext/rainbow/goldmark.go b/format/mdext/rainbow/goldmark.go
deleted file mode 100644
index 59a36178..00000000
--- a/format/mdext/rainbow/goldmark.go
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package rainbow
-
-import (
- "fmt"
- "unicode"
-
- "github.com/rivo/uniseg"
- "github.com/yuin/goldmark"
- "github.com/yuin/goldmark/ast"
- "github.com/yuin/goldmark/renderer"
- "github.com/yuin/goldmark/renderer/html"
- "github.com/yuin/goldmark/util"
- "go.mau.fi/util/random"
-)
-
-// Extension is a goldmark extension that adds rainbow text coloring to the HTML renderer.
-var Extension = &extRainbow{}
-
-type extRainbow struct{}
-type rainbowRenderer struct {
- HardWraps bool
- ColorID string
-}
-
-var defaultRB = &rainbowRenderer{HardWraps: true, ColorID: random.String(16)}
-
-func (er *extRainbow) Extend(m goldmark.Markdown) {
- m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(defaultRB, 0)))
-}
-
-func (rb *rainbowRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
- reg.Register(ast.KindText, rb.renderText)
- reg.Register(ast.KindString, rb.renderString)
-}
-
-type rainbowBufWriter struct {
- util.BufWriter
- ColorID string
-}
-
-func (rbw rainbowBufWriter) WriteString(s string) (int, error) {
- i := 0
- graphemes := uniseg.NewGraphemes(s)
- for graphemes.Next() {
- runes := graphemes.Runes()
- if len(runes) == 1 && unicode.IsSpace(runes[0]) {
- i2, err := rbw.BufWriter.WriteRune(runes[0])
- i += i2
- if err != nil {
- return i, err
- }
- continue
- }
- i2, err := fmt.Fprintf(rbw.BufWriter, "%s", rbw.ColorID, graphemes.Str())
- i += i2
- if err != nil {
- return i, err
- }
- }
- return i, nil
-}
-
-func (rbw rainbowBufWriter) Write(data []byte) (int, error) {
- return rbw.WriteString(string(data))
-}
-
-func (rbw rainbowBufWriter) WriteByte(c byte) error {
- _, err := rbw.WriteRune(rune(c))
- return err
-}
-
-func (rbw rainbowBufWriter) WriteRune(r rune) (int, error) {
- if unicode.IsSpace(r) {
- return rbw.BufWriter.WriteRune(r)
- } else {
- return fmt.Fprintf(rbw.BufWriter, "%c", rbw.ColorID, r)
- }
-}
-
-func (rb *rainbowRenderer) renderText(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
- if !entering {
- return ast.WalkContinue, nil
- }
- n := node.(*ast.Text)
- segment := n.Segment
- if n.IsRaw() {
- html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, segment.Value(source))
- } else {
- html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, segment.Value(source))
- if n.HardLineBreak() || (n.SoftLineBreak() && rb.HardWraps) {
- _, _ = w.WriteString("
\n")
- } else if n.SoftLineBreak() {
- _ = w.WriteByte('\n')
- }
- }
- return ast.WalkContinue, nil
-}
-
-func (rb *rainbowRenderer) renderString(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
- if !entering {
- return ast.WalkContinue, nil
- }
- n := node.(*ast.String)
- if n.IsCode() {
- _, _ = w.Write(n.Value)
- } else {
- if n.IsRaw() {
- html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, n.Value)
- } else {
- html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, n.Value)
- }
- }
- return ast.WalkContinue, nil
-}
diff --git a/format/mdext/rainbow/gradient.go b/format/mdext/rainbow/gradient.go
deleted file mode 100644
index 34c499e6..00000000
--- a/format/mdext/rainbow/gradient.go
+++ /dev/null
@@ -1,56 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package rainbow
-
-import (
- "regexp"
- "strings"
-
- "github.com/lucasb-eyer/go-colorful"
-)
-
-// GradientTable from https://github.com/lucasb-eyer/go-colorful/blob/master/doc/gradientgen/gradientgen.go
-type GradientTable []struct {
- Col colorful.Color
- Pos float64
-}
-
-func (gt GradientTable) GetInterpolatedColorFor(t float64) colorful.Color {
- for i := 0; i < len(gt)-1; i++ {
- c1 := gt[i]
- c2 := gt[i+1]
- if c1.Pos <= t && t <= c2.Pos {
- t := (t - c1.Pos) / (c2.Pos - c1.Pos)
- return c1.Col.BlendHcl(c2.Col, t).Clamped()
- }
- }
- return gt[len(gt)-1].Col
-}
-
-var Gradient = GradientTable{
- {colorful.LinearRgb(1, 0, 0), 0 / 11.0},
- {colorful.LinearRgb(1, 0.5, 0), 1 / 11.0},
- {colorful.LinearRgb(1, 1, 0), 2 / 11.0},
- {colorful.LinearRgb(0.5, 1, 0), 3 / 11.0},
- {colorful.LinearRgb(0, 1, 0), 4 / 11.0},
- {colorful.LinearRgb(0, 1, 0.5), 5 / 11.0},
- {colorful.LinearRgb(0, 1, 1), 6 / 11.0},
- {colorful.LinearRgb(0, 0.5, 1), 7 / 11.0},
- {colorful.LinearRgb(0, 0, 1), 8 / 11.0},
- {colorful.LinearRgb(0.5, 0, 1), 9 / 11.0},
- {colorful.LinearRgb(1, 0, 1), 10 / 11.0},
- {colorful.LinearRgb(1, 0, 0.5), 11 / 11.0},
-}
-
-func ApplyColor(htmlBody string) string {
- count := strings.Count(htmlBody, defaultRB.ColorID)
- i := -1
- return regexp.MustCompile(defaultRB.ColorID).ReplaceAllStringFunc(htmlBody, func(match string) string {
- i++
- return Gradient.GetInterpolatedColorFor(float64(i) / float64(count)).Hex()
- })
-}
diff --git a/go.mod b/go.mod
index f45b8990..49a1d4e4 100644
--- a/go.mod
+++ b/go.mod
@@ -1,45 +1,42 @@
module maunium.net/go/mautrix
-go 1.22.0
+go 1.25.0
-toolchain go1.23.2
+toolchain go1.26.0
require (
- filippo.io/edwards25519 v1.1.0
+ filippo.io/edwards25519 v1.2.0
github.com/chzyer/readline v1.5.1
- github.com/gorilla/mux v1.8.0
- github.com/gorilla/websocket v1.5.0
- github.com/lib/pq v1.10.9
- github.com/lucasb-eyer/go-colorful v1.2.0
- github.com/mattn/go-sqlite3 v1.14.24
- github.com/rivo/uniseg v0.4.7
+ github.com/coder/websocket v1.8.14
+ github.com/lib/pq v1.11.2
+ github.com/mattn/go-sqlite3 v1.14.34
github.com/rs/xid v1.6.0
- github.com/rs/zerolog v1.33.0
+ github.com/rs/zerolog v1.34.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
- github.com/stretchr/testify v1.9.0
+ github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
- github.com/yuin/goldmark v1.7.7
- go.mau.fi/util v0.8.1
- go.mau.fi/zeroconfig v0.1.3
- golang.org/x/crypto v0.28.0
- golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
- golang.org/x/net v0.30.0
- golang.org/x/sync v0.8.0
+ github.com/yuin/goldmark v1.7.16
+ go.mau.fi/util v0.9.6
+ go.mau.fi/zeroconfig v0.2.0
+ golang.org/x/crypto v0.48.0
+ golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa
+ golang.org/x/net v0.50.0
+ golang.org/x/sync v0.19.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0
)
require (
- github.com/coreos/go-systemd/v22 v22.5.0 // indirect
+ github.com/coreos/go-systemd/v22 v22.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
- github.com/mattn/go-colorable v0.1.13 // indirect
- github.com/mattn/go-isatty v0.0.19 // indirect
- github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect
+ github.com/mattn/go-colorable v0.1.14 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
- github.com/tidwall/pretty v1.2.0 // indirect
- golang.org/x/sys v0.26.0 // indirect
- golang.org/x/text v0.19.0 // indirect
+ github.com/tidwall/pretty v1.2.1 // indirect
+ golang.org/x/sys v0.41.0 // indirect
+ golang.org/x/text v0.34.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)
diff --git a/go.sum b/go.sum
index e7a58076..871a5156 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,5 @@
-filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
-filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
+filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
+filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
@@ -8,73 +8,70 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
-github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
+github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
+github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
+github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo=
+github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
-github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
-github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
-github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
-github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
-github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
-github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
-github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
-github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
+github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs=
+github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
+github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
-github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
-github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw=
-github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
+github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
-github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
-github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
-github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
-github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
+github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
+github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
-github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
-github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
-github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
+github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
-github.com/yuin/goldmark v1.7.7 h1:5m9rrB1sW3JUMToKFQfb+FGt1U7r57IHu5GrYrG2nqU=
-github.com/yuin/goldmark v1.7.7/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
-go.mau.fi/util v0.8.1 h1:Ga43cz6esQBYqcjZ/onRoVnYWoUwjWbsxVeJg2jOTSo=
-go.mau.fi/util v0.8.1/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc=
-go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
-go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
-golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
-golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
-golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
-golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
-golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
-golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
-golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
-golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE=
+github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
+go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts=
+go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI=
+go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU=
+go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w=
+golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
+golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
+golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
+golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
+golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
+golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
-golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
-golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
+golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
+golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
+golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
diff --git a/hicli/cryptohelper.go b/hicli/cryptohelper.go
deleted file mode 100644
index 2a2e9626..00000000
--- a/hicli/cryptohelper.go
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "fmt"
- "time"
-
- "github.com/rs/zerolog"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-type hiCryptoHelper HiClient
-
-var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil)
-
-func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*event.EncryptedEventContent, error) {
- roomMeta, err := h.DB.Room.Get(ctx, roomID)
- if err != nil {
- return nil, fmt.Errorf("failed to get room metadata: %w", err)
- } else if roomMeta == nil {
- return nil, fmt.Errorf("unknown room")
- }
- return (*HiClient)(h).Encrypt(ctx, roomMeta, evtType, content)
-}
-
-func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
- return h.Crypto.DecryptMegolmEvent(ctx, evt)
-}
-
-func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
- return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
-}
-
-func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
- err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
- userID: {deviceID},
- h.Account.UserID: {"*"},
- })
- if err != nil {
- zerolog.Ctx(ctx).Err(err).
- Stringer("room_id", roomID).
- Stringer("session_id", sessionID).
- Stringer("user_id", userID).
- Msg("Failed to send room key request")
- } else {
- zerolog.Ctx(ctx).Debug().
- Stringer("room_id", roomID).
- Stringer("session_id", sessionID).
- Stringer("user_id", userID).
- Msg("Sent room key request")
- }
-}
-
-func (h *hiCryptoHelper) Init(ctx context.Context) error {
- return nil
-}
diff --git a/hicli/database/account.go b/hicli/database/account.go
deleted file mode 100644
index 1dde74fd..00000000
--- a/hicli/database/account.go
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
- "database/sql"
- "errors"
-
- "go.mau.fi/util/dbutil"
-
- "maunium.net/go/mautrix/id"
-)
-
-const (
- getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1`
- putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2`
- upsertAccountQuery = `
- INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch)
- VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id)
- DO UPDATE SET device_id = excluded.device_id,
- access_token = excluded.access_token,
- homeserver_url = excluded.homeserver_url,
- next_batch = excluded.next_batch
- `
-)
-
-type AccountQuery struct {
- *dbutil.QueryHelper[*Account]
-}
-
-func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) {
- var exists bool
- if exists, err = aq.GetDB().TableExists(ctx, "account"); err != nil || !exists {
- return
- }
- err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
-}
-
-func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) {
- return aq.QueryOne(ctx, getAccountQuery, userID)
-}
-
-func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error {
- return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID)
-}
-
-func (aq *AccountQuery) Put(ctx context.Context, account *Account) error {
- return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...)
-}
-
-type Account struct {
- UserID id.UserID
- DeviceID id.DeviceID
- AccessToken string
- HomeserverURL string
- NextBatch string
-}
-
-func (a *Account) Scan(row dbutil.Scannable) (*Account, error) {
- return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch))
-}
-
-func (a *Account) sqlVariables() []any {
- return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch}
-}
diff --git a/hicli/database/accountdata.go b/hicli/database/accountdata.go
deleted file mode 100644
index 8723b595..00000000
--- a/hicli/database/accountdata.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
- "database/sql"
- "encoding/json"
- "unsafe"
-
- "go.mau.fi/util/dbutil"
-
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-const (
- upsertAccountDataQuery = `
- INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3)
- ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content
- `
- upsertRoomAccountDataQuery = `
- INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4)
- ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content
- `
-)
-
-type AccountDataQuery struct {
- *dbutil.QueryHelper[*AccountData]
-}
-
-func unsafeJSONString(content json.RawMessage) *string {
- if content == nil {
- return nil
- }
- str := unsafe.String(unsafe.SliceData(content), len(content))
- return &str
-}
-
-func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error {
- return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content))
-}
-
-func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error {
- return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content))
-}
-
-type AccountData struct {
- UserID id.UserID `json:"user_id"`
- RoomID id.RoomID `json:"room_id,omitempty"`
- Type string `json:"type"`
- Content json.RawMessage `json:"content"`
-}
-
-func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) {
- var roomID sql.NullString
- err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content))
- if err != nil {
- return nil, err
- }
- a.RoomID = id.RoomID(roomID.String)
- return a, nil
-}
-
-func (a *AccountData) sqlVariables() []any {
- return []any{a.UserID, dbutil.StrPtr(a.RoomID), a.Type, unsafeJSONString(a.Content)}
-}
diff --git a/hicli/database/cachedmedia.go b/hicli/database/cachedmedia.go
deleted file mode 100644
index 2ccaca3b..00000000
--- a/hicli/database/cachedmedia.go
+++ /dev/null
@@ -1,150 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
- "database/sql"
- "net/http"
- "time"
-
- "go.mau.fi/util/dbutil"
- "go.mau.fi/util/jsontime"
- "golang.org/x/exp/slices"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/crypto/attachment"
- "maunium.net/go/mautrix/id"
-)
-
-const (
- insertCachedMediaQuery = `
- INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
- ON CONFLICT (mxc) DO NOTHING
- `
- upsertCachedMediaQuery = `
- INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
- ON CONFLICT (mxc) DO UPDATE
- SET enc_file = excluded.enc_file,
- file_name = excluded.file_name,
- mime_type = excluded.mime_type,
- size = excluded.size,
- hash = excluded.hash,
- error = excluded.error
- WHERE excluded.error IS NULL OR cached_media.hash IS NULL
- `
- getCachedMediaQuery = `
- SELECT mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error
- FROM cached_media
- WHERE mxc = $1
- `
-)
-
-type CachedMediaQuery struct {
- *dbutil.QueryHelper[*CachedMedia]
-}
-
-func (cmq *CachedMediaQuery) Add(ctx context.Context, cm *CachedMedia) error {
- return cmq.Exec(ctx, insertCachedMediaQuery, cm.sqlVariables()...)
-}
-
-func (cmq *CachedMediaQuery) Put(ctx context.Context, cm *CachedMedia) error {
- return cmq.Exec(ctx, upsertCachedMediaQuery, cm.sqlVariables()...)
-}
-
-func (cmq *CachedMediaQuery) Get(ctx context.Context, mxc id.ContentURI) (*CachedMedia, error) {
- return cmq.QueryOne(ctx, getCachedMediaQuery, &mxc)
-}
-
-type MediaError struct {
- Matrix *mautrix.RespError `json:"data"`
- StatusCode int `json:"status_code"`
- ReceivedAt jsontime.UnixMilli `json:"received_at"`
- Attempts int `json:"attempts"`
-}
-
-const MaxMediaBackoff = 7 * 24 * time.Hour
-
-func (me *MediaError) backoff() time.Duration {
- return min(time.Duration(2< 0 {
- err = eq.Exec(ctx, updateReactionCountsQuery, evtID, dbutil.JSON{Data: &res.Counts})
- if err != nil {
- return err
- }
- }
- }
- return nil
- })
-}
-
-type EventRowID int64
-
-func (m EventRowID) GetMassInsertValues() [1]any {
- return [1]any{m}
-}
-
-type Event struct {
- RowID EventRowID `json:"rowid"`
- TimelineRowID TimelineRowID `json:"timeline_rowid"`
-
- RoomID id.RoomID `json:"room_id"`
- ID id.EventID `json:"event_id"`
- Sender id.UserID `json:"sender"`
- Type string `json:"type"`
- StateKey *string `json:"state_key,omitempty"`
- Timestamp jsontime.UnixMilli `json:"timestamp"`
-
- Content json.RawMessage `json:"content"`
- Decrypted json.RawMessage `json:"decrypted,omitempty"`
- DecryptedType string `json:"decrypted_type,omitempty"`
- Unsigned json.RawMessage `json:"unsigned,omitempty"`
-
- TransactionID string `json:"transaction_id,omitempty"`
-
- RedactedBy id.EventID `json:"redacted_by,omitempty"`
- RelatesTo id.EventID `json:"relates_to,omitempty"`
- RelationType event.RelationType `json:"relation_type,omitempty"`
-
- MegolmSessionID id.SessionID `json:"-,omitempty"`
- DecryptionError string `json:"decryption_error,omitempty"`
- SendError string `json:"send_error,omitempty"`
-
- Reactions map[string]int `json:"reactions,omitempty"`
- LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"`
-}
-
-func MautrixToEvent(evt *event.Event) *Event {
- dbEvt := &Event{
- RoomID: evt.RoomID,
- ID: evt.ID,
- Sender: evt.Sender,
- Type: evt.Type.Type,
- StateKey: evt.StateKey,
- Timestamp: jsontime.UM(time.UnixMilli(evt.Timestamp)),
- Content: evt.Content.VeryRaw,
- MegolmSessionID: getMegolmSessionID(evt),
- TransactionID: evt.Unsigned.TransactionID,
- }
- if !strings.HasPrefix(dbEvt.TransactionID, "hicli-mautrix-go_") {
- dbEvt.TransactionID = ""
- }
- dbEvt.RelatesTo, dbEvt.RelationType = getRelatesToFromEvent(evt)
- dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned)
- if evt.Unsigned.RedactedBecause != nil {
- dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID
- }
- return dbEvt
-}
-
-func (e *Event) AsRawMautrix() *event.Event {
- evt := &event.Event{
- RoomID: e.RoomID,
- ID: e.ID,
- Sender: e.Sender,
- Type: event.Type{Type: e.Type, Class: event.MessageEventType},
- StateKey: e.StateKey,
- Timestamp: e.Timestamp.UnixMilli(),
- Content: event.Content{VeryRaw: e.Content},
- }
- if e.Decrypted != nil {
- evt.Content.VeryRaw = e.Decrypted
- evt.Type.Type = e.DecryptedType
- evt.Mautrix.WasEncrypted = true
- }
- if e.StateKey != nil {
- evt.Type.Class = event.StateEventType
- }
- _ = json.Unmarshal(e.Unsigned, &evt.Unsigned)
- return evt
-}
-
-func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
- var timestamp int64
- var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, sendError, decryptedType sql.NullString
- err := row.Scan(
- &e.RowID,
- &e.TimelineRowID,
- &e.RoomID,
- &e.ID,
- &e.Sender,
- &e.Type,
- &e.StateKey,
- ×tamp,
- (*[]byte)(&e.Content),
- (*[]byte)(&e.Decrypted),
- &decryptedType,
- (*[]byte)(&e.Unsigned),
- &transactionID,
- &redactedBy,
- &relatesTo,
- &relationType,
- &megolmSessionID,
- &decryptionError,
- &sendError,
- dbutil.JSON{Data: &e.Reactions},
- &e.LastEditRowID,
- )
- if err != nil {
- return nil, err
- }
- e.Timestamp = jsontime.UM(time.UnixMilli(timestamp))
- e.TransactionID = transactionID.String
- e.RedactedBy = id.EventID(redactedBy.String)
- e.RelatesTo = id.EventID(relatesTo.String)
- e.RelationType = event.RelationType(relationType.String)
- e.MegolmSessionID = id.SessionID(megolmSessionID.String)
- e.DecryptedType = decryptedType.String
- e.DecryptionError = decryptionError.String
- e.SendError = sendError.String
- return e, nil
-}
-
-var relatesToPath = exgjson.Path("m.relates_to", "event_id")
-var relationTypePath = exgjson.Path("m.relates_to", "rel_type")
-
-func getRelatesToFromEvent(evt *event.Event) (id.EventID, event.RelationType) {
- if evt.StateKey != nil {
- return "", ""
- }
- return GetRelatesToFromBytes(evt.Content.VeryRaw)
-}
-
-func GetRelatesToFromBytes(content []byte) (id.EventID, event.RelationType) {
- results := gjson.GetManyBytes(content, relatesToPath, relationTypePath)
- if len(results) == 2 && results[0].Exists() && results[1].Exists() && results[0].Type == gjson.String && results[1].Type == gjson.String {
- return id.EventID(results[0].Str), event.RelationType(results[1].Str)
- }
- return "", ""
-}
-
-func getMegolmSessionID(evt *event.Event) id.SessionID {
- if evt.Type != event.EventEncrypted {
- return ""
- }
- res := gjson.GetBytes(evt.Content.VeryRaw, "session_id")
- if res.Exists() && res.Type == gjson.String {
- return id.SessionID(res.Str)
- }
- return ""
-}
-
-func (e *Event) sqlVariables() []any {
- var reactions any
- if e.Reactions != nil {
- reactions = e.Reactions
- }
- return []any{
- e.RoomID,
- e.ID,
- e.Sender,
- e.Type,
- e.StateKey,
- e.Timestamp.UnixMilli(),
- unsafeJSONString(e.Content),
- unsafeJSONString(e.Decrypted),
- dbutil.StrPtr(e.DecryptedType),
- unsafeJSONString(e.Unsigned),
- dbutil.StrPtr(e.TransactionID),
- dbutil.StrPtr(e.RedactedBy),
- dbutil.StrPtr(e.RelatesTo),
- dbutil.StrPtr(e.RelationType),
- dbutil.StrPtr(e.MegolmSessionID),
- dbutil.StrPtr(e.DecryptionError),
- dbutil.StrPtr(e.SendError),
- dbutil.JSON{Data: reactions},
- e.LastEditRowID,
- }
-}
-
-func (e *Event) CanUseForPreview() bool {
- return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type ||
- (e.Type == event.EventEncrypted.Type &&
- (e.DecryptedType == event.EventMessage.Type || e.DecryptedType == event.EventSticker.Type))) &&
- e.RelationType != event.RelReplace && e.RedactedBy == ""
-}
-
-func (e *Event) BumpsSortingTimestamp() bool {
- return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) &&
- e.RelationType != event.RelReplace
-}
diff --git a/hicli/database/receipt.go b/hicli/database/receipt.go
deleted file mode 100644
index 8830efc7..00000000
--- a/hicli/database/receipt.go
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
- "time"
-
- "go.mau.fi/util/dbutil"
- "go.mau.fi/util/exslices"
- "go.mau.fi/util/jsontime"
-
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-const (
- upsertReceiptQuery = `
- INSERT INTO receipt (room_id, user_id, receipt_type, thread_id, event_id, timestamp)
- VALUES ($1, $2, $3, $4, $5, $6)
- ON CONFLICT (room_id, user_id, receipt_type, thread_id) DO UPDATE
- SET event_id = excluded.event_id,
- timestamp = excluded.timestamp
- `
-)
-
-var receiptMassInserter = dbutil.NewMassInsertBuilder[*Receipt, [1]any](upsertReceiptQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
-
-type ReceiptQuery struct {
- *dbutil.QueryHelper[*Receipt]
-}
-
-func (rq *ReceiptQuery) Put(ctx context.Context, receipt *Receipt) error {
- return rq.Exec(ctx, upsertReceiptQuery, receipt.sqlVariables()...)
-}
-
-func (rq *ReceiptQuery) PutMany(ctx context.Context, roomID id.RoomID, receipts ...*Receipt) error {
- if len(receipts) > 1000 {
- return rq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
- for _, receiptChunk := range exslices.Chunk(receipts, 200) {
- err := rq.PutMany(ctx, roomID, receiptChunk...)
- if err != nil {
- return err
- }
- }
- return nil
- })
- }
- query, params := receiptMassInserter.Build([1]any{roomID}, receipts)
- return rq.Exec(ctx, query, params...)
-}
-
-type Receipt struct {
- RoomID id.RoomID `json:"room_id"`
- UserID id.UserID `json:"user_id"`
- ReceiptType event.ReceiptType `json:"receipt_type"`
- ThreadID event.ThreadID `json:"thread_id"`
- EventID id.EventID `json:"event_id"`
- Timestamp jsontime.UnixMilli `json:"timestamp"`
-}
-
-func (r *Receipt) Scan(row dbutil.Scannable) (*Receipt, error) {
- var ts int64
- err := row.Scan(&r.RoomID, &r.UserID, &r.ReceiptType, &r.ThreadID, &r.EventID, &ts)
- if err != nil {
- return nil, err
- }
- r.Timestamp = jsontime.UM(time.UnixMilli(ts))
- return r, nil
-}
-
-func (r *Receipt) sqlVariables() []any {
- return []any{r.RoomID, r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()}
-}
-
-func (r *Receipt) GetMassInsertValues() [5]any {
- return [5]any{r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()}
-}
diff --git a/hicli/database/room.go b/hicli/database/room.go
deleted file mode 100644
index d9293cf8..00000000
--- a/hicli/database/room.go
+++ /dev/null
@@ -1,255 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
- "database/sql"
- "errors"
- "time"
-
- "go.mau.fi/util/dbutil"
- "go.mau.fi/util/jsontime"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-const (
- getRoomBaseQuery = `
- SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias,
- lazy_load_summary, encryption_event, has_member_list,
- preview_event_rowid, sorting_timestamp, prev_batch
- FROM room
- `
- getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2`
- getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1`
- ensureRoomExistsQuery = `
- INSERT INTO room (room_id) VALUES ($1)
- ON CONFLICT (room_id) DO NOTHING
- `
- upsertRoomFromSyncQuery = `
- UPDATE room
- SET creation_content = COALESCE(room.creation_content, $2),
- name = COALESCE($3, room.name),
- name_quality = CASE WHEN $3 IS NOT NULL THEN $4 ELSE room.name_quality END,
- avatar = COALESCE($5, room.avatar),
- explicit_avatar = CASE WHEN $5 IS NOT NULL THEN $6 ELSE room.explicit_avatar END,
- topic = COALESCE($7, room.topic),
- canonical_alias = COALESCE($8, room.canonical_alias),
- lazy_load_summary = COALESCE($9, room.lazy_load_summary),
- encryption_event = COALESCE($10, room.encryption_event),
- has_member_list = room.has_member_list OR $11,
- preview_event_rowid = COALESCE($12, room.preview_event_rowid),
- sorting_timestamp = COALESCE($13, room.sorting_timestamp),
- prev_batch = COALESCE($14, room.prev_batch)
- WHERE room_id = $1
- `
- setRoomPrevBatchQuery = `
- UPDATE room SET prev_batch = $2 WHERE room_id = $1
- `
- updateRoomPreviewIfLaterOnTimelineQuery = `
- UPDATE room
- SET preview_event_rowid = $2
- WHERE room_id = $1
- AND COALESCE((SELECT rowid FROM timeline WHERE event_rowid = $2), -1)
- > COALESCE((SELECT rowid FROM timeline WHERE event_rowid = preview_event_rowid), 0)
- RETURNING preview_event_rowid
- `
- recalculateRoomPreviewEventQuery = `
- SELECT rowid
- FROM event
- WHERE
- room_id = $1
- AND (type IN ('m.room.message', 'm.sticker')
- OR (type = 'm.room.encrypted'
- AND decrypted_type IN ('m.room.message', 'm.sticker')))
- AND relation_type <> 'm.replace'
- AND redacted_by IS NULL
- ORDER BY timestamp DESC
- LIMIT 1
- `
-)
-
-type RoomQuery struct {
- *dbutil.QueryHelper[*Room]
-}
-
-func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) {
- return rq.QueryOne(ctx, getRoomByIDQuery, roomID)
-}
-
-func (rq *RoomQuery) GetBySortTS(ctx context.Context, maxTS time.Time, limit int) ([]*Room, error) {
- return rq.QueryMany(ctx, getRoomsBySortingTimestampQuery, maxTS.UnixMilli(), limit)
-}
-
-func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error {
- return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...)
-}
-
-func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error {
- return rq.Exec(ctx, ensureRoomExistsQuery, roomID)
-}
-
-func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error {
- return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch)
-}
-
-func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) (previewChanged bool, err error) {
- var newPreviewRowID EventRowID
- err = rq.GetDB().QueryRow(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID).Scan(&newPreviewRowID)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- } else if err == nil {
- previewChanged = newPreviewRowID == rowID
- }
- return
-}
-
-func (rq *RoomQuery) RecalculatePreview(ctx context.Context, roomID id.RoomID) (rowID EventRowID, err error) {
- err = rq.GetDB().QueryRow(ctx, recalculateRoomPreviewEventQuery, roomID).Scan(&rowID)
- return
-}
-
-type NameQuality int
-
-const (
- NameQualityNil NameQuality = iota
- NameQualityParticipants
- NameQualityCanonicalAlias
- NameQualityExplicit
-)
-
-const PrevBatchPaginationComplete = "fi.mau.gomuks.pagination_complete"
-
-type Room struct {
- ID id.RoomID `json:"room_id"`
- CreationContent *event.CreateEventContent `json:"creation_content,omitempty"`
-
- Name *string `json:"name,omitempty"`
- NameQuality NameQuality `json:"name_quality"`
- Avatar *id.ContentURI `json:"avatar,omitempty"`
- ExplicitAvatar bool `json:"explicit_avatar"`
- Topic *string `json:"topic,omitempty"`
- CanonicalAlias *id.RoomAlias `json:"canonical_alias,omitempty"`
-
- LazyLoadSummary *mautrix.LazyLoadSummary `json:"lazy_load_summary,omitempty"`
-
- EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"`
- HasMemberList bool `json:"has_member_list"`
-
- PreviewEventRowID EventRowID `json:"preview_event_rowid"`
- SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"`
-
- PrevBatch string `json:"prev_batch"`
-}
-
-func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) {
- if r.Name != nil && r.NameQuality >= other.NameQuality {
- other.Name = r.Name
- other.NameQuality = r.NameQuality
- hasChanges = true
- }
- if r.Avatar != nil {
- other.Avatar = r.Avatar
- other.ExplicitAvatar = r.ExplicitAvatar
- hasChanges = true
- }
- if r.Topic != nil {
- other.Topic = r.Topic
- hasChanges = true
- }
- if r.CanonicalAlias != nil {
- other.CanonicalAlias = r.CanonicalAlias
- hasChanges = true
- }
- if r.LazyLoadSummary != nil {
- other.LazyLoadSummary = r.LazyLoadSummary
- hasChanges = true
- }
- if r.EncryptionEvent != nil && other.EncryptionEvent == nil {
- other.EncryptionEvent = r.EncryptionEvent
- hasChanges = true
- }
- if r.HasMemberList && !other.HasMemberList {
- hasChanges = true
- other.HasMemberList = true
- }
- if r.PreviewEventRowID > other.PreviewEventRowID {
- other.PreviewEventRowID = r.PreviewEventRowID
- hasChanges = true
- }
- if r.SortingTimestamp.After(other.SortingTimestamp.Time) {
- other.SortingTimestamp = r.SortingTimestamp
- hasChanges = true
- }
- if r.PrevBatch != "" && other.PrevBatch == "" {
- other.PrevBatch = r.PrevBatch
- hasChanges = true
- }
- return
-}
-
-func (r *Room) Scan(row dbutil.Scannable) (*Room, error) {
- var prevBatch sql.NullString
- var previewEventRowID, sortingTimestamp sql.NullInt64
- err := row.Scan(
- &r.ID,
- dbutil.JSON{Data: &r.CreationContent},
- &r.Name,
- &r.NameQuality,
- &r.Avatar,
- &r.ExplicitAvatar,
- &r.Topic,
- &r.CanonicalAlias,
- dbutil.JSON{Data: &r.LazyLoadSummary},
- dbutil.JSON{Data: &r.EncryptionEvent},
- &r.HasMemberList,
- &previewEventRowID,
- &sortingTimestamp,
- &prevBatch,
- )
- if err != nil {
- return nil, err
- }
- r.PrevBatch = prevBatch.String
- r.PreviewEventRowID = EventRowID(previewEventRowID.Int64)
- r.SortingTimestamp = jsontime.UM(time.UnixMilli(sortingTimestamp.Int64))
- return r, nil
-}
-
-func (r *Room) sqlVariables() []any {
- return []any{
- r.ID,
- dbutil.JSONPtr(r.CreationContent),
- r.Name,
- r.NameQuality,
- r.Avatar,
- r.ExplicitAvatar,
- r.Topic,
- r.CanonicalAlias,
- dbutil.JSONPtr(r.LazyLoadSummary),
- dbutil.JSONPtr(r.EncryptionEvent),
- r.HasMemberList,
- dbutil.NumPtr(r.PreviewEventRowID),
- dbutil.UnixMilliPtr(r.SortingTimestamp.Time),
- dbutil.StrPtr(r.PrevBatch),
- }
-}
-
-func (r *Room) BumpSortingTimestamp(evt *Event) bool {
- if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp.Time) {
- return false
- }
- r.SortingTimestamp = evt.Timestamp
- now := time.Now()
- if r.SortingTimestamp.After(now) {
- r.SortingTimestamp = jsontime.UM(now)
- }
- return true
-}
diff --git a/hicli/database/sessionrequest.go b/hicli/database/sessionrequest.go
deleted file mode 100644
index 6690c13f..00000000
--- a/hicli/database/sessionrequest.go
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
-
- "go.mau.fi/util/dbutil"
-
- "maunium.net/go/mautrix/id"
-)
-
-const (
- putSessionRequestQueueEntry = `
- INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent)
- VALUES ($1, $2, $3, $4, $5, $6)
- ON CONFLICT (session_id) DO UPDATE
- SET min_index = MIN(excluded.min_index, session_request.min_index),
- backup_checked = excluded.backup_checked OR session_request.backup_checked,
- request_sent = excluded.request_sent OR session_request.request_sent
- `
- removeSessionRequestQuery = `
- DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2
- `
- getNextSessionsToRequestQuery = `
- SELECT room_id, session_id, sender, min_index, backup_checked, request_sent
- FROM session_request
- WHERE request_sent = false OR backup_checked = false
- ORDER BY backup_checked, rowid
- LIMIT $1
- `
-)
-
-type SessionRequestQuery struct {
- *dbutil.QueryHelper[*SessionRequest]
-}
-
-func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) {
- return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count)
-}
-
-func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error {
- return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex)
-}
-
-func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error {
- return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...)
-}
-
-type SessionRequest struct {
- RoomID id.RoomID
- SessionID id.SessionID
- Sender id.UserID
- MinIndex uint32
- BackupChecked bool
- RequestSent bool
-}
-
-func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) {
- return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent))
-}
-
-func (s *SessionRequest) sqlVariables() []any {
- return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent}
-}
diff --git a/hicli/database/state.go b/hicli/database/state.go
deleted file mode 100644
index c12f9f60..00000000
--- a/hicli/database/state.go
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
- "fmt"
-
- "go.mau.fi/util/dbutil"
- "go.mau.fi/util/exslices"
-
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-const (
- setCurrentStateQuery = `
- INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
- ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership
- `
- addCurrentStateQuery = `
- INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
- ON CONFLICT DO NOTHING
- `
- deleteCurrentStateQuery = `
- DELETE FROM current_state WHERE room_id = $1
- `
- getCurrentRoomStateQuery = `
- SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned,
- transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid
- FROM current_state cs
- JOIN event ON cs.event_rowid = event.rowid
- WHERE cs.room_id = $1
- `
- getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3`
-)
-
-var massInsertCurrentStateBuilder = dbutil.NewMassInsertBuilder[*CurrentStateEntry, [1]any](addCurrentStateQuery, "($1, $%d, $%d, $%d, $%d)")
-
-const currentStateMassInsertBatchSize = 1000
-
-type CurrentStateEntry struct {
- EventType event.Type
- StateKey string
- EventRowID EventRowID
- Membership event.Membership
-}
-
-func (cse *CurrentStateEntry) GetMassInsertValues() [4]any {
- return [4]any{cse.EventType.Type, cse.StateKey, cse.EventRowID, dbutil.StrPtr(cse.Membership)}
-}
-
-type CurrentStateQuery struct {
- *dbutil.QueryHelper[*Event]
-}
-
-func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error {
- return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
-}
-
-func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, deleteOld bool, entries []*CurrentStateEntry) error {
- var err error
- if deleteOld {
- err = csq.Exec(ctx, deleteCurrentStateQuery, roomID)
- if err != nil {
- return fmt.Errorf("failed to delete old state: %w", err)
- }
- }
- for _, entryChunk := range exslices.Chunk(entries, currentStateMassInsertBatchSize) {
- query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk)
- err = csq.Exec(ctx, query, params...)
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-func (csq *CurrentStateQuery) Add(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error {
- return csq.Exec(ctx, addCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
-}
-
-func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) {
- return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey)
-}
-
-func (csq *CurrentStateQuery) GetAll(ctx context.Context, roomID id.RoomID) ([]*Event, error) {
- return csq.QueryMany(ctx, getCurrentRoomStateQuery, roomID)
-}
diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go
deleted file mode 100644
index fcd6aceb..00000000
--- a/hicli/database/statestore.go
+++ /dev/null
@@ -1,188 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
- "database/sql"
- "errors"
- "fmt"
-
- "go.mau.fi/util/dbutil"
- "golang.org/x/exp/slices"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/crypto"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-const (
- getMembershipQuery = `
- SELECT membership FROM current_state
- WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2
- `
- getStateEventContentQuery = `
- SELECT event.content FROM current_state cs
- LEFT JOIN event ON event.rowid = cs.event_rowid
- WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3
- `
- getRoomJoinedMembersQuery = `
- SELECT state_key FROM current_state
- WHERE room_id = $1 AND event_type = 'm.room.member' AND membership = 'join'
- `
- getRoomJoinedOrInvitedMembersQuery = `
- SELECT state_key FROM current_state
- WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite')
- `
- getHasFetchedMembersQuery = `
- SELECT has_member_list FROM room WHERE room_id = $1
- `
- isRoomEncryptedQuery = `
- SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1
- `
- getRoomEncryptionEventQuery = `
- SELECT room.encryption_event FROM room WHERE room_id = $1
- `
- findSharedRoomsQuery = `
- SELECT room_id FROM current_state
- WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join'
- `
-)
-
-type ClientStateStore struct {
- *Database
-}
-
-var _ mautrix.StateStore = (*ClientStateStore)(nil)
-var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil)
-var _ crypto.StateStore = (*ClientStateStore)(nil)
-
-func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
- return c.IsMembership(ctx, roomID, userID, event.MembershipJoin)
-}
-
-func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
- return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin)
-}
-
-func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
- var membership event.Membership
- err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- membership = event.MembershipLeave
- }
- return slices.Contains(allowedMemberships, membership)
-}
-
-func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
- content, err := c.TryGetMember(ctx, roomID, userID)
- if content == nil {
- content = &event.MemberEventContent{Membership: event.MembershipLeave}
- }
- return content, err
-}
-
-func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) {
- err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content})
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
-}
-
-func (c *ClientStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) {
- //TODO implement me
- panic("implement me")
-}
-
-func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) {
- err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content})
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
-}
-
-func (c *ClientStateStore) GetRoomJoinedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
- rows, err := c.Query(ctx, getRoomJoinedMembersQuery, roomID)
- return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
-}
-
-func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
- rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID)
- return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
-}
-
-func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (hasFetched bool, err error) {
- //err = c.QueryRow(ctx, getHasFetchedMembersQuery, roomID).Scan(&hasFetched)
- //if errors.Is(err, sql.ErrNoRows) {
- // err = nil
- //}
- //return
- return false, fmt.Errorf("not implemented")
-}
-
-func (c *ClientStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
- return fmt.Errorf("not implemented")
-}
-
-func (c *ClientStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
- return nil, fmt.Errorf("not implemented")
-}
-
-func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) {
- err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
-}
-
-func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) {
- err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID).
- Scan(&dbutil.JSON{Data: &content})
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- }
- return
-}
-
-func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) {
- // TODO for multiuser support, this might need to filter by the local user's membership
- rows, err := c.Query(ctx, findSharedRoomsQuery, userID)
- return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
-}
-
-// Update methods are all intentionally no-ops as the state store wants to have the full event
-
-func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
- return nil
-}
-
-func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
- return nil
-}
-
-func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
- return nil
-}
-
-func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
- return nil
-}
-
-func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
- return nil
-}
-
-func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {}
-
-func (c *ClientStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
- return nil
-}
diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go
deleted file mode 100644
index 0a01c7f5..00000000
--- a/hicli/database/timeline.go
+++ /dev/null
@@ -1,132 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
- "database/sql"
- "errors"
- "sync"
-
- "go.mau.fi/util/dbutil"
-
- "maunium.net/go/mautrix/id"
-)
-
-const (
- clearTimelineQuery = `
- DELETE FROM timeline WHERE room_id = $1
- `
- appendTimelineQuery = `
- INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) RETURNING rowid, event_rowid
- `
- prependTimelineQuery = `
- INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3)
- `
- checkTimelineContainsQuery = `
- SELECT EXISTS(SELECT 1 FROM timeline WHERE room_id = $1 AND event_rowid = $2)
- `
- findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline`
- getTimelineQuery = `
- SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned,
- transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid
- FROM timeline
- JOIN event ON event.rowid = timeline.event_rowid
- WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2)
- ORDER BY timeline.rowid DESC
- LIMIT $3
- `
-)
-
-type TimelineRowID int64
-
-type TimelineRowTuple struct {
- Timeline TimelineRowID `json:"timeline_rowid"`
- Event EventRowID `json:"event_rowid"`
-}
-
-var timelineRowTupleScanner = dbutil.ConvertRowFn[TimelineRowTuple](func(row dbutil.Scannable) (trt TimelineRowTuple, err error) {
- err = row.Scan(&trt.Timeline, &trt.Event)
- return
-})
-
-func (trt TimelineRowTuple) GetMassInsertValues() [2]any {
- return [2]any{trt.Timeline, trt.Event}
-}
-
-var appendTimelineQueryBuilder = dbutil.NewMassInsertBuilder[EventRowID, [1]any](appendTimelineQuery, "($1, $%d)")
-var prependTimelineQueryBuilder = dbutil.NewMassInsertBuilder[TimelineRowTuple, [1]any](prependTimelineQuery, "($1, $%d, $%d)")
-
-type TimelineQuery struct {
- *dbutil.QueryHelper[*Event]
-
- minRowID TimelineRowID
- minRowIDFound bool
- prependLock sync.Mutex
-}
-
-// Clear clears the timeline of a given room.
-func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error {
- return tq.Exec(ctx, clearTimelineQuery, roomID)
-}
-
-func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFrom TimelineRowID, err error) {
- tq.prependLock.Lock()
- defer tq.prependLock.Unlock()
- if !tq.minRowIDFound {
- err = tq.GetDB().QueryRow(ctx, findMinRowIDQuery).Scan(&tq.minRowID)
- if err != nil && !errors.Is(err, sql.ErrNoRows) {
- return
- }
- if tq.minRowID >= 0 {
- // No negative row IDs exist, start at -2
- tq.minRowID = -2
- } else {
- // We fetched the lowest row ID, but we want the next available one, so decrement one
- tq.minRowID--
- }
- tq.minRowIDFound = true
- }
- startFrom = tq.minRowID
- tq.minRowID -= TimelineRowID(count)
- return
-}
-
-// Prepend adds the given event row IDs to the beginning of the timeline.
-// The events must be sorted in reverse chronological order (newest event first).
-func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) (prependEntries []TimelineRowTuple, err error) {
- var startFrom TimelineRowID
- startFrom, err = tq.reserveRowIDs(ctx, len(rowIDs))
- if err != nil {
- return
- }
- prependEntries = make([]TimelineRowTuple, len(rowIDs))
- for i, rowID := range rowIDs {
- prependEntries[i] = TimelineRowTuple{
- Timeline: startFrom - TimelineRowID(i),
- Event: rowID,
- }
- }
- query, params := prependTimelineQueryBuilder.Build([1]any{roomID}, prependEntries)
- err = tq.Exec(ctx, query, params...)
- return
-}
-
-// Append adds the given event row IDs to the end of the timeline.
-func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) ([]TimelineRowTuple, error) {
- query, params := appendTimelineQueryBuilder.Build([1]any{roomID}, rowIDs)
- return timelineRowTupleScanner.NewRowIter(tq.GetDB().Query(ctx, query, params...)).AsList()
-}
-
-func (tq *TimelineQuery) Get(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) {
- return tq.QueryMany(ctx, getTimelineQuery, roomID, before, limit)
-}
-
-func (tq *TimelineQuery) Has(ctx context.Context, roomID id.RoomID, eventRowID EventRowID) (exists bool, err error) {
- err = tq.GetDB().QueryRow(ctx, checkTimelineContainsQuery, roomID, eventRowID).Scan(&exists)
- return
-}
diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql
deleted file mode 100644
index f8c84a61..00000000
--- a/hicli/database/upgrades/00-latest-revision.sql
+++ /dev/null
@@ -1,248 +0,0 @@
--- v0 -> v2 (compatible with v1+): Latest revision
-CREATE TABLE account (
- user_id TEXT NOT NULL PRIMARY KEY,
- device_id TEXT NOT NULL,
- access_token TEXT NOT NULL,
- homeserver_url TEXT NOT NULL,
-
- next_batch TEXT NOT NULL
-) STRICT;
-
-CREATE TABLE room (
- room_id TEXT NOT NULL PRIMARY KEY,
- creation_content TEXT,
-
- name TEXT,
- name_quality INTEGER NOT NULL DEFAULT 0,
- avatar TEXT,
- explicit_avatar INTEGER NOT NULL DEFAULT 0,
- topic TEXT,
- canonical_alias TEXT,
- lazy_load_summary TEXT,
-
- encryption_event TEXT,
- has_member_list INTEGER NOT NULL DEFAULT false,
-
- preview_event_rowid INTEGER,
- sorting_timestamp INTEGER,
-
- prev_batch TEXT,
-
- CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL
-) STRICT;
-CREATE INDEX room_type_idx ON room (creation_content ->> 'type');
-CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC);
-
-CREATE TABLE account_data (
- user_id TEXT NOT NULL,
- type TEXT NOT NULL,
- content TEXT NOT NULL,
-
- PRIMARY KEY (user_id, type)
-) STRICT;
-
-CREATE TABLE room_account_data (
- user_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- type TEXT NOT NULL,
- content TEXT NOT NULL,
-
- PRIMARY KEY (user_id, room_id, type),
- CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
-) STRICT;
-CREATE INDEX room_account_data_room_id_idx ON room_account_data (room_id);
-
-CREATE TABLE event (
- rowid INTEGER PRIMARY KEY,
-
- room_id TEXT NOT NULL,
- event_id TEXT NOT NULL,
- sender TEXT NOT NULL,
- type TEXT NOT NULL,
- state_key TEXT,
- timestamp INTEGER NOT NULL,
-
- content TEXT NOT NULL,
- decrypted TEXT,
- decrypted_type TEXT,
- unsigned TEXT NOT NULL,
-
- transaction_id TEXT,
-
- redacted_by TEXT,
- relates_to TEXT,
- relation_type TEXT,
-
- megolm_session_id TEXT,
- decryption_error TEXT,
- send_error TEXT,
-
- reactions TEXT,
- last_edit_rowid INTEGER,
-
- CONSTRAINT event_id_unique_key UNIQUE (event_id),
- CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id),
- CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
-) STRICT;
-CREATE INDEX event_room_id_idx ON event (room_id);
-CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by);
-CREATE INDEX event_relates_to_idx ON event (room_id, relates_to);
-CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id);
-
-CREATE TRIGGER event_update_redacted_by
- AFTER INSERT
- ON event
- WHEN NEW.type = 'm.room.redaction'
-BEGIN
- UPDATE event SET redacted_by = NEW.event_id WHERE room_id = NEW.room_id AND event_id = NEW.content ->> 'redacts';
-END;
-
-CREATE TRIGGER event_update_last_edit_when_redacted
- AFTER UPDATE
- ON event
- WHEN OLD.redacted_by IS NULL
- AND NEW.redacted_by IS NOT NULL
- AND NEW.relation_type = 'm.replace'
- AND NEW.state_key IS NULL
-BEGIN
- UPDATE event
- SET last_edit_rowid = COALESCE(
- (SELECT rowid
- FROM event edit
- WHERE edit.room_id = event.room_id
- AND edit.relates_to = event.event_id
- AND edit.relation_type = 'm.replace'
- AND edit.type = event.type
- AND edit.sender = event.sender
- AND edit.redacted_by IS NULL
- AND edit.state_key IS NULL
- ORDER BY edit.timestamp DESC
- LIMIT 1),
- 0)
- WHERE event_id = NEW.relates_to
- AND last_edit_rowid = NEW.rowid
- AND state_key IS NULL
- AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation'));
-END;
-
-CREATE TRIGGER event_insert_update_last_edit
- AFTER INSERT
- ON event
- WHEN NEW.relation_type = 'm.replace'
- AND NEW.redacted_by IS NULL
- AND NEW.state_key IS NULL
-BEGIN
- UPDATE event
- SET last_edit_rowid = NEW.rowid
- WHERE event_id = NEW.relates_to
- AND type = NEW.type
- AND sender = NEW.sender
- AND state_key IS NULL
- AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation'))
- AND NEW.timestamp >
- COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0);
-END;
-
-CREATE TRIGGER event_insert_fill_reactions
- AFTER INSERT
- ON event
- WHEN NEW.type = 'm.reaction'
- AND NEW.relation_type = 'm.annotation'
- AND NEW.redacted_by IS NULL
- AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text'
-BEGIN
- UPDATE event
- SET reactions=json_set(
- reactions,
- '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'),
- coalesce(
- reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')),
- 0
- ) + 1)
- WHERE event_id = NEW.relates_to
- AND reactions IS NOT NULL;
-END;
-
-CREATE TRIGGER event_redact_fill_reactions
- AFTER UPDATE
- ON event
- WHEN NEW.type = 'm.reaction'
- AND NEW.relation_type = 'm.annotation'
- AND NEW.redacted_by IS NOT NULL
- AND OLD.redacted_by IS NULL
- AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text'
-BEGIN
- UPDATE event
- SET reactions=json_set(
- reactions,
- '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'),
- coalesce(
- reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')),
- 0
- ) - 1)
- WHERE event_id = NEW.relates_to
- AND reactions IS NOT NULL;
-END;
-
-CREATE TABLE cached_media (
- mxc TEXT NOT NULL PRIMARY KEY,
- event_rowid INTEGER,
- enc_file TEXT,
- file_name TEXT,
- mime_type TEXT,
- size INTEGER,
- hash BLOB,
- error TEXT,
-
- CONSTRAINT cached_media_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE SET NULL
-) STRICT;
-
-CREATE TABLE session_request (
- room_id TEXT NOT NULL,
- session_id TEXT NOT NULL,
- sender TEXT NOT NULL,
- min_index INTEGER NOT NULL,
- backup_checked INTEGER NOT NULL DEFAULT false,
- request_sent INTEGER NOT NULL DEFAULT false,
-
- PRIMARY KEY (session_id),
- CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
-) STRICT;
-CREATE INDEX session_request_room_idx ON session_request (room_id);
-
-CREATE TABLE timeline (
- rowid INTEGER PRIMARY KEY,
- room_id TEXT NOT NULL,
- event_rowid INTEGER NOT NULL,
-
- CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
- CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE,
- CONSTRAINT timeline_event_unique_key UNIQUE (event_rowid)
-) STRICT;
-CREATE INDEX timeline_room_id_idx ON timeline (room_id);
-
-CREATE TABLE current_state (
- room_id TEXT NOT NULL,
- event_type TEXT NOT NULL,
- state_key TEXT NOT NULL,
- event_rowid INTEGER NOT NULL,
-
- membership TEXT,
-
- PRIMARY KEY (room_id, event_type, state_key),
- CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
- CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid)
-) STRICT, WITHOUT ROWID;
-
-CREATE TABLE receipt (
- room_id TEXT NOT NULL,
- user_id TEXT NOT NULL,
- receipt_type TEXT NOT NULL,
- thread_id TEXT NOT NULL,
- event_id TEXT NOT NULL,
- timestamp INTEGER NOT NULL,
-
- PRIMARY KEY (room_id, user_id, receipt_type, thread_id),
- CONSTRAINT receipt_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
- -- note: there's no foreign key on event ID because receipts could point at events that are too far in history.
-) STRICT;
diff --git a/hicli/database/upgrades/02-explicit-avatar-flag.sql b/hicli/database/upgrades/02-explicit-avatar-flag.sql
deleted file mode 100644
index c11e8801..00000000
--- a/hicli/database/upgrades/02-explicit-avatar-flag.sql
+++ /dev/null
@@ -1,2 +0,0 @@
--- v2 (compatible with v1+): Add explicit avatar flag to rooms
-ALTER TABLE room ADD COLUMN explicit_avatar INTEGER NOT NULL DEFAULT 0;
diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go
deleted file mode 100644
index 87b6b8b2..00000000
--- a/hicli/decryptionqueue.go
+++ /dev/null
@@ -1,209 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "fmt"
- "sync"
-
- "github.com/rs/zerolog"
-
- "maunium.net/go/mautrix/crypto"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/hicli/database"
- "maunium.net/go/mautrix/id"
-)
-
-func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
- data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID)
- if err != nil {
- return nil, err
- } else if data == nil {
- return nil, nil
- }
- decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey)
- if err != nil {
- return nil, err
- }
- return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted)
-}
-
-func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) {
- log := zerolog.Ctx(ctx)
- err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex)
- if err != nil {
- log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session")
- }
- events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID)
- if err != nil {
- log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption")
- return
- } else if len(events) == 0 {
- log.Trace().Msg("No events to retry decryption for")
- return
- }
- decrypted := events[:0]
- for _, evt := range events {
- if evt.Decrypted != nil {
- continue
- }
-
- var mautrixEvt *event.Event
- mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
- if err != nil {
- log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
- } else {
- decrypted = append(decrypted, evt)
- h.cacheMedia(ctx, mautrixEvt, evt.RowID)
- }
- }
- if len(decrypted) > 0 {
- var newPreview database.EventRowID
- err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
- for _, evt := range decrypted {
- err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType)
- if err != nil {
- return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err)
- }
- if evt.CanUseForPreview() {
- var previewChanged bool
- previewChanged, err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID)
- if err != nil {
- return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err)
- } else if previewChanged {
- newPreview = evt.RowID
- }
- }
- }
- return nil
- })
- if err != nil {
- log.Err(err).Msg("Failed to save decrypted events")
- } else {
- h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewEventRowID: newPreview, RoomID: roomID})
- }
- }
-}
-
-func (h *HiClient) WakeupRequestQueue() {
- select {
- case h.requestQueueWakeup <- struct{}{}:
- default:
- }
-}
-
-func (h *HiClient) RunRequestQueue(ctx context.Context) {
- log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger()
- ctx = log.WithContext(ctx)
- log.Info().Msg("Starting key request queue")
- defer func() {
- log.Info().Msg("Stopping key request queue")
- }()
- for {
- err := h.FetchKeysForOutdatedUsers(ctx)
- if err != nil {
- log.Err(err).Msg("Failed to fetch outdated device lists for tracked users")
- }
- madeRequests, err := h.RequestQueuedSessions(ctx)
- if err != nil {
- log.Err(err).Msg("Failed to handle session request queue")
- } else if madeRequests {
- continue
- }
- select {
- case <-ctx.Done():
- return
- case <-h.requestQueueWakeup:
- }
- }
-}
-
-func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) {
- defer doneFunc()
- log := zerolog.Ctx(ctx)
- if !req.BackupChecked {
- sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID)
- if err != nil {
- log.Err(err).
- Stringer("session_id", req.SessionID).
- Msg("Failed to fetch session from key backup")
-
- // TODO should this have retries instead of just storing it's checked?
- req.BackupChecked = true
- err = h.DB.SessionRequest.Put(ctx, req)
- if err != nil {
- log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup")
- }
- } else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex {
- req.BackupChecked = true
- err = h.DB.SessionRequest.Put(ctx, req)
- if err != nil {
- log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup")
- }
- } else {
- log.Debug().Stringer("session_id", req.SessionID).
- Msg("Found session with sufficiently low first known index, removing from queue")
- err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex())
- if err != nil {
- log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue")
- }
- }
- } else {
- err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{
- h.Account.UserID: {"*"},
- req.Sender: {"*"},
- })
- //var err error
- if err != nil {
- log.Err(err).
- Stringer("session_id", req.SessionID).
- Msg("Failed to send key request")
- } else {
- log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request")
- req.RequestSent = true
- err = h.DB.SessionRequest.Put(ctx, req)
- if err != nil {
- log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request")
- }
- }
- }
-}
-
-const MaxParallelRequests = 5
-
-func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) {
- sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests)
- if err != nil {
- return false, fmt.Errorf("failed to get next events to decrypt: %w", err)
- } else if len(sessions) == 0 {
- return false, nil
- }
- var wg sync.WaitGroup
- wg.Add(len(sessions))
- for _, req := range sessions {
- go h.requestQueuedSession(ctx, req, wg.Done)
- }
- wg.Wait()
-
- return true, err
-}
-
-func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error {
- outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx)
- if err != nil {
- return err
- } else if len(outdatedUsers) == 0 {
- return nil
- }
- _, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false)
- if err != nil {
- return err
- }
- // TODO backoff for users that fail to be fetched?
- return nil
-}
diff --git a/hicli/events.go b/hicli/events.go
deleted file mode 100644
index b96fd266..00000000
--- a/hicli/events.go
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/hicli/database"
- "maunium.net/go/mautrix/id"
-)
-
-type SyncRoom struct {
- Meta *database.Room `json:"meta"`
- Timeline []database.TimelineRowTuple `json:"timeline"`
- State map[event.Type]map[string]database.EventRowID `json:"state"`
- Events []*database.Event `json:"events"`
- Reset bool `json:"reset"`
-}
-
-type SyncComplete struct {
- Rooms map[id.RoomID]*SyncRoom `json:"rooms"`
-}
-
-func (c *SyncComplete) IsEmpty() bool {
- return len(c.Rooms) == 0
-}
-
-type EventsDecrypted struct {
- RoomID id.RoomID `json:"room_id"`
- PreviewEventRowID database.EventRowID `json:"preview_event_rowid,omitempty"`
- Events []*database.Event `json:"events"`
-}
-
-type Typing struct {
- RoomID id.RoomID `json:"room_id"`
- event.TypingEventContent
-}
-
-type SendComplete struct {
- Event *database.Event `json:"event"`
- Error error `json:"error"`
-}
-
-type ClientState struct {
- IsLoggedIn bool `json:"is_logged_in"`
- IsVerified bool `json:"is_verified"`
- UserID id.UserID `json:"user_id,omitempty"`
- DeviceID id.DeviceID `json:"device_id,omitempty"`
- HomeserverURL string `json:"homeserver_url,omitempty"`
-}
diff --git a/hicli/hicli.go b/hicli/hicli.go
deleted file mode 100644
index 78a1acc0..00000000
--- a/hicli/hicli.go
+++ /dev/null
@@ -1,250 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix.
-package hicli
-
-import (
- "context"
- "errors"
- "fmt"
- "net"
- "net/http"
- "net/url"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/rs/zerolog"
- "go.mau.fi/util/dbutil"
- "go.mau.fi/util/exerrors"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/crypto"
- "maunium.net/go/mautrix/crypto/backup"
- "maunium.net/go/mautrix/hicli/database"
- "maunium.net/go/mautrix/id"
- "maunium.net/go/mautrix/pushrules"
-)
-
-type HiClient struct {
- DB *database.Database
- Account *database.Account
- Client *mautrix.Client
- Crypto *crypto.OlmMachine
- CryptoStore *crypto.SQLCryptoStore
- ClientStore *database.ClientStateStore
- Log zerolog.Logger
-
- Verified bool
-
- KeyBackupVersion id.KeyBackupVersion
- KeyBackupKey *backup.MegolmBackupKey
-
- PushRules atomic.Pointer[pushrules.PushRuleset]
-
- EventHandler func(evt any)
-
- firstSyncReceived bool
- syncingID int
- syncLock sync.Mutex
- stopSync atomic.Pointer[context.CancelFunc]
- encryptLock sync.Mutex
-
- requestQueueWakeup chan struct{}
-
- jsonRequestsLock sync.Mutex
- jsonRequests map[int64]context.CancelCauseFunc
-
- paginationInterrupterLock sync.Mutex
- paginationInterrupter map[id.RoomID]context.CancelCauseFunc
-}
-
-var ErrTimelineReset = errors.New("got limited timeline sync response")
-
-func New(rawDB, cryptoDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient {
- if cryptoDB == nil {
- cryptoDB = rawDB
- }
- if rawDB.Owner == "" {
- rawDB.Owner = "hicli"
- rawDB.IgnoreForeignTables = true
- }
- if rawDB.Log == nil {
- rawDB.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger())
- }
- db := database.New(rawDB)
- c := &HiClient{
- DB: db,
- Log: log,
-
- requestQueueWakeup: make(chan struct{}, 1),
- jsonRequests: make(map[int64]context.CancelCauseFunc),
- paginationInterrupter: make(map[id.RoomID]context.CancelCauseFunc),
-
- EventHandler: evtHandler,
- }
- c.ClientStore = &database.ClientStateStore{Database: db}
- c.Client = &mautrix.Client{
- UserAgent: mautrix.DefaultUserAgent,
- Client: &http.Client{
- Transport: &http.Transport{
- DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext,
- TLSHandshakeTimeout: 10 * time.Second,
- // This needs to be relatively high to allow initial syncs
- ResponseHeaderTimeout: 180 * time.Second,
- ForceAttemptHTTP2: true,
- },
- Timeout: 180 * time.Second,
- },
- Syncer: (*hiSyncer)(c),
- Store: (*hiStore)(c),
- StateStore: c.ClientStore,
- Log: log.With().Str("component", "mautrix client").Logger(),
- }
- c.CryptoStore = crypto.NewSQLCryptoStore(cryptoDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey)
- cryptoLog := log.With().Str("component", "crypto").Logger()
- c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore)
- c.Crypto.SessionReceived = c.handleReceivedMegolmSession
- c.Crypto.DisableRatchetTracking = true
- c.Crypto.DisableDecryptKeyFetching = true
- c.Client.Crypto = (*hiCryptoHelper)(c)
- return c
-}
-
-func (h *HiClient) IsLoggedIn() bool {
- return h.Account != nil
-}
-
-func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount *database.Account) error {
- if expectedAccount != nil && userID != expectedAccount.UserID {
- panic(fmt.Errorf("invalid parameters: different user ID in expected account and user ID"))
- }
- err := h.DB.Upgrade(ctx)
- if err != nil {
- return fmt.Errorf("failed to upgrade hicli db: %w", err)
- }
- err = h.CryptoStore.DB.Upgrade(ctx)
- if err != nil {
- return fmt.Errorf("failed to upgrade crypto db: %w", err)
- }
- account, err := h.DB.Account.Get(ctx, userID)
- if err != nil {
- return err
- } else if account == nil && expectedAccount != nil {
- err = h.DB.Account.Put(ctx, expectedAccount)
- if err != nil {
- return err
- }
- account = expectedAccount
- } else if expectedAccount != nil && expectedAccount.DeviceID != account.DeviceID {
- return fmt.Errorf("device ID mismatch: expected %s, got %s", expectedAccount.DeviceID, account.DeviceID)
- }
- if account != nil {
- zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials")
- h.Account = account
- h.CryptoStore.AccountID = account.UserID.String()
- h.CryptoStore.DeviceID = account.DeviceID
- h.Client.UserID = account.UserID
- h.Client.DeviceID = account.DeviceID
- h.Client.AccessToken = account.AccessToken
- h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL)
- if err != nil {
- return err
- }
- err = h.CheckServerVersions(ctx)
- if err != nil {
- return err
- }
- err = h.Crypto.Load(ctx)
- if err != nil {
- return fmt.Errorf("failed to load olm machine: %w", err)
- }
-
- h.Verified, err = h.checkIsCurrentDeviceVerified(ctx)
- if err != nil {
- return err
- }
- zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status")
- if h.Verified {
- err = h.loadPrivateKeys(ctx)
- if err != nil {
- return err
- }
- go h.Sync()
- }
- }
- return nil
-}
-
-var ErrFailedToCheckServerVersions = errors.New("failed to check server versions")
-var ErrOutdatedServer = errors.New("homeserver is outdated")
-var MinimumSpecVersion = mautrix.SpecV11
-
-func (h *HiClient) CheckServerVersions(ctx context.Context) error {
- versions, err := h.Client.Versions(ctx)
- if err != nil {
- return exerrors.NewDualError(ErrFailedToCheckServerVersions, err)
- } else if !versions.Contains(MinimumSpecVersion) {
- return fmt.Errorf("%w (minimum: %s, highest supported: %s)", ErrOutdatedServer, MinimumSpecVersion, versions.GetLatest())
- }
- return nil
-}
-
-func (h *HiClient) IsSyncing() bool {
- return h.stopSync.Load() != nil
-}
-
-func (h *HiClient) Sync() {
- h.Client.StopSync()
- if fn := h.stopSync.Load(); fn != nil {
- (*fn)()
- }
- h.syncLock.Lock()
- defer h.syncLock.Unlock()
- h.syncingID++
- syncingID := h.syncingID
- log := h.Log.With().
- Str("action", "sync").
- Int("sync_id", syncingID).
- Logger()
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
- h.stopSync.Store(&cancel)
- go h.RunRequestQueue(h.Log.WithContext(ctx))
- go h.LoadPushRules(h.Log.WithContext(ctx))
- ctx = log.WithContext(ctx)
- log.Info().Msg("Starting syncing")
- err := h.Client.SyncWithContext(ctx)
- if err != nil && ctx.Err() == nil {
- log.Err(err).Msg("Fatal error in syncer")
- } else {
- log.Info().Msg("Syncing stopped")
- }
-}
-
-func (h *HiClient) LoadPushRules(ctx context.Context) {
- rules, err := h.Client.GetPushRules(ctx)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to load push rules")
- return
- }
- h.PushRules.Store(rules)
- zerolog.Ctx(ctx).Debug().Msg("Updated push rules from fetch")
-}
-
-func (h *HiClient) Stop() {
- h.Client.StopSync()
- if fn := h.stopSync.Swap(nil); fn != nil {
- (*fn)()
- }
- h.syncLock.Lock()
- h.syncLock.Unlock()
- err := h.DB.Close()
- if err != nil {
- h.Log.Err(err).Msg("Failed to close database cleanly")
- }
-}
diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go
deleted file mode 100644
index bdf1598f..00000000
--- a/hicli/hitest/hitest.go
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package main
-
-import (
- "context"
- "fmt"
- "io"
- "strings"
-
- "github.com/chzyer/readline"
- _ "github.com/mattn/go-sqlite3"
- "github.com/rs/zerolog"
- "go.mau.fi/util/dbutil"
- _ "go.mau.fi/util/dbutil/litestream"
- "go.mau.fi/util/exerrors"
- "go.mau.fi/util/exzerolog"
- "go.mau.fi/zeroconfig"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/hicli"
- "maunium.net/go/mautrix/id"
-)
-
-var writerTypeReadline zeroconfig.WriterType = "hitest_readline"
-
-func main() {
- hicli.InitialDeviceDisplayName = "mautrix hitest"
- rl := exerrors.Must(readline.New("> "))
- defer func() {
- _ = rl.Close()
- }()
- zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) {
- return rl.Stdout(), nil
- })
- debug := zerolog.DebugLevel
- log := exerrors.Must((&zeroconfig.Config{
- MinLevel: &debug,
- Writers: []zeroconfig.WriterConfig{{
- Type: writerTypeReadline,
- Format: zeroconfig.LogFormatPrettyColored,
- }},
- }).Compile())
- exzerolog.SetupDefaults(log)
-
- rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal"))
- ctx := log.WithContext(context.Background())
- cli := hicli.New(rawDB, nil, *log, []byte("meow"), func(a any) {
- _, _ = fmt.Fprintf(rl, "Received event of type %T\n", a)
- switch evt := a.(type) {
- case *hicli.SyncComplete:
- for _, room := range evt.Rooms {
- name := "name unset"
- if room.Meta.Name != nil {
- name = *room.Meta.Name
- }
- _, _ = fmt.Fprintf(rl, "Room %s (%s) in sync:\n", name, room.Meta.ID)
- _, _ = fmt.Fprintf(rl, " Preview: %d, sort: %v\n", room.Meta.PreviewEventRowID, room.Meta.SortingTimestamp)
- _, _ = fmt.Fprintf(rl, " Timeline: +%d %v, reset: %t\n", len(room.Timeline), room.Timeline, room.Reset)
- }
- case *hicli.EventsDecrypted:
- for _, decrypted := range evt.Events {
- _, _ = fmt.Fprintf(rl, "Delayed decryption of %s completed: %s / %s\n", decrypted.ID, decrypted.DecryptedType, decrypted.Decrypted)
- }
- if evt.PreviewEventRowID != 0 {
- _, _ = fmt.Fprintf(rl, "Room preview updated: %+v\n", evt.PreviewEventRowID)
- }
- case *hicli.Typing:
- _, _ = fmt.Fprintf(rl, "Typing list in %s: %+v\n", evt.RoomID, evt.UserIDs)
- }
- })
- userID, _ := cli.DB.Account.GetFirstUserID(ctx)
- exerrors.PanicIfNotNil(cli.Start(ctx, userID, nil))
- if !cli.IsLoggedIn() {
- rl.SetPrompt("User ID: ")
- userID := id.UserID(exerrors.Must(rl.Readline()))
- _, serverName := exerrors.Must2(userID.Parse())
- discovery := exerrors.Must(mautrix.DiscoverClientAPI(ctx, serverName))
- password := exerrors.Must(rl.ReadPassword("Password: "))
- recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: "))
- exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode)))
- }
- rl.SetPrompt("> ")
-
- for {
- line, err := rl.Readline()
- if err != nil {
- break
- }
- fields := strings.Fields(line)
- if len(fields) == 0 {
- continue
- }
- switch strings.ToLower(fields[0]) {
- case "send":
- resp, err := cli.Send(ctx, id.RoomID(fields[1]), event.EventMessage, &event.MessageEventContent{
- Body: strings.Join(fields[2:], " "),
- MsgType: event.MsgText,
- })
- _, _ = fmt.Fprintln(rl, err)
- _, _ = fmt.Fprintf(rl, "%+v\n", resp)
- }
- }
- cli.Stop()
-}
diff --git a/hicli/json-commands.go b/hicli/json-commands.go
deleted file mode 100644
index c9dc89d2..00000000
--- a/hicli/json-commands.go
+++ /dev/null
@@ -1,178 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "time"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/hicli/database"
- "maunium.net/go/mautrix/id"
-)
-
-func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any, error) {
- switch req.Command {
- case "get_state":
- return h.State(), nil
- case "cancel":
- return unmarshalAndCall(req.Data, func(params *cancelRequestParams) (bool, error) {
- h.jsonRequestsLock.Lock()
- cancelTarget, ok := h.jsonRequests[params.RequestID]
- h.jsonRequestsLock.Unlock()
- if ok {
- return false, nil
- }
- if params.Reason == "" {
- cancelTarget(nil)
- } else {
- cancelTarget(errors.New(params.Reason))
- }
- return true, nil
- })
- case "send_message":
- return unmarshalAndCall(req.Data, func(params *sendMessageParams) (*database.Event, error) {
- return h.SendMessage(ctx, params.RoomID, params.Text, params.MediaPath, params.ReplyTo, params.Mentions)
- })
- case "send_event":
- return unmarshalAndCall(req.Data, func(params *sendEventParams) (*database.Event, error) {
- return h.Send(ctx, params.RoomID, params.EventType, params.Content)
- })
- case "mark_read":
- return unmarshalAndCall(req.Data, func(params *markReadParams) (bool, error) {
- return true, h.MarkRead(ctx, params.RoomID, params.EventID, params.ReceiptType)
- })
- case "set_typing":
- return unmarshalAndCall(req.Data, func(params *setTypingParams) (bool, error) {
- return true, h.SetTyping(ctx, params.RoomID, time.Duration(params.Timeout)*time.Millisecond)
- })
- case "get_event":
- return unmarshalAndCall(req.Data, func(params *getEventParams) (*database.Event, error) {
- return h.GetEvent(ctx, params.RoomID, params.EventID)
- })
- case "get_events_by_rowids":
- return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) {
- return h.GetEventsByRowIDs(ctx, params.RowIDs)
- })
- case "get_room_state":
- return unmarshalAndCall(req.Data, func(params *getRoomStateParams) ([]*database.Event, error) {
- return h.GetRoomState(ctx, params.RoomID, params.FetchMembers, params.Refetch)
- })
- case "paginate":
- return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) {
- return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit)
- })
- case "paginate_server":
- return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) {
- return h.PaginateServer(ctx, params.RoomID, params.Limit)
- })
- case "ensure_group_session_shared":
- return unmarshalAndCall(req.Data, func(params *ensureGroupSessionSharedParams) (bool, error) {
- return true, h.EnsureGroupSessionShared(ctx, params.RoomID)
- })
- case "login":
- return unmarshalAndCall(req.Data, func(params *loginParams) (bool, error) {
- return true, h.LoginPassword(ctx, params.HomeserverURL, params.Username, params.Password)
- })
- case "verify":
- return unmarshalAndCall(req.Data, func(params *verifyParams) (bool, error) {
- return true, h.VerifyWithRecoveryKey(ctx, params.RecoveryKey)
- })
- case "discover_homeserver":
- return unmarshalAndCall(req.Data, func(params *discoverHomeserverParams) (*mautrix.ClientWellKnown, error) {
- _, homeserver, err := params.UserID.Parse()
- if err != nil {
- return nil, err
- }
- return mautrix.DiscoverClientAPI(ctx, homeserver)
- })
- default:
- return nil, fmt.Errorf("unknown command %q", req.Command)
- }
-}
-
-func unmarshalAndCall[T, O any](data json.RawMessage, fn func(*T) (O, error)) (output O, err error) {
- var input T
- err = json.Unmarshal(data, &input)
- if err != nil {
- return
- }
- return fn(&input)
-}
-
-type cancelRequestParams struct {
- RequestID int64 `json:"request_id"`
- Reason string `json:"reason"`
-}
-
-type sendMessageParams struct {
- RoomID id.RoomID `json:"room_id"`
- Text string `json:"text"`
- MediaPath string `json:"media_path"`
- ReplyTo id.EventID `json:"reply_to"`
- Mentions *event.Mentions `json:"mentions"`
-}
-
-type sendEventParams struct {
- RoomID id.RoomID `json:"room_id"`
- EventType event.Type `json:"type"`
- Content json.RawMessage `json:"content"`
-}
-
-type markReadParams struct {
- RoomID id.RoomID `json:"room_id"`
- EventID id.EventID `json:"event_id"`
- ReceiptType event.ReceiptType `json:"receipt_type"`
-}
-
-type setTypingParams struct {
- RoomID id.RoomID `json:"room_id"`
- Timeout int `json:"timeout"`
-}
-
-type getEventParams struct {
- RoomID id.RoomID `json:"room_id"`
- EventID id.EventID `json:"event_id"`
-}
-
-type getEventsByRowIDsParams struct {
- RowIDs []database.EventRowID `json:"row_ids"`
-}
-
-type getRoomStateParams struct {
- RoomID id.RoomID `json:"room_id"`
- Refetch bool `json:"refetch"`
- FetchMembers bool `json:"fetch_members"`
-}
-
-type ensureGroupSessionSharedParams struct {
- RoomID id.RoomID `json:"room_id"`
-}
-
-type loginParams struct {
- HomeserverURL string `json:"homeserver_url"`
- Username string `json:"username"`
- Password string `json:"password"`
-}
-
-type verifyParams struct {
- RecoveryKey string `json:"recovery_key"`
-}
-
-type discoverHomeserverParams struct {
- UserID id.UserID `json:"user_id"`
-}
-
-type paginateParams struct {
- RoomID id.RoomID `json:"room_id"`
- MaxTimelineID database.TimelineRowID `json:"max_timeline_id"`
- Limit int `json:"limit"`
-}
diff --git a/hicli/json.go b/hicli/json.go
deleted file mode 100644
index a27fd007..00000000
--- a/hicli/json.go
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "sync/atomic"
-
- "go.mau.fi/util/exerrors"
-)
-
-type JSONCommand struct {
- Command string `json:"command"`
- RequestID int64 `json:"request_id"`
- Data json.RawMessage `json:"data"`
-}
-
-type JSONEventHandler func(*JSONCommand)
-
-var outgoingEventCounter atomic.Int64
-
-func (jeh JSONEventHandler) HandleEvent(evt any) {
- var command string
- switch evt.(type) {
- case *SyncComplete:
- command = "sync_complete"
- case *EventsDecrypted:
- command = "events_decrypted"
- case *Typing:
- command = "typing"
- case *SendComplete:
- command = "send_complete"
- case *ClientState:
- command = "client_state"
- default:
- panic(fmt.Errorf("unknown event type %T", evt))
- }
- data, err := json.Marshal(evt)
- if err != nil {
- panic(fmt.Errorf("failed to marshal event %T: %w", evt, err))
- }
- jeh(&JSONCommand{
- Command: command,
- RequestID: -outgoingEventCounter.Add(1),
- Data: data,
- })
-}
-
-func (h *HiClient) State() *ClientState {
- state := &ClientState{}
- if acc := h.Account; acc != nil {
- state.IsLoggedIn = true
- state.UserID = acc.UserID
- state.DeviceID = acc.DeviceID
- state.HomeserverURL = acc.HomeserverURL
- state.IsVerified = h.Verified
- }
- return state
-}
-
-func (h *HiClient) dispatchCurrentState() {
- h.EventHandler(h.State())
-}
-
-func (h *HiClient) SubmitJSONCommand(ctx context.Context, req *JSONCommand) *JSONCommand {
- if req.Command == "ping" {
- return &JSONCommand{
- Command: "pong",
- RequestID: req.RequestID,
- }
- }
- log := h.Log.With().Int64("request_id", req.RequestID).Str("command", req.Command).Logger()
- ctx, cancel := context.WithCancelCause(ctx)
- defer func() {
- cancel(nil)
- h.jsonRequestsLock.Lock()
- delete(h.jsonRequests, req.RequestID)
- h.jsonRequestsLock.Unlock()
- }()
- ctx = log.WithContext(ctx)
- h.jsonRequestsLock.Lock()
- h.jsonRequests[req.RequestID] = cancel
- h.jsonRequestsLock.Unlock()
- resp, err := h.handleJSONCommand(ctx, req)
- if err != nil {
- if errors.Is(err, context.Canceled) {
- causeErr := context.Cause(ctx)
- if causeErr != ctx.Err() {
- err = fmt.Errorf("%w: %w", err, causeErr)
- }
- }
- return &JSONCommand{
- Command: "error",
- RequestID: req.RequestID,
- Data: exerrors.Must(json.Marshal(err.Error())),
- }
- }
- var respData json.RawMessage
- respData, err = json.Marshal(resp)
- if err != nil {
- return &JSONCommand{
- Command: "error",
- RequestID: req.RequestID,
- Data: exerrors.Must(json.Marshal(fmt.Sprintf("failed to marshal response json: %v", err))),
- }
- }
- return &JSONCommand{
- Command: "response",
- RequestID: req.RequestID,
- Data: respData,
- }
-}
diff --git a/hicli/login.go b/hicli/login.go
deleted file mode 100644
index 6dbaf6e6..00000000
--- a/hicli/login.go
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "fmt"
- "net/url"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/hicli/database"
- "maunium.net/go/mautrix/id"
-)
-
-var InitialDeviceDisplayName = "mautrix hiclient"
-
-func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error {
- var err error
- h.Client.HomeserverURL, err = url.Parse(homeserverURL)
- if err != nil {
- return err
- }
- return h.Login(ctx, &mautrix.ReqLogin{
- Type: mautrix.AuthTypePassword,
- Identifier: mautrix.UserIdentifier{
- Type: mautrix.IdentifierTypeUser,
- User: username,
- },
- Password: password,
- InitialDeviceDisplayName: InitialDeviceDisplayName,
- })
-}
-
-func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error {
- err := h.CheckServerVersions(ctx)
- if err != nil {
- return err
- }
- req.StoreCredentials = true
- req.StoreHomeserverURL = true
- resp, err := h.Client.Login(ctx, req)
- if err != nil {
- return err
- }
- defer h.dispatchCurrentState()
- h.Account = &database.Account{
- UserID: resp.UserID,
- DeviceID: resp.DeviceID,
- AccessToken: resp.AccessToken,
- HomeserverURL: h.Client.HomeserverURL.String(),
- }
- h.CryptoStore.AccountID = resp.UserID.String()
- h.CryptoStore.DeviceID = resp.DeviceID
- err = h.DB.Account.Put(ctx, h.Account)
- if err != nil {
- return err
- }
- err = h.Crypto.Load(ctx)
- if err != nil {
- return fmt.Errorf("failed to load olm machine: %w", err)
- }
- err = h.Crypto.ShareKeys(ctx, 0)
- if err != nil {
- return err
- }
- _, err = h.Crypto.FetchKeys(ctx, []id.UserID{h.Account.UserID}, true)
- if err != nil {
- return fmt.Errorf("failed to fetch own devices: %w", err)
- }
- return nil
-}
-
-func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryKey string) error {
- err := h.LoginPassword(ctx, homeserverURL, username, password)
- if err != nil {
- return err
- }
- err = h.VerifyWithRecoveryKey(ctx, recoveryKey)
- if err != nil {
- return err
- }
- return nil
-}
diff --git a/hicli/paginate.go b/hicli/paginate.go
deleted file mode 100644
index da927b9b..00000000
--- a/hicli/paginate.go
+++ /dev/null
@@ -1,240 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "errors"
- "fmt"
-
- "github.com/rs/zerolog"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/hicli/database"
- "maunium.net/go/mautrix/id"
-)
-
-var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress")
-
-func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.EventRowID) ([]*database.Event, error) {
- events, err := h.DB.Event.GetByRowIDs(ctx, rowIDs...)
- if err != nil {
- return nil, err
- } else if len(events) == 0 {
- return events, nil
- }
- firstRoomID := events[0].RoomID
- allInSameRoom := true
- for _, evt := range events {
- if evt.RoomID != firstRoomID {
- allInSameRoom = false
- break
- }
- }
- if allInSameRoom {
- err = h.DB.Event.FillLastEditRowIDs(ctx, firstRoomID, events)
- if err != nil {
- return events, fmt.Errorf("failed to fill last edit row IDs: %w", err)
- }
- err = h.DB.Event.FillReactionCounts(ctx, firstRoomID, events)
- if err != nil {
- return events, fmt.Errorf("failed to fill reaction counts: %w", err)
- }
- } else {
- // TODO slow path where events are collected and filling is done one room at a time?
- }
- return events, nil
-}
-
-func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) {
- if evt, err := h.DB.Event.GetByID(ctx, eventID); err != nil {
- return nil, fmt.Errorf("failed to get event from database: %w", err)
- } else if evt != nil {
- return evt, nil
- } else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil {
- return nil, fmt.Errorf("failed to get event from server: %w", err)
- } else {
- return h.processEvent(ctx, serverEvt, nil, false)
- }
-}
-
-func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) {
- var evts []*event.Event
- if refetch {
- resp, err := h.Client.StateAsArray(ctx, roomID)
- if err != nil {
- return nil, fmt.Errorf("failed to refetch state: %w", err)
- }
- evts = resp
- } else if fetchMembers {
- resp, err := h.Client.Members(ctx, roomID)
- if err != nil {
- return nil, fmt.Errorf("failed to fetch members: %w", err)
- }
- evts = resp.Chunk
- }
- if evts != nil {
- err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
- room, err := h.DB.Room.Get(ctx, roomID)
- if err != nil {
- return fmt.Errorf("failed to get room from database: %w", err)
- }
- updatedRoom := &database.Room{
- ID: room.ID,
- HasMemberList: true,
- }
- entries := make([]*database.CurrentStateEntry, len(evts))
- for i, evt := range evts {
- dbEvt, err := h.processEvent(ctx, evt, nil, false)
- if err != nil {
- return fmt.Errorf("failed to process event %s: %w", evt.ID, err)
- }
- entries[i] = &database.CurrentStateEntry{
- EventType: evt.Type,
- StateKey: *evt.StateKey,
- EventRowID: dbEvt.RowID,
- }
- if evt.Type == event.StateMember {
- entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string))
- } else {
- processImportantEvent(ctx, evt, room, updatedRoom)
- }
- }
- err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries)
- if err != nil {
- return err
- }
- roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
- if roomChanged {
- err = h.DB.Room.Upsert(ctx, updatedRoom)
- if err != nil {
- return fmt.Errorf("failed to save room data: %w", err)
- }
- }
- return nil
- })
- if err != nil {
- return nil, err
- }
- }
- return h.DB.CurrentState.GetAll(ctx, roomID)
-}
-
-type PaginationResponse struct {
- Events []*database.Event `json:"events"`
- HasMore bool `json:"has_more"`
-}
-
-func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) (*PaginationResponse, error) {
- evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID)
- if err != nil {
- return nil, err
- } else if len(evts) > 0 {
- return &PaginationResponse{Events: evts, HasMore: true}, nil
- } else {
- return h.PaginateServer(ctx, roomID, limit)
- }
-}
-
-func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) (*PaginationResponse, error) {
- ctx, cancel := context.WithCancelCause(ctx)
- h.paginationInterrupterLock.Lock()
- if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating {
- h.paginationInterrupterLock.Unlock()
- return nil, ErrPaginationAlreadyInProgress
- }
- h.paginationInterrupter[roomID] = cancel
- h.paginationInterrupterLock.Unlock()
- defer func() {
- h.paginationInterrupterLock.Lock()
- delete(h.paginationInterrupter, roomID)
- h.paginationInterrupterLock.Unlock()
- }()
-
- room, err := h.DB.Room.Get(ctx, roomID)
- if err != nil {
- return nil, fmt.Errorf("failed to get room from database: %w", err)
- } else if room.PrevBatch == database.PrevBatchPaginationComplete {
- return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil
- }
- resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit)
- if err != nil {
- return nil, fmt.Errorf("failed to get messages from server: %w", err)
- }
- events := make([]*database.Event, len(resp.Chunk))
- if resp.End == "" {
- resp.End = database.PrevBatchPaginationComplete
- }
- if resp.End == database.PrevBatchPaginationComplete || len(resp.Chunk) == 0 {
- err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
- if err != nil {
- return nil, fmt.Errorf("failed to set prev_batch: %w", err)
- }
- return &PaginationResponse{Events: events, HasMore: resp.End != ""}, nil
- }
- wakeupSessionRequests := false
- err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
- if err = ctx.Err(); err != nil {
- return err
- }
- eventRowIDs := make([]database.EventRowID, len(resp.Chunk))
- decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
- iOffset := 0
- for i, evt := range resp.Chunk {
- dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, true)
- if err != nil {
- return err
- } else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil {
- return fmt.Errorf("failed to check if event exists in timeline: %w", err)
- } else if exists {
- zerolog.Ctx(ctx).Warn().
- Int64("row_id", int64(dbEvt.RowID)).
- Str("event_id", dbEvt.ID.String()).
- Msg("Event already exists in timeline, skipping")
- iOffset++
- continue
- }
- events[i-iOffset] = dbEvt
- eventRowIDs[i-iOffset] = events[i-iOffset].RowID
- }
- if iOffset >= len(events) {
- events = events[:0]
- return nil
- }
- events = events[:len(events)-iOffset]
- eventRowIDs = eventRowIDs[:len(eventRowIDs)-iOffset]
- wakeupSessionRequests = len(decryptionQueue) > 0
- for _, entry := range decryptionQueue {
- err = h.DB.SessionRequest.Put(ctx, entry)
- if err != nil {
- return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
- }
- }
- err = h.DB.Event.FillLastEditRowIDs(ctx, roomID, events)
- if err != nil {
- return fmt.Errorf("failed to fill last edit row IDs: %w", err)
- }
- err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
- if err != nil {
- return fmt.Errorf("failed to set prev_batch: %w", err)
- }
- var tuples []database.TimelineRowTuple
- tuples, err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs)
- if err != nil {
- return fmt.Errorf("failed to prepend events to timeline: %w", err)
- }
- for i, evt := range events {
- evt.TimelineRowID = tuples[i].Timeline
- }
- return nil
- })
- if err == nil && wakeupSessionRequests {
- h.WakeupRequestQueue()
- }
- return &PaginationResponse{Events: events, HasMore: true}, err
-}
diff --git a/hicli/send.go b/hicli/send.go
deleted file mode 100644
index 76852dde..00000000
--- a/hicli/send.go
+++ /dev/null
@@ -1,287 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "strings"
- "time"
-
- "github.com/rs/zerolog"
- "github.com/yuin/goldmark"
- "go.mau.fi/util/jsontime"
- "go.mau.fi/util/ptr"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/crypto"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/format"
- "maunium.net/go/mautrix/format/mdext/rainbow"
- "maunium.net/go/mautrix/hicli/database"
- "maunium.net/go/mautrix/id"
-)
-
-var (
- rainbowWithHTML = goldmark.New(format.Extensions, format.HTMLOptions, goldmark.WithExtensions(rainbow.Extension))
-)
-
-func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, mediaPath string, replyTo id.EventID, mentions *event.Mentions) (*database.Event, error) {
- var content event.MessageEventContent
- if strings.HasPrefix(text, "/rainbow ") {
- text = strings.TrimPrefix(text, "/rainbow ")
- content = format.RenderMarkdownCustom(text, rainbowWithHTML)
- content.FormattedBody = rainbow.ApplyColor(content.FormattedBody)
- } else if strings.HasPrefix(text, "/plain ") {
- text = strings.TrimPrefix(text, "/plain ")
- content = format.RenderMarkdown(text, false, false)
- } else if strings.HasPrefix(text, "/html ") {
- text = strings.TrimPrefix(text, "/html ")
- content = format.RenderMarkdown(text, false, true)
- } else {
- content = format.RenderMarkdown(text, true, false)
- }
- if mentions != nil {
- content.Mentions.Room = mentions.Room
- for _, userID := range mentions.UserIDs {
- if userID != h.Account.UserID {
- content.Mentions.Add(userID)
- }
- }
- }
- if replyTo != "" {
- content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(replyTo)
- }
- return h.Send(ctx, roomID, event.EventMessage, &content)
-}
-
-func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType) error {
- content := &mautrix.ReqSetReadMarkers{
- FullyRead: eventID,
- }
- if receiptType == event.ReceiptTypeRead {
- content.Read = eventID
- } else if receiptType == event.ReceiptTypeReadPrivate {
- content.ReadPrivate = eventID
- } else {
- return fmt.Errorf("invalid receipt type: %v", receiptType)
- }
- err := h.Client.SetReadMarkers(ctx, roomID, content)
- if err != nil {
- return fmt.Errorf("failed to mark event as read: %w", err)
- }
- return nil
-}
-
-func (h *HiClient) SetTyping(ctx context.Context, roomID id.RoomID, timeout time.Duration) error {
- _, err := h.Client.UserTyping(ctx, roomID, timeout > 0, timeout)
- return err
-}
-
-func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) {
- roomMeta, err := h.DB.Room.Get(ctx, roomID)
- if err != nil {
- return nil, fmt.Errorf("failed to get room metadata: %w", err)
- } else if roomMeta == nil {
- return nil, fmt.Errorf("unknown room")
- }
- var decryptedType event.Type
- var decryptedContent json.RawMessage
- var megolmSessionID id.SessionID
- if roomMeta.EncryptionEvent != nil && evtType != event.EventReaction {
- decryptedType = evtType
- decryptedContent, err = json.Marshal(content)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal event content: %w", err)
- }
- encryptedContent, err := h.Encrypt(ctx, roomMeta, evtType, content)
- if err != nil {
- return nil, fmt.Errorf("failed to encrypt event: %w", err)
- }
- megolmSessionID = encryptedContent.SessionID
- content = encryptedContent
- evtType = event.EventEncrypted
- }
- mainContent, err := json.Marshal(content)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal event content: %w", err)
- }
- txnID := "hicli-" + h.Client.TxnID()
- relatesTo, relationType := database.GetRelatesToFromBytes(mainContent)
- dbEvt := &database.Event{
- RoomID: roomID,
- ID: id.EventID(fmt.Sprintf("~%s", txnID)),
- Sender: h.Account.UserID,
- Type: evtType.Type,
- Timestamp: jsontime.UnixMilliNow(),
- Content: mainContent,
- Decrypted: decryptedContent,
- DecryptedType: decryptedType.Type,
- Unsigned: []byte("{}"),
- TransactionID: txnID,
- RelatesTo: relatesTo,
- RelationType: relationType,
- MegolmSessionID: megolmSessionID,
- DecryptionError: "",
- SendError: "not sent",
- Reactions: map[string]int{},
- LastEditRowID: ptr.Ptr(database.EventRowID(0)),
- }
- _, err = h.DB.Event.Insert(ctx, dbEvt)
- if err != nil {
- return nil, fmt.Errorf("failed to insert event into database: %w", err)
- }
- ctx = context.WithoutCancel(ctx)
- go func() {
- err := h.SetTyping(ctx, roomID, 0)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to stop typing while sending message")
- }
- }()
- go func() {
- var err error
- defer func() {
- h.EventHandler(&SendComplete{
- Event: dbEvt,
- Error: err,
- })
- }()
- var resp *mautrix.RespSendEvent
- resp, err = h.Client.SendMessageEvent(ctx, roomID, evtType, content, mautrix.ReqSendEvent{
- Timestamp: dbEvt.Timestamp.UnixMilli(),
- TransactionID: txnID,
- DontEncrypt: true,
- })
- if err != nil {
- dbEvt.SendError = err.Error()
- err = fmt.Errorf("failed to send event: %w", err)
- err2 := h.DB.Event.UpdateSendError(ctx, dbEvt.RowID, dbEvt.SendError)
- if err2 != nil {
- zerolog.Ctx(ctx).Err(err2).AnErr("send_error", err).
- Msg("Failed to update send error in database after sending failed")
- }
- return
- }
- dbEvt.ID = resp.EventID
- err = h.DB.Event.UpdateID(ctx, dbEvt.RowID, dbEvt.ID)
- if err != nil {
- err = fmt.Errorf("failed to update event ID in database: %w", err)
- }
- }()
- return dbEvt, nil
-}
-
-func (h *HiClient) Encrypt(ctx context.Context, room *database.Room, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
- h.encryptLock.Lock()
- defer h.encryptLock.Unlock()
- encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content)
- if errors.Is(err, crypto.SessionExpired) || errors.Is(err, crypto.NoGroupSession) || errors.Is(err, crypto.SessionNotShared) {
- if err = h.shareGroupSession(ctx, room); err != nil {
- err = fmt.Errorf("failed to share group session: %w", err)
- } else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content); err != nil {
- err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
- }
- }
- return
-}
-
-func (h *HiClient) EnsureGroupSessionShared(ctx context.Context, roomID id.RoomID) error {
- h.encryptLock.Lock()
- defer h.encryptLock.Unlock()
- if session, err := h.CryptoStore.GetOutboundGroupSession(ctx, roomID); err != nil {
- return fmt.Errorf("failed to get previous outbound group session: %w", err)
- } else if session != nil && session.Shared && !session.Expired() {
- return nil
- } else if roomMeta, err := h.DB.Room.Get(ctx, roomID); err != nil {
- return fmt.Errorf("failed to get room metadata: %w", err)
- } else if roomMeta == nil {
- return fmt.Errorf("unknown room")
- } else {
- return h.shareGroupSession(ctx, roomMeta)
- }
-}
-
-func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error {
- if room.HasMemberList {
- return nil
- }
- resp, err := h.Client.Members(ctx, room.ID)
- if err != nil {
- return fmt.Errorf("failed to get room member list: %w", err)
- }
- err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
- entries := make([]*database.CurrentStateEntry, len(resp.Chunk))
- for i, evt := range resp.Chunk {
- dbEvt, err := h.processEvent(ctx, evt, nil, true)
- if err != nil {
- return err
- }
- entries[i] = &database.CurrentStateEntry{
- EventType: evt.Type,
- StateKey: *evt.StateKey,
- EventRowID: dbEvt.RowID,
- Membership: event.Membership(evt.Content.Raw["membership"].(string)),
- }
- }
- err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries)
- if err != nil {
- return err
- }
- return h.DB.Room.Upsert(ctx, &database.Room{
- ID: room.ID,
- HasMemberList: true,
- })
- })
- if err != nil {
- return fmt.Errorf("failed to process room member list: %w", err)
- }
- return nil
-}
-
-func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error {
- err := h.loadMembers(ctx, room)
- if err != nil {
- return err
- }
- shareToInvited := h.shouldShareKeysToInvitedUsers(ctx, room.ID)
- var users []id.UserID
- if shareToInvited {
- users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, room.ID)
- } else {
- users, err = h.ClientStore.GetRoomJoinedMembers(ctx, room.ID)
- }
- if err != nil {
- return fmt.Errorf("failed to get room member list: %w", err)
- } else if err = h.Crypto.ShareGroupSession(ctx, room.ID, users); err != nil {
- return fmt.Errorf("failed to share group session: %w", err)
- }
- return nil
-}
-
-func (h *HiClient) shouldShareKeysToInvitedUsers(ctx context.Context, roomID id.RoomID) bool {
- historyVisibility, err := h.DB.CurrentState.Get(ctx, roomID, event.StateHistoryVisibility, "")
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get history visibility event")
- return false
- }
- mautrixEvt := historyVisibility.AsRawMautrix()
- err = mautrixEvt.Content.ParseRaw(mautrixEvt.Type)
- if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to parse history visibility event")
- return false
- }
- hv, ok := mautrixEvt.Content.Parsed.(*event.HistoryVisibilityEventContent)
- if !ok {
- zerolog.Ctx(ctx).Warn().Msg("Unexpected parsed content type for history visibility event")
- return false
- }
- return hv.HistoryVisibility == event.HistoryVisibilityInvited ||
- hv.HistoryVisibility == event.HistoryVisibilityShared ||
- hv.HistoryVisibility == event.HistoryVisibilityWorldReadable
-}
diff --git a/hicli/sync.go b/hicli/sync.go
deleted file mode 100644
index 16930b59..00000000
--- a/hicli/sync.go
+++ /dev/null
@@ -1,744 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "errors"
- "fmt"
- "strings"
-
- "github.com/rs/zerolog"
- "github.com/tidwall/gjson"
- "github.com/tidwall/sjson"
- "go.mau.fi/util/exzerolog"
- "go.mau.fi/util/jsontime"
- "golang.org/x/exp/slices"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/crypto"
- "maunium.net/go/mautrix/crypto/olm"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/hicli/database"
- "maunium.net/go/mautrix/id"
- "maunium.net/go/mautrix/pushrules"
-)
-
-type syncContext struct {
- shouldWakeupRequestQueue bool
-
- evt *SyncComplete
-}
-
-func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
- log := zerolog.Ctx(ctx)
- postponedToDevices := resp.ToDevice.Events[:0]
- for _, evt := range resp.ToDevice.Events {
- evt.Type.Class = event.ToDeviceEventType
- err := evt.Content.ParseRaw(evt.Type)
- if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
- log.Warn().Err(err).
- Stringer("event_type", &evt.Type).
- Stringer("sender", evt.Sender).
- Msg("Failed to parse to-device event, skipping")
- continue
- }
-
- switch content := evt.Content.Parsed.(type) {
- case *event.EncryptedEventContent:
- h.Crypto.HandleEncryptedEvent(ctx, evt)
- case *event.RoomKeyWithheldEventContent:
- h.Crypto.HandleRoomKeyWithheld(ctx, content)
- default:
- postponedToDevices = append(postponedToDevices, evt)
- }
- }
- resp.ToDevice.Events = postponedToDevices
-
- return nil
-}
-
-func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
- h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
- go h.asyncPostProcessSyncResponse(ctx, resp, since)
- syncCtx := ctx.Value(syncContextKey).(*syncContext)
- if syncCtx.shouldWakeupRequestQueue {
- h.WakeupRequestQueue()
- }
- h.firstSyncReceived = true
- if !syncCtx.evt.IsEmpty() {
- h.EventHandler(syncCtx.evt)
- }
-}
-
-func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
- for _, evt := range resp.ToDevice.Events {
- switch content := evt.Content.Parsed.(type) {
- case *event.SecretRequestEventContent:
- h.Crypto.HandleSecretRequest(ctx, evt.Sender, content)
- case *event.RoomKeyRequestEventContent:
- h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content)
- }
- }
-}
-
-func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
- if len(resp.DeviceLists.Changed) > 0 {
- zerolog.Ctx(ctx).Debug().
- Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)).
- Msg("Marking changed device lists for tracked users as outdated")
- err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed)
- if err != nil {
- return fmt.Errorf("failed to mark changed device lists as outdated: %w", err)
- }
- ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
- }
-
- for _, evt := range resp.AccountData.Events {
- evt.Type.Class = event.AccountDataEventType
- err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw)
- if err != nil {
- return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
- }
- if evt.Type == event.AccountDataPushRules {
- err = evt.Content.ParseRaw(evt.Type)
- if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
- zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to parse push rules in sync")
- } else if pushRules, ok := evt.Content.Parsed.(*pushrules.EventContent); ok {
- h.PushRules.Store(pushRules.Ruleset)
- zerolog.Ctx(ctx).Debug().Msg("Updated push rules from sync")
- }
- }
- }
- for roomID, room := range resp.Rooms.Join {
- err := h.processSyncJoinedRoom(ctx, roomID, room)
- if err != nil {
- return fmt.Errorf("failed to process joined room %s: %w", roomID, err)
- }
- }
- for roomID, room := range resp.Rooms.Leave {
- err := h.processSyncLeftRoom(ctx, roomID, room)
- if err != nil {
- return fmt.Errorf("failed to process left room %s: %w", roomID, err)
- }
- }
- h.Account.NextBatch = resp.NextBatch
- err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch)
- if err != nil {
- return fmt.Errorf("failed to save next_batch: %w", err)
- }
- return nil
-}
-
-func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt {
- receiptList := make([]*database.Receipt, 0)
- for eventID, receipts := range *content {
- for receiptType, users := range receipts {
- for userID, receiptInfo := range users {
- receiptList = append(receiptList, &database.Receipt{
- UserID: userID,
- ReceiptType: receiptType,
- ThreadID: receiptInfo.ThreadID,
- EventID: eventID,
- Timestamp: jsontime.UM(receiptInfo.Timestamp),
- })
- }
- }
- }
- return receiptList
-}
-
-func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
- existingRoomData, err := h.DB.Room.Get(ctx, roomID)
- if err != nil {
- return fmt.Errorf("failed to get room data: %w", err)
- } else if existingRoomData == nil {
- err = h.DB.Room.CreateRow(ctx, roomID)
- if err != nil {
- return fmt.Errorf("failed to ensure room row exists: %w", err)
- }
- existingRoomData = &database.Room{ID: roomID, SortingTimestamp: jsontime.UnixMilliNow()}
- }
-
- for _, evt := range room.AccountData.Events {
- evt.Type.Class = event.AccountDataEventType
- evt.RoomID = roomID
- err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw)
- if err != nil {
- return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
- }
- }
- err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
- if err != nil {
- return err
- }
- for _, evt := range room.Ephemeral.Events {
- evt.Type.Class = event.EphemeralEventType
- err = evt.Content.ParseRaw(evt.Type)
- if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
- zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content")
- continue
- }
- switch evt.Type {
- case event.EphemeralEventReceipt:
- err = h.DB.Receipt.PutMany(ctx, roomID, receiptsToList(evt.Content.AsReceipt())...)
- if err != nil {
- return fmt.Errorf("failed to save receipts: %w", err)
- }
- case event.EphemeralEventTyping:
- go h.EventHandler(&Typing{
- RoomID: roomID,
- TypingEventContent: *evt.Content.AsTyping(),
- })
- }
- if evt.Type != event.EphemeralEventReceipt {
- continue
- }
- }
- return nil
-}
-
-func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error {
- existingRoomData, err := h.DB.Room.Get(ctx, roomID)
- if err != nil {
- return fmt.Errorf("failed to get room data: %w", err)
- } else if existingRoomData == nil {
- return nil
- }
- return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary)
-}
-
-func isDecryptionErrorRetryable(err error) bool {
- return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld)
-}
-
-func removeReplyFallback(evt *event.Event) []byte {
- if evt.Type != event.EventMessage && evt.Type != event.EventSticker {
- return nil
- }
- _ = evt.Content.ParseRaw(evt.Type)
- content, ok := evt.Content.Parsed.(*event.MessageEventContent)
- if ok && content.RelatesTo.GetReplyTo() != "" {
- prevFormattedBody := content.FormattedBody
- content.RemoveReplyFallback()
- if content.FormattedBody != prevFormattedBody {
- bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody)
- bytes, err2 := sjson.SetBytes(bytes, "body", content.Body)
- if err == nil && err2 == nil {
- return bytes
- }
- }
- }
- return nil
-}
-
-func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
- err := evt.Content.ParseRaw(evt.Type)
- if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
- return nil, nil, "", err
- }
- decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
- if err != nil {
- return nil, nil, "", err
- }
- withoutFallback := removeReplyFallback(decrypted)
- if withoutFallback != nil {
- return decrypted, withoutFallback, decrypted.Type.Type, nil
- }
- return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil
-}
-
-func (h *HiClient) addMediaCache(
- ctx context.Context,
- eventRowID database.EventRowID,
- uri id.ContentURIString,
- file *event.EncryptedFileInfo,
- info *event.FileInfo,
- fileName string,
-) {
- parsedMXC := uri.ParseOrIgnore()
- if !parsedMXC.IsValid() {
- return
- }
- cm := &database.CachedMedia{
- MXC: parsedMXC,
- EventRowID: eventRowID,
- FileName: fileName,
- }
- if file != nil {
- cm.EncFile = &file.EncryptedFile
- }
- if info != nil {
- cm.MimeType = info.MimeType
- }
- err := h.DB.CachedMedia.Put(ctx, cm)
- if err != nil {
- zerolog.Ctx(ctx).Warn().Err(err).
- Stringer("mxc", parsedMXC).
- Int64("event_rowid", int64(eventRowID)).
- Msg("Failed to add cached media entry")
- }
-}
-
-func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) {
- switch evt.Type {
- case event.EventMessage, event.EventSticker:
- content, ok := evt.Content.Parsed.(*event.MessageEventContent)
- if !ok {
- return
- }
- if content.File != nil {
- h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName())
- } else if content.URL != "" {
- h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName())
- }
- if content.GetInfo().ThumbnailFile != nil {
- h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "")
- } else if content.GetInfo().ThumbnailURL != "" {
- h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "")
- }
- case event.StateRoomAvatar:
- _ = evt.Content.ParseRaw(evt.Type)
- content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
- if !ok {
- return
- }
- h.addMediaCache(ctx, rowID, content.URL, nil, nil, "")
- case event.StateMember:
- _ = evt.Content.ParseRaw(evt.Type)
- content, ok := evt.Content.Parsed.(*event.MemberEventContent)
- if !ok {
- return
- }
- h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "")
- }
-}
-
-func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) {
- if checkDB {
- dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID)
- if err != nil {
- return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err)
- } else if dbEvt != nil {
- return dbEvt, nil
- }
- }
- dbEvt := database.MautrixToEvent(evt)
- contentWithoutFallback := removeReplyFallback(evt)
- if contentWithoutFallback != nil {
- dbEvt.Content = contentWithoutFallback
- }
- var decryptionErr error
- var decryptedMautrixEvt *event.Event
- if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
- decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
- if decryptionErr != nil {
- dbEvt.DecryptionError = decryptionErr.Error()
- }
- } else if evt.Type == event.EventRedaction {
- if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() {
- var err error
- evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts)
- if err != nil {
- return dbEvt, fmt.Errorf("failed to set redacts field: %w", err)
- }
- } else if evt.Redacts == "" {
- evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
- }
- }
- _, err := h.DB.Event.Upsert(ctx, dbEvt)
- if err != nil {
- return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
- }
- if decryptedMautrixEvt != nil {
- h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID)
- } else {
- h.cacheMedia(ctx, evt, dbEvt.RowID)
- }
- if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
- req, ok := decryptionQueue[dbEvt.MegolmSessionID]
- if !ok {
- req = &database.SessionRequest{
- RoomID: evt.RoomID,
- SessionID: dbEvt.MegolmSessionID,
- Sender: evt.Sender,
- }
- }
- minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext)
- req.MinIndex = min(uint32(minIndex), req.MinIndex)
- if decryptionQueue != nil {
- decryptionQueue[dbEvt.MegolmSessionID] = req
- } else {
- err = h.DB.SessionRequest.Put(ctx, req)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).
- Stringer("session_id", dbEvt.MegolmSessionID).
- Msg("Failed to save session request")
- } else {
- h.WakeupRequestQueue()
- }
- }
- }
- return dbEvt, err
-}
-
-func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error {
- updatedRoom := &database.Room{
- ID: room.ID,
-
- SortingTimestamp: room.SortingTimestamp,
- NameQuality: room.NameQuality,
- }
- heroesChanged := false
- if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil {
- summary = room.LazyLoadSummary
- } else if room.LazyLoadSummary == nil ||
- !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) ||
- !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) ||
- !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) {
- updatedRoom.LazyLoadSummary = summary
- heroesChanged = true
- }
- decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
- allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events))
- recalculatePreviewEvent := false
- addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) {
- if rowID != 0 {
- dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID)
- } else {
- dbEvt, err = h.DB.Event.GetByID(ctx, evtID)
- }
- if err != nil {
- return nil, fmt.Errorf("failed to get redaction target: %w", err)
- } else if dbEvt == nil {
- return nil, nil
- }
- allNewEvents = append(allNewEvents, dbEvt)
- return dbEvt, nil
- }
- processRedaction := func(evt *event.Event) error {
- dbEvt, err := addOldEvent(0, evt.Redacts)
- if err != nil {
- return fmt.Errorf("failed to get redaction target: %w", err)
- }
- if dbEvt == nil {
- return nil
- }
- if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
- _, err = addOldEvent(0, dbEvt.RelatesTo)
- if err != nil {
- return fmt.Errorf("failed to get relation target of redaction target: %w", err)
- }
- }
- if updatedRoom.PreviewEventRowID == dbEvt.RowID {
- updatedRoom.PreviewEventRowID = 0
- recalculatePreviewEvent = true
- }
- return nil
- }
- processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) {
- evt.RoomID = room.ID
- dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, false)
- if err != nil {
- return -1, err
- }
- if isTimeline {
- if dbEvt.CanUseForPreview() {
- updatedRoom.PreviewEventRowID = dbEvt.RowID
- recalculatePreviewEvent = false
- }
- updatedRoom.BumpSortingTimestamp(dbEvt)
- }
- if evt.StateKey != nil {
- var membership event.Membership
- if evt.Type == event.StateMember {
- membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str)
- if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) {
- heroesChanged = true
- }
- } else if evt.Type == event.StateElementFunctionalMembers {
- heroesChanged = true
- }
- err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership)
- if err != nil {
- return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err)
- }
- processImportantEvent(ctx, evt, room, updatedRoom)
- }
- allNewEvents = append(allNewEvents, dbEvt)
- if evt.Type == event.EventRedaction && evt.Redacts != "" {
- err = processRedaction(evt)
- if err != nil {
- return -1, fmt.Errorf("failed to process redaction: %w", err)
- }
- } else if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
- _, err = addOldEvent(0, dbEvt.RelatesTo)
- if err != nil {
- return -1, fmt.Errorf("failed to get relation target of event: %w", err)
- }
- }
- return dbEvt.RowID, nil
- }
- changedState := make(map[event.Type]map[string]database.EventRowID)
- setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) {
- if _, ok := changedState[evtType]; !ok {
- changedState[evtType] = make(map[string]database.EventRowID)
- }
- changedState[evtType][stateKey] = rowID
- }
- for _, evt := range state.Events {
- evt.Type.Class = event.StateEventType
- rowID, err := processNewEvent(evt, false)
- if err != nil {
- return err
- }
- setNewState(evt.Type, *evt.StateKey, rowID)
- }
- var timelineRowTuples []database.TimelineRowTuple
- var err error
- if len(timeline.Events) > 0 {
- timelineIDs := make([]database.EventRowID, len(timeline.Events))
- for i, evt := range timeline.Events {
- if evt.StateKey != nil {
- evt.Type.Class = event.StateEventType
- } else {
- evt.Type.Class = event.MessageEventType
- }
- timelineIDs[i], err = processNewEvent(evt, true)
- if err != nil {
- return err
- }
- if evt.StateKey != nil {
- setNewState(evt.Type, *evt.StateKey, timelineIDs[i])
- }
- }
- for _, entry := range decryptionQueue {
- err = h.DB.SessionRequest.Put(ctx, entry)
- if err != nil {
- return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
- }
- }
- if len(decryptionQueue) > 0 {
- ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
- }
- if timeline.Limited {
- err = h.DB.Timeline.Clear(ctx, room.ID)
- if err != nil {
- return fmt.Errorf("failed to clear old timeline: %w", err)
- }
- updatedRoom.PrevBatch = timeline.PrevBatch
- h.paginationInterrupterLock.Lock()
- if interrupt, ok := h.paginationInterrupter[room.ID]; ok {
- interrupt(ErrTimelineReset)
- }
- h.paginationInterrupterLock.Unlock()
- }
- timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs)
- if err != nil {
- return fmt.Errorf("failed to append timeline: %w", err)
- }
- } else {
- timelineRowTuples = make([]database.TimelineRowTuple, 0)
- }
- if recalculatePreviewEvent && updatedRoom.PreviewEventRowID == 0 {
- updatedRoom.PreviewEventRowID, err = h.DB.Room.RecalculatePreview(ctx, room.ID)
- if err != nil {
- return fmt.Errorf("failed to recalculate preview event: %w", err)
- }
- _, err = addOldEvent(updatedRoom.PreviewEventRowID, "")
- if err != nil {
- return fmt.Errorf("failed to get preview event: %w", err)
- }
- }
- // Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset
- if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil {
- name, dmAvatarURL, err := h.calculateRoomParticipantName(ctx, room.ID, summary)
- if err != nil {
- return fmt.Errorf("failed to calculate room name: %w", err)
- }
- updatedRoom.Name = &name
- updatedRoom.NameQuality = database.NameQualityParticipants
- if !dmAvatarURL.IsEmpty() && !room.ExplicitAvatar {
- updatedRoom.Avatar = &dmAvatarURL
- }
- }
- if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) {
- updatedRoom.PrevBatch = timeline.PrevBatch
- }
- roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
- if roomChanged {
- err = h.DB.Room.Upsert(ctx, updatedRoom)
- if err != nil {
- return fmt.Errorf("failed to save room data: %w", err)
- }
- }
- if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
- ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
- Meta: room,
- Timeline: timelineRowTuples,
- State: changedState,
- Reset: timeline.Limited,
- Events: allNewEvents,
- }
- }
- return nil
-}
-
-func joinMemberNames(names []string, totalCount int) string {
- if len(names) == 1 {
- return names[0]
- } else if len(names) < 5 || (len(names) == 5 && totalCount <= 6) {
- return strings.Join(names[:len(names)-1], ", ") + " and " + names[len(names)-1]
- } else {
- return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), totalCount-5)
- }
-}
-
-func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, id.ContentURI, error) {
- var primaryAvatarURL id.ContentURI
- if summary == nil || len(summary.Heroes) == 0 {
- return "Empty room", primaryAvatarURL, nil
- }
- var functionalMembers []id.UserID
- functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "")
- if err != nil {
- return "", primaryAvatarURL, fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err)
- } else if functionalMembersEvt != nil {
- mautrixEvt := functionalMembersEvt.AsRawMautrix()
- _ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type)
- content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent)
- if ok {
- functionalMembers = content.ServiceMembers
- }
- }
- var members, leftMembers []string
- var memberCount int
- if summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0 {
- memberCount = *summary.JoinedMemberCount
- } else if summary.InvitedMemberCount != nil {
- memberCount = *summary.InvitedMemberCount
- }
- for _, hero := range summary.Heroes {
- if slices.Contains(functionalMembers, hero) {
- memberCount--
- continue
- } else if len(members) >= 5 {
- break
- }
- heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String())
- if err != nil {
- return "", primaryAvatarURL, fmt.Errorf("failed to get %s's member event: %w", hero, err)
- } else if heroEvt == nil {
- leftMembers = append(leftMembers, hero.String())
- continue
- }
- membership := gjson.GetBytes(heroEvt.Content, "membership").Str
- name := gjson.GetBytes(heroEvt.Content, "displayname").Str
- if name == "" {
- name = hero.String()
- }
- avatarURL := gjson.GetBytes(heroEvt.Content, "avatar_url").Str
- if avatarURL != "" {
- primaryAvatarURL = id.ContentURIString(avatarURL).ParseOrIgnore()
- }
- if membership == "join" || membership == "invite" {
- members = append(members, name)
- } else {
- leftMembers = append(leftMembers, name)
- }
- }
- if len(members)+len(leftMembers) > 1 || !primaryAvatarURL.IsValid() {
- primaryAvatarURL = id.ContentURI{}
- }
- if len(members) > 0 {
- return joinMemberNames(members, memberCount), primaryAvatarURL, nil
- } else if len(leftMembers) > 0 {
- return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), primaryAvatarURL, nil
- } else {
- return "Empty room", primaryAvatarURL, nil
- }
-}
-
-func intPtrEqual(a, b *int) bool {
- if a == nil || b == nil {
- return a == b
- }
- return *a == *b
-}
-
-func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) {
- if evt.StateKey == nil {
- return
- }
- switch evt.Type {
- case event.StateCreate, event.StateRoomName, event.StateCanonicalAlias, event.StateRoomAvatar, event.StateTopic, event.StateEncryption:
- if *evt.StateKey != "" {
- return
- }
- default:
- return
- }
- err := evt.Content.ParseRaw(evt.Type)
- if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
- zerolog.Ctx(ctx).Warn().Err(err).
- Stringer("event_type", &evt.Type).
- Stringer("event_id", evt.ID).
- Msg("Failed to parse state event, skipping")
- return
- }
- switch evt.Type {
- case event.StateCreate:
- updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent)
- case event.StateEncryption:
- newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent)
- if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm {
- updatedRoom.EncryptionEvent = newEncryption
- }
- case event.StateRoomName:
- content, ok := evt.Content.Parsed.(*event.RoomNameEventContent)
- if ok {
- updatedRoom.Name = &content.Name
- updatedRoom.NameQuality = database.NameQualityExplicit
- if content.Name == "" {
- if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" {
- updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias)
- updatedRoom.NameQuality = database.NameQualityCanonicalAlias
- } else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" {
- updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias)
- updatedRoom.NameQuality = database.NameQualityCanonicalAlias
- } else {
- updatedRoom.NameQuality = database.NameQualityNil
- }
- }
- }
- case event.StateCanonicalAlias:
- content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent)
- if ok {
- updatedRoom.CanonicalAlias = &content.Alias
- if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias {
- updatedRoom.Name = (*string)(&content.Alias)
- updatedRoom.NameQuality = database.NameQualityCanonicalAlias
- if content.Alias == "" {
- updatedRoom.NameQuality = database.NameQualityNil
- }
- }
- }
- case event.StateRoomAvatar:
- content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
- if ok {
- url, _ := content.URL.Parse()
- updatedRoom.Avatar = &url
- updatedRoom.ExplicitAvatar = true
- }
- case event.StateTopic:
- content, ok := evt.Content.Parsed.(*event.TopicEventContent)
- if ok {
- updatedRoom.Topic = &content.Topic
- }
- }
- return
-}
diff --git a/hicli/syncwrap.go b/hicli/syncwrap.go
deleted file mode 100644
index 13837202..00000000
--- a/hicli/syncwrap.go
+++ /dev/null
@@ -1,96 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "fmt"
- "time"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/id"
-)
-
-type hiSyncer HiClient
-
-var _ mautrix.Syncer = (*hiSyncer)(nil)
-
-type contextKey int
-
-const (
- syncContextKey contextKey = iota
-)
-
-func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
- c := (*HiClient)(h)
- ctx = context.WithValue(ctx, syncContextKey, &syncContext{evt: &SyncComplete{Rooms: make(map[id.RoomID]*SyncRoom, len(resp.Rooms.Join))}})
- err := c.preProcessSyncResponse(ctx, resp, since)
- if err != nil {
- return err
- }
- err = c.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
- return c.processSyncResponse(ctx, resp, since)
- })
- if err != nil {
- return err
- }
- c.postProcessSyncResponse(ctx, resp, since)
- return nil
-}
-
-func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
- (*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second")
- return 1 * time.Second, nil
-}
-
-func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
- if !h.Verified {
- return &mautrix.Filter{
- Presence: mautrix.FilterPart{
- NotRooms: []id.RoomID{"*"},
- },
- Room: mautrix.RoomFilter{
- NotRooms: []id.RoomID{"*"},
- },
- }
- }
- return &mautrix.Filter{
- Presence: mautrix.FilterPart{
- NotRooms: []id.RoomID{"*"},
- },
- Room: mautrix.RoomFilter{
- State: mautrix.FilterPart{
- LazyLoadMembers: true,
- },
- Timeline: mautrix.FilterPart{
- Limit: 100,
- LazyLoadMembers: true,
- },
- },
- }
-}
-
-type hiStore HiClient
-
-var _ mautrix.SyncStore = (*hiStore)(nil)
-
-// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing
-
-func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil }
-func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil }
-
-func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
- // This is intentionally a no-op: we don't want to save the next batch before processing the sync
- return nil
-}
-
-func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) {
- if h.Account.UserID != userID {
- return "", fmt.Errorf("mismatching user ID")
- }
- return h.Account.NextBatch, nil
-}
diff --git a/hicli/verify.go b/hicli/verify.go
deleted file mode 100644
index 6dc2a4c3..00000000
--- a/hicli/verify.go
+++ /dev/null
@@ -1,162 +0,0 @@
-// Copyright (c) 2024 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package hicli
-
-import (
- "context"
- "encoding/base64"
- "fmt"
-
- "github.com/rs/zerolog"
-
- "maunium.net/go/mautrix/crypto"
- "maunium.net/go/mautrix/crypto/backup"
- "maunium.net/go/mautrix/crypto/ssss"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) {
- keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx)
- if keys == nil {
- return false, fmt.Errorf("own cross-signing keys not found")
- }
- isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey)
- if err != nil {
- return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err)
- }
- return isVerified, nil
-}
-
-func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error {
- latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
- if err != nil {
- return fmt.Errorf("failed to get key backup latest version: %w", err)
- }
- h.KeyBackupVersion = latestVersion.Version
- data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey)
- if err != nil {
- return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err)
- }
- key, err := backup.MegolmBackupKeyFromBytes(data)
- if err != nil {
- return fmt.Errorf("failed to parse megolm backup key: %w", err)
- }
- err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes()))
- if err != nil {
- return fmt.Errorf("failed to store megolm backup key: %w", err)
- }
- h.KeyBackupKey = key
- return nil
-}
-
-func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) {
- secretData, err := h.CryptoStore.GetSecret(ctx, secret)
- if err != nil {
- return nil, fmt.Errorf("failed to get secret %s: %w", secret, err)
- }
- data, err := base64.StdEncoding.DecodeString(secretData)
- if err != nil {
- return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err)
- }
- return data, nil
-}
-
-func (h *HiClient) loadPrivateKeys(ctx context.Context) error {
- zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys")
- masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster)
- if err != nil {
- return fmt.Errorf("failed to get master key: %w", err)
- }
- selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning)
- if err != nil {
- return fmt.Errorf("failed to get self-signing key: %w", err)
- }
- userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning)
- if err != nil {
- return fmt.Errorf("failed to get user signing key: %w", err)
- }
- err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{
- MasterKey: masterKeySeed,
- SelfSigningKey: selfSigningKeySeed,
- UserSigningKey: userSigningKeySeed,
- })
- if err != nil {
- return fmt.Errorf("failed to import cross-signing private keys: %w", err)
- }
- zerolog.Ctx(ctx).Debug().Msg("Loading key backup key")
- keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1)
- if err != nil {
- return fmt.Errorf("failed to get megolm backup key: %w", err)
- }
- h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey)
- if err != nil {
- return fmt.Errorf("failed to parse megolm backup key: %w", err)
- }
- zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version")
- latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
- if err != nil {
- return fmt.Errorf("failed to get key backup latest version: %w", err)
- }
- h.KeyBackupVersion = latestVersion.Version
- zerolog.Ctx(ctx).Debug().Msg("Secrets loaded")
- return nil
-}
-
-func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error {
- keys := h.Crypto.CrossSigningKeys
- err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed()))
- if err != nil {
- return err
- }
- err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed()))
- if err != nil {
- return err
- }
- err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed()))
- if err != nil {
- return err
- }
- return nil
-}
-
-func (h *HiClient) VerifyWithRecoveryKey(ctx context.Context, code string) error {
- defer h.dispatchCurrentState()
- keyID, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx)
- if err != nil {
- return fmt.Errorf("failed to get default SSSS key data: %w", err)
- }
- key, err := keyData.VerifyRecoveryKey(keyID, code)
- if err != nil {
- return err
- }
- err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key)
- if err != nil {
- return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err)
- }
- err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity())
- if err != nil {
- return fmt.Errorf("failed to sign own device: %w", err)
- }
- err = h.Crypto.SignOwnMasterKey(ctx)
- if err != nil {
- return fmt.Errorf("failed to sign own master key: %w", err)
- }
- err = h.storeCrossSigningPrivateKeys(ctx)
- if err != nil {
- return fmt.Errorf("failed to store cross-signing private keys: %w", err)
- }
- err = h.fetchKeyBackupKey(ctx, key)
- if err != nil {
- return fmt.Errorf("failed to fetch key backup key: %w", err)
- }
- h.Verified = true
- if !h.IsSyncing() {
- go h.Sync()
- }
- return nil
-}
diff --git a/id/contenturi.go b/id/contenturi.go
index e6a313f5..67127b6c 100644
--- a/id/contenturi.go
+++ b/id/contenturi.go
@@ -17,8 +17,14 @@ import (
)
var (
- InvalidContentURI = errors.New("invalid Matrix content URI")
- InputNotJSONString = errors.New("input doesn't look like a JSON string")
+ ErrInvalidContentURI = errors.New("invalid Matrix content URI")
+ ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ InvalidContentURI = ErrInvalidContentURI
+ InputNotJSONString = ErrInputNotJSONString
)
// ContentURIString is a string that's expected to be a Matrix content URI.
@@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !strings.HasPrefix(uri, "mxc://") {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = uri[6 : 6+index]
parsed.FileID = uri[6+index+1:]
@@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !bytes.HasPrefix(uri, mxcBytes) {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = string(uri[6 : 6+index])
parsed.FileID = string(uri[6+index+1:])
@@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) {
*uri = ContentURI{}
return nil
} else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' {
- return InputNotJSONString
+ return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString)
}
parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1])
if err != nil {
diff --git a/id/crypto.go b/id/crypto.go
index 355a84a8..ee857f78 100644
--- a/id/crypto.go
+++ b/id/crypto.go
@@ -53,6 +53,34 @@ const (
KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
)
+type KeySource string
+
+func (source KeySource) String() string {
+ return string(source)
+}
+
+func (source KeySource) Int() int {
+ switch source {
+ case KeySourceDirect:
+ return 100
+ case KeySourceBackup:
+ return 90
+ case KeySourceImport:
+ return 80
+ case KeySourceForward:
+ return 50
+ default:
+ return 0
+ }
+}
+
+const (
+ KeySourceDirect KeySource = "direct"
+ KeySourceBackup KeySource = "backup"
+ KeySourceImport KeySource = "import"
+ KeySourceForward KeySource = "forward"
+)
+
// BackupVersion is an arbitrary string that identifies a server side key backup.
type KeyBackupVersion string
diff --git a/id/matrixuri.go b/id/matrixuri.go
index acd8e0c0..d5c78bc7 100644
--- a/id/matrixuri.go
+++ b/id/matrixuri.go
@@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{
func (uri *MatrixURI) getQuery() url.Values {
q := make(url.Values)
- if uri.Via != nil && len(uri.Via) > 0 {
+ if len(uri.Via) > 0 {
q["via"] = uri.Via
}
if len(uri.Action) > 0 {
@@ -210,10 +210,14 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) {
if len(parts[1]) == 0 {
return nil, ErrEmptySecondSegment
}
- parsed.MXID1 = parts[1]
+ var err error
+ parsed.MXID1, err = url.PathUnescape(parts[1])
+ if err != nil {
+ return nil, fmt.Errorf("failed to url decode second segment %q: %w", parts[1], err)
+ }
// Step 6: if the first part is a room and the URI has 4 segments, construct a second level identifier
- if (parsed.Sigil1 == '!' || parsed.Sigil1 == '#') && len(parts) == 4 {
+ if parsed.Sigil1 == '!' && len(parts) == 4 {
// a: find the sigil from the third segment
switch parts[2] {
case "e", "event":
@@ -226,7 +230,10 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) {
if len(parts[3]) == 0 {
return nil, ErrEmptyFourthSegment
}
- parsed.MXID2 = parts[3]
+ parsed.MXID2, err = url.PathUnescape(parts[3])
+ if err != nil {
+ return nil, fmt.Errorf("failed to url decode fourth segment %q: %w", parts[3], err)
+ }
}
// Step 7: parse the query and extract via and action items
diff --git a/id/matrixuri_test.go b/id/matrixuri_test.go
index d26d4bfd..90a0754d 100644
--- a/id/matrixuri_test.go
+++ b/id/matrixuri_test.go
@@ -16,12 +16,11 @@ import (
)
var (
- roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"}
- roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}}
- roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"}
- roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"}
- roomAliasEventLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"}
- userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"}
+ roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"}
+ roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}}
+ roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"}
+ roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"}
+ userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"}
escapeRoomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "meow & 🐈️:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF/dtndJ0j9je+kIK3XpV1s"}
)
@@ -31,7 +30,6 @@ func TestMatrixURI_MatrixToURL(t *testing.T) {
assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.MatrixToURL())
assert.Equal(t, "https://matrix.to/#/%23someroom:example.org", roomAliasLink.MatrixToURL())
assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.MatrixToURL())
- assert.Equal(t, "https://matrix.to/#/%23someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.MatrixToURL())
assert.Equal(t, "https://matrix.to/#/@user:example.org", userLink.MatrixToURL())
assert.Equal(t, "https://matrix.to/#/%21meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/$uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.MatrixToURL())
}
@@ -41,7 +39,6 @@ func TestMatrixURI_String(t *testing.T) {
assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.String())
assert.Equal(t, "matrix:r/someroom:example.org", roomAliasLink.String())
assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.String())
- assert.Equal(t, "matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.String())
assert.Equal(t, "matrix:u/user:example.org", userLink.String())
assert.Equal(t, "matrix:roomid/meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/e/uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.String())
}
@@ -80,8 +77,12 @@ func TestParseMatrixURI_RoomID(t *testing.T) {
parsedVia, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org")
require.NoError(t, err)
require.NotNil(t, parsedVia)
+ parsedEncoded, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl%3Aexample.org")
+ require.NoError(t, err)
+ require.NotNil(t, parsedEncoded)
assert.Equal(t, roomIDLink, *parsed)
+ assert.Equal(t, roomIDLink, *parsedEncoded)
assert.Equal(t, roomIDViaLink, *parsedVia)
}
@@ -98,19 +99,11 @@ func TestParseMatrixURI_UserID(t *testing.T) {
}
func TestParseMatrixURI_EventID(t *testing.T) {
- parsed1, err := id.ParseMatrixURI("matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
+ parsed, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
require.NoError(t, err)
- require.NotNil(t, parsed1)
- parsed2, err := id.ParseMatrixURI("matrix:room/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
- require.NoError(t, err)
- require.NotNil(t, parsed2)
- parsed3, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
- require.NoError(t, err)
- require.NotNil(t, parsed3)
+ require.NotNil(t, parsed)
- assert.Equal(t, roomAliasEventLink, *parsed1)
- assert.Equal(t, roomAliasEventLink, *parsed2)
- assert.Equal(t, roomIDEventLink, *parsed3)
+ assert.Equal(t, roomIDEventLink, *parsed)
}
func TestParseMatrixToURL_RoomAlias(t *testing.T) {
@@ -158,21 +151,13 @@ func TestParseMatrixToURL_UserID(t *testing.T) {
}
func TestParseMatrixToURL_EventID(t *testing.T) {
- parsed1, err := id.ParseMatrixToURL("https://matrix.to/#/#someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
+ parsed, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
require.NoError(t, err)
- require.NotNil(t, parsed1)
- parsed2, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
+ require.NotNil(t, parsed)
+ parsedEncoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
require.NoError(t, err)
- require.NotNil(t, parsed2)
- parsed1Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%23someroom:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
- require.NoError(t, err)
- require.NotNil(t, parsed1)
- parsed2Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
- require.NoError(t, err)
- require.NotNil(t, parsed2)
+ require.NotNil(t, parsedEncoded)
- assert.Equal(t, roomAliasEventLink, *parsed1)
- assert.Equal(t, roomAliasEventLink, *parsed1Encoded)
- assert.Equal(t, roomIDEventLink, *parsed2)
- assert.Equal(t, roomIDEventLink, *parsed2Encoded)
+ assert.Equal(t, roomIDEventLink, *parsed)
+ assert.Equal(t, roomIDEventLink, *parsedEncoded)
}
diff --git a/id/opaque.go b/id/opaque.go
index 1d9f0dcf..c1ad4988 100644
--- a/id/opaque.go
+++ b/id/opaque.go
@@ -32,6 +32,9 @@ type EventID string
// https://github.com/matrix-org/matrix-doc/pull/2716
type BatchID string
+// A DelayID is a string identifying a delayed event.
+type DelayID string
+
func (roomID RoomID) String() string {
return string(roomID)
}
diff --git a/id/roomversion.go b/id/roomversion.go
new file mode 100644
index 00000000..578c10bd
--- /dev/null
+++ b/id/roomversion.go
@@ -0,0 +1,265 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package id
+
+import (
+ "errors"
+ "fmt"
+ "slices"
+)
+
+type RoomVersion string
+
+const (
+ RoomV0 RoomVersion = "" // No room version, used for rooms created before room versions were introduced, equivalent to v1
+ RoomV1 RoomVersion = "1"
+ RoomV2 RoomVersion = "2"
+ RoomV3 RoomVersion = "3"
+ RoomV4 RoomVersion = "4"
+ RoomV5 RoomVersion = "5"
+ RoomV6 RoomVersion = "6"
+ RoomV7 RoomVersion = "7"
+ RoomV8 RoomVersion = "8"
+ RoomV9 RoomVersion = "9"
+ RoomV10 RoomVersion = "10"
+ RoomV11 RoomVersion = "11"
+ RoomV12 RoomVersion = "12"
+)
+
+func (rv RoomVersion) Equals(versions ...RoomVersion) bool {
+ return slices.Contains(versions, rv)
+}
+
+func (rv RoomVersion) NotEquals(versions ...RoomVersion) bool {
+ return !rv.Equals(versions...)
+}
+
+var ErrUnknownRoomVersion = errors.New("unknown room version")
+
+func (rv RoomVersion) unknownVersionError() error {
+ return fmt.Errorf("%w %s", ErrUnknownRoomVersion, rv)
+}
+
+func (rv RoomVersion) IsKnown() bool {
+ switch rv {
+ case RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11, RoomV12:
+ return true
+ default:
+ return false
+ }
+}
+
+type StateResVersion int
+
+const (
+ // StateResV1 is the original state resolution algorithm.
+ StateResV1 StateResVersion = 0
+ // StateResV2 is state resolution v2 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1759
+ StateResV2 StateResVersion = 1
+ // StateResV2_1 is state resolution v2.1 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/4297
+ StateResV2_1 StateResVersion = 2
+)
+
+// StateResVersion returns the version of the state resolution algorithm used by this room version.
+func (rv RoomVersion) StateResVersion() StateResVersion {
+ switch rv {
+ case RoomV0, RoomV1:
+ return StateResV1
+ case RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11:
+ return StateResV2
+ case RoomV12:
+ return StateResV2_1
+ default:
+ panic(rv.unknownVersionError())
+ }
+}
+
+type EventIDFormat int
+
+const (
+ // EventIDFormatCustom is the original format used by room v1 and v2.
+ // Event IDs in this format are an arbitrary string followed by a colon and the server name.
+ EventIDFormatCustom EventIDFormat = 0
+ // EventIDFormatBase64 is the format used by room v3 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1659.
+ // Event IDs in this format are the standard unpadded base64-encoded SHA256 reference hash of the event.
+ EventIDFormatBase64 EventIDFormat = 1
+ // EventIDFormatURLSafeBase64 is the format used by room v4 and later introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/2002.
+ // Event IDs in this format are the url-safe unpadded base64-encoded SHA256 reference hash of the event.
+ EventIDFormatURLSafeBase64 EventIDFormat = 2
+)
+
+// EventIDFormat returns the format of event IDs used by this room version.
+func (rv RoomVersion) EventIDFormat() EventIDFormat {
+ switch rv {
+ case RoomV0, RoomV1, RoomV2:
+ return EventIDFormatCustom
+ case RoomV3:
+ return EventIDFormatBase64
+ default:
+ return EventIDFormatURLSafeBase64
+ }
+}
+
+/////////////////////
+// Room v5 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2077
+
+// EnforceSigningKeyValidity returns true if the `valid_until_ts` field of federation signing keys
+// must be enforced on received events.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2076
+func (rv RoomVersion) EnforceSigningKeyValidity() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4)
+}
+
+/////////////////////
+// Room v6 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2240
+
+// SpecialCasedAliasesAuth returns true if the `m.room.aliases` event authorization is special cased
+// to only always allow servers to modify the state event with their own server name as state key.
+// This also implies that the `aliases` field is protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2432
+func (rv RoomVersion) SpecialCasedAliasesAuth() bool {
+ return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+// ForbidFloatsAndBigInts returns true if floats and integers greater than 2^53-1 or lower than -2^53+1 are forbidden everywhere.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2540
+func (rv RoomVersion) ForbidFloatsAndBigInts() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+// NotificationsPowerLevels returns true if the `notifications` field in `m.room.power_levels` is validated in event auth.
+// However, the field is not protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2209
+func (rv RoomVersion) NotificationsPowerLevels() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+/////////////////////
+// Room v7 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2998
+
+// Knocks returns true if the `knock` join rule is supported.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2403
+func (rv RoomVersion) Knocks() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6)
+}
+
+/////////////////////
+// Room v8 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3289
+
+// RestrictedJoins returns true if the `restricted` join rule is supported.
+// This also implies that the `allow` field in the `m.room.join_rules` event is supported and protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3083
+func (rv RoomVersion) RestrictedJoins() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7)
+}
+
+/////////////////////
+// Room v9 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3375
+
+// RestrictedJoinsFix returns true if the `join_authorised_via_users_server` field in `m.room.member` events is protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3375
+func (rv RoomVersion) RestrictedJoinsFix() bool {
+ return rv.RestrictedJoins() && rv != RoomV8
+}
+
+//////////////////////
+// Room v10 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3604
+
+// ValidatePowerLevelInts returns true if the known values in `m.room.power_levels` must be integers (and not strings).
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3667
+func (rv RoomVersion) ValidatePowerLevelInts() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
+}
+
+// KnockRestricted returns true if the `knock_restricted` join rule is supported.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3787
+func (rv RoomVersion) KnockRestricted() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
+}
+
+//////////////////////
+// Room v11 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3820
+
+// CreatorInContent returns true if the `m.room.create` event has a `creator` field in content.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2175
+func (rv RoomVersion) CreatorInContent() bool {
+ return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+// RedactsInContent returns true if the `m.room.redaction` event has the `redacts` field in content instead of at the top level.
+// The redaction protection is also moved from the top level to the content field.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2174
+// (and https://github.com/matrix-org/matrix-spec-proposals/pull/2176 for the redaction protection).
+func (rv RoomVersion) RedactsInContent() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+// UpdatedRedactionRules returns true if various updates to the redaction algorithm are applied.
+//
+// Specifically:
+//
+// * the `membership`, `origin`, and `prev_state` fields at the top level of all events are no longer protected.
+// * the entire content of `m.room.create` is protected.
+// * the `redacts` field in `m.room.redaction` content is protected instead of the top-level field.
+// * the `m.room.power_levels` event protects the `invite` field in content.
+// * the `signed` field inside the `third_party_invite` field in content of `m.room.member` events is protected.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2176,
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3821, and
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3989
+func (rv RoomVersion) UpdatedRedactionRules() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+//////////////////////
+// Room v12 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/4304
+
+// Return value of StateResVersion was changed to StateResV2_1
+
+// PrivilegedRoomCreators returns true if the creator(s) of a room always have infinite power level.
+// This also implies that the `m.room.create` event has an `additional_creators` field,
+// and that the creators can't be present in the `m.room.power_levels` event.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/4289
+func (rv RoomVersion) PrivilegedRoomCreators() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
+}
+
+// RoomIDIsCreateEventID returns true if the ID of rooms is the same as the ID of the `m.room.create` event.
+// This also implies that `m.room.create` events do not have a `room_id` field.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/4291
+func (rv RoomVersion) RoomIDIsCreateEventID() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
+}
diff --git a/id/servername.go b/id/servername.go
new file mode 100644
index 00000000..923705b6
--- /dev/null
+++ b/id/servername.go
@@ -0,0 +1,58 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package id
+
+import (
+ "regexp"
+ "strconv"
+)
+
+type ParsedServerNameType int
+
+const (
+ ServerNameDNS ParsedServerNameType = iota
+ ServerNameIPv4
+ ServerNameIPv6
+)
+
+type ParsedServerName struct {
+ Type ParsedServerNameType
+ Host string
+ Port int
+}
+
+var ServerNameRegex = regexp.MustCompile(`^(?:\[([0-9A-Fa-f:.]{2,45})]|(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|([0-9A-Za-z.-]{1,255}))(?::(\d{1,5}))?$`)
+
+func ValidateServerName(serverName string) bool {
+ return len(serverName) <= 255 && len(serverName) > 0 && ServerNameRegex.MatchString(serverName)
+}
+
+func ParseServerName(serverName string) *ParsedServerName {
+ if len(serverName) > 255 || len(serverName) < 1 {
+ return nil
+ }
+ match := ServerNameRegex.FindStringSubmatch(serverName)
+ if len(match) != 5 {
+ return nil
+ }
+ port, _ := strconv.Atoi(match[4])
+ parsed := &ParsedServerName{
+ Port: port,
+ }
+ switch {
+ case match[1] != "":
+ parsed.Type = ServerNameIPv6
+ parsed.Host = match[1]
+ case match[2] != "":
+ parsed.Type = ServerNameIPv4
+ parsed.Host = match[2]
+ case match[3] != "":
+ parsed.Type = ServerNameDNS
+ parsed.Host = match[3]
+ }
+ return parsed
+}
diff --git a/id/trust.go b/id/trust.go
index 04f6e36b..6255093e 100644
--- a/id/trust.go
+++ b/id/trust.go
@@ -16,6 +16,7 @@ type TrustState int
const (
TrustStateBlacklisted TrustState = -100
+ TrustStateDeviceKeyMismatch TrustState = -5
TrustStateUnset TrustState = 0
TrustStateUnknownDevice TrustState = 10
TrustStateForwarded TrustState = 20
@@ -23,7 +24,7 @@ const (
TrustStateCrossSignedTOFU TrustState = 100
TrustStateCrossSignedVerified TrustState = 200
TrustStateVerified TrustState = 300
- TrustStateInvalid TrustState = (1 << 31) - 1
+ TrustStateInvalid TrustState = -2147483647
)
func (ts *TrustState) UnmarshalText(data []byte) error {
@@ -44,6 +45,8 @@ func ParseTrustState(val string) TrustState {
switch strings.ToLower(val) {
case "blacklisted":
return TrustStateBlacklisted
+ case "device-key-mismatch":
+ return TrustStateDeviceKeyMismatch
case "unverified":
return TrustStateUnset
case "cross-signed-untrusted":
@@ -67,6 +70,8 @@ func (ts TrustState) String() string {
switch ts {
case TrustStateBlacklisted:
return "blacklisted"
+ case TrustStateDeviceKeyMismatch:
+ return "device-key-mismatch"
case TrustStateUnset:
return "unverified"
case TrustStateCrossSignedUntrusted:
diff --git a/id/userid.go b/id/userid.go
index 1e1f3b29..726a0d58 100644
--- a/id/userid.go
+++ b/id/userid.go
@@ -30,10 +30,11 @@ func NewEncodedUserID(localpart, homeserver string) UserID {
}
var (
- ErrInvalidUserID = errors.New("is not a valid user ID")
- ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed")
- ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters")
- ErrEmptyLocalpart = errors.New("empty localparts are not allowed")
+ ErrInvalidUserID = errors.New("is not a valid user ID")
+ ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed")
+ ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters")
+ ErrEmptyLocalpart = errors.New("empty localparts are not allowed")
+ ErrNoncompliantServerPart = errors.New("is not a valid server name")
)
// ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format
@@ -43,10 +44,10 @@ func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte,
}
sigil = identifier[0]
strIdentifier := string(identifier)
- if strings.ContainsRune(strIdentifier, ':') {
- parts := strings.SplitN(strIdentifier, ":", 2)
- localpart = parts[0][1:]
- homeserver = parts[1]
+ colonIdx := strings.IndexByte(strIdentifier, ':')
+ if colonIdx > 0 {
+ localpart = strIdentifier[1:colonIdx]
+ homeserver = strIdentifier[colonIdx+1:]
} else {
localpart = strIdentifier[1:]
}
@@ -103,21 +104,32 @@ func ValidateUserLocalpart(localpart string) error {
return nil
}
-// ParseAndValidate parses the user ID into the localpart and server name like Parse,
-// and also validates that the localpart is allowed according to the user identifiers spec.
-func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.Parse()
+// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts.
+// This should be used with care: there are real users still using historical localparts.
+func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) {
+ localpart, homeserver, err = userID.ParseAndValidateRelaxed()
if err == nil {
err = ValidateUserLocalpart(localpart)
}
- if err == nil && len(userID) > UserIDMaxLength {
+ return
+}
+
+// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse,
+// and also validates that the user ID is not too long and that the server name is valid.
+func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) {
+ if len(userID) > UserIDMaxLength {
err = ErrUserIDTooLong
+ return
+ }
+ localpart, homeserver, err = userID.Parse()
+ if err == nil && !ValidateServerName(homeserver) {
+ err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart)
}
return
}
func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.ParseAndValidate()
+ localpart, homeserver, err = userID.ParseAndValidateStrict()
if err == nil {
localpart, err = DecodeUserLocalpart(localpart)
}
@@ -207,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) {
for i := 0; i < len(strBytes); i++ {
b := strBytes[i]
if !isValidByte(b) {
- return "", fmt.Errorf("Byte pos %d: Invalid byte", i)
+ return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b)
}
if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _
if i+1 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after underscore at %d", i)
}
if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i)
+ return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i)
}
if strBytes[i+1] == '_' {
outputBuffer.WriteByte('_')
@@ -225,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) {
i++ // skip next byte since we just handled it
} else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8
if i+2 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after equals sign at %d", i)
}
dst := make([]byte, 1)
_, err := hex.Decode(dst, strBytes[i+1:i+3])
diff --git a/id/userid_test.go b/id/userid_test.go
index 359bc687..57a88066 100644
--- a/id/userid_test.go
+++ b/id/userid_test.go
@@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) {
assert.True(t, errors.Is(err, id.ErrInvalidUserID))
}
-func TestUserID_ParseAndValidate_Invalid(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) {
const inputUserID = "@s p a c e:maunium.net"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart))
}
-func TestUserID_ParseAndValidate_Empty(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) {
const inputUserID = "@:ponies.im"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrEmptyLocalpart))
}
-func TestUserID_ParseAndValidate_Long(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Long(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrUserIDTooLong))
}
-func TestUserID_ParseAndValidate_NotLong(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.NoError(t, err)
}
@@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) {
const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8"
const inputServerName = "example.com"
userID := id.NewEncodedUserID(inputLocalpart, inputServerName)
- parsedLocalpart, parsedServerName, err := userID.ParseAndValidate()
+ parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict()
assert.NoError(t, err)
assert.Equal(t, encodedLocalpart, parsedLocalpart)
assert.Equal(t, inputServerName, parsedServerName)
diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go
index f2591428..4d2bc7cf 100644
--- a/mediaproxy/mediaproxy.go
+++ b/mediaproxy/mediaproxy.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,51 +8,107 @@ package mediaproxy
import (
"context"
- "encoding/json"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
- "net"
"net/http"
"net/textproto"
+ "net/url"
+ "os"
"strconv"
"strings"
"time"
- "github.com/gorilla/mux"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/hlog"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
+ "go.mau.fi/util/ptr"
+ "go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/federation"
+ "maunium.net/go/mautrix/id"
)
type GetMediaResponse interface {
isGetMediaResponse()
}
-func (*GetMediaResponseURL) isGetMediaResponse() {}
-func (*GetMediaResponseData) isGetMediaResponse() {}
+func (*GetMediaResponseURL) isGetMediaResponse() {}
+func (*GetMediaResponseData) isGetMediaResponse() {}
+func (*GetMediaResponseCallback) isGetMediaResponse() {}
+func (*GetMediaResponseFile) isGetMediaResponse() {}
type GetMediaResponseURL struct {
URL string
ExpiresAt time.Time
}
+type GetMediaResponseWriter interface {
+ GetMediaResponse
+ io.WriterTo
+ GetContentType() string
+ GetContentLength() int64
+}
+
+var (
+ _ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil)
+ _ GetMediaResponseWriter = (*GetMediaResponseData)(nil)
+)
+
type GetMediaResponseData struct {
Reader io.ReadCloser
ContentType string
ContentLength int64
}
-type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error)
+func (d *GetMediaResponseData) WriteTo(w io.Writer) (int64, error) {
+ return io.Copy(w, d.Reader)
+}
+
+func (d *GetMediaResponseData) GetContentType() string {
+ return d.ContentType
+}
+
+func (d *GetMediaResponseData) GetContentLength() int64 {
+ return d.ContentLength
+}
+
+type GetMediaResponseCallback struct {
+ Callback func(w io.Writer) (int64, error)
+ ContentType string
+ ContentLength int64
+}
+
+func (d *GetMediaResponseCallback) WriteTo(w io.Writer) (int64, error) {
+ return d.Callback(w)
+}
+
+func (d *GetMediaResponseCallback) GetContentLength() int64 {
+ return d.ContentLength
+}
+
+func (d *GetMediaResponseCallback) GetContentType() string {
+ return d.ContentType
+}
+
+type FileMeta struct {
+ ContentType string
+ ReplacementFile string
+}
+
+type GetMediaResponseFile struct {
+ Callback func(w *os.File) (*FileMeta, error)
+}
+
+type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error)
type MediaProxy struct {
- KeyServer *federation.KeyServer
- ProxyClient *http.Client
-
- ForceProxyLegacyFederation bool
+ KeyServer *federation.KeyServer
+ ServerAuth *federation.ServerAuth
GetMedia GetMediaFunc
PrepareProxyRequest func(*http.Request)
@@ -60,9 +116,8 @@ type MediaProxy struct {
serverName string
serverKey *federation.SigningKey
- FederationRouter *mux.Router
- LegacyMediaRouter *mux.Router
- ClientMediaRouter *mux.Router
+ FederationRouter *http.ServeMux
+ ClientMediaRouter *http.ServeMux
}
func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
@@ -70,18 +125,10 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
if err != nil {
return nil, err
}
- return &MediaProxy{
+ mp := &MediaProxy{
serverName: serverName,
serverKey: parsed,
GetMedia: getMedia,
- ProxyClient: &http.Client{
- Transport: &http.Transport{
- DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext,
- TLSHandshakeTimeout: 10 * time.Second,
- ForceAttemptHTTP2: false,
- },
- Timeout: 60 * time.Second,
- },
KeyServer: &federation.KeyServer{
KeyProvider: &federation.StaticServerKey{
ServerName: serverName,
@@ -93,13 +140,27 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
},
},
- }, nil
+ }
+ mp.FederationRouter = http.NewServeMux()
+ mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation)
+ mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation)
+ mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion)
+ mp.ClientMediaRouter = http.NewServeMux()
+ mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported)
+ return mp, nil
}
type BasicConfig struct {
ServerName string `yaml:"server_name" json:"server_name"`
ServerKey string `yaml:"server_key" json:"server_key"`
- AllowProxy bool `yaml:"allow_proxy" json:"allow_proxy"`
+ FederationAuth bool `yaml:"federation_auth" json:"federation_auth"`
WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"`
}
@@ -108,12 +169,12 @@ func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error)
if err != nil {
return nil, err
}
- if !cfg.AllowProxy {
- mp.DisallowProxying()
- }
if cfg.WellKnownResponse != "" {
mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse
}
+ if cfg.FederationAuth {
+ mp.EnableServerAuth(nil, nil)
+ }
return mp, nil
}
@@ -123,8 +184,8 @@ type ServerConfig struct {
}
func (mp *MediaProxy) Listen(cfg ServerConfig) error {
- router := mux.NewRouter()
- mp.RegisterRoutes(router)
+ router := http.NewServeMux()
+ mp.RegisterRoutes(router, zerolog.Nop())
return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
}
@@ -136,99 +197,183 @@ func (mp *MediaProxy) GetServerKey() *federation.SigningKey {
return mp.serverKey
}
-func (mp *MediaProxy) DisallowProxying() {
- mp.ProxyClient = nil
+func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache federation.KeyCache) {
+ if keyCache == nil {
+ keyCache = federation.NewInMemoryCache()
+ }
+ if client == nil {
+ resCache, _ := keyCache.(federation.ResolutionCache)
+ client = federation.NewClient(mp.serverName, mp.serverKey, resCache)
+ }
+ mp.ServerAuth = federation.NewServerAuth(client, keyCache, func(auth federation.XMatrixAuth) string {
+ return mp.GetServerName()
+ })
}
-func (mp *MediaProxy) RegisterRoutes(router *mux.Router) {
- if mp.FederationRouter == nil {
- mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter()
+func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux, log zerolog.Logger) {
+ errorBodies := exhttp.ErrorBodies{
+ NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
+ MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
}
- if mp.LegacyMediaRouter == nil {
- mp.LegacyMediaRouter = router.PathPrefix("/_matrix/media").Subrouter()
- }
- if mp.ClientMediaRouter == nil {
- mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter()
- }
-
- mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet)
- mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet)
- addClientRoutes := func(router *mux.Router, prefix string) {
- router.HandleFunc(prefix+"/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
- router.HandleFunc(prefix+"/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet)
- router.HandleFunc(prefix+"/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
- router.HandleFunc(prefix+"/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut)
- router.HandleFunc(prefix+"/upload", mp.UploadNotSupported).Methods(http.MethodPost)
- router.HandleFunc(prefix+"/create", mp.UploadNotSupported).Methods(http.MethodPost)
- router.HandleFunc(prefix+"/config", mp.UploadNotSupported).Methods(http.MethodGet)
- router.HandleFunc(prefix+"/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet)
- }
- addClientRoutes(mp.LegacyMediaRouter, "/v3")
- addClientRoutes(mp.LegacyMediaRouter, "/r0")
- addClientRoutes(mp.LegacyMediaRouter, "/v1")
- addClientRoutes(mp.ClientMediaRouter, "")
- mp.LegacyMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
- mp.LegacyMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
- mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
- mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
- mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
- mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
- corsMiddleware := func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Access-Control-Allow-Origin", "*")
- w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
- w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
- w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';")
- next.ServeHTTP(w, r)
- })
- }
- mp.LegacyMediaRouter.Use(corsMiddleware)
- mp.ClientMediaRouter.Use(corsMiddleware)
- mp.KeyServer.Register(router)
+ router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware(
+ mp.FederationRouter,
+ exhttp.StripPrefix("/_matrix/federation"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
+ router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware(
+ mp.ClientMediaRouter,
+ exhttp.StripPrefix("/_matrix/client/v1/media"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ exhttp.CORSMiddleware,
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
+ mp.KeyServer.Register(router, log)
}
-func (mp *MediaProxy) proxyDownload(ctx context.Context, w http.ResponseWriter, url, fileName string) {
- log := zerolog.Ctx(ctx)
- req, err := http.NewRequest(http.MethodGet, url, nil)
+var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
+
+func queryToMap(vals url.Values) map[string]string {
+ m := make(map[string]string, len(vals))
+ for k, v := range vals {
+ m[k] = v[0]
+ }
+ return m
+}
+
+func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
+ mediaID := r.PathValue("mediaID")
+ if !id.IsValidMediaID(mediaID) {
+ mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w)
+ return nil
+ }
+ resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query()))
if err != nil {
- log.Err(err).Str("url", url).Msg("Failed to create proxy request")
- jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{
- ErrCode: "M_UNKNOWN",
- Err: "Failed to create proxy request",
- })
- return
- }
- req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (media proxy)")
- if mp.PrepareProxyRequest != nil {
- mp.PrepareProxyRequest(req)
- }
- resp, err := mp.ProxyClient.Do(req)
- defer func() {
- if resp != nil && resp.Body != nil {
- _ = resp.Body.Close()
+ var mautrixRespError mautrix.RespError
+ if errors.Is(err, ErrInvalidMediaIDSyntax) {
+ mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
+ } else if errors.As(err, &mautrixRespError) {
+ mautrixRespError.Write(w)
+ } else {
+ zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL")
+ mautrix.MNotFound.WithMessage("Media not found").Write(w)
}
- }()
+ return nil
+ }
+ return resp
+}
+
+func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Writer {
+ mpw := multipart.NewWriter(w)
+ w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1))
+ w.WriteHeader(http.StatusOK)
+ metaPart, err := mpw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {"application/json"},
+ })
if err != nil {
- log.Err(err).Str("url", url).Msg("Failed to proxy download")
- jsonResponse(w, http.StatusServiceUnavailable, &mautrix.RespError{
- ErrCode: "M_UNKNOWN",
- Err: "Failed to proxy download",
- })
- return
- } else if resp.StatusCode != http.StatusOK {
- log.Warn().Str("url", url).Int("status", resp.StatusCode).Msg("Unexpected status code proxying download")
- jsonResponse(w, resp.StatusCode, &mautrix.RespError{
- ErrCode: "M_UNKNOWN",
- Err: "Unexpected status code proxying download",
- })
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to create multipart metadata field")
+ return nil
+ }
+ _, err = metaPart.Write([]byte(`{}`))
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to write multipart metadata field")
+ return nil
+ }
+ return mpw
+}
+
+func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) {
+ if mp.ServerAuth != nil {
+ var err *mautrix.RespError
+ r, err = mp.ServerAuth.Authenticate(r)
+ if err != nil {
+ err.Write(w)
+ return
+ }
+ }
+ ctx := r.Context()
+ log := zerolog.Ctx(ctx)
+
+ resp := mp.getMedia(w, r)
+ if resp == nil {
return
}
- w.Header()["Content-Type"] = resp.Header["Content-Type"]
- w.Header()["Content-Length"] = resp.Header["Content-Length"]
- w.Header()["Last-Modified"] = resp.Header["Last-Modified"]
- w.Header()["Cache-Control"] = resp.Header["Cache-Control"]
+
+ var mpw *multipart.Writer
+ if urlResp, ok := resp.(*GetMediaResponseURL); ok {
+ mpw = startMultipart(ctx, w)
+ if mpw == nil {
+ return
+ }
+ _, err := mpw.CreatePart(textproto.MIMEHeader{
+ "Location": {urlResp.URL},
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to create multipart redirect field")
+ return
+ }
+ } else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
+ responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
+ mpw = startMultipart(ctx, w)
+ if mpw == nil {
+ return fmt.Errorf("failed to start multipart writer")
+ }
+ dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {mimeType},
+ })
+ if err != nil {
+ return fmt.Errorf("failed to create multipart data field: %w", err)
+ }
+ _, err = wt.WriteTo(dataPart)
+ return err
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to do media proxy with temp file")
+ if !responseStarted {
+ var mautrixRespError mautrix.RespError
+ if errors.As(err, &mautrixRespError) {
+ mautrixRespError.Write(w)
+ } else {
+ mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
+ }
+ }
+ return
+ }
+ } else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
+ mpw = startMultipart(ctx, w)
+ if mpw == nil {
+ return
+ }
+ dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {dataResp.GetContentType()},
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to create multipart data field")
+ return
+ }
+ _, err = dataResp.WriteTo(dataPart)
+ if err != nil {
+ log.Err(err).Msg("Failed to write multipart data field")
+ return
+ }
+ } else {
+ panic(fmt.Errorf("unknown GetMediaResponse type %T", resp))
+ }
+ err := mpw.Close()
+ if err != nil {
+ log.Err(err).Msg("Failed to close multipart writer")
+ return
+ }
+}
+
+func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName string) {
+ w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
contentDisposition := "attachment"
- switch resp.Header.Get("Content-Type") {
+ switch mimeType {
case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif",
"image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime",
"audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav",
@@ -241,113 +386,14 @@ func (mp *MediaProxy) proxyDownload(ctx context.Context, w http.ResponseWriter,
})
}
w.Header().Set("Content-Disposition", contentDisposition)
- w.WriteHeader(http.StatusOK)
- _, err = io.Copy(w, resp.Body)
- if err != nil {
- log.Debug().Err(err).Msg("Failed to write proxy response")
- }
-}
-
-type ResponseError struct {
- Status int
- Data any
-}
-
-func (err *ResponseError) Error() string {
- return fmt.Sprintf("HTTP %d: %v", err.Status, err.Data)
-}
-
-var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
-
-func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
- mediaID := mux.Vars(r)["mediaID"]
- resp, err := mp.GetMedia(r.Context(), mediaID)
- if err != nil {
- var respError *ResponseError
- if errors.Is(err, ErrInvalidMediaIDSyntax) {
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MNotFound.ErrCode,
- Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName),
- })
- } else if errors.As(err, &respError) {
- jsonResponse(w, respError.Status, respError.Data)
- } else {
- zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL")
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MNotFound.ErrCode,
- Err: "Media not found",
- })
- }
- return nil
- }
- return resp
-}
-
-func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) {
- ctx := r.Context()
- log := zerolog.Ctx(ctx)
- // TODO check destination header in X-Matrix auth
-
- resp := mp.getMedia(w, r)
- if resp == nil {
- return
- }
-
- mpw := multipart.NewWriter(w)
- w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1))
- w.WriteHeader(http.StatusOK)
- metaPart, err := mpw.CreatePart(textproto.MIMEHeader{
- "Content-Type": {"application/json"},
- })
- if err != nil {
- log.Err(err).Msg("Failed to create multipart metadata field")
- return
- }
- _, err = metaPart.Write([]byte(`{}`))
- if err != nil {
- log.Err(err).Msg("Failed to write multipart metadata field")
- return
- }
- if urlResp, ok := resp.(*GetMediaResponseURL); ok {
- _, err = mpw.CreatePart(textproto.MIMEHeader{
- "Location": {urlResp.URL},
- })
- if err != nil {
- log.Err(err).Msg("Failed to create multipart redirect field")
- return
- }
- } else if dataResp, ok := resp.(*GetMediaResponseData); ok {
- dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
- "Content-Type": {dataResp.ContentType},
- })
- if err != nil {
- log.Err(err).Msg("Failed to create multipart data field")
- return
- }
- _, err = io.Copy(dataPart, dataResp.Reader)
- if err != nil {
- log.Err(err).Msg("Failed to write multipart data field")
- return
- }
- } else {
- panic("unknown GetMediaResponse type")
- }
- err = mpw.Close()
- if err != nil {
- log.Err(err).Msg("Failed to close multipart writer")
- return
- }
+ w.Header().Set("Content-Type", mimeType)
}
func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := zerolog.Ctx(ctx)
- vars := mux.Vars(r)
- if vars["serverName"] != mp.serverName {
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MNotFound.ErrCode,
- Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName),
- })
+ if r.PathValue("serverName") != mp.serverName {
+ mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
return
}
resp := mp.getMedia(w, r)
@@ -356,13 +402,6 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
}
if urlResp, ok := resp.(*GetMediaResponseURL); ok {
- // Proxy if the config allows proxying and the request doesn't allow redirects.
- // In any other case, redirect to the URL.
- isFederated := strings.HasPrefix(r.Header.Get("Authorization"), "X-Matrix")
- if mp.ProxyClient != nil && (r.URL.Query().Get("allow_redirect") != "true" || (mp.ForceProxyLegacyFederation && isFederated)) {
- mp.proxyDownload(ctx, w, urlResp.URL, vars["fileName"])
- return
- }
w.Header().Set("Location", urlResp.URL)
expirySeconds := (time.Until(urlResp.ExpiresAt) - 5*time.Minute).Seconds()
if urlResp.ExpiresAt.IsZero() {
@@ -374,51 +413,113 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-store")
}
w.WriteHeader(http.StatusTemporaryRedirect)
- } else if dataResp, ok := resp.(*GetMediaResponseData); ok {
- w.Header().Set("Content-Type", dataResp.ContentType)
- if dataResp.ContentLength != 0 {
- w.Header().Set("Content-Length", strconv.FormatInt(dataResp.ContentLength, 10))
+ } else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
+ responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
+ mp.addHeaders(w, mimeType, r.PathValue("fileName"))
+ w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
+ w.WriteHeader(http.StatusOK)
+ _, err := wt.WriteTo(w)
+ return err
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to do media proxy with temp file")
+ if !responseStarted {
+ var mautrixRespError mautrix.RespError
+ if errors.As(err, &mautrixRespError) {
+ mautrixRespError.Write(w)
+ } else {
+ mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
+ }
+ }
+ }
+ } else if writerResp, ok := resp.(GetMediaResponseWriter); ok {
+ if dataResp, ok := writerResp.(*GetMediaResponseData); ok {
+ defer dataResp.Reader.Close()
+ }
+ mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName"))
+ if writerResp.GetContentLength() != 0 {
+ w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10))
}
w.WriteHeader(http.StatusOK)
- _, err := io.Copy(w, dataResp.Reader)
+ _, err := writerResp.WriteTo(w)
if err != nil {
log.Err(err).Msg("Failed to write media data")
}
} else {
- panic("unknown GetMediaResponse type")
+ panic(fmt.Errorf("unknown GetMediaResponse type %T", resp))
}
}
-func jsonResponse(w http.ResponseWriter, status int, response interface{}) {
- w.Header().Add("Content-Type", "application/json")
- w.WriteHeader(status)
- _ = json.NewEncoder(w).Encode(response)
+func doTempFileDownload(
+ data *GetMediaResponseFile,
+ respond func(w io.WriterTo, size int64, mimeType string) error,
+) (bool, error) {
+ tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*")
+ if err != nil {
+ return false, fmt.Errorf("failed to create temp file: %w", err)
+ }
+ origTempFile := tempFile
+ defer func() {
+ _ = origTempFile.Close()
+ _ = os.Remove(origTempFile.Name())
+ }()
+ meta, err := data.Callback(tempFile)
+ if err != nil {
+ return false, err
+ }
+ if meta.ReplacementFile != "" {
+ tempFile, err = os.Open(meta.ReplacementFile)
+ if err != nil {
+ return false, fmt.Errorf("failed to open replacement file: %w", err)
+ }
+ defer func() {
+ _ = tempFile.Close()
+ _ = os.Remove(origTempFile.Name())
+ }()
+ } else {
+ _, err = tempFile.Seek(0, io.SeekStart)
+ if err != nil {
+ return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ }
+ }
+ fileInfo, err := tempFile.Stat()
+ if err != nil {
+ return false, fmt.Errorf("failed to stat temp file: %w", err)
+ }
+ mimeType := meta.ContentType
+ if mimeType == "" {
+ buf := make([]byte, 512)
+ n, err := tempFile.Read(buf)
+ if err != nil {
+ return false, fmt.Errorf("failed to read temp file to detect mime: %w", err)
+ }
+ buf = buf[:n]
+ _, err = tempFile.Seek(0, io.SeekStart)
+ if err != nil {
+ return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ }
+ mimeType = http.DetectContentType(buf)
+ }
+ err = respond(tempFile, fileInfo.Size(), mimeType)
+ if err != nil {
+ return true, err
+ }
+ return true, nil
}
+var (
+ ErrUploadNotSupported = mautrix.MUnrecognized.
+ WithMessage("This is a media proxy and does not support media uploads.").
+ WithStatus(http.StatusNotImplemented)
+ ErrPreviewURLNotSupported = mautrix.MUnrecognized.
+ WithMessage("This is a media proxy and does not support URL previews.").
+ WithStatus(http.StatusNotImplemented)
+)
+
func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{
- ErrCode: mautrix.MUnrecognized.ErrCode,
- Err: "This is a media proxy and does not support media uploads.",
- })
+ ErrUploadNotSupported.Write(w)
}
func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{
- ErrCode: mautrix.MUnrecognized.ErrCode,
- Err: "This is a media proxy and does not support URL previews.",
- })
-}
-
-func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
- ErrCode: mautrix.MUnrecognized.ErrCode,
- Err: "Unrecognized endpoint",
- })
-}
-
-func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) {
- jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{
- ErrCode: mautrix.MUnrecognized.ErrCode,
- Err: "Invalid method for endpoint",
- })
+ ErrPreviewURLNotSupported.Write(w)
}
diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go
new file mode 100644
index 00000000..507c24a5
--- /dev/null
+++ b/mockserver/mockserver.go
@@ -0,0 +1,307 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mockserver
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
+ "github.com/stretchr/testify/require"
+ "go.mau.fi/util/dbutil"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
+ "go.mau.fi/util/random"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/crypto"
+ "maunium.net/go/mautrix/crypto/cryptohelper"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+func mustDecode(r *http.Request, data any) {
+ exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data))
+}
+
+type userAndDeviceID struct {
+ UserID id.UserID
+ DeviceID id.DeviceID
+}
+
+type MockServer struct {
+ Router *http.ServeMux
+ Server *httptest.Server
+
+ AccessTokenToUserID map[string]userAndDeviceID
+ DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
+ AccountData map[id.UserID]map[event.Type]json.RawMessage
+ DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
+ OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey
+ MasterKeys map[id.UserID]mautrix.CrossSigningKeys
+ SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+ UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+
+ PopOTKs bool
+ MemoryStore bool
+}
+
+func Create(t testing.TB) *MockServer {
+ t.Helper()
+
+ server := MockServer{
+ AccessTokenToUserID: map[string]userAndDeviceID{},
+ DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
+ AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ PopOTKs: true,
+ MemoryStore: true,
+ }
+
+ router := http.NewServeMux()
+ router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
+ router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
+ router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim)
+ router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
+ router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
+ router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
+ router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp)
+ router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload)
+ server.Router = router
+ server.Server = httptest.NewServer(router)
+ t.Cleanup(server.Server.Close)
+ return &server
+}
+
+func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID {
+ authHeader := r.Header.Get("Authorization")
+ authHeader = strings.TrimPrefix(authHeader, "Bearer ")
+ userID, ok := ms.AccessTokenToUserID[authHeader]
+ if !ok {
+ panic("no user ID found for access token " + authHeader)
+ }
+ return userID
+}
+
+func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
+ exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
+}
+
+func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
+ var loginReq mautrix.ReqLogin
+ mustDecode(r, &loginReq)
+
+ deviceID := loginReq.DeviceID
+ if deviceID == "" {
+ deviceID = id.DeviceID(random.String(10))
+ }
+
+ accessToken := random.String(30)
+ userID := id.UserID(loginReq.Identifier.User)
+ ms.AccessTokenToUserID[accessToken] = userAndDeviceID{
+ UserID: userID,
+ DeviceID: deviceID,
+ }
+
+ exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{
+ AccessToken: accessToken,
+ DeviceID: deviceID,
+ UserID: userID,
+ })
+}
+
+func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqSendToDevice
+ mustDecode(r, &req)
+ evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
+
+ for user, devices := range req.Messages {
+ for device, content := range devices {
+ if _, ok := ms.DeviceInbox[user]; !ok {
+ ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{}
+ }
+ content.ParseRaw(evtType)
+ ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{
+ Sender: ms.getUserID(r).UserID,
+ Type: evtType,
+ Content: *content,
+ })
+ }
+ }
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
+ userID := id.UserID(r.PathValue("userID"))
+ eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
+
+ jsonData, _ := io.ReadAll(r.Body)
+ if _, ok := ms.AccountData[userID]; !ok {
+ ms.AccountData[userID] = map[event.Type]json.RawMessage{}
+ }
+ ms.AccountData[userID][eventType] = json.RawMessage(jsonData)
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqQueryKeys
+ mustDecode(r, &req)
+ resp := mautrix.RespQueryKeys{
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ }
+ for user := range req.DeviceKeys {
+ resp.MasterKeys[user] = ms.MasterKeys[user]
+ resp.UserSigningKeys[user] = ms.UserSigningKeys[user]
+ resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user]
+ resp.DeviceKeys[user] = ms.DeviceKeys[user]
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
+}
+
+func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqClaimKeys
+ mustDecode(r, &req)
+ resp := mautrix.RespClaimKeys{
+ OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
+ }
+ for user, devices := range req.OneTimeKeys {
+ resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
+ for device := range devices {
+ keys := ms.OneTimeKeys[user][device]
+ for keyID, key := range keys {
+ if ms.PopOTKs {
+ delete(keys, keyID)
+ }
+ resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{
+ keyID: key,
+ }
+ break
+ }
+ }
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
+}
+
+func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqUploadKeys
+ mustDecode(r, &req)
+
+ uid := ms.getUserID(r)
+ userID := uid.UserID
+ if _, ok := ms.DeviceKeys[userID]; !ok {
+ ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
+ }
+ if _, ok := ms.OneTimeKeys[userID]; !ok {
+ ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
+ }
+
+ if req.DeviceKeys != nil {
+ ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys
+ }
+ otks, ok := ms.OneTimeKeys[userID][uid.DeviceID]
+ if !ok {
+ otks = map[id.KeyID]mautrix.OneTimeKey{}
+ ms.OneTimeKeys[userID][uid.DeviceID] = otks
+ }
+ if req.OneTimeKeys != nil {
+ maps.Copy(otks, req.OneTimeKeys)
+ }
+
+ exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{
+ OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)},
+ })
+}
+
+func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.UploadCrossSigningKeysReq[any]
+ mustDecode(r, &req)
+
+ userID := ms.getUserID(r).UserID
+ ms.MasterKeys[userID] = req.Master
+ ms.SelfSigningKeys[userID] = req.SelfSigning
+ ms.UserSigningKeys[userID] = req.UserSigning
+
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
+ t.Helper()
+ if ctx == nil {
+ ctx = context.TODO()
+ }
+ client, err := mautrix.NewClient(ms.Server.URL, "", "")
+ require.NoError(t, err)
+ client.Client = ms.Server.Client()
+
+ _, err = client.Login(ctx, &mautrix.ReqLogin{
+ Type: mautrix.AuthTypePassword,
+ Identifier: mautrix.UserIdentifier{
+ Type: mautrix.IdentifierTypeUser,
+ User: userID.String(),
+ },
+ DeviceID: deviceID,
+ Password: "password",
+ StoreCredentials: true,
+ })
+ require.NoError(t, err)
+
+ var store any
+ if ms.MemoryStore {
+ store = crypto.NewMemoryStore(nil)
+ client.StateStore = mautrix.NewMemoryStateStore()
+ } else {
+ store, err = dbutil.NewFromConfig("", dbutil.Config{
+ PoolConfig: dbutil.PoolConfig{
+ Type: "sqlite3-fk-wal",
+ URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)),
+ MaxOpenConns: 5,
+ MaxIdleConns: 1,
+ },
+ }, nil)
+ require.NoError(t, err)
+ }
+ cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store)
+ require.NoError(t, err)
+ client.Crypto = cryptoHelper
+
+ err = cryptoHelper.Init(ctx)
+ require.NoError(t, err)
+
+ machineLog := globallog.Logger.With().
+ Stringer("my_user_id", userID).
+ Stringer("my_device_id", deviceID).
+ Logger()
+ cryptoHelper.Machine().Log = &machineLog
+
+ err = cryptoHelper.Machine().ShareKeys(ctx, 50)
+ require.NoError(t, err)
+
+ return client, cryptoHelper.Machine().CryptoStore
+}
+
+func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) {
+ t.Helper()
+
+ for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {
+ client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt)
+ ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:]
+ }
+}
diff --git a/pushrules/action.go b/pushrules/action.go
index 9838e88b..b5a884b2 100644
--- a/pushrules/action.go
+++ b/pushrules/action.go
@@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error {
if ok {
action.Action = ActionSetTweak
action.Tweak = PushActionTweak(tweak)
- action.Value, _ = val["value"]
+ action.Value = val["value"]
}
}
return nil
diff --git a/pushrules/action_test.go b/pushrules/action_test.go
index a8f68415..3c0aa168 100644
--- a/pushrules/action_test.go
+++ b/pushrules/action_test.go
@@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
err = pa.UnmarshalJSON([]byte(`9001`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`"foo"`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.PushActionType("foo"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.ActionSetTweak, pa.Action)
assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak)
@@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data)
}
@@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, []byte(`"something else"`), data)
}
diff --git a/pushrules/condition.go b/pushrules/condition.go
index dbe83a61..caa717de 100644
--- a/pushrules/condition.go
+++ b/pushrules/condition.go
@@ -27,6 +27,11 @@ type Room interface {
GetMemberCount() int
}
+type PowerLevelfulRoom interface {
+ Room
+ GetPowerLevels() *event.PowerLevelsEventContent
+}
+
// EventfulRoom is an extension of Room to support MSC3664.
type EventfulRoom interface {
Room
@@ -38,11 +43,12 @@ type PushCondKind string
// The allowed push condition kinds as specified in https://spec.matrix.org/v1.2/client-server-api/#conditions-1
const (
- KindEventMatch PushCondKind = "event_match"
- KindContainsDisplayName PushCondKind = "contains_display_name"
- KindRoomMemberCount PushCondKind = "room_member_count"
- KindEventPropertyIs PushCondKind = "event_property_is"
- KindEventPropertyContains PushCondKind = "event_property_contains"
+ KindEventMatch PushCondKind = "event_match"
+ KindContainsDisplayName PushCondKind = "contains_display_name"
+ KindRoomMemberCount PushCondKind = "room_member_count"
+ KindEventPropertyIs PushCondKind = "event_property_is"
+ KindEventPropertyContains PushCondKind = "event_property_contains"
+ KindSenderNotificationPermission PushCondKind = "sender_notification_permission"
// MSC3664: https://github.com/matrix-org/matrix-spec-proposals/pull/3664
@@ -82,6 +88,8 @@ func (cond *PushCondition) Match(room Room, evt *event.Event) bool {
return cond.matchDisplayName(room, evt)
case KindRoomMemberCount:
return cond.matchMemberCount(room)
+ case KindSenderNotificationPermission:
+ return cond.matchSenderNotificationPermission(room, evt.Sender, cond.Key)
default:
return false
}
@@ -334,3 +342,18 @@ func (cond *PushCondition) matchMemberCount(room Room) bool {
return false
}
}
+
+func (cond *PushCondition) matchSenderNotificationPermission(room Room, sender id.UserID, key string) bool {
+ if key != "room" {
+ return false
+ }
+ plRoom, ok := room.(PowerLevelfulRoom)
+ if !ok {
+ return false
+ }
+ pls := plRoom.GetPowerLevels()
+ if pls == nil {
+ return false
+ }
+ return pls.GetUserLevel(sender) >= pls.Notifications.Room()
+}
diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go
index 0d3eaf7a..37af3e34 100644
--- a/pushrules/condition_test.go
+++ b/pushrules/condition_test.go
@@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi
}
}
-func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition {
- return &pushrules.PushCondition{
- Kind: pushrules.KindEventPropertyContains,
- Key: key,
- Value: value,
- }
-}
-
func TestPushCondition_Match_InvalidKind(t *testing.T) {
condition := &pushrules.PushCondition{
Kind: pushrules.PushCondKind("invalid"),
diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go
index a531ca28..a5a0f5e7 100644
--- a/pushrules/pushrules_test.go
+++ b/pushrules/pushrules_test.go
@@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) {
},
}
pushRuleset, err := pushrules.EventToPushRules(evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.NotNil(t, pushRuleset)
assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{})
diff --git a/pushrules/rule.go b/pushrules/rule.go
index ee6d33c4..cf659695 100644
--- a/pushrules/rule.go
+++ b/pushrules/rule.go
@@ -8,7 +8,10 @@ package pushrules
import (
"encoding/gob"
+ "regexp"
+ "strings"
+ "go.mau.fi/util/exerrors"
"go.mau.fi/util/glob"
"maunium.net/go/mautrix/event"
@@ -165,13 +168,20 @@ func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool {
}
func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool {
- pattern := glob.CompileWithImplicitContains(rule.Pattern)
- if pattern == nil {
- return false
- }
msg, ok := evt.Content.Raw["body"].(string)
if !ok {
return false
}
- return pattern.Match(msg)
+ var buf strings.Builder
+ // As per https://spec.matrix.org/unstable/client-server-api/#push-rules, content rules are case-insensitive
+ // and must match whole words, so wrap the converted glob in (?i) and \b.
+ buf.WriteString(`(?i)\b`)
+ // strings.Builder will never return errors
+ exerrors.PanicIfNotNil(glob.ToRegexPattern(rule.Pattern, &buf))
+ buf.WriteString(`\b`)
+ pattern, err := regexp.Compile(buf.String())
+ if err != nil {
+ return false
+ }
+ return pattern.MatchString(msg)
}
diff --git a/pushrules/rule_test.go b/pushrules/rule_test.go
index 803c721e..7ff839a7 100644
--- a/pushrules/rule_test.go
+++ b/pushrules/rule_test.go
@@ -186,6 +186,34 @@ func TestPushRule_Match_Content(t *testing.T) {
assert.True(t, rule.Match(blankTestRoom, evt))
}
+func TestPushRule_Match_WordBoundary(t *testing.T) {
+ rule := &pushrules.PushRule{
+ Type: pushrules.ContentRule,
+ Enabled: true,
+ Pattern: "test",
+ }
+
+ evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
+ MsgType: event.MsgEmote,
+ Body: "is testing pushrules",
+ })
+ assert.False(t, rule.Match(blankTestRoom, evt))
+}
+
+func TestPushRule_Match_CaseInsensitive(t *testing.T) {
+ rule := &pushrules.PushRule{
+ Type: pushrules.ContentRule,
+ Enabled: true,
+ Pattern: "test",
+ }
+
+ evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
+ MsgType: event.MsgEmote,
+ Body: "is TeSt-InG pushrules",
+ })
+ assert.True(t, rule.Match(blankTestRoom, evt))
+}
+
func TestPushRule_Match_Content_Fail(t *testing.T) {
rule := &pushrules.PushRule{
Type: pushrules.ContentRule,
diff --git a/pushrules/ruleset.go b/pushrules/ruleset.go
index 609997b4..c42d4799 100644
--- a/pushrules/ruleset.go
+++ b/pushrules/ruleset.go
@@ -68,6 +68,9 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) {
var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}}
func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) {
+ if rs == nil {
+ return nil
+ }
// Add push rule collections to array in priority order
arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride}
// Loop until one of the push rule collections matches the room/event combo.
diff --git a/requests.go b/requests.go
index a6b0ea8b..cc8b7266 100644
--- a/requests.go
+++ b/requests.go
@@ -2,7 +2,9 @@ package mautrix
import (
"encoding/json"
+ "fmt"
"strconv"
+ "time"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/event"
@@ -38,20 +40,40 @@ const (
type Direction rune
+func (d Direction) MarshalJSON() ([]byte, error) {
+ return json.Marshal(string(d))
+}
+
+func (d *Direction) UnmarshalJSON(data []byte) error {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ switch str {
+ case "f":
+ *d = DirectionForward
+ case "b":
+ *d = DirectionBackward
+ default:
+ return fmt.Errorf("invalid direction %q, must be 'f' or 'b'", str)
+ }
+ return nil
+}
+
const (
DirectionForward Direction = 'f'
DirectionBackward Direction = 'b'
)
// ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
-type ReqRegister struct {
+type ReqRegister[UIAType any] struct {
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"`
InhibitLogin bool `json:"inhibit_login,omitempty"`
RefreshToken bool `json:"refresh_token,omitempty"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
// Type for registration, only used for appservice user registrations
// https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions
@@ -91,6 +113,10 @@ type ReqLogin struct {
StoreHomeserverURL bool `json:"-"`
}
+type ReqPutDevice struct {
+ DisplayName string `json:"display_name,omitempty"`
+}
+
type ReqUIAuthFallback struct {
Session string `json:"session"`
User string `json:"user"`
@@ -115,14 +141,17 @@ type ReqCreateRoom struct {
InitialState []*event.Event `json:"initial_state,omitempty"`
Preset string `json:"preset,omitempty"`
IsDirect bool `json:"is_direct,omitempty"`
- RoomVersion string `json:"room_version,omitempty"`
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"`
MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"`
+ MeowCreateTS int64 `json:"fi.mau.origin_server_ts,omitempty"`
BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"`
BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"`
BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"`
+ BeeperBridgeName string `json:"com.beeper.bridge_name,omitempty"`
+ BeeperBridgeAccountID string `json:"com.beeper.bridge_account_id,omitempty"`
}
// ReqRedact is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid
@@ -132,12 +161,37 @@ type ReqRedact struct {
Extra map[string]interface{}
}
+type ReqRedactUser struct {
+ Reason string `json:"reason"`
+ Limit int `json:"-"`
+}
+
type ReqMembers struct {
At string `json:"at"`
Membership event.Membership `json:"membership,omitempty"`
NotMembership event.Membership `json:"not_membership,omitempty"`
}
+type ReqJoinRoom struct {
+ Via []string `json:"-"`
+ Reason string `json:"reason,omitempty"`
+ ThirdPartySigned any `json:"third_party_signed,omitempty"`
+}
+
+type ReqKnockRoom struct {
+ Via []string `json:"-"`
+ Reason string `json:"reason,omitempty"`
+}
+
+type ReqSearchUserDirectory struct {
+ SearchTerm string `json:"search_term"`
+ Limit int `json:"limit,omitempty"`
+}
+
+type ReqMutualRooms struct {
+ From string `json:"-"`
+}
+
// ReqInvite3PID is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1
// It is also a JSON object used in https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom
type ReqInvite3PID struct {
@@ -166,6 +220,8 @@ type ReqKickUser struct {
type ReqBanUser struct {
Reason string `json:"reason,omitempty"`
UserID id.UserID `json:"user_id"`
+
+ MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"`
}
// ReqUnbanUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban
@@ -181,7 +237,8 @@ type ReqTyping struct {
}
type ReqPresence struct {
- Presence event.Presence `json:"presence"`
+ Presence event.Presence `json:"presence"`
+ StatusMsg string `json:"status_msg,omitempty"`
}
type ReqAliasCreate struct {
@@ -263,11 +320,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 {
return ""
}
-type UploadCrossSigningKeysReq struct {
+type UploadCrossSigningKeysReq[UIAType any] struct {
Master CrossSigningKeys `json:"master_key"`
SelfSigning CrossSigningKeys `json:"self_signing_key"`
UserSigning CrossSigningKeys `json:"user_signing_key"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type KeyMap map[id.DeviceKeyID]string
@@ -309,20 +366,40 @@ type ReqSendToDevice struct {
Messages map[id.UserID]map[id.DeviceID]*event.Content `json:"messages"`
}
+type ReqSendEvent struct {
+ Timestamp int64
+ TransactionID string
+ UnstableDelay time.Duration
+ UnstableStickyDuration time.Duration
+ DontEncrypt bool
+ MeowEventID id.EventID
+}
+
+type ReqDelayedEvents struct {
+ DelayID id.DelayID `json:"-"`
+ Status event.DelayStatus `json:"-"`
+ NextBatch string `json:"-"`
+}
+
+type ReqUpdateDelayedEvent struct {
+ DelayID id.DelayID `json:"-"`
+ Action event.DelayAction `json:"action"`
+}
+
// ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid
type ReqDeviceInfo struct {
DisplayName string `json:"display_name,omitempty"`
}
// ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid
-type ReqDeleteDevice struct {
- Auth interface{} `json:"auth,omitempty"`
+type ReqDeleteDevice[UIAType any] struct {
+ Auth UIAType `json:"auth,omitempty"`
}
// ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices
-type ReqDeleteDevices struct {
+type ReqDeleteDevices[UIAType any] struct {
Devices []id.DeviceID `json:"devices"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type ReqPutPushRule struct {
@@ -334,18 +411,6 @@ type ReqPutPushRule struct {
Pattern string `json:"pattern"`
}
-// Deprecated: MSC2716 was abandoned
-type ReqBatchSend struct {
- PrevEventID id.EventID `json:"-"`
- BatchID id.BatchID `json:"-"`
-
- BeeperNewMessages bool `json:"-"`
- BeeperMarkReadBy id.UserID `json:"-"`
-
- StateEventsAtStart []*event.Event `json:"state_events_at_start"`
- Events []*event.Event `json:"events"`
-}
-
type ReqBeeperBatchSend struct {
// ForwardIfNoMessages should be set to true if the batch should be forward
// backfilled if there are no messages currently in the room.
@@ -381,6 +446,33 @@ type ReqSendReceipt struct {
ThreadID string `json:"thread_id,omitempty"`
}
+type ReqPublicRooms struct {
+ IncludeAllNetworks bool
+ Limit int
+ Since string
+ ThirdPartyInstanceID string
+}
+
+func (req *ReqPublicRooms) Query() map[string]string {
+ query := map[string]string{}
+ if req == nil {
+ return query
+ }
+ if req.IncludeAllNetworks {
+ query["include_all_networks"] = "true"
+ }
+ if req.Limit > 0 {
+ query["limit"] = strconv.Itoa(req.Limit)
+ }
+ if req.Since != "" {
+ query["since"] = req.Since
+ }
+ if req.ThirdPartyInstanceID != "" {
+ query["third_party_instance_id"] = req.ThirdPartyInstanceID
+ }
+ return query
+}
+
// ReqHierarchy contains the parameters for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
//
// As it's a GET method, there is no JSON body, so this is only query parameters.
@@ -473,3 +565,54 @@ type ReqReport struct {
Reason string `json:"reason,omitempty"`
Score int `json:"score,omitempty"`
}
+
+type ReqGetRelations struct {
+ RelationType event.RelationType
+ EventType event.Type
+
+ Dir Direction
+ From string
+ To string
+ Limit int
+ Recurse bool
+}
+
+func (rgr *ReqGetRelations) PathSuffix() ClientURLPath {
+ if rgr.RelationType != "" {
+ if rgr.EventType.Type != "" {
+ return ClientURLPath{rgr.RelationType, rgr.EventType.Type}
+ }
+ return ClientURLPath{rgr.RelationType}
+ }
+ return ClientURLPath{}
+}
+
+func (rgr *ReqGetRelations) Query() map[string]string {
+ query := map[string]string{}
+ if rgr.Dir != 0 {
+ query["dir"] = string(rgr.Dir)
+ }
+ if rgr.From != "" {
+ query["from"] = rgr.From
+ }
+ if rgr.To != "" {
+ query["to"] = rgr.To
+ }
+ if rgr.Limit > 0 {
+ query["limit"] = strconv.Itoa(rgr.Limit)
+ }
+ if rgr.Recurse {
+ query["recurse"] = "true"
+ }
+ return query
+}
+
+// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type ReqSuspend struct {
+ Suspended bool `json:"suspended"`
+}
+
+// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type ReqLocked struct {
+ Locked bool `json:"locked"`
+}
diff --git a/responses.go b/responses.go
index 26aaac77..4fbe1fbc 100644
--- a/responses.go
+++ b/responses.go
@@ -4,13 +4,16 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "maps"
"reflect"
+ "slices"
"strconv"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -32,6 +35,11 @@ type RespJoinRoom struct {
RoomID id.RoomID `json:"room_id"`
}
+// RespKnockRoom is the JSON response for https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias
+type RespKnockRoom struct {
+ RoomID id.RoomID `json:"room_id"`
+}
+
// RespLeaveRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave
type RespLeaveRoom struct{}
@@ -97,6 +105,29 @@ type RespContext struct {
// RespSendEvent is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid
type RespSendEvent struct {
EventID id.EventID `json:"event_id"`
+
+ UnstableDelayID id.DelayID `json:"delay_id,omitempty"`
+}
+
+type RespUpdateDelayedEvent struct{}
+
+type RespDelayedEvents struct {
+ Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"`
+ Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"`
+ NextBatch string `json:"next_batch,omitempty"`
+
+ // Deprecated: Synapse implementation still returns this
+ DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"`
+ // Deprecated: Synapse implementation still returns this
+ FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"`
+}
+
+type RespRedactUserEvents struct {
+ IsMoreEvents bool `json:"is_more_events"`
+ RedactedEvents struct {
+ Total int `json:"total"`
+ SoftFailed int `json:"soft_failed"`
+ } `json:"redacted_events"`
}
// RespMediaConfig is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixmediav3config
@@ -155,8 +186,89 @@ type RespUserDisplayName struct {
}
type RespUserProfile struct {
- DisplayName string `json:"displayname"`
- AvatarURL id.ContentURI `json:"avatar_url"`
+ DisplayName string `json:"displayname,omitempty"`
+ AvatarURL id.ContentURI `json:"avatar_url,omitempty"`
+ Extra map[string]any `json:"-"`
+}
+
+type marshalableUserProfile RespUserProfile
+
+func (r *RespUserProfile) UnmarshalJSON(data []byte) error {
+ err := json.Unmarshal(data, &r.Extra)
+ if err != nil {
+ return err
+ }
+ r.DisplayName, _ = r.Extra["displayname"].(string)
+ avatarURL, _ := r.Extra["avatar_url"].(string)
+ if avatarURL != "" {
+ r.AvatarURL, _ = id.ParseContentURI(avatarURL)
+ }
+ delete(r.Extra, "displayname")
+ delete(r.Extra, "avatar_url")
+ return nil
+}
+
+func (r *RespUserProfile) MarshalJSON() ([]byte, error) {
+ if len(r.Extra) == 0 {
+ return json.Marshal((*marshalableUserProfile)(r))
+ }
+ marshalMap := maps.Clone(r.Extra)
+ if r.DisplayName != "" {
+ marshalMap["displayname"] = r.DisplayName
+ } else {
+ delete(marshalMap, "displayname")
+ }
+ if !r.AvatarURL.IsEmpty() {
+ marshalMap["avatar_url"] = r.AvatarURL.String()
+ } else {
+ delete(marshalMap, "avatar_url")
+ }
+ return json.Marshal(marshalMap)
+}
+
+type RespSearchUserDirectory struct {
+ Limited bool `json:"limited"`
+ Results []*UserDirectoryEntry `json:"results"`
+}
+
+type UserDirectoryEntry struct {
+ RespUserProfile
+ UserID id.UserID `json:"user_id"`
+}
+
+func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error {
+ err := r.RespUserProfile.UnmarshalJSON(data)
+ if err != nil {
+ return err
+ }
+ userIDStr, _ := r.Extra["user_id"].(string)
+ r.UserID = id.UserID(userIDStr)
+ delete(r.Extra, "user_id")
+ return nil
+}
+
+func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) {
+ if r.Extra == nil {
+ r.Extra = make(map[string]any)
+ }
+ r.Extra["user_id"] = r.UserID.String()
+ return r.RespUserProfile.MarshalJSON()
+}
+
+type RespMutualRooms struct {
+ Joined []id.RoomID `json:"joined"`
+ NextBatch string `json:"next_batch,omitempty"`
+ Count int `json:"count,omitempty"`
+}
+
+type RespRoomSummary struct {
+ PublicRoomInfo
+
+ Membership event.Membership `json:"membership,omitempty"`
+
+ UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"`
+ UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"`
+ UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"`
}
// RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable
@@ -207,6 +319,9 @@ type RespLogin struct {
DeviceID id.DeviceID `json:"device_id"`
UserID id.UserID `json:"user_id"`
WellKnown *ClientWellKnown `json:"well_known,omitempty"`
+
+ RefreshToken string `json:"refresh_token,omitempty"`
+ ExpiresInMS int64 `json:"expires_in_ms,omitempty"`
}
// RespLogout is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout
@@ -227,6 +342,24 @@ type LazyLoadSummary struct {
InvitedMemberCount *int `json:"m.invited_member_count,omitempty"`
}
+func (lls *LazyLoadSummary) MemberCount() int {
+ if lls == nil {
+ return 0
+ }
+ return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount)
+}
+
+func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool {
+ if lls == other {
+ return true
+ } else if lls == nil || other == nil {
+ return false
+ }
+ return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) &&
+ ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) &&
+ slices.Equal(lls.Heroes, other.Heroes)
+}
+
type SyncEventsList struct {
Events []*event.Event `json:"events,omitempty"`
}
@@ -322,6 +455,7 @@ type BeeperInboxPreviewEvent struct {
type SyncJoinedRoom struct {
Summary LazyLoadSummary `json:"summary"`
State SyncEventsList `json:"state"`
+ StateAfter *SyncEventsList `json:"state_after,omitempty"`
Timeline SyncTimeline `json:"timeline"`
Ephemeral SyncEventsList `json:"ephemeral"`
AccountData SyncEventsList `json:"account_data"`
@@ -347,16 +481,7 @@ func (sjr SyncJoinedRoom) MarshalJSON() ([]byte, error) {
}
type SyncInvitedRoom struct {
- Summary LazyLoadSummary `json:"summary"`
- State SyncEventsList `json:"invite_state"`
-}
-
-type marshalableSyncInvitedRoom SyncInvitedRoom
-
-var syncInvitedRoomPathsToDelete = []string{"summary"}
-
-func (sir SyncInvitedRoom) MarshalJSON() ([]byte, error) {
- return marshalAndDeleteEmpty((marshalableSyncInvitedRoom)(sir), syncInvitedRoomPathsToDelete)
+ State SyncEventsList `json:"invite_state"`
}
type SyncKnockedRoom struct {
@@ -421,29 +546,19 @@ type RespDeviceInfo struct {
LastSeenTS int64 `json:"last_seen_ts"`
}
-// Deprecated: MSC2716 was abandoned
-type RespBatchSend struct {
- StateEventIDs []id.EventID `json:"state_event_ids"`
- EventIDs []id.EventID `json:"event_ids"`
-
- InsertionEventID id.EventID `json:"insertion_event_id"`
- BatchEventID id.EventID `json:"batch_event_id"`
- BaseInsertionEventID id.EventID `json:"base_insertion_event_id"`
-
- NextBatchID id.BatchID `json:"next_batch_id"`
-}
-
type RespBeeperBatchSend struct {
EventIDs []id.EventID `json:"event_ids"`
}
// RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities
type RespCapabilities struct {
- RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
- ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
- SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
- SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
- ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
+ RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
+ ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
+ SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
+ SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
+ ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
+ GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"`
+ UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"`
Custom map[string]interface{} `json:"-"`
}
@@ -552,29 +667,44 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool {
return available
}
+type CapUnstableAccountModeration struct {
+ Suspend bool `json:"suspend"`
+ Lock bool `json:"lock"`
+}
+
+type RespPublicRooms struct {
+ Chunk []*PublicRoomInfo `json:"chunk"`
+ NextBatch string `json:"next_batch,omitempty"`
+ PrevBatch string `json:"prev_batch,omitempty"`
+ TotalRoomCountEstimate int `json:"total_room_count_estimate"`
+}
+
+type PublicRoomInfo struct {
+ RoomID id.RoomID `json:"room_id"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"`
+ GuestCanJoin bool `json:"guest_can_join"`
+ JoinRule event.JoinRule `json:"join_rule,omitempty"`
+ Name string `json:"name,omitempty"`
+ NumJoinedMembers int `json:"num_joined_members"`
+ RoomType event.RoomType `json:"room_type"`
+ Topic string `json:"topic,omitempty"`
+ WorldReadable bool `json:"world_readable"`
+
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ Encryption id.Algorithm `json:"encryption,omitempty"`
+ AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"`
+}
+
// RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
type RespHierarchy struct {
- NextBatch string `json:"next_batch,omitempty"`
- Rooms []ChildRoomsChunk `json:"rooms"`
+ NextBatch string `json:"next_batch,omitempty"`
+ Rooms []*ChildRoomsChunk `json:"rooms"`
}
type ChildRoomsChunk struct {
- AvatarURL id.ContentURI `json:"avatar_url,omitempty"`
- CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"`
- ChildrenState []StrippedStateWithTime `json:"children_state"`
- GuestCanJoin bool `json:"guest_can_join"`
- JoinRule event.JoinRule `json:"join_rule,omitempty"`
- Name string `json:"name,omitempty"`
- NumJoinedMembers int `json:"num_joined_members"`
- RoomID id.RoomID `json:"room_id"`
- RoomType event.RoomType `json:"room_type"`
- Topic string `json:"topic,omitempty"`
- WorldReadble bool `json:"world_readable"`
-}
-
-type StrippedStateWithTime struct {
- event.StrippedState
- Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
+ PublicRoomInfo
+ ChildrenState []*event.Event `json:"children_state"`
}
type RespAppservicePing struct {
@@ -623,3 +753,47 @@ type RespRoomKeysUpdate struct {
Count int `json:"count"`
ETag string `json:"etag"`
}
+
+type RespOpenIDToken struct {
+ AccessToken string `json:"access_token"`
+ ExpiresIn int `json:"expires_in"`
+ MatrixServerName string `json:"matrix_server_name"`
+ TokenType string `json:"token_type"` // Always "Bearer"
+}
+
+type RespGetRelations struct {
+ Chunk []*event.Event `json:"chunk"`
+ NextBatch string `json:"next_batch,omitempty"`
+ PrevBatch string `json:"prev_batch,omitempty"`
+ RecursionDepth int `json:"recursion_depth,omitempty"`
+}
+
+// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type RespSuspended struct {
+ Suspended bool `json:"suspended"`
+}
+
+// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type RespLocked struct {
+ Locked bool `json:"locked"`
+}
+
+type ConnectionInfo struct {
+ IP string `json:"ip,omitempty"`
+ LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"`
+ UserAgent string `json:"user_agent,omitempty"`
+}
+
+type SessionInfo struct {
+ Connections []ConnectionInfo `json:"connections,omitempty"`
+}
+
+type DeviceInfo struct {
+ Sessions []SessionInfo `json:"sessions,omitempty"`
+}
+
+// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid
+type RespWhoIs struct {
+ UserID id.UserID `json:"user_id,omitempty"`
+ Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"`
+}
diff --git a/responses_test.go b/responses_test.go
index b23d85ad..73d82635 100644
--- a/responses_test.go
+++ b/responses_test.go
@@ -8,7 +8,6 @@ package mautrix_test
import (
"encoding/json"
- "fmt"
"testing"
"github.com/stretchr/testify/assert"
@@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) {
var caps mautrix.RespCapabilities
err := json.Unmarshal([]byte(sampleData), &caps)
require.NoError(t, err)
- fmt.Println(caps)
require.NotNil(t, caps.RoomVersions)
assert.Equal(t, "9", caps.RoomVersions.Default)
diff --git a/room.go b/room.go
index c3ddb7e6..4292bff5 100644
--- a/room.go
+++ b/room.go
@@ -5,8 +5,6 @@ import (
"maunium.net/go/mautrix/id"
)
-type RoomStateMap = map[event.Type]map[string]*event.Event
-
// Room represents a single Matrix room.
type Room struct {
ID id.RoomID
@@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) {
// GetStateEvent returns the state event for the given type/state_key combo, or nil.
func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event {
- stateEventMap, _ := room.State[eventType]
- evt, _ := stateEventMap[stateKey]
+ stateEventMap := room.State[eventType]
+ evt := stateEventMap[stateKey]
return evt
}
diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go
index d594c307..11957dfa 100644
--- a/sqlstatestore/statestore.go
+++ b/sqlstatestore/statestore.go
@@ -62,6 +62,9 @@ func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID)
}
func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error {
+ if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
_, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
return err
}
@@ -85,14 +88,11 @@ func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID
query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ","))
}
rows, err := store.Query(ctx, query, args...)
- if err != nil {
- return nil, err
- }
members := make(map[id.UserID]*event.MemberEventContent)
- return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) {
+ return members, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (ret Member, err error) {
err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL)
return
- }).Iter(func(m Member) (bool, error) {
+ }, err).Iter(func(m Member) (bool, error) {
members[m.UserID] = &m.MemberEventContent
return true, nil
})
@@ -159,10 +159,7 @@ func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserI
`
}
rows, err := store.Query(ctx, query, userID)
- if err != nil {
- return nil, err
- }
- return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
+ return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
}
func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
@@ -188,6 +185,11 @@ func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID,
}
func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ } else if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
@@ -220,6 +222,11 @@ func (u *userProfileRow) GetMassInsertValues() [5]any {
var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ } else if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
var nameSkeleton []byte
if !store.DisableNameDisambiguation && len(member.Displayname) > 0 {
nameSkeletonArr := confusable.SkeletonHash(member.Displayname)
@@ -241,6 +248,9 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room
const userProfileMassInsertBatchSize = 500
func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
return store.DoTxn(ctx, nil, func(ctx context.Context) error {
err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
if err != nil {
@@ -303,7 +313,7 @@ func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.Ro
}
func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (fetched bool, err error) {
- err = store.QueryRow(ctx, "SELECT members_fetched FROM mx_room_state WHERE room_id=$1", roomID).Scan(&fetched)
+ err = store.QueryRow(ctx, "SELECT COALESCE(members_fetched, false) FROM mx_room_state WHERE room_id=$1", roomID).Scan(&fetched)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
@@ -311,6 +321,9 @@ func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.Roo
}
func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true)
ON CONFLICT (room_id) DO UPDATE SET members_fetched=true
@@ -340,6 +353,9 @@ func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID)
}
func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
contentBytes, err := json.Marshal(content)
if err != nil {
return fmt.Errorf("failed to marshal content JSON: %w", err)
@@ -354,7 +370,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro
func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
var data []byte
err := store.
- QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
+ QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID).
Scan(&data)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
@@ -377,6 +393,9 @@ func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (
}
func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
@@ -385,89 +404,92 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID
}
func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
+ levels = &event.PowerLevelsEventContent{}
err = store.
- QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
- Scan(&dbutil.JSON{Data: &levels})
+ QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent})
if errors.Is(err, sql.ErrNoRows) {
- err = nil
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ if levels.CreateEvent != nil {
+ err = levels.CreateEvent.Content.ParseRaw(event.StateCreate)
}
return
}
func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
- if store.Dialect == dbutil.Postgres {
- var powerLevel int
- err := store.
- QueryRow(ctx, `
- SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
- FROM mx_room_state WHERE room_id=$1
- `, roomID, userID).
- Scan(&powerLevel)
- return powerLevel, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
- }
- return levels.GetUserLevel(userID), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
}
+ return levels.GetUserLevel(userID), nil
}
func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
- if store.Dialect == dbutil.Postgres {
- defaultType := "events_default"
- defaultValue := 0
- if eventType.IsState() {
- defaultType = "state_default"
- defaultValue = 50
- }
- var powerLevel int
- err := store.
- QueryRow(ctx, `
- SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
- FROM mx_room_state WHERE room_id=$1
- `, roomID, eventType.Type, defaultType, defaultValue).
- Scan(&powerLevel)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- powerLevel = defaultValue
- }
- return powerLevel, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
- }
- return levels.GetEventLevel(eventType), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
}
+ return levels.GetEventLevel(eventType), nil
}
func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
- if store.Dialect == dbutil.Postgres {
- defaultType := "events_default"
- defaultValue := 0
- if eventType.IsState() {
- defaultType = "state_default"
- defaultValue = 50
- }
- var hasPower bool
- err := store.
- QueryRow(ctx, `SELECT
- COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
- >=
- COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
- FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
- Scan(&hasPower)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- hasPower = defaultValue == 0
- }
- return hasPower, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return false, err
- }
- return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return false, err
}
+ return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
+}
+
+func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
+ if evt.Type != event.StateCreate {
+ return fmt.Errorf("invalid event type for create event: %s", evt.Type)
+ } else if evt.RoomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ _, err := store.Exec(ctx, `
+ INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2)
+ ON CONFLICT (room_id) DO UPDATE SET create_event=excluded.create_event
+ `, evt.RoomID, dbutil.JSON{Data: evt})
+ return err
+}
+
+func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) {
+ err = store.
+ QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &evt})
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ if evt != nil {
+ err = evt.Content.ParseRaw(event.StateCreate)
+ }
+ return
+}
+
+func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ _, err := store.Exec(ctx, `
+ INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2)
+ ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules
+ `, roomID, dbutil.JSON{Data: rules})
+ return err
+}
+
+func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) {
+ levels = &event.JoinRulesEventContent{}
+ err = store.
+ QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &levels})
+ if errors.Is(err, sql.ErrNoRows) {
+ levels = nil
+ err = nil
+ }
+ return
}
diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql
index a58cc56a..4679f1c6 100644
--- a/sqlstatestore/v00-latest-revision.sql
+++ b/sqlstatestore/v00-latest-revision.sql
@@ -1,4 +1,4 @@
--- v0 -> v7 (compatible with v3+): Latest revision
+-- v0 -> v10 (compatible with v3+): Latest revision
CREATE TABLE mx_registrations (
user_id TEXT PRIMARY KEY
@@ -26,5 +26,7 @@ CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
power_levels jsonb,
encryption jsonb,
+ create_event jsonb,
+ join_rules jsonb,
members_fetched BOOLEAN NOT NULL DEFAULT false
);
diff --git a/sqlstatestore/v08-create-event.sql b/sqlstatestore/v08-create-event.sql
new file mode 100644
index 00000000..9f1b55c9
--- /dev/null
+++ b/sqlstatestore/v08-create-event.sql
@@ -0,0 +1,2 @@
+-- v8 (compatible with v3+): Add create event to room state table
+ALTER TABLE mx_room_state ADD COLUMN create_event jsonb;
diff --git a/sqlstatestore/v09-clear-empty-room-ids.sql b/sqlstatestore/v09-clear-empty-room-ids.sql
new file mode 100644
index 00000000..ca951068
--- /dev/null
+++ b/sqlstatestore/v09-clear-empty-room-ids.sql
@@ -0,0 +1,3 @@
+-- v9 (compatible with v3+): Clear invalid rows
+DELETE FROM mx_room_state WHERE room_id='';
+DELETE FROM mx_user_profile WHERE room_id='' OR user_id='';
diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql
new file mode 100644
index 00000000..3074c46a
--- /dev/null
+++ b/sqlstatestore/v10-join-rules.sql
@@ -0,0 +1,2 @@
+-- v10 (compatible with v3+): Add join rules to room state table
+ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb;
diff --git a/statestore.go b/statestore.go
index e728b885..2bd498dd 100644
--- a/statestore.go
+++ b/statestore.go
@@ -34,6 +34,12 @@ type StateStore interface {
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
+ SetCreate(ctx context.Context, evt *event.Event) error
+ GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error)
+
+ GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error)
+ SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error
+
HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error)
MarkMembersFetched(ctx context.Context, roomID id.RoomID) error
GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error)
@@ -68,9 +74,13 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
err = store.SetPowerLevels(ctx, evt.RoomID, content)
case *event.EncryptionEventContent:
err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
+ case *event.CreateEventContent:
+ err = store.SetCreate(ctx, evt)
+ case *event.JoinRulesEventContent:
+ err = store.SetJoinRules(ctx, evt.RoomID, content)
default:
switch evt.Type {
- case event.StateMember, event.StatePowerLevels, event.StateEncryption:
+ case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate:
zerolog.Ctx(ctx).Warn().
Stringer("event_id", evt.ID).
Str("event_type", evt.Type.Type).
@@ -101,11 +111,14 @@ type MemoryStateStore struct {
MembersFetched map[id.RoomID]bool `json:"members_fetched"`
PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"`
+ Create map[id.RoomID]*event.Event `json:"create"`
+ JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"`
registrationsLock sync.RWMutex
membersLock sync.RWMutex
powerLevelsLock sync.RWMutex
encryptionLock sync.RWMutex
+ joinRulesLock sync.RWMutex
}
func NewMemoryStateStore() StateStore {
@@ -115,6 +128,8 @@ func NewMemoryStateStore() StateStore {
MembersFetched: make(map[id.RoomID]bool),
PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
Encryption: make(map[id.RoomID]*event.EncryptionEventContent),
+ Create: make(map[id.RoomID]*event.Event),
+ JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent),
}
}
@@ -298,6 +313,9 @@ func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomI
func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
store.powerLevelsLock.RLock()
levels = store.PowerLevels[roomID]
+ if levels != nil && levels.CreateEvent == nil {
+ levels.CreateEvent = store.Create[roomID]
+ }
store.powerLevelsLock.RUnlock()
return
}
@@ -314,6 +332,23 @@ func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.Room
return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
}
+func (store *MemoryStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
+ store.powerLevelsLock.Lock()
+ store.Create[evt.RoomID] = evt
+ if pls, ok := store.PowerLevels[evt.RoomID]; ok && pls.CreateEvent == nil {
+ pls.CreateEvent = evt
+ }
+ store.powerLevelsLock.Unlock()
+ return nil
+}
+
+func (store *MemoryStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) {
+ store.powerLevelsLock.RLock()
+ evt := store.Create[roomID]
+ store.powerLevelsLock.RUnlock()
+ return evt, nil
+}
+
func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
store.encryptionLock.Lock()
store.Encryption[roomID] = content
@@ -327,6 +362,19 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R
return store.Encryption[roomID], nil
}
+func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error {
+ store.joinRulesLock.Lock()
+ store.JoinRules[roomID] = content
+ store.joinRulesLock.Unlock()
+ return nil
+}
+
+func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) {
+ store.joinRulesLock.RLock()
+ defer store.joinRulesLock.RUnlock()
+ return store.JoinRules[roomID], nil
+}
+
func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
cfg, err := store.GetEncryptionEvent(ctx, roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
diff --git a/synapseadmin/client.go b/synapseadmin/client.go
index 775b4b13..6925ca7d 100644
--- a/synapseadmin/client.go
+++ b/synapseadmin/client.go
@@ -14,9 +14,9 @@ import (
//
// https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html
type Client struct {
- *mautrix.Client
+ Client *mautrix.Client
}
func (cli *Client) BuildAdminURL(path ...any) string {
- return cli.BuildURL(mautrix.SynapseAdminURLPath(path))
+ return cli.Client.BuildURL(mautrix.SynapseAdminURLPath(path))
}
diff --git a/synapseadmin/register.go b/synapseadmin/register.go
index 641f9b56..05e0729a 100644
--- a/synapseadmin/register.go
+++ b/synapseadmin/register.go
@@ -73,7 +73,7 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string {
// This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided.
func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) {
var resp respGetRegisterNonce
- _, err := cli.MakeRequest(ctx, http.MethodGet, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), nil, &resp)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "register"), nil, &resp)
if err != nil {
return "", err
}
@@ -93,7 +93,7 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string
}
req.SHA1Checksum = req.Sign(sharedSecret)
var resp mautrix.RespRegister
- _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), &req, &resp)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodPost, cli.BuildAdminURL("v1", "register"), &req, &resp)
if err != nil {
return nil, err
}
diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go
index 6c072e23..0925b748 100644
--- a/synapseadmin/roomapi.go
+++ b/synapseadmin/roomapi.go
@@ -75,12 +75,17 @@ type RespListRooms struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api
func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) {
var resp RespListRooms
- var reqURL string
- reqURL = cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
- _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}
+func (cli *Client) RoomInfo(ctx context.Context, roomID id.RoomID) (resp *RoomInfo, err error) {
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomID)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ return
+}
+
type RespRoomMessages = mautrix.RespMessages
// RoomMessages returns a list of messages in a room.
@@ -104,13 +109,14 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to
if limit != 0 {
query["limit"] = strconv.Itoa(limit)
}
- urlPath := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query)
- _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
+ urlPath := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return resp, err
}
type ReqDeleteRoom struct {
Purge bool `json:"purge,omitempty"`
+ ForcePurge bool `json:"force_purge,omitempty"`
Block bool `json:"block,omitempty"`
Message string `json:"message,omitempty"`
RoomName string `json:"room_name,omitempty"`
@@ -121,6 +127,19 @@ type RespDeleteRoom struct {
DeleteID string `json:"delete_id"`
}
+type RespDeleteRoomResult struct {
+ KickedUsers []id.UserID `json:"kicked_users,omitempty"`
+ FailedToKickUsers []id.UserID `json:"failed_to_kick_users,omitempty"`
+ LocalAliases []id.RoomAlias `json:"local_aliases,omitempty"`
+ NewRoomID id.RoomID `json:"new_room_id,omitempty"`
+}
+
+type RespDeleteRoomStatus struct {
+ Status string `json:"status,omitempty"`
+ Error string `json:"error,omitempty"`
+ ShutdownRoom RespDeleteRoomResult `json:"shutdown_room,omitempty"`
+}
+
// DeleteRoom deletes a room from the server, optionally blocking it and/or purging all data from the database.
//
// This calls the async version of the endpoint, which will return immediately and delete the room in the background.
@@ -129,10 +148,37 @@ type RespDeleteRoom struct {
func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (RespDeleteRoom, error) {
reqURL := cli.BuildAdminURL("v2", "rooms", roomID)
var resp RespDeleteRoom
- _, err := cli.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp)
return resp, err
}
+func (cli *Client) DeleteRoomStatus(ctx context.Context, deleteID string) (resp RespDeleteRoomStatus, err error) {
+ reqURL := cli.BuildAdminURL("v2", "rooms", "delete_status", deleteID)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ return
+}
+
+// DeleteRoomSync deletes a room from the server, optionally blocking it and/or purging all data from the database.
+//
+// This calls the synchronous version of the endpoint, which will block until the room is deleted.
+//
+// https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-1-old-version
+func (cli *Client) DeleteRoomSync(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (resp RespDeleteRoomResult, err error) {
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomID)
+ httpClient := &http.Client{}
+ _, err = cli.Client.MakeFullRequest(ctx, mautrix.FullRequest{
+ Method: http.MethodDelete,
+ URL: reqURL,
+ RequestJSON: &req,
+ ResponseJSON: &resp,
+ MaxAttempts: 1,
+ // Use a fresh HTTP client without timeouts
+ Client: httpClient,
+ })
+ httpClient.CloseIdleConnections()
+ return
+}
+
type RespRoomsMembers struct {
Members []id.UserID `json:"members"`
Total int `json:"total"`
@@ -144,7 +190,7 @@ type RespRoomsMembers struct {
func (cli *Client) RoomMembers(ctx context.Context, roomID id.RoomID) (RespRoomsMembers, error) {
reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "members")
var resp RespRoomsMembers
- _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}
@@ -157,7 +203,7 @@ type ReqMakeRoomAdmin struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api
func (cli *Client) MakeRoomAdmin(ctx context.Context, roomIDOrAlias string, req ReqMakeRoomAdmin) error {
reqURL := cli.BuildAdminURL("v1", "rooms", roomIDOrAlias, "make_room_admin")
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
return err
}
@@ -170,7 +216,7 @@ type ReqJoinUserToRoom struct {
// https://matrix-org.github.io/synapse/latest/admin_api/room_membership.html
func (cli *Client) JoinUserToRoom(ctx context.Context, roomID id.RoomID, req ReqJoinUserToRoom) error {
reqURL := cli.BuildAdminURL("v1", "join", roomID)
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
return err
}
@@ -183,7 +229,7 @@ type ReqBlockRoom struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#block-room-api
func (cli *Client) BlockRoom(ctx context.Context, roomID id.RoomID, req ReqBlockRoom) error {
reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block")
- _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
return err
}
@@ -199,6 +245,6 @@ type RoomsBlockResponse struct {
func (cli *Client) GetRoomBlockStatus(ctx context.Context, roomID id.RoomID) (RoomsBlockResponse, error) {
var resp RoomsBlockResponse
reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block")
- _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}
diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go
index 9cbb17e4..b1de55b6 100644
--- a/synapseadmin/userapi.go
+++ b/synapseadmin/userapi.go
@@ -32,7 +32,7 @@ type ReqResetPassword struct {
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password
func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error {
reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID)
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
return err
}
@@ -43,8 +43,8 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability
func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) {
- u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username})
- _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
+ u := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username})
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
if err == nil && !resp.Available {
err = fmt.Errorf(`request returned OK status without "available": true`)
}
@@ -65,7 +65,7 @@ type RespListDevices struct {
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices
func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) {
- _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp)
return
}
@@ -89,7 +89,7 @@ type RespUserInfo struct {
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account
func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) {
- _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp)
return
}
@@ -102,7 +102,20 @@ type ReqDeleteUser struct {
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account
func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req ReqDeleteUser) error {
reqURL := cli.BuildAdminURL("v1", "deactivate", userID)
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ return err
+}
+
+type ReqSuspendUser struct {
+ Suspend bool `json:"suspend"`
+}
+
+// SuspendAccount suspends or unsuspends a specific local user account.
+//
+// https://element-hq.github.io/synapse/latest/admin_api/user_admin_api.html#suspendunsuspend-account
+func (cli *Client) SuspendAccount(ctx context.Context, userID id.UserID, req ReqSuspendUser) error {
+ reqURL := cli.BuildAdminURL("v1", "suspend", userID)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
return err
}
@@ -124,7 +137,7 @@ type ReqCreateOrModifyAccount struct {
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#create-or-modify-account
func (cli *Client) CreateOrModifyAccount(ctx context.Context, userID id.UserID, req ReqCreateOrModifyAccount) error {
reqURL := cli.BuildAdminURL("v2", "users", userID)
- _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
return err
}
@@ -140,7 +153,7 @@ type ReqSetRatelimit = RatelimitOverride
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#set-ratelimit
func (cli *Client) SetUserRatelimit(ctx context.Context, userID id.UserID, req ReqSetRatelimit) error {
reqURL := cli.BuildAdminURL("v1", "users", userID, "override_ratelimit")
- _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
return err
}
@@ -150,7 +163,7 @@ type RespUserRatelimit = RatelimitOverride
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#get-status-of-ratelimit
func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp RespUserRatelimit, err error) {
- _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp)
return
}
@@ -158,6 +171,6 @@ func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#delete-ratelimit
func (cli *Client) DeleteUserRatelimit(ctx context.Context, userID id.UserID) (err error) {
- _, err = cli.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil)
return
}
diff --git a/sync.go b/sync.go
index d4208404..598df8e0 100644
--- a/sync.go
+++ b/sync.go
@@ -90,6 +90,7 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc
err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack())
}
}()
+ ctx = context.WithValue(ctx, SyncTokenContextKey, since)
for _, listener := range s.syncListeners {
if !listener(ctx, res, since) {
@@ -97,33 +98,38 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc
}
}
- s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice)
- s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence)
- s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData)
+ s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice, false)
+ s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence, false)
+ s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData, false)
for roomID, roomData := range res.Rooms.Join {
- s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState)
- s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline)
- s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral)
- s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData)
+ if roomData.StateAfter == nil {
+ s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState, false)
+ s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, false)
+ } else {
+ s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, true)
+ s.processSyncEvents(ctx, roomID, roomData.StateAfter.Events, event.SourceJoin|event.SourceState, false)
+ }
+ s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral, false)
+ s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData, false)
}
for roomID, roomData := range res.Rooms.Invite {
- s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState)
+ s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState, false)
}
for roomID, roomData := range res.Rooms.Leave {
- s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState)
- s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline)
+ s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState, false)
+ s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline, false)
}
return
}
-func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) {
+func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source, ignoreState bool) {
for _, evt := range events {
- s.processSyncEvent(ctx, roomID, evt, source)
+ s.processSyncEvent(ctx, roomID, evt, source, ignoreState)
}
}
-func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) {
+func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source, ignoreState bool) {
evt.RoomID = roomID
// Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer.
@@ -149,6 +155,7 @@ func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID,
}
evt.Mautrix.EventSource = source
+ evt.Mautrix.IgnoreState = ignoreState
s.Dispatch(ctx, evt)
}
@@ -191,8 +198,8 @@ func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, e
}
var defaultFilter = Filter{
- Room: RoomFilter{
- Timeline: FilterPart{
+ Room: &RoomFilter{
+ Timeline: &FilterPart{
Limit: 50,
},
},
@@ -257,7 +264,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState)
func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool {
for _, meta := range resp.Rooms.Invite {
- var inviteState []event.StrippedState
+ var inviteState []*event.Event
var inviteEvt *event.Event
for _, evt := range meta.State.Events {
if evt.Type == event.StateMember && evt.GetStateKey() == cli.UserID.String() {
@@ -265,12 +272,7 @@ func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string
} else {
evt.Type.Class = event.StateEventType
_ = evt.Content.ParseRaw(evt.Type)
- inviteState = append(inviteState, event.StrippedState{
- Content: evt.Content,
- Type: evt.Type,
- StateKey: evt.GetStateKey(),
- Sender: evt.Sender,
- })
+ inviteState = append(inviteState, evt)
}
}
if inviteEvt != nil {
diff --git a/url.go b/url.go
index f35ae5e2..91b3d49d 100644
--- a/url.go
+++ b/url.go
@@ -57,13 +57,13 @@ func BuildURL(baseURL *url.URL, path ...any) *url.URL {
// BuildURL builds a URL with the Client's homeserver and appservice user ID set already.
func (cli *Client) BuildURL(urlPath PrefixableURLPath) string {
- return cli.BuildURLWithQuery(urlPath, nil)
+ return cli.BuildURLWithFullQuery(urlPath, nil)
}
// BuildClientURL builds a URL with the Client's homeserver and appservice user ID set already.
// This method also automatically prepends the client API prefix (/_matrix/client).
func (cli *Client) BuildClientURL(urlPath ...any) string {
- return cli.BuildURLWithQuery(ClientURLPath(urlPath), nil)
+ return cli.BuildURLWithFullQuery(ClientURLPath(urlPath), nil)
}
type PrefixableURLPath interface {
@@ -97,6 +97,19 @@ func (saup SynapseAdminURLPath) FullPath() []any {
// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver
// and appservice user ID set already.
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
+ return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) {
+ for k, v := range urlQuery {
+ q.Set(k, v)
+ }
+ })
+}
+
+// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver
+// and appservice user ID set already.
+func (cli *Client) BuildURLWithFullQuery(urlPath PrefixableURLPath, fn func(q url.Values)) string {
+ if cli == nil {
+ return "client is nil"
+ }
hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...)
query := hsURL.Query()
if cli.SetAppServiceUserID {
@@ -106,10 +119,8 @@ func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[str
query.Set("device_id", string(cli.DeviceID))
query.Set("org.matrix.msc3202.device_id", string(cli.DeviceID))
}
- if urlQuery != nil {
- for k, v := range urlQuery {
- query.Set(k, v)
- }
+ if fn != nil {
+ fn(query)
}
hsURL.RawQuery = query.Encode()
return hsURL.String()
diff --git a/version.go b/version.go
index 29368573..f00bbf39 100644
--- a/version.go
+++ b/version.go
@@ -4,10 +4,11 @@ import (
"fmt"
"regexp"
"runtime"
+ "runtime/debug"
"strings"
)
-const Version = "v0.21.1"
+const Version = "v0.26.3"
var GoModVersion = ""
var Commit = ""
@@ -15,11 +16,20 @@ var VersionWithCommit = Version
var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go")
-var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`)
-
func init() {
+ if GoModVersion == "" {
+ info, _ := debug.ReadBuildInfo()
+ if info != nil {
+ for _, mod := range info.Deps {
+ if mod.Path == "maunium.net/go/mautrix" {
+ GoModVersion = mod.Version
+ break
+ }
+ }
+ }
+ }
if GoModVersion != "" {
- match := goModVersionRegex.FindStringSubmatch(GoModVersion)
+ match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion)
if match != nil {
Commit = match[1]
}
diff --git a/versions.go b/versions.go
index 672018ff..61b2e4ea 100644
--- a/versions.go
+++ b/versions.go
@@ -60,17 +60,28 @@ type UnstableFeature struct {
}
var (
- FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
- FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
- FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
+ FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
+ FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
+ FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
+ FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"}
+ FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"}
+ FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"}
+ FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"}
+ FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"}
+ FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116}
+ FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"}
- BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
- BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
- BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
- BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
- BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
- BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
- BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
+ BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
+ BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
+ BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
+ BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
+ BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
+ BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
+ BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
+ BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"}
+ BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"}
)
func (versions *RespVersions) Supports(feature UnstableFeature) bool {
@@ -110,6 +121,12 @@ var (
SpecV19 = MustParseSpecVersion("v1.9")
SpecV110 = MustParseSpecVersion("v1.10")
SpecV111 = MustParseSpecVersion("v1.11")
+ SpecV112 = MustParseSpecVersion("v1.12")
+ SpecV113 = MustParseSpecVersion("v1.13")
+ SpecV114 = MustParseSpecVersion("v1.14")
+ SpecV115 = MustParseSpecVersion("v1.15")
+ SpecV116 = MustParseSpecVersion("v1.16")
+ SpecV117 = MustParseSpecVersion("v1.17")
)
func (svf SpecVersionFormat) String() string {