%s", html.EscapeString(qr)),
+ Info: &event.FileInfo{
+ MimeType: "image/png",
+ Width: qrSizePx,
+ Height: qrSizePx,
+ Size: len(qrData),
+ },
+ }
+ if *prevEventID != "" {
+ content.SetEdit(*prevEventID)
+ }
+ newEventID, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil)
+ if err != nil {
+ return err
+ }
+ if *prevEventID == "" {
+ *prevEventID = newEventID.EventID
+ }
+ return nil
+}
+
+func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error {
+ for _, att := range atts {
+ if att.FileName == "" {
+ return fmt.Errorf("missing attachment filename")
+ }
+ mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType)
+ if err != nil {
+ return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err)
+ }
+ content := &event.MessageEventContent{
+ MsgType: att.Type,
+ FileName: att.FileName,
+ URL: mxc,
+ File: file,
+ Info: &event.FileInfo{
+ MimeType: att.Info.MimeType,
+ Width: att.Info.Width,
+ Height: att.Info.Height,
+ Size: att.Info.Size,
+ },
+ Body: att.FileName,
+ }
+ _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil)
+ if err != nil {
+ return nil
+ }
+ }
+ return nil
+}
+
+type contextKey int
+
+const (
+ contextKeyPrevEventID contextKey = iota
+)
+
+func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
+ prevEvent, ok := ce.Ctx.Value(contextKeyPrevEventID).(*id.EventID)
+ if !ok {
+ prevEvent = new(id.EventID)
+ ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent)
+ }
+ cancelCtx, cancelFunc := context.WithCancel(ce.Ctx)
+ defer cancelFunc()
+ StoreCommandState(ce.User, &CommandState{
+ Action: "Login",
+ Cancel: cancelFunc,
+ })
+ defer StoreCommandState(ce.User, nil)
+ switch step.DisplayAndWaitParams.Type {
+ case bridgev2.LoginDisplayTypeQR:
+ err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent)
+ if err != nil {
+ ce.Reply("Failed to send QR code: %v", err)
+ login.Cancel()
+ return
+ }
+ case bridgev2.LoginDisplayTypeEmoji:
+ ce.ReplyAdvanced(step.DisplayAndWaitParams.Data, false, false)
+ case bridgev2.LoginDisplayTypeCode:
+ ce.ReplyAdvanced(fmt.Sprintf("%s", html.EscapeString(step.DisplayAndWaitParams.Data)), false, true)
+ case bridgev2.LoginDisplayTypeNothing:
+ // Do nothing
+ default:
+ ce.Reply("Unsupported display type %q", step.DisplayAndWaitParams.Type)
+ login.Cancel()
+ return
+ }
+ nextStep, err := login.Wait(cancelCtx)
+ // Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited)
+ if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) {
+ _, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{
+ Parsed: &event.RedactionEventContent{
+ Redacts: *prevEvent,
+ },
+ }, nil)
+ *prevEvent = ""
+ }
+ if err != nil {
+ ce.Reply("Login failed: %v", err)
+ return
+ }
+ doLoginStep(ce, login, nextStep, override)
+}
+
+type cookieLoginCommandState struct {
+ Login bridgev2.LoginProcessCookies
+ Data *bridgev2.LoginCookiesParams
+ Override *bridgev2.UserLogin
+}
+
+func (clcs *cookieLoginCommandState) prompt(ce *Event) {
+ ce.Reply("Login URL: <%s>", clcs.Data.URL)
+ StoreCommandState(ce.User, &CommandState{
+ Next: MinimalCommandHandlerFunc(clcs.submit),
+ Action: "Login",
+ Meta: clcs,
+ Cancel: clcs.Login.Cancel,
+ })
+}
+
+func (clcs *cookieLoginCommandState) submit(ce *Event) {
+ ce.Redact()
+
+ cookiesInput := make(map[string]string)
+ if strings.HasPrefix(strings.TrimSpace(ce.RawArgs), "curl") {
+ parsed, err := curl.Parse(ce.RawArgs)
+ if err != nil {
+ ce.Reply("Failed to parse curl: %v", err)
+ return
+ }
+ reqCookies := make(map[string]string)
+ for _, cookie := range parsed.Cookies() {
+ reqCookies[cookie.Name], err = url.PathUnescape(cookie.Value)
+ if err != nil {
+ ce.Reply("Failed to parse cookie %s: %v", cookie.Name, err)
+ return
+ }
+ }
+ var missingKeys, unsupportedKeys []string
+ for _, field := range clcs.Data.Fields {
+ var value string
+ var supported bool
+ for _, src := range field.Sources {
+ switch src.Type {
+ case bridgev2.LoginCookieTypeCookie:
+ supported = true
+ value = reqCookies[src.Name]
+ case bridgev2.LoginCookieTypeRequestHeader:
+ supported = true
+ value = parsed.Header.Get(src.Name)
+ case bridgev2.LoginCookieTypeRequestBody:
+ supported = true
+ switch {
+ case parsed.MultipartForm != nil:
+ values, ok := parsed.MultipartForm.Value[src.Name]
+ if ok && len(values) > 0 {
+ value = values[0]
+ }
+ case parsed.ParsedJSON != nil:
+ untypedValue, ok := parsed.ParsedJSON[src.Name]
+ if ok {
+ value = fmt.Sprintf("%v", untypedValue)
+ }
+ }
+ }
+ if value != "" {
+ cookiesInput[field.ID] = value
+ break
+ }
+ }
+ if value == "" && field.Required {
+ if supported {
+ missingKeys = append(missingKeys, field.ID)
+ } else {
+ unsupportedKeys = append(unsupportedKeys, field.ID)
+ }
+ }
+ }
+ if len(unsupportedKeys) > 0 {
+ ce.Reply("Some keys can't be extracted from a cURL request: %+v\n\nPlease provide a JSON object instead.", unsupportedKeys)
+ return
+ } else if len(missingKeys) > 0 {
+ ce.Reply("Missing some keys: %+v", missingKeys)
+ return
+ }
+ } else {
+ err := json.Unmarshal([]byte(ce.RawArgs), &cookiesInput)
+ if err != nil {
+ ce.Reply("Failed to parse input as JSON: %v", err)
+ return
+ }
+ for _, field := range clcs.Data.Fields {
+ val, ok := cookiesInput[field.ID]
+ if ok {
+ cookiesInput[field.ID] = maybeURLDecodeCookie(val, &field)
+ }
+ }
+ }
+ var missingKeys []string
+ for _, field := range clcs.Data.Fields {
+ val, ok := cookiesInput[field.ID]
+ if !ok && field.Required {
+ missingKeys = append(missingKeys, field.ID)
+ }
+ if match, _ := regexp.MatchString(field.Pattern, val); !match {
+ ce.Reply("Invalid value for %s: `%s` doesn't match regex `%s`", field.ID, val, field.Pattern)
+ return
+ }
+ }
+ if len(missingKeys) > 0 {
+ ce.Reply("Missing some keys: %+v", missingKeys)
+ return
+ }
+ StoreCommandState(ce.User, nil)
+ nextStep, err := clcs.Login.SubmitCookies(ce.Ctx, cookiesInput)
+ if err != nil {
+ ce.Reply("Login failed: %v", err)
+ return
+ }
+ doLoginStep(ce, clcs.Login, nextStep, clcs.Override)
+}
+
+func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
+ if val == "" {
+ return val
+ }
+ isCookie := slices.ContainsFunc(field.Sources, func(src bridgev2.LoginCookieFieldSource) bool {
+ return src.Type == bridgev2.LoginCookieTypeCookie
+ })
+ if !isCookie {
+ return val
+ }
+ decoded, err := url.PathUnescape(val)
+ if err != nil {
+ return val
+ }
+ return decoded
+}
+
+func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
+ ce.Log.Debug().Any("next_step", step).Msg("Got next login step")
+ if step.Instructions != "" {
+ ce.Reply(step.Instructions)
+ }
+
+ switch step.Type {
+ case bridgev2.LoginStepTypeDisplayAndWait:
+ doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step, override)
+ case bridgev2.LoginStepTypeCookies:
+ (&cookieLoginCommandState{
+ Login: login.(bridgev2.LoginProcessCookies),
+ Data: step.CookiesParams,
+ Override: override,
+ }).prompt(ce)
+ case bridgev2.LoginStepTypeUserInput:
+ err := sendUserInputAttachments(ce, step.UserInputParams.Attachments)
+ if err != nil {
+ ce.Reply("Failed to send attachments: %v", err)
+ }
+ (&userInputLoginCommandState{
+ Login: login.(bridgev2.LoginProcessUserInput),
+ RemainingFields: step.UserInputParams.Fields,
+ Data: make(map[string]string),
+ Override: override,
+ }).promptNext(ce)
+ case bridgev2.LoginStepTypeComplete:
+ if override != nil && override.ID != step.CompleteParams.UserLoginID {
+ ce.Log.Info().
+ Str("old_login_id", string(override.ID)).
+ Str("new_login_id", string(step.CompleteParams.UserLoginID)).
+ Msg("Login resulted in different remote ID than what was being overridden. Deleting previous login")
+ override.Delete(ce.Ctx, status.BridgeState{
+ StateEvent: status.StateLoggedOut,
+ Reason: "LOGIN_OVERRIDDEN",
+ }, bridgev2.DeleteOpts{LogoutRemote: true})
+ }
+ default:
+ panic(fmt.Errorf("unknown login step type %q", step.Type))
+ }
+}
+
+var CommandListLogins = &FullHandler{
+ Func: fnListLogins,
+ Name: "list-logins",
+ Help: HelpMeta{
+ Section: HelpSectionAuth,
+ Description: "List your logins",
+ },
+ RequiresLoginPermission: true,
+}
+
+func fnListLogins(ce *Event) {
+ logins := ce.User.GetFormattedUserLogins()
+ if len(logins) == 0 {
+ ce.Reply("You're not logged in")
+ } else {
+ ce.Reply("%s", logins)
+ }
+}
+
+var CommandLogout = &FullHandler{
+ Func: fnLogout,
+ Name: "logout",
+ Help: HelpMeta{
+ Section: HelpSectionAuth,
+ Description: "Log out of the bridge",
+ Args: "<_login ID_>",
+ },
+}
+
+func fnLogout(ce *Event) {
+ if len(ce.Args) == 0 {
+ ce.Reply("Usage: `$cmdprefix logout %s not found", html.EscapeString(identifier)), false, true)
+ return
+ }
+ formattedName := formatResolveIdentifierResult(resp)
+ if createChat {
+ name := resp.Portal.Name
+ if name == "" {
+ name = resp.Portal.MXID.String()
+ }
+ if !resp.JustCreated {
+ ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
+ } else {
+ ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
+ }
+ } else {
+ ce.Reply("Found %s", formattedName)
+ }
+}
+
+var CommandCreateGroup = &FullHandler{
+ Func: fnCreateGroup,
+ Name: "create-group",
+ Aliases: []string{"create"},
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Create a new group chat for the current Matrix room",
+ Args: "[_group type_]",
+ },
+ RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI],
+}
+
+func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) {
+ evt, err := provider.GetStateEvent(ctx, roomID, evtType, "")
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation")
+ } else if evt != nil {
+ content, _ = evt.Content.Parsed.(T)
+ }
+ return
+}
+
+func fnCreateGroup(ce *Event) {
+ ce.Bridge.Matrix.GetCapabilities()
+ login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group")
+ if api == nil {
+ return
+ }
+ stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState)
+ if !ok {
+ ce.Reply("Matrix connector doesn't support fetching room state")
+ return
+ }
+ members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to get room members for group creation")
+ ce.Reply("Failed to get room members: %v", err)
+ return
+ }
+ caps := ce.Bridge.Network.GetCapabilities()
+ params := &bridgev2.GroupCreateParams{
+ Username: "",
+ Participants: make([]networkid.UserID, 0, len(members)-2),
+ Parent: nil, // TODO check space parent event
+ Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider),
+ Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider),
+ Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider),
+ Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider),
+ RoomID: ce.RoomID,
+ }
+ for userID, member := range members {
+ if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() {
+ continue
+ }
+ if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok {
+ params.Participants = append(params.Participants, parsedUserID)
+ } else if !ce.Bridge.Config.SplitPortals {
+ if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil {
+ ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member")
+ } else if user != nil {
+ // TODO add user logins to participants
+ //for _, login := range user.GetUserLogins() {
+ // params.Participants = append(params.Participants, login.GetUserID())
+ //}
+ }
+ }
+ }
+
+ if len(caps.Provisioning.GroupCreation) == 0 {
+ ce.Reply("No group creation types defined in network capabilities")
+ return
+ } else if len(remainingArgs) > 0 {
+ params.Type = remainingArgs[0]
+ } else if len(caps.Provisioning.GroupCreation) == 1 {
+ for params.Type = range caps.Provisioning.GroupCreation {
+ // The loop assigns the variable we want
+ }
+ } else {
+ types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `")
+ ce.Reply("Please specify type of group to create: `%s`", types)
+ return
+ }
+ resp, err := provisionutil.CreateGroup(ce.Ctx, login, params)
+ if err != nil {
+ ce.Reply("Failed to create group: %v", err)
+ return
+ }
+ var postfix string
+ if len(resp.FailedParticipants) > 0 {
+ failedParticipantsStrings := make([]string, len(resp.FailedParticipants))
+ i := 0
+ for participantID, meta := range resp.FailedParticipants {
+ failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason)
+ i++
+ }
+ postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n")
+ }
+ ce.Reply("Successfully created group `%s`%s", resp.ID, postfix)
+}
+
+var CommandSearch = &FullHandler{
+ Func: fnSearch,
+ Name: "search",
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Search for users on the remote network",
+ Args: "<_query_>",
+ },
+ RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.UserSearchingNetworkAPI],
+}
+
+func fnSearch(ce *Event) {
+ if len(ce.Args) == 0 {
+ ce.Reply("Usage: `$cmdprefix search
+ Blockquote = "blockquote", // blockquote
+ InlineLink = "inline_link", // a
+ UserLink = "user_link", //
+ RoomLink = "room_link", //
+ EventLink = "event_link", //
+ AtRoomMention = "at_room_mention", // @room (no html tag)
+ UnorderedList = "unordered_list", // ul + li
+ OrderedList = "ordered_list", // ol + li
+ ListStart = "ordered_list.start", //
+ ListJumpValue = "ordered_list.jump_value", // -
+ CustomEmoji = "custom_emoji", //
+ Spoiler = "spoiler", //
+ SpoilerReason = "spoiler.reason", //
+ TextForegroundColor = "color.foreground", //
+ TextBackgroundColor = "color.background", //
+ HorizontalLine = "horizontal_line", // hr
+ Headers = "headers", // h1, h2, h3, h4, h5, h6
+ Superscript = "superscript", // sup
+ Subscript = "subscript", // sub
+ Math = "math", //
+ DetailsSummary = "details_summary", // ...
...
+ Table = "table", // table, thead, tbody, tr, th, td
+}
diff --git a/event/capabilities.go b/event/capabilities.go
new file mode 100644
index 00000000..a86c726b
--- /dev/null
+++ b/event/capabilities.go
@@ -0,0 +1,414 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package event
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "mime"
+ "slices"
+ "strings"
+
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
+ "golang.org/x/exp/constraints"
+ "golang.org/x/exp/maps"
+)
+
+type RoomFeatures struct {
+ ID string `json:"id,omitempty"`
+
+ // N.B. New fields need to be added to the Hash function to be included in the deduplication hash.
+
+ Formatting FormattingFeatureMap `json:"formatting,omitempty"`
+ File FileFeatureMap `json:"file,omitempty"`
+ State StateFeatureMap `json:"state,omitempty"`
+ MemberActions MemberFeatureMap `json:"member_actions,omitempty"`
+
+ MaxTextLength int `json:"max_text_length,omitempty"`
+
+ LocationMessage CapabilitySupportLevel `json:"location_message,omitempty"`
+ Poll CapabilitySupportLevel `json:"poll,omitempty"`
+ Thread CapabilitySupportLevel `json:"thread,omitempty"`
+ Reply CapabilitySupportLevel `json:"reply,omitempty"`
+
+ Edit CapabilitySupportLevel `json:"edit,omitempty"`
+ EditMaxCount int `json:"edit_max_count,omitempty"`
+ EditMaxAge *jsontime.Seconds `json:"edit_max_age,omitempty"`
+ Delete CapabilitySupportLevel `json:"delete,omitempty"`
+ DeleteForMe bool `json:"delete_for_me,omitempty"`
+ DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"`
+
+ DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"`
+
+ Reaction CapabilitySupportLevel `json:"reaction,omitempty"`
+ ReactionCount int `json:"reaction_count,omitempty"`
+ AllowedReactions []string `json:"allowed_reactions,omitempty"`
+ CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"`
+
+ ReadReceipts bool `json:"read_receipts,omitempty"`
+ TypingNotifications bool `json:"typing_notifications,omitempty"`
+ Archive bool `json:"archive,omitempty"`
+ MarkAsUnread bool `json:"mark_as_unread,omitempty"`
+ DeleteChat bool `json:"delete_chat,omitempty"`
+ DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"`
+
+ MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"`
+
+ PerMessageProfileRelay bool `json:"-"`
+}
+
+func (rf *RoomFeatures) GetID() string {
+ if rf.ID != "" {
+ return rf.ID
+ }
+ return base64.RawURLEncoding.EncodeToString(rf.Hash())
+}
+
+func (rf *RoomFeatures) Clone() *RoomFeatures {
+ if rf == nil {
+ return nil
+ }
+ clone := *rf
+ clone.File = clone.File.Clone()
+ clone.Formatting = maps.Clone(clone.Formatting)
+ clone.State = clone.State.Clone()
+ clone.MemberActions = clone.MemberActions.Clone()
+ clone.EditMaxAge = ptr.Clone(clone.EditMaxAge)
+ clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge)
+ clone.DisappearingTimer = clone.DisappearingTimer.Clone()
+ clone.AllowedReactions = slices.Clone(clone.AllowedReactions)
+ clone.MessageRequest = clone.MessageRequest.Clone()
+ return &clone
+}
+
+type MemberFeatureMap map[MemberAction]CapabilitySupportLevel
+
+func (mfm MemberFeatureMap) Clone() MemberFeatureMap {
+ return maps.Clone(mfm)
+}
+
+type MemberAction string
+
+const (
+ MemberActionBan MemberAction = "ban"
+ MemberActionKick MemberAction = "kick"
+ MemberActionLeave MemberAction = "leave"
+ MemberActionRevokeInvite MemberAction = "revoke_invite"
+ MemberActionInvite MemberAction = "invite"
+)
+
+type StateFeatureMap map[string]*StateFeatures
+
+func (sfm StateFeatureMap) Clone() StateFeatureMap {
+ dup := maps.Clone(sfm)
+ for key, value := range dup {
+ dup[key] = value.Clone()
+ }
+ return dup
+}
+
+type StateFeatures struct {
+ Level CapabilitySupportLevel `json:"level"`
+}
+
+func (sf *StateFeatures) Clone() *StateFeatures {
+ if sf == nil {
+ return nil
+ }
+ clone := *sf
+ return &clone
+}
+
+func (sf *StateFeatures) Hash() []byte {
+ return sf.Level.Hash()
+}
+
+type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel
+
+type FileFeatureMap map[CapabilityMsgType]*FileFeatures
+
+func (ffm FileFeatureMap) Clone() FileFeatureMap {
+ dup := maps.Clone(ffm)
+ for key, value := range dup {
+ dup[key] = value.Clone()
+ }
+ return dup
+}
+
+type DisappearingTimerCapability struct {
+ Types []DisappearingType `json:"types"`
+ Timers []jsontime.Milliseconds `json:"timers,omitempty"`
+
+ OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"`
+}
+
+func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability {
+ if dtc == nil {
+ return nil
+ }
+ clone := *dtc
+ clone.Types = slices.Clone(clone.Types)
+ clone.Timers = slices.Clone(clone.Timers)
+ return &clone
+}
+
+func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool {
+ if dtc == nil || content == nil || content.Type == DisappearingTypeNone {
+ return true
+ }
+ return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer))
+}
+
+type MessageRequestFeatures struct {
+ AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"`
+ AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"`
+}
+
+func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures {
+ return ptr.Clone(mrf)
+}
+
+func (mrf *MessageRequestFeatures) Hash() []byte {
+ if mrf == nil {
+ return nil
+ }
+ hasher := sha256.New()
+ hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage)
+ hashValue(hasher, "accept_with_button", mrf.AcceptWithButton)
+ return hasher.Sum(nil)
+}
+
+type CapabilityMsgType = MessageType
+
+// Message types which are used for event capability signaling, but aren't real values for the msgtype field.
+const (
+ CapMsgVoice CapabilityMsgType = "org.matrix.msc3245.voice"
+ CapMsgGIF CapabilityMsgType = "fi.mau.gif"
+ CapMsgSticker CapabilityMsgType = "m.sticker"
+)
+
+type CapabilitySupportLevel int
+
+func (csl CapabilitySupportLevel) Partial() bool {
+ return csl >= CapLevelPartialSupport
+}
+
+func (csl CapabilitySupportLevel) Full() bool {
+ return csl >= CapLevelFullySupported
+}
+
+func (csl CapabilitySupportLevel) Reject() bool {
+ return csl <= CapLevelRejected
+}
+
+const (
+ CapLevelRejected CapabilitySupportLevel = -2 // The feature is unsupported and messages using it will be rejected.
+ CapLevelDropped CapabilitySupportLevel = -1 // The feature is unsupported and has no fallback. The message will go through, but data may be lost.
+ CapLevelUnsupported CapabilitySupportLevel = 0 // The feature is unsupported, but may have a fallback.
+ CapLevelPartialSupport CapabilitySupportLevel = 1 // The feature is partially supported (e.g. it may be converted to a different format).
+ CapLevelFullySupported CapabilitySupportLevel = 2 // The feature is fully supported and can be safely used.
+)
+
+type FormattingFeature string
+
+const (
+ FmtBold FormattingFeature = "bold" // strong, b
+ FmtItalic FormattingFeature = "italic" // em, i
+ FmtUnderline FormattingFeature = "underline" // u
+ FmtStrikethrough FormattingFeature = "strikethrough" // del, s
+ FmtInlineCode FormattingFeature = "inline_code" // code
+ FmtCodeBlock FormattingFeature = "code_block" // pre + code
+ FmtSyntaxHighlighting FormattingFeature = "code_block.syntax_highlighting" //
+ FmtBlockquote FormattingFeature = "blockquote" // blockquote
+ FmtInlineLink FormattingFeature = "inline_link" // a
+ FmtUserLink FormattingFeature = "user_link" //
+ FmtRoomLink FormattingFeature = "room_link" //
+ FmtEventLink FormattingFeature = "event_link" //
+ FmtAtRoomMention FormattingFeature = "at_room_mention" // @room (no html tag)
+ FmtUnorderedList FormattingFeature = "unordered_list" // ul + li
+ FmtOrderedList FormattingFeature = "ordered_list" // ol + li
+ FmtListStart FormattingFeature = "ordered_list.start" //
+ FmtListJumpValue FormattingFeature = "ordered_list.jump_value" // -
+ FmtCustomEmoji FormattingFeature = "custom_emoji" //
+ FmtSpoiler FormattingFeature = "spoiler" //
+ FmtSpoilerReason FormattingFeature = "spoiler.reason" //
+ FmtTextForegroundColor FormattingFeature = "color.foreground" //
+ FmtTextBackgroundColor FormattingFeature = "color.background" //
+ FmtHorizontalLine FormattingFeature = "horizontal_line" // hr
+ FmtHeaders FormattingFeature = "headers" // h1, h2, h3, h4, h5, h6
+ FmtSuperscript FormattingFeature = "superscript" // sup
+ FmtSubscript FormattingFeature = "subscript" // sub
+ FmtMath FormattingFeature = "math" //
+ FmtDetailsSummary FormattingFeature = "details_summary" // ...
...
+ FmtTable FormattingFeature = "table" // table, thead, tbody, tr, th, td
+)
+
+type FileFeatures struct {
+ // N.B. New fields need to be added to the Hash function to be included in the deduplication hash.
+
+ MimeTypes map[string]CapabilitySupportLevel `json:"mime_types"`
+
+ Caption CapabilitySupportLevel `json:"caption,omitempty"`
+ MaxCaptionLength int `json:"max_caption_length,omitempty"`
+
+ MaxSize int64 `json:"max_size,omitempty"`
+ MaxWidth int `json:"max_width,omitempty"`
+ MaxHeight int `json:"max_height,omitempty"`
+ MaxDuration *jsontime.Seconds `json:"max_duration,omitempty"`
+
+ ViewOnce bool `json:"view_once,omitempty"`
+}
+
+func (ff *FileFeatures) GetMimeSupport(inputType string) CapabilitySupportLevel {
+ match, ok := ff.MimeTypes[inputType]
+ if ok {
+ return match
+ }
+ if strings.IndexByte(inputType, ';') != -1 {
+ plainMime, _, _ := mime.ParseMediaType(inputType)
+ if plainMime != "" {
+ if match, ok = ff.MimeTypes[plainMime]; ok {
+ return match
+ }
+ }
+ }
+ if slash := strings.IndexByte(inputType, '/'); slash > 0 {
+ generalType := fmt.Sprintf("%s/*", inputType[:slash])
+ if match, ok = ff.MimeTypes[generalType]; ok {
+ return match
+ }
+ }
+ match, ok = ff.MimeTypes["*/*"]
+ if ok {
+ return match
+ }
+ return CapLevelRejected
+}
+
+type hashable interface {
+ Hash() []byte
+}
+
+func hashMap[Key ~string, Value hashable](w io.Writer, name string, data map[Key]Value) {
+ keys := maps.Keys(data)
+ slices.Sort(keys)
+ exerrors.Must(w.Write([]byte(name)))
+ for _, key := range keys {
+ exerrors.Must(w.Write([]byte(key)))
+ exerrors.Must(w.Write(data[key].Hash()))
+ exerrors.Must(w.Write([]byte{0}))
+ }
+}
+
+func hashValue(w io.Writer, name string, data hashable) {
+ exerrors.Must(w.Write([]byte(name)))
+ exerrors.Must(w.Write(data.Hash()))
+}
+
+func hashInt[T constraints.Integer](w io.Writer, name string, data T) {
+ exerrors.Must(w.Write(binary.BigEndian.AppendUint64([]byte(name), uint64(data))))
+}
+
+func hashBool[T ~bool](w io.Writer, name string, data T) {
+ exerrors.Must(w.Write([]byte(name)))
+ if data {
+ exerrors.Must(w.Write([]byte{1}))
+ } else {
+ exerrors.Must(w.Write([]byte{0}))
+ }
+}
+
+func (csl CapabilitySupportLevel) Hash() []byte {
+ return []byte{byte(csl + 128)}
+}
+
+func (rf *RoomFeatures) Hash() []byte {
+ hasher := sha256.New()
+
+ hashMap(hasher, "formatting", rf.Formatting)
+ hashMap(hasher, "file", rf.File)
+ hashMap(hasher, "state", rf.State)
+ hashMap(hasher, "member_actions", rf.MemberActions)
+
+ hashInt(hasher, "max_text_length", rf.MaxTextLength)
+
+ hashValue(hasher, "location_message", rf.LocationMessage)
+ hashValue(hasher, "poll", rf.Poll)
+ hashValue(hasher, "thread", rf.Thread)
+ hashValue(hasher, "reply", rf.Reply)
+
+ hashValue(hasher, "edit", rf.Edit)
+ hashInt(hasher, "edit_max_count", rf.EditMaxCount)
+ hashInt(hasher, "edit_max_age", rf.EditMaxAge.Get())
+
+ hashValue(hasher, "delete", rf.Delete)
+ hashBool(hasher, "delete_for_me", rf.DeleteForMe)
+ hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get())
+ hashValue(hasher, "disappearing_timer", rf.DisappearingTimer)
+
+ hashValue(hasher, "reaction", rf.Reaction)
+ hashInt(hasher, "reaction_count", rf.ReactionCount)
+ hasher.Write([]byte("allowed_reactions"))
+ for _, reaction := range rf.AllowedReactions {
+ hasher.Write([]byte(reaction))
+ }
+ hashBool(hasher, "custom_emoji_reactions", rf.CustomEmojiReactions)
+
+ hashBool(hasher, "read_receipts", rf.ReadReceipts)
+ hashBool(hasher, "typing_notifications", rf.TypingNotifications)
+ hashBool(hasher, "archive", rf.Archive)
+ hashBool(hasher, "mark_as_unread", rf.MarkAsUnread)
+ hashBool(hasher, "delete_chat", rf.DeleteChat)
+ hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone)
+ hashValue(hasher, "message_request", rf.MessageRequest)
+
+ return hasher.Sum(nil)
+}
+
+func (dtc *DisappearingTimerCapability) Hash() []byte {
+ if dtc == nil {
+ return nil
+ }
+ hasher := sha256.New()
+ hasher.Write([]byte("types"))
+ for _, t := range dtc.Types {
+ hasher.Write([]byte(t))
+ }
+ hasher.Write([]byte("timers"))
+ for _, timer := range dtc.Timers {
+ hashInt(hasher, "", timer.Milliseconds())
+ }
+ return hasher.Sum(nil)
+}
+
+func (ff *FileFeatures) Hash() []byte {
+ hasher := sha256.New()
+ hashMap(hasher, "mime_types", ff.MimeTypes)
+ hashValue(hasher, "caption", ff.Caption)
+ hashInt(hasher, "max_caption_length", ff.MaxCaptionLength)
+ hashInt(hasher, "max_size", ff.MaxSize)
+ hashInt(hasher, "max_width", ff.MaxWidth)
+ hashInt(hasher, "max_height", ff.MaxHeight)
+ hashInt(hasher, "max_duration", ff.MaxDuration.Get())
+ hashBool(hasher, "view_once", ff.ViewOnce)
+ return hasher.Sum(nil)
+}
+
+func (ff *FileFeatures) Clone() *FileFeatures {
+ if ff == nil {
+ return nil
+ }
+ clone := *ff
+ clone.MimeTypes = maps.Clone(clone.MimeTypes)
+ clone.MaxDuration = ptr.Clone(clone.MaxDuration)
+ return &clone
+}
diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go
new file mode 100644
index 00000000..ce07c4c0
--- /dev/null
+++ b/event/cmdschema/content.go
@@ -0,0 +1,78 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "reflect"
+ "slices"
+
+ "go.mau.fi/util/exsync"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type EventContent struct {
+ Command string `json:"command"`
+ Aliases []string `json:"aliases,omitempty"`
+ Parameters []*Parameter `json:"parameters,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ TailParam string `json:"fi.mau.tail_parameter,omitempty"`
+}
+
+func (ec *EventContent) Validate() error {
+ if ec == nil {
+ return fmt.Errorf("event content is nil")
+ } else if ec.Command == "" {
+ return fmt.Errorf("command is empty")
+ }
+ var tailFound bool
+ dupMap := exsync.NewSet[string]()
+ for i, p := range ec.Parameters {
+ if err := p.Validate(); err != nil {
+ return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err)
+ } else if !dupMap.Add(p.Key) {
+ return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1)
+ } else if p.Key == ec.TailParam {
+ tailFound = true
+ } else if tailFound && !p.Optional {
+ return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam)
+ }
+ }
+ if ec.TailParam != "" && !tailFound {
+ return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam)
+ }
+ return nil
+}
+
+func (ec *EventContent) IsValid() bool {
+ return ec.Validate() == nil
+}
+
+func (ec *EventContent) StateKey(owner id.UserID) string {
+ hash := sha256.Sum256([]byte(ec.Command + owner.String()))
+ return base64.StdEncoding.EncodeToString(hash[:])
+}
+
+func (ec *EventContent) Equals(other *EventContent) bool {
+ if ec == nil || other == nil {
+ return ec == other
+ }
+ return ec.Command == other.Command &&
+ slices.Equal(ec.Aliases, other.Aliases) &&
+ slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) &&
+ ec.Description.Equals(other.Description) &&
+ ec.TailParam == other.TailParam
+}
+
+func init() {
+ event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{})
+}
diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go
new file mode 100644
index 00000000..4193b297
--- /dev/null
+++ b/event/cmdschema/parameter.go
@@ -0,0 +1,286 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+
+ "go.mau.fi/util/exslices"
+
+ "maunium.net/go/mautrix/event"
+)
+
+type Parameter struct {
+ Key string `json:"key"`
+ Schema *ParameterSchema `json:"schema"`
+ Optional bool `json:"optional,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ DefaultValue any `json:"fi.mau.default_value,omitempty"`
+}
+
+func (p *Parameter) Equals(other *Parameter) bool {
+ if p == nil || other == nil {
+ return p == other
+ }
+ return p.Key == other.Key &&
+ p.Schema.Equals(other.Schema) &&
+ p.Optional == other.Optional &&
+ p.Description.Equals(other.Description) &&
+ p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values
+}
+
+func (p *Parameter) Validate() error {
+ if p == nil {
+ return fmt.Errorf("parameter is nil")
+ } else if p.Key == "" {
+ return fmt.Errorf("key is empty")
+ }
+ return p.Schema.Validate()
+}
+
+func (p *Parameter) IsValid() bool {
+ return p.Validate() == nil
+}
+
+func (p *Parameter) GetDefaultValue() any {
+ if p != nil && p.DefaultValue != nil {
+ return p.DefaultValue
+ } else if p == nil || p.Optional {
+ return nil
+ }
+ return p.Schema.GetDefaultValue()
+}
+
+type PrimitiveType string
+
+const (
+ PrimitiveTypeString PrimitiveType = "string"
+ PrimitiveTypeInteger PrimitiveType = "integer"
+ PrimitiveTypeBoolean PrimitiveType = "boolean"
+ PrimitiveTypeServerName PrimitiveType = "server_name"
+ PrimitiveTypeUserID PrimitiveType = "user_id"
+ PrimitiveTypeRoomID PrimitiveType = "room_id"
+ PrimitiveTypeRoomAlias PrimitiveType = "room_alias"
+ PrimitiveTypeEventID PrimitiveType = "event_id"
+)
+
+func (pt PrimitiveType) Schema() *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypePrimitive,
+ Type: pt,
+ }
+}
+
+func (pt PrimitiveType) IsValid() bool {
+ switch pt {
+ case PrimitiveTypeString,
+ PrimitiveTypeInteger,
+ PrimitiveTypeBoolean,
+ PrimitiveTypeServerName,
+ PrimitiveTypeUserID,
+ PrimitiveTypeRoomID,
+ PrimitiveTypeRoomAlias,
+ PrimitiveTypeEventID:
+ return true
+ default:
+ return false
+ }
+}
+
+type SchemaType string
+
+const (
+ SchemaTypePrimitive SchemaType = "primitive"
+ SchemaTypeArray SchemaType = "array"
+ SchemaTypeUnion SchemaType = "union"
+ SchemaTypeLiteral SchemaType = "literal"
+)
+
+type ParameterSchema struct {
+ SchemaType SchemaType `json:"schema_type"`
+ Type PrimitiveType `json:"type,omitempty"` // Only for primitive
+ Items *ParameterSchema `json:"items,omitempty"` // Only for array
+ Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union
+ Value any `json:"value,omitempty"` // Only for literal
+}
+
+func Literal(value any) *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypeLiteral,
+ Value: value,
+ }
+}
+
+func Enum(values ...any) *ParameterSchema {
+ return Union(exslices.CastFunc(values, Literal)...)
+}
+
+func flattenUnion(variants []*ParameterSchema) []*ParameterSchema {
+ var flattened []*ParameterSchema
+ for _, variant := range variants {
+ switch variant.SchemaType {
+ case SchemaTypeArray:
+ panic(fmt.Errorf("illegal array schema in union"))
+ case SchemaTypeUnion:
+ flattened = append(flattened, flattenUnion(variant.Variants)...)
+ default:
+ flattened = append(flattened, variant)
+ }
+ }
+ return flattened
+}
+
+func Union(variants ...*ParameterSchema) *ParameterSchema {
+ needsFlattening := false
+ for _, variant := range variants {
+ if variant.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in union"))
+ } else if variant.SchemaType == SchemaTypeUnion {
+ needsFlattening = true
+ }
+ }
+ if needsFlattening {
+ variants = flattenUnion(variants)
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeUnion,
+ Variants: variants,
+ }
+}
+
+func Array(items *ParameterSchema) *ParameterSchema {
+ if items.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in array"))
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeArray,
+ Items: items,
+ }
+}
+
+func (ps *ParameterSchema) GetDefaultValue() any {
+ if ps == nil {
+ return nil
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ switch ps.Type {
+ case PrimitiveTypeInteger:
+ return 0
+ case PrimitiveTypeBoolean:
+ return false
+ default:
+ return ""
+ }
+ case SchemaTypeArray:
+ return []any{}
+ case SchemaTypeUnion:
+ if len(ps.Variants) > 0 {
+ return ps.Variants[0].GetDefaultValue()
+ }
+ return nil
+ case SchemaTypeLiteral:
+ return ps.Value
+ default:
+ return nil
+ }
+}
+
+func (ps *ParameterSchema) IsValid() bool {
+ return ps.validate("") == nil
+}
+
+func (ps *ParameterSchema) Validate() error {
+ return ps.validate("")
+}
+
+func (ps *ParameterSchema) validate(parent SchemaType) error {
+ if ps == nil {
+ return fmt.Errorf("schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ if !ps.Type.IsValid() {
+ return fmt.Errorf("invalid primitive type %s", ps.Type)
+ } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("primitive schema has extra fields")
+ }
+ return nil
+ case SchemaTypeArray:
+ if parent != "" {
+ return fmt.Errorf("arrays can't be nested in other types")
+ } else if err := ps.Items.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("item schema is invalid: %w", err)
+ } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("array schema has extra fields")
+ }
+ return nil
+ case SchemaTypeUnion:
+ if len(ps.Variants) == 0 {
+ return fmt.Errorf("no variants specified for union")
+ } else if parent != "" && parent != SchemaTypeArray {
+ return fmt.Errorf("unions can't be nested in anything other than arrays")
+ }
+ for i, v := range ps.Variants {
+ if err := v.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("variant #%d is invalid: %w", i+1, err)
+ }
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Value != nil {
+ return fmt.Errorf("union schema has extra fields")
+ }
+ return nil
+ case SchemaTypeLiteral:
+ switch typedVal := ps.Value.(type) {
+ case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue:
+ // ok
+ case map[string]any:
+ if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" {
+ return fmt.Errorf("literal value has invalid map data")
+ }
+ default:
+ return fmt.Errorf("literal value has unsupported type %T", ps.Value)
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Variants != nil {
+ return fmt.Errorf("literal schema has extra fields")
+ }
+ return nil
+ default:
+ return fmt.Errorf("invalid schema type %s", ps.SchemaType)
+ }
+}
+
+func (ps *ParameterSchema) Equals(other *ParameterSchema) bool {
+ if ps == nil || other == nil {
+ return ps == other
+ }
+ return ps.SchemaType == other.SchemaType &&
+ ps.Type == other.Type &&
+ ps.Items.Equals(other.Items) &&
+ slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) &&
+ ps.Value == other.Value // TODO this won't work for room/event ID values
+}
+
+func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool {
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type == prim
+ case SchemaTypeUnion:
+ for _, variant := range ps.Variants {
+ if variant.AllowsPrimitive(prim) {
+ return true
+ }
+ }
+ return false
+ case SchemaTypeArray:
+ return ps.Items.AllowsPrimitive(prim)
+ default:
+ return false
+ }
+}
diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go
new file mode 100644
index 00000000..92e69b60
--- /dev/null
+++ b/event/cmdschema/parse.go
@@ -0,0 +1,478 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+const botArrayOpener = "<"
+const botArrayCloser = ">"
+
+func parseQuoted(val string) (parsed, remaining string, quoted bool) {
+ if len(val) == 0 {
+ return
+ }
+ if !strings.HasPrefix(val, `"`) {
+ spaceIdx := strings.IndexByte(val, ' ')
+ if spaceIdx == -1 {
+ parsed = val
+ } else {
+ parsed = val[:spaceIdx]
+ remaining = strings.TrimLeft(val[spaceIdx+1:], " ")
+ }
+ return
+ }
+ val = val[1:]
+ var buf strings.Builder
+ for {
+ quoteIdx := strings.IndexByte(val, '"')
+ var valUntilQuote string
+ if quoteIdx == -1 {
+ valUntilQuote = val
+ } else {
+ valUntilQuote = val[:quoteIdx]
+ }
+ escapeIdx := strings.IndexByte(valUntilQuote, '\\')
+ if escapeIdx >= 0 {
+ buf.WriteString(val[:escapeIdx])
+ if len(val) > escapeIdx+1 {
+ buf.WriteByte(val[escapeIdx+1])
+ }
+ val = val[min(escapeIdx+2, len(val)):]
+ } else if quoteIdx >= 0 {
+ buf.WriteString(val[:quoteIdx])
+ val = val[quoteIdx+1:]
+ break
+ } else if buf.Len() == 0 {
+ // Unterminated quote, no escape characters, val is the whole input
+ return val, "", true
+ } else {
+ // Unterminated quote, but there were escape characters previously
+ buf.WriteString(val)
+ val = ""
+ break
+ }
+ }
+ return buf.String(), strings.TrimLeft(val, " "), true
+}
+
+// ParseInput tries to parse the given text into a bot command event matching this command definition.
+//
+// If the prefix doesn't match, this will return a nil content and nil error.
+// If the prefix does match, some content is always returned, but there may still be an error if parsing failed.
+func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) {
+ prefix := ec.parsePrefix(input, sigils, owner.String())
+ if prefix == "" {
+ return nil, nil
+ }
+ content = &event.MessageEventContent{
+ MsgType: event.MsgText,
+ Body: input,
+ Mentions: &event.Mentions{UserIDs: []id.UserID{owner}},
+ MSC4391BotCommand: &event.MSC4391BotCommandInput{
+ Command: ec.Command,
+ },
+ }
+ content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):])
+ return content, err
+}
+
+func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) {
+ args := make(map[string]any)
+ var retErr error
+ setError := func(err error) {
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+ processParameter := func(param *Parameter, isLast, isTail, isNamed bool) {
+ origInput := input
+ var nextVal string
+ var wasQuoted bool
+ if param.Schema.SchemaType == SchemaTypeArray {
+ hasOpener := strings.HasPrefix(input, botArrayOpener)
+ arrayClosed := false
+ if hasOpener {
+ input = input[len(botArrayOpener):]
+ if strings.HasPrefix(input, botArrayCloser) {
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ }
+ }
+ var collector []any
+ for len(input) > 0 && !arrayClosed {
+ //origInput = input
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) {
+ // The value wasn't quoted and has the array delimiter at the end, close the array
+ nextVal = strings.TrimRight(nextVal, botArrayCloser)
+ arrayClosed = true
+ } else if hasOpener && strings.HasPrefix(input, botArrayCloser) {
+ // The value was quoted or there was a space, and the next character is the
+ // array delimiter, close the array
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ } else if !hasOpener && !isLast {
+ // For array arguments in the middle without the <> delimiters, stop after the first item
+ arrayClosed = true
+ }
+ parsedVal, err := param.Schema.Items.ParseString(nextVal)
+ if err == nil {
+ collector = append(collector, parsedVal)
+ } else if hasOpener || isLast {
+ setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err))
+ } else {
+ //input = origInput
+ }
+ }
+ args[param.Key] = collector
+ } else {
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if (isLast || isTail) && !wasQuoted && len(input) > 0 {
+ // If the last argument is not quoted, just treat the rest of the string
+ // as the argument without escapes (arguments with escapes should be quoted).
+ nextVal += " " + input
+ input = ""
+ }
+ // Special case for named boolean parameters: if no value is given, treat it as true
+ if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) {
+ args[param.Key] = true
+ return
+ }
+ if nextVal == "" && !wasQuoted && !isNamed && !param.Optional {
+ setError(fmt.Errorf("missing value for required parameter %s", param.Key))
+ }
+ parsedVal, err := param.Schema.ParseString(nextVal)
+ if err != nil {
+ args[param.Key] = param.GetDefaultValue()
+ // For optional parameters that fail to parse, restore the input and try passing it as the next parameter
+ if param.Optional && !isLast && !isNamed {
+ input = strings.TrimLeft(origInput, " ")
+ } else if !param.Optional || isNamed {
+ setError(fmt.Errorf("failed to parse %s: %w", param.Key, err))
+ }
+ } else {
+ args[param.Key] = parsedVal
+ }
+ }
+ }
+ skipParams := make([]bool, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ for strings.HasPrefix(input, "--") {
+ nameEndIdx := strings.IndexAny(input, " =")
+ if nameEndIdx == -1 {
+ nameEndIdx = len(input)
+ }
+ overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx])
+ if overrideParam != nil {
+ // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input
+ input = strings.TrimPrefix(input[nameEndIdx:], "=")
+ skipParams[paramIdx] = true
+ processParameter(overrideParam, false, false, true)
+ } else {
+ break
+ }
+ }
+ isTail := param.Key == ec.TailParam
+ if skipParams[i] || (param.Optional && !isTail) {
+ continue
+ }
+ processParameter(param, i == len(ec.Parameters)-1, isTail, false)
+ }
+ jsonArgs, marshalErr := json.Marshal(args)
+ if marshalErr != nil {
+ return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr)
+ }
+ return jsonArgs, retErr
+}
+
+func (ec *EventContent) parameterByName(name string) (*Parameter, int) {
+ for i, param := range ec.Parameters {
+ if strings.EqualFold(param.Key, name) {
+ return param, i
+ }
+ }
+ return nil, -1
+}
+
+func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) {
+ input := origInput
+ var chosenSigil string
+ for _, sigil := range sigils {
+ if strings.HasPrefix(input, sigil) {
+ chosenSigil = sigil
+ break
+ }
+ }
+ if chosenSigil == "" {
+ return ""
+ }
+ input = input[len(chosenSigil):]
+ var chosenAlias string
+ if !strings.HasPrefix(input, ec.Command) {
+ for _, alias := range ec.Aliases {
+ if strings.HasPrefix(input, alias) {
+ chosenAlias = alias
+ break
+ }
+ }
+ if chosenAlias == "" {
+ return ""
+ }
+ } else {
+ chosenAlias = ec.Command
+ }
+ input = strings.TrimPrefix(input[len(chosenAlias):], owner)
+ if input == "" || input[0] == ' ' {
+ input = strings.TrimLeft(input, " ")
+ return origInput[:len(origInput)-len(input)]
+ }
+ return ""
+}
+
+func (pt PrimitiveType) ValidateValue(value any) bool {
+ _, err := pt.NormalizeValue(value)
+ return err == nil
+}
+
+func normalizeNumber(value any) (int, error) {
+ switch typedValue := value.(type) {
+ case int:
+ return typedValue, nil
+ case int64:
+ return int(typedValue), nil
+ case float64:
+ return int(typedValue), nil
+ case json.Number:
+ if i, err := typedValue.Int64(); err != nil {
+ return 0, fmt.Errorf("failed to parse json.Number: %w", err)
+ } else {
+ return int(i), nil
+ }
+ default:
+ return 0, fmt.Errorf("unsupported type %T for integer", value)
+ }
+}
+
+func (pt PrimitiveType) NormalizeValue(value any) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return normalizeNumber(value)
+ case PrimitiveTypeBoolean:
+ bv, ok := value.(bool)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for boolean", value)
+ }
+ return bv, nil
+ case PrimitiveTypeString, PrimitiveTypeServerName:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for string", value)
+ }
+ return str, pt.validateStringValue(str)
+ case PrimitiveTypeUserID, PrimitiveTypeRoomAlias:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value)
+ } else if plainErr := pt.validateStringValue(str); plainErr == nil {
+ return str, nil
+ } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID {
+ return parsed.UserID(), nil
+ } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias {
+ return parsed.RoomAlias(), nil
+ } else {
+ return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1)
+ }
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ riv, err := NormalizeRoomIDValue(value)
+ if err != nil {
+ return nil, err
+ }
+ return riv, riv.Validate()
+ default:
+ return nil, fmt.Errorf("cannot normalize value for argument type %s", pt)
+ }
+}
+
+func (pt PrimitiveType) validateStringValue(value string) error {
+ switch pt {
+ case PrimitiveTypeString:
+ return nil
+ case PrimitiveTypeServerName:
+ if !id.ValidateServerName(value) {
+ return fmt.Errorf("invalid server name: %q", value)
+ }
+ return nil
+ case PrimitiveTypeUserID:
+ _, _, err := id.UserID(value).ParseAndValidateRelaxed()
+ return err
+ case PrimitiveTypeRoomAlias:
+ sigil, localpart, serverName := id.ParseCommonIdentifier(value)
+ if sigil != '#' || localpart == "" || serverName == "" {
+ return fmt.Errorf("invalid room alias: %q", value)
+ } else if !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name in room alias: %q", serverName)
+ }
+ return nil
+ default:
+ panic(fmt.Errorf("validateStringValue called with invalid type %s", pt))
+ }
+}
+
+func parseBoolean(val string) (bool, error) {
+ if len(val) == 0 {
+ return false, fmt.Errorf("cannot parse empty string as boolean")
+ }
+ switch strings.ToLower(val) {
+ case "t", "true", "y", "yes", "1":
+ return true, nil
+ case "f", "false", "n", "no", "0":
+ return false, nil
+ default:
+ return false, fmt.Errorf("invalid boolean string: %q", val)
+ }
+}
+
+var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`)
+
+func parseRoomOrEventID(value string) (*RoomIDValue, error) {
+ if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") {
+ matches := markdownLinkRegex.FindStringSubmatch(value)
+ if len(matches) == 2 {
+ value = matches[1]
+ }
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil && strings.HasPrefix(value, "!") {
+ return &RoomIDValue{
+ Type: PrimitiveTypeRoomID,
+ RoomID: id.RoomID(value),
+ }, nil
+ }
+ if err != nil {
+ return nil, err
+ } else if parsed.Sigil1 != '!' {
+ return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1)
+ } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' {
+ return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2)
+ }
+ valType := PrimitiveTypeRoomID
+ if parsed.MXID2 != "" {
+ valType = PrimitiveTypeEventID
+ }
+ return &RoomIDValue{
+ Type: valType,
+ RoomID: parsed.RoomID(),
+ Via: parsed.Via,
+ EventID: parsed.EventID(),
+ }, nil
+}
+
+func (pt PrimitiveType) ParseString(value string) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return strconv.Atoi(value)
+ case PrimitiveTypeBoolean:
+ return parseBoolean(value)
+ case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID:
+ return value, pt.validateStringValue(value)
+ case PrimitiveTypeRoomAlias:
+ plainErr := pt.validateStringValue(value)
+ if plainErr == nil {
+ return value, nil
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 != '#' {
+ return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1)
+ }
+ return parsed.RoomAlias(), nil
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, err
+ } else if pt != parsed.Type {
+ return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("cannot parse string for argument type %s", pt)
+ }
+}
+
+func (ps *ParameterSchema) ParseString(value string) (any, error) {
+ if ps == nil {
+ return nil, fmt.Errorf("parameter schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type.ParseString(value)
+ case SchemaTypeLiteral:
+ switch typedValue := ps.Value.(type) {
+ case string:
+ if value == typedValue {
+ return typedValue, nil
+ } else {
+ return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value)
+ }
+ case int, int64, float64, json.Number:
+ expectedVal, _ := normalizeNumber(typedValue)
+ intVal, err := strconv.Atoi(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse integer literal: %w", err)
+ } else if intVal != expectedVal {
+ return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal)
+ }
+ return intVal, nil
+ case bool:
+ boolVal, err := parseBoolean(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse boolean literal: %w", err)
+ } else if boolVal != typedValue {
+ return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal)
+ }
+ return boolVal, nil
+ case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage:
+ expectedVal, _ := NormalizeRoomIDValue(typedValue)
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err)
+ } else if !parsed.Equals(expectedVal) {
+ return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("unsupported literal type %T", ps.Value)
+ }
+ case SchemaTypeUnion:
+ var errs []error
+ for _, variant := range ps.Variants {
+ if parsed, err := variant.ParseString(value); err == nil {
+ return parsed, nil
+ } else {
+ errs = append(errs, err)
+ }
+ }
+ return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...))
+ case SchemaTypeArray:
+ return nil, fmt.Errorf("cannot parse string for array schema type")
+ default:
+ return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType)
+ }
+}
diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go
new file mode 100644
index 00000000..1e0d1817
--- /dev/null
+++ b/event/cmdschema/parse_test.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "bytes"
+ "encoding/json"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exbytes"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/event/cmdschema/testdata"
+)
+
+type QuoteParseOutput struct {
+ Parsed string
+ Remaining string
+ Quoted bool
+}
+
+func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error {
+ var arr []any
+ if err := json.Unmarshal(data, &arr); err != nil {
+ return err
+ }
+ qpo.Parsed = arr[0].(string)
+ qpo.Remaining = arr[1].(string)
+ qpo.Quoted = arr[2].(bool)
+ return nil
+}
+
+type QuoteParseTestData struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Output QuoteParseOutput `json:"output"`
+}
+
+func loadFile[T any](name string) (into T) {
+ quoteData := exerrors.Must(testdata.FS.ReadFile(name))
+ exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into))
+ return
+}
+
+func TestParseQuoted(t *testing.T) {
+ qptd := loadFile[[]QuoteParseTestData]("parse_quote.json")
+ for _, test := range qptd {
+ t.Run(test.Name, func(t *testing.T) {
+ parsed, remaining, quoted := parseQuoted(test.Input)
+ assert.Equalf(t, test.Output, QuoteParseOutput{
+ Parsed: parsed,
+ Remaining: remaining,
+ Quoted: quoted,
+ }, "Failed with input `%s`", test.Input)
+ // Note: can't just test that requoted == input, because some inputs
+ // have unnecessary escapes which won't survive roundtripping
+ t.Run("roundtrip", func(t *testing.T) {
+ requoted := quoteString(parsed) + " " + remaining
+ reparsed, newRemaining, _ := parseQuoted(requoted)
+ assert.Equal(t, parsed, reparsed)
+ assert.Equal(t, remaining, newRemaining)
+ })
+ })
+ }
+}
+
+type CommandTestData struct {
+ Spec *EventContent
+ Tests []*CommandTestUnit
+}
+
+type CommandTestUnit struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Broken string `json:"broken,omitempty"`
+ Error bool `json:"error"`
+ Output json.RawMessage `json:"output"`
+}
+
+func compactJSON(input json.RawMessage) json.RawMessage {
+ var buf bytes.Buffer
+ exerrors.PanicIfNotNil(json.Compact(&buf, input))
+ return buf.Bytes()
+}
+
+func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) {
+ for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) {
+ t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) {
+ ctd := loadFile[CommandTestData]("commands/" + cmd.Name())
+ for _, test := range ctd.Tests {
+ outputStr := exbytes.UnsafeString(compactJSON(test.Output))
+ t.Run(test.Name, func(t *testing.T) {
+ if test.Broken != "" {
+ t.Skip(test.Broken)
+ }
+ output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input)
+ if test.Error {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ if outputStr == "null" {
+ assert.Nil(t, output)
+ } else {
+ assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command)
+ assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go
new file mode 100644
index 00000000..98c421fc
--- /dev/null
+++ b/event/cmdschema/roomid.go
@@ -0,0 +1,135 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+ "strings"
+
+ "maunium.net/go/mautrix/id"
+)
+
+var ParameterSchemaJoinableRoom = Union(
+ PrimitiveTypeRoomID.Schema(),
+ PrimitiveTypeRoomAlias.Schema(),
+)
+
+type RoomIDValue struct {
+ Type PrimitiveType `json:"type"`
+ RoomID id.RoomID `json:"id"`
+ Via []string `json:"via,omitempty"`
+ EventID id.EventID `json:"event_id,omitempty"`
+}
+
+func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) {
+ switch typedValue := input.(type) {
+ case map[string]any, json.RawMessage:
+ var raw json.RawMessage
+ if raw, err = json.Marshal(input); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ } else if err = json.Unmarshal(raw, &riv); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ }
+ case *RoomIDValue:
+ riv = typedValue
+ case RoomIDValue:
+ riv = &typedValue
+ default:
+ err = fmt.Errorf("unsupported type %T for room or event ID", input)
+ }
+ return
+}
+
+func (riv *RoomIDValue) String() string {
+ return riv.URI().String()
+}
+
+func (riv *RoomIDValue) URI() *id.MatrixURI {
+ if riv == nil {
+ return nil
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ return riv.RoomID.URI(riv.Via...)
+ case PrimitiveTypeEventID:
+ return riv.RoomID.EventURI(riv.EventID, riv.Via...)
+ default:
+ return nil
+ }
+}
+
+func (riv *RoomIDValue) Equals(other *RoomIDValue) bool {
+ if riv == nil || other == nil {
+ return riv == other
+ }
+ return riv.Type == other.Type &&
+ riv.RoomID == other.RoomID &&
+ riv.EventID == other.EventID &&
+ slices.Equal(riv.Via, other.Via)
+}
+
+func (riv *RoomIDValue) Validate() error {
+ if riv == nil {
+ return fmt.Errorf("value is nil")
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ if riv.EventID != "" {
+ return fmt.Errorf("event ID must be empty for room ID type")
+ }
+ case PrimitiveTypeEventID:
+ if !strings.HasPrefix(riv.EventID.String(), "$") {
+ return fmt.Errorf("event ID not valid: %q", riv.EventID)
+ }
+ default:
+ return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type)
+ }
+ for _, via := range riv.Via {
+ if !id.ValidateServerName(via) {
+ return fmt.Errorf("invalid server name %q in vias", via)
+ }
+ }
+ sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID)
+ if sigil != '!' {
+ return fmt.Errorf("room ID does not start with !: %q", riv.RoomID)
+ } else if localpart == "" && serverName == "" {
+ return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID)
+ } else if serverName != "" && !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name %q in room ID", serverName)
+ }
+ return nil
+}
+
+func (riv *RoomIDValue) IsValid() bool {
+ return riv.Validate() == nil
+}
+
+type RoomIDOrString string
+
+func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error {
+ if len(data) == 0 {
+ return fmt.Errorf("empty data for room ID or string")
+ }
+ if data[0] == '"' {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(str)
+ return nil
+ }
+ var riv RoomIDValue
+ if err := json.Unmarshal(data, &riv); err != nil {
+ return err
+ } else if err = riv.Validate(); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(riv.String())
+ return nil
+}
diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go
new file mode 100644
index 00000000..c5c57c53
--- /dev/null
+++ b/event/cmdschema/stringify.go
@@ -0,0 +1,122 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "strconv"
+ "strings"
+)
+
+var quoteEscaper = strings.NewReplacer(
+ `"`, `\"`,
+ `\`, `\\`,
+)
+
+const charsToQuote = ` \` + botArrayOpener + botArrayCloser
+
+func quoteString(val string) string {
+ if val == "" {
+ return `""`
+ }
+ val = quoteEscaper.Replace(val)
+ if strings.ContainsAny(val, charsToQuote) {
+ return `"` + val + `"`
+ }
+ return val
+}
+
+func (ec *EventContent) StringifyArgs(args any) string {
+ var argMap map[string]any
+ switch typedArgs := args.(type) {
+ case json.RawMessage:
+ err := json.Unmarshal(typedArgs, &argMap)
+ if err != nil {
+ return ""
+ }
+ case map[string]any:
+ argMap = typedArgs
+ default:
+ if b, err := json.Marshal(args); err != nil {
+ return ""
+ } else if err = json.Unmarshal(b, &argMap); err != nil {
+ return ""
+ }
+ }
+ parts := make([]string, 0, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ isLast := i == len(ec.Parameters)-1
+ val := argMap[param.Key]
+ if val == nil {
+ val = param.DefaultValue
+ if val == nil && !param.Optional {
+ val = param.Schema.GetDefaultValue()
+ }
+ }
+ if val == nil {
+ continue
+ }
+ var stringified string
+ if param.Schema.SchemaType == SchemaTypeArray {
+ stringified = arrayArgumentToString(val, isLast)
+ } else {
+ stringified = singleArgumentToString(val)
+ }
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ return strings.Join(parts, " ")
+}
+
+func arrayArgumentToString(val any, isLast bool) string {
+ valArr, ok := val.([]any)
+ if !ok {
+ return ""
+ }
+ parts := make([]string, 0, len(valArr))
+ for _, elem := range valArr {
+ stringified := singleArgumentToString(elem)
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ joinedParts := strings.Join(parts, " ")
+ if isLast && len(parts) > 0 {
+ return joinedParts
+ }
+ return botArrayOpener + joinedParts + botArrayCloser
+}
+
+func singleArgumentToString(val any) string {
+ switch typedVal := val.(type) {
+ case string:
+ return quoteString(typedVal)
+ case json.Number:
+ return typedVal.String()
+ case bool:
+ return strconv.FormatBool(typedVal)
+ case int:
+ return strconv.Itoa(typedVal)
+ case int64:
+ return strconv.FormatInt(typedVal, 10)
+ case float64:
+ return strconv.FormatInt(int64(typedVal), 10)
+ case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue:
+ normalized, err := NormalizeRoomIDValue(typedVal)
+ if err != nil {
+ return ""
+ }
+ uri := normalized.URI()
+ if uri == nil {
+ return ""
+ }
+ return quoteString(uri.String())
+ default:
+ return ""
+ }
+}
diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json
new file mode 100644
index 00000000..e53382db
--- /dev/null
+++ b/event/cmdschema/testdata/commands.schema.json
@@ -0,0 +1,281 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "commands.schema.json",
+ "title": "ParseInput test cases",
+ "description": "JSON schema for test case files containing command specifications and test cases",
+ "type": "object",
+ "required": [
+ "spec",
+ "tests"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "spec": {
+ "title": "MSC4391 Command Description",
+ "description": "JSON schema defining the structure of a bot command event content",
+ "type": "object",
+ "required": [
+ "command"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "command": {
+ "type": "string",
+ "description": "The command name that triggers this bot command"
+ },
+ "aliases": {
+ "type": "array",
+ "description": "Alternative names/aliases for this command",
+ "items": {
+ "type": "string"
+ }
+ },
+ "parameters": {
+ "type": "array",
+ "description": "List of parameters accepted by this command",
+ "items": {
+ "$ref": "#/$defs/Parameter"
+ }
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of the command"
+ },
+ "fi.mau.tail_parameter": {
+ "type": "string",
+ "description": "The key of the parameter that accepts remaining arguments as tail text"
+ },
+ "source": {
+ "type": "string",
+ "description": "The user ID of the bot that responds to this command"
+ }
+ }
+ },
+ "tests": {
+ "type": "array",
+ "description": "Array of test cases for the command",
+ "items": {
+ "type": "object",
+ "description": "A single test case for command parsing",
+ "required": [
+ "name",
+ "input"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "The command input string to parse"
+ },
+ "output": {
+ "description": "The expected parsed parameter values, or null if the parsing is expected to fail",
+ "oneOf": [
+ {
+ "type": "object",
+ "additionalProperties": true
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
+ "error": {
+ "type": "boolean",
+ "description": "Whether parsing should result in an error. May still produce output.",
+ "default": false
+ }
+ }
+ }
+ }
+ },
+ "$defs": {
+ "ExtensibleTextContainer": {
+ "type": "object",
+ "description": "Container for text that can have multiple representations",
+ "required": [
+ "m.text"
+ ],
+ "properties": {
+ "m.text": {
+ "type": "array",
+ "description": "Array of text representations in different formats",
+ "items": {
+ "$ref": "#/$defs/ExtensibleText"
+ }
+ }
+ }
+ },
+ "ExtensibleText": {
+ "type": "object",
+ "description": "A text representation with a specific MIME type",
+ "required": [
+ "body"
+ ],
+ "properties": {
+ "body": {
+ "type": "string",
+ "description": "The text content"
+ },
+ "mimetype": {
+ "type": "string",
+ "description": "The MIME type of the text (e.g., text/plain, text/html)",
+ "default": "text/plain",
+ "examples": [
+ "text/plain",
+ "text/html"
+ ]
+ }
+ }
+ },
+ "Parameter": {
+ "type": "object",
+ "description": "A parameter definition for a command",
+ "required": [
+ "key",
+ "schema"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "key": {
+ "type": "string",
+ "description": "The identifier for this parameter"
+ },
+ "schema": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema defining the type and structure of this parameter"
+ },
+ "optional": {
+ "type": "boolean",
+ "description": "Whether this parameter is optional",
+ "default": false
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of this parameter"
+ },
+ "fi.mau.default_value": {
+ "description": "Default value for this parameter if not provided"
+ }
+ }
+ },
+ "ParameterSchema": {
+ "type": "object",
+ "description": "Schema definition for a parameter value",
+ "required": [
+ "schema_type"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "schema_type": {
+ "type": "string",
+ "enum": [
+ "primitive",
+ "array",
+ "union",
+ "literal"
+ ],
+ "description": "The type of schema"
+ }
+ },
+ "allOf": [
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "primitive"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "type"
+ ],
+ "properties": {
+ "type": {
+ "type": "string",
+ "enum": [
+ "string",
+ "integer",
+ "boolean",
+ "server_name",
+ "user_id",
+ "room_id",
+ "room_alias",
+ "event_id"
+ ],
+ "description": "The primitive type (only for schema_type: primitive)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "array"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "items"
+ ],
+ "properties": {
+ "items": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema for array items (only for schema_type: array)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "union"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "variants"
+ ],
+ "properties": {
+ "variants": {
+ "type": "array",
+ "description": "The possible variants (only for schema_type: union)",
+ "items": {
+ "$ref": "#/$defs/ParameterSchema"
+ },
+ "minItems": 1
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "literal"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "value"
+ ],
+ "properties": {
+ "value": {
+ "description": "The literal value (only for schema_type: literal)"
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+}
diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json
new file mode 100644
index 00000000..6ce1f4da
--- /dev/null
+++ b/event/cmdschema/testdata/commands/flags.json
@@ -0,0 +1,126 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "flag",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "user",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "user_id"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true,
+ "fi.mau.default_value": false
+ }
+ ],
+ "fi.mau.tail_parameter": "user"
+ },
+ "tests": [
+ {
+ "name": "no flags",
+ "input": "/flag mrrp",
+ "output": {
+ "meow": "mrrp",
+ "user": null
+ }
+ },
+ {
+ "name": "no flags, has tail",
+ "input": "/flag mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com"
+ }
+ },
+ {
+ "name": "named flag at start",
+ "input": "/flag --woof=yes mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "boolean flag without value",
+ "input": "/flag --woof mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "user id flag without value",
+ "input": "/flag --user --woof mrrp",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle",
+ "input": "/flag mrrp --woof=yes @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle with different value",
+ "input": "/flag mrrp --woof=no @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named",
+ "input": "/flag --woof=no --meow=mrrp --user=@user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named with quotes",
+ "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"",
+ "output": {
+ "meow": "meow meow mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "invalid value for named parameter",
+ "input": "/flag --user=meowings mrrp --woof",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json
new file mode 100644
index 00000000..1351c292
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_id_or_alias.json
@@ -0,0 +1,85 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "room",
+ "schema": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "room": "#test:matrix.org"
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link with url encoding",
+ "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net",
+ "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!#test/room\nversion 11, with @🐈️:maunium.net",
+ "via": [
+ "maunium.net"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json
new file mode 100644
index 00000000..aa266054
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_reference_list.json
@@ -0,0 +1,106 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "rooms",
+ "schema": {
+ "schema_type": "array",
+ "items": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "rooms": [
+ "#test:matrix.org"
+ ]
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "two room ids",
+ "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"
+ },
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI and matrix.to URL",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ },
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json
new file mode 100644
index 00000000..94667323
--- /dev/null
+++ b/event/cmdschema/testdata/commands/simple.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test simple",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "success",
+ "input": "/test simple mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "directed success",
+ "input": "/test simple@testbot mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "missing parameter",
+ "input": "/test simple",
+ "error": true,
+ "output": {
+ "meow": ""
+ }
+ },
+ {
+ "name": "directed at another bot",
+ "input": "/test simple@anotherbot mrrp",
+ "error": false,
+ "output": null
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json
new file mode 100644
index 00000000..9782f8ec
--- /dev/null
+++ b/event/cmdschema/testdata/commands/tail.json
@@ -0,0 +1,60 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "tail",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "reason",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true
+ }
+ ],
+ "fi.mau.tail_parameter": "reason"
+ },
+ "tests": [
+ {
+ "name": "no tail or flag",
+ "input": "/tail mrrp",
+ "output": {
+ "meow": "mrrp",
+ "reason": ""
+ }
+ },
+ {
+ "name": "tail, no flag",
+ "input": "/tail mrrp meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow"
+ }
+ },
+ {
+ "name": "flag before tail",
+ "input": "/tail mrrp --woof meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow",
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go
new file mode 100644
index 00000000..eceea3d2
--- /dev/null
+++ b/event/cmdschema/testdata/data.go
@@ -0,0 +1,14 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package testdata
+
+import (
+ "embed"
+)
+
+//go:embed *
+var FS embed.FS
diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json
new file mode 100644
index 00000000..8f52b7f5
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.json
@@ -0,0 +1,30 @@
+[
+ {"name": "empty string", "input": "", "output": ["", "", false]},
+ {"name": "single word", "input": "meow", "output": ["meow", "", false]},
+ {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]},
+ {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]},
+ {"name": "only spaces", "input": " ", "output": ["", "", false]},
+ {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]},
+ {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]},
+ {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]},
+ {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]},
+ {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]},
+ {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]},
+ {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]},
+ {"name": "just opening quote", "input": "\"", "output": ["", "", true]},
+ {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]},
+ {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]},
+ {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]},
+ {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]},
+ {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]},
+ {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]},
+ {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]},
+ {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]},
+ {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]},
+ {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]},
+ {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]}
+]
diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json
new file mode 100644
index 00000000..9f249116
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.schema.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "parse_quote.schema.json",
+ "title": "parseQuote test cases",
+ "description": "Test cases for the parseQuoted function",
+ "type": "array",
+ "items": {
+ "type": "object",
+ "required": [
+ "name",
+ "input",
+ "output"
+ ],
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "Name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "Input string to be parsed"
+ },
+ "output": {
+ "type": "array",
+ "description": "Expected output of parsing: [first word, remaining text, was quoted]",
+ "minItems": 3,
+ "maxItems": 3,
+ "prefixItems": [
+ {
+ "type": "string",
+ "description": "First parsed word"
+ },
+ {
+ "type": "string",
+ "description": "Remaining text after the first word"
+ },
+ {
+ "type": "boolean",
+ "description": "Whether the first word was quoted"
+ }
+ ]
+ }
+ },
+ "additionalProperties": false
+ }
+}
diff --git a/event/content.go b/event/content.go
index 24c1c193..814aeec4 100644
--- a/event/content.go
+++ b/event/content.go
@@ -18,6 +18,7 @@ import (
// This is used by Content.ParseRaw() for creating the correct type of struct.
var TypeMap = map[Type]reflect.Type{
StateMember: reflect.TypeOf(MemberEventContent{}),
+ StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}),
StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}),
StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}),
StateRoomName: reflect.TypeOf(RoomNameEventContent{}),
@@ -38,7 +39,20 @@ var TypeMap = map[Type]reflect.Type{
StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}),
StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}),
StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}),
- StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}),
+
+ StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
+ StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
+
+ StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}),
+ StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}),
+ StateLegacyPolicyUser: reflect.TypeOf(ModPolicyContent{}),
+ StateUnstablePolicyRoom: reflect.TypeOf(ModPolicyContent{}),
+ StateUnstablePolicyServer: reflect.TypeOf(ModPolicyContent{}),
+ StateUnstablePolicyUser: reflect.TypeOf(ModPolicyContent{}),
+
+ StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}),
+ StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}),
+ StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}),
EventMessage: reflect.TypeOf(MessageEventContent{}),
EventSticker: reflect.TypeOf(MessageEventContent{}),
@@ -46,37 +60,55 @@ var TypeMap = map[Type]reflect.Type{
EventRedaction: reflect.TypeOf(RedactionEventContent{}),
EventReaction: reflect.TypeOf(ReactionEventContent{}),
- BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
+ EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}),
+ EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}),
+
+ BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
+ BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}),
+ BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}),
+ BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}),
+ BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}),
AccountDataRoomTags: reflect.TypeOf(TagEventContent{}),
AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}),
AccountDataFullyRead: reflect.TypeOf(FullyReadEventContent{}),
AccountDataIgnoredUserList: reflect.TypeOf(IgnoredUserListEventContent{}),
+ AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}),
+ AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}),
- EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
- EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
- EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
+ EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
+ EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}),
+ BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}),
- InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}),
InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}),
+ InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}),
+ InRoomVerificationDone: reflect.TypeOf(VerificationDoneEventContent{}),
+ InRoomVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}),
+
InRoomVerificationAccept: reflect.TypeOf(VerificationAcceptEventContent{}),
InRoomVerificationKey: reflect.TypeOf(VerificationKeyEventContent{}),
- InRoomVerificationMAC: reflect.TypeOf(VerificationMacEventContent{}),
- InRoomVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}),
+ InRoomVerificationMAC: reflect.TypeOf(VerificationMACEventContent{}),
ToDeviceRoomKey: reflect.TypeOf(RoomKeyEventContent{}),
ToDeviceForwardedRoomKey: reflect.TypeOf(ForwardedRoomKeyEventContent{}),
ToDeviceRoomKeyRequest: reflect.TypeOf(RoomKeyRequestEventContent{}),
ToDeviceEncrypted: reflect.TypeOf(EncryptedEventContent{}),
ToDeviceRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}),
+ ToDeviceSecretRequest: reflect.TypeOf(SecretRequestEventContent{}),
+ ToDeviceSecretSend: reflect.TypeOf(SecretSendEventContent{}),
ToDeviceDummy: reflect.TypeOf(DummyEventContent{}),
- ToDeviceVerificationStart: reflect.TypeOf(VerificationStartEventContent{}),
- ToDeviceVerificationAccept: reflect.TypeOf(VerificationAcceptEventContent{}),
- ToDeviceVerificationKey: reflect.TypeOf(VerificationKeyEventContent{}),
- ToDeviceVerificationMAC: reflect.TypeOf(VerificationMacEventContent{}),
- ToDeviceVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}),
ToDeviceVerificationRequest: reflect.TypeOf(VerificationRequestEventContent{}),
+ ToDeviceVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}),
+ ToDeviceVerificationStart: reflect.TypeOf(VerificationStartEventContent{}),
+ ToDeviceVerificationDone: reflect.TypeOf(VerificationDoneEventContent{}),
+ ToDeviceVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}),
+
+ ToDeviceVerificationAccept: reflect.TypeOf(VerificationAcceptEventContent{}),
+ ToDeviceVerificationKey: reflect.TypeOf(VerificationKeyEventContent{}),
+ ToDeviceVerificationMAC: reflect.TypeOf(VerificationMACEventContent{}),
ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}),
@@ -100,7 +132,7 @@ var TypeMap = map[Type]reflect.Type{
// When being marshaled into JSON, the data in Parsed will be marshaled first and then recursively merged
// with the data in Raw. Values in Raw are preferred, but nested objects will be recursed into before merging,
// rather than overriding the whole object with the one in Raw).
-// If one of them is nil, the only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead.
+// If one of them is nil, then only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead.
type Content struct {
VeryRaw json.RawMessage
Raw map[string]interface{}
@@ -167,6 +199,13 @@ func IsUnsupportedContentType(err error) bool {
var ErrContentAlreadyParsed = errors.New("content is already parsed")
var ErrUnsupportedContentType = errors.New("unsupported event type")
+func (content *Content) GetRaw() map[string]interface{} {
+ if content.Raw == nil {
+ content.Raw = make(map[string]interface{})
+ }
+ return content.Raw
+}
+
func (content *Content) ParseRaw(evtType Type) error {
if content.Parsed != nil {
return ErrContentAlreadyParsed
@@ -204,6 +243,7 @@ func init() {
gob.Register(&BridgeEventContent{})
gob.Register(&SpaceChildEventContent{})
gob.Register(&SpaceParentEventContent{})
+ gob.Register(&ElementFunctionalMembersContent{})
gob.Register(&RoomNameEventContent{})
gob.Register(&RoomAvatarEventContent{})
gob.Register(&TopicEventContent{})
@@ -231,6 +271,15 @@ func init() {
gob.Register(&RoomKeyWithheldEventContent{})
}
+func CastOrDefault[T any](content *Content) *T {
+ casted, ok := content.Parsed.(*T)
+ if ok {
+ return casted
+ }
+ casted2, _ := content.Parsed.(T)
+ return &casted2
+}
+
// Helper cast functions below
func (content *Content) AsMember() *MemberEventContent {
@@ -345,6 +394,13 @@ func (content *Content) AsSpaceParent() *SpaceParentEventContent {
}
return casted
}
+func (content *Content) AsElementFunctionalMembers() *ElementFunctionalMembersContent {
+ casted, ok := content.Parsed.(*ElementFunctionalMembersContent)
+ if !ok {
+ return &ElementFunctionalMembersContent{}
+ }
+ return casted
+}
func (content *Content) AsMessage() *MessageEventContent {
casted, ok := content.Parsed.(*MessageEventContent)
if !ok {
@@ -401,6 +457,13 @@ func (content *Content) AsIgnoredUserList() *IgnoredUserListEventContent {
}
return casted
}
+func (content *Content) AsMarkedUnread() *MarkedUnreadEventContent {
+ casted, ok := content.Parsed.(*MarkedUnreadEventContent)
+ if !ok {
+ return &MarkedUnreadEventContent{}
+ }
+ return casted
+}
func (content *Content) AsTyping() *TypingEventContent {
casted, ok := content.Parsed.(*TypingEventContent)
if !ok {
@@ -506,3 +569,59 @@ func (content *Content) AsModPolicy() *ModPolicyContent {
}
return casted
}
+func (content *Content) AsVerificationRequest() *VerificationRequestEventContent {
+ casted, ok := content.Parsed.(*VerificationRequestEventContent)
+ if !ok {
+ return &VerificationRequestEventContent{}
+ }
+ return casted
+}
+func (content *Content) AsVerificationReady() *VerificationReadyEventContent {
+ casted, ok := content.Parsed.(*VerificationReadyEventContent)
+ if !ok {
+ return &VerificationReadyEventContent{}
+ }
+ return casted
+}
+func (content *Content) AsVerificationStart() *VerificationStartEventContent {
+ casted, ok := content.Parsed.(*VerificationStartEventContent)
+ if !ok {
+ return &VerificationStartEventContent{}
+ }
+ return casted
+}
+func (content *Content) AsVerificationDone() *VerificationDoneEventContent {
+ casted, ok := content.Parsed.(*VerificationDoneEventContent)
+ if !ok {
+ return &VerificationDoneEventContent{}
+ }
+ return casted
+}
+func (content *Content) AsVerificationCancel() *VerificationCancelEventContent {
+ casted, ok := content.Parsed.(*VerificationCancelEventContent)
+ if !ok {
+ return &VerificationCancelEventContent{}
+ }
+ return casted
+}
+func (content *Content) AsVerificationAccept() *VerificationAcceptEventContent {
+ casted, ok := content.Parsed.(*VerificationAcceptEventContent)
+ if !ok {
+ return &VerificationAcceptEventContent{}
+ }
+ return casted
+}
+func (content *Content) AsVerificationKey() *VerificationKeyEventContent {
+ casted, ok := content.Parsed.(*VerificationKeyEventContent)
+ if !ok {
+ return &VerificationKeyEventContent{}
+ }
+ return casted
+}
+func (content *Content) AsVerificationMAC() *VerificationMACEventContent {
+ casted, ok := content.Parsed.(*VerificationMACEventContent)
+ if !ok {
+ return &VerificationMACEventContent{}
+ }
+ return casted
+}
diff --git a/event/delayed.go b/event/delayed.go
new file mode 100644
index 00000000..fefb62af
--- /dev/null
+++ b/event/delayed.go
@@ -0,0 +1,70 @@
+package event
+
+import (
+ "encoding/json"
+
+ "go.mau.fi/util/jsontime"
+
+ "maunium.net/go/mautrix/id"
+)
+
+type ScheduledDelayedEvent struct {
+ DelayID id.DelayID `json:"delay_id"`
+ RoomID id.RoomID `json:"room_id"`
+ Type Type `json:"type"`
+ StateKey *string `json:"state_key,omitempty"`
+ Delay int64 `json:"delay"`
+ RunningSince jsontime.UnixMilli `json:"running_since"`
+ Content Content `json:"content"`
+}
+
+func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) {
+ evt := &Event{
+ ID: eventID,
+ RoomID: e.RoomID,
+ Type: e.Type,
+ StateKey: e.StateKey,
+ Content: e.Content,
+ Timestamp: ts.UnixMilli(),
+ }
+ return evt, evt.Content.ParseRaw(evt.Type)
+}
+
+type FinalisedDelayedEvent struct {
+ DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"`
+ Outcome DelayOutcome `json:"outcome"`
+ Reason DelayReason `json:"reason"`
+ Error json.RawMessage `json:"error,omitempty"`
+ EventID id.EventID `json:"event_id,omitempty"`
+ Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
+}
+
+type DelayStatus string
+
+var (
+ DelayStatusScheduled DelayStatus = "scheduled"
+ DelayStatusFinalised DelayStatus = "finalised"
+)
+
+type DelayAction string
+
+var (
+ DelayActionSend DelayAction = "send"
+ DelayActionCancel DelayAction = "cancel"
+ DelayActionRestart DelayAction = "restart"
+)
+
+type DelayOutcome string
+
+var (
+ DelayOutcomeSend DelayOutcome = "send"
+ DelayOutcomeCancel DelayOutcome = "cancel"
+)
+
+type DelayReason string
+
+var (
+ DelayReasonAction DelayReason = "action"
+ DelayReasonError DelayReason = "error"
+ DelayReasonDelay DelayReason = "delay"
+)
diff --git a/event/encryption.go b/event/encryption.go
index fa1ac2dd..c60cb91a 100644
--- a/event/encryption.go
+++ b/event/encryption.go
@@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error {
return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext)
case id.AlgorithmMegolmV1:
if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' {
- return id.InputNotJSONString
+ return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString)
}
content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1]
}
@@ -132,8 +132,9 @@ type RoomKeyRequestEventContent struct {
type RequestedKeyInfo struct {
Algorithm id.Algorithm `json:"algorithm"`
RoomID id.RoomID `json:"room_id"`
- SenderKey id.SenderKey `json:"sender_key"`
SessionID id.SessionID `json:"session_id"`
+ // Deprecated: Matrix v1.3
+ SenderKey id.SenderKey `json:"sender_key"`
}
type RoomKeyWithheldCode string
@@ -176,4 +177,27 @@ func (withheld *RoomKeyWithheldEventContent) Is(other error) bool {
return withheld.Code == "" || otherWithheld.Code == "" || withheld.Code == otherWithheld.Code
}
+type SecretRequestAction string
+
+func (a SecretRequestAction) String() string {
+ return string(a)
+}
+
+const (
+ SecretRequestRequest = "request"
+ SecretRequestCancellation = "request_cancellation"
+)
+
+type SecretRequestEventContent struct {
+ Name id.Secret `json:"name,omitempty"`
+ Action SecretRequestAction `json:"action"`
+ RequestingDeviceID id.DeviceID `json:"requesting_device_id"`
+ RequestID string `json:"request_id"`
+}
+
+type SecretSendEventContent struct {
+ RequestID string `json:"request_id"`
+ Secret string `json:"secret"`
+}
+
type DummyEventContent struct{}
diff --git a/event/events.go b/event/events.go
index f7b4d4d6..72c1e161 100644
--- a/event/events.go
+++ b/event/events.go
@@ -118,6 +118,9 @@ type MautrixInfo struct {
DecryptionDuration time.Duration
CheckpointSent bool
+ // When using MSC4222 and the state_after field, this field is set
+ // for timeline events to indicate they shouldn't update room state.
+ IgnoreState bool
}
func (evt *Event) GetStateKey() string {
@@ -127,27 +130,29 @@ func (evt *Event) GetStateKey() string {
return ""
}
-type StrippedState struct {
- Content Content `json:"content"`
- Type Type `json:"type"`
- StateKey string `json:"state_key"`
- Sender id.UserID `json:"sender"`
-}
-
type Unsigned struct {
- PrevContent *Content `json:"prev_content,omitempty"`
- PrevSender id.UserID `json:"prev_sender,omitempty"`
- ReplacesState id.EventID `json:"replaces_state,omitempty"`
- Age int64 `json:"age,omitempty"`
- TransactionID string `json:"transaction_id,omitempty"`
- Relations *Relations `json:"m.relations,omitempty"`
- RedactedBecause *Event `json:"redacted_because,omitempty"`
- InviteRoomState []StrippedState `json:"invite_room_state,omitempty"`
+ PrevContent *Content `json:"prev_content,omitempty"`
+ PrevSender id.UserID `json:"prev_sender,omitempty"`
+ Membership Membership `json:"membership,omitempty"`
+ ReplacesState id.EventID `json:"replaces_state,omitempty"`
+ Age int64 `json:"age,omitempty"`
+ TransactionID string `json:"transaction_id,omitempty"`
+ Relations *Relations `json:"m.relations,omitempty"`
+ RedactedBecause *Event `json:"redacted_because,omitempty"`
+ InviteRoomState []*Event `json:"invite_room_state,omitempty"`
- BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"`
+ BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"`
+ BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"`
+ BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"`
+ BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"`
+
+ ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"`
+ ElementPolicyServerSpammy bool `json:"io.element.synapse.policy_server_spammy,omitempty"`
}
func (us *Unsigned) IsEmpty() bool {
- return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 &&
- us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil
+ return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" &&
+ us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil &&
+ us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() &&
+ !us.ElementSoftFailed
}
diff --git a/event/member.go b/event/member.go
index ebafdcb7..9956a36b 100644
--- a/event/member.go
+++ b/event/member.go
@@ -7,8 +7,6 @@
package event
import (
- "encoding/json"
-
"maunium.net/go/mautrix/id"
)
@@ -35,19 +33,37 @@ const (
// MemberEventContent represents the content of a m.room.member state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroommember
type MemberEventContent struct {
- Membership Membership `json:"membership"`
- AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
- Displayname string `json:"displayname,omitempty"`
- IsDirect bool `json:"is_direct,omitempty"`
- ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
- Reason string `json:"reason,omitempty"`
+ Membership Membership `json:"membership"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ Displayname string `json:"displayname,omitempty"`
+ IsDirect bool `json:"is_direct,omitempty"`
+ ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
+ Reason string `json:"reason,omitempty"`
+ JoinAuthorisedViaUsersServer id.UserID `json:"join_authorised_via_users_server,omitempty"`
+ MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"`
+
+ MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"`
+}
+
+type SignedThirdPartyInvite struct {
+ Token string `json:"token"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
+ MXID string `json:"mxid"`
}
type ThirdPartyInvite struct {
- DisplayName string `json:"display_name"`
- Signed struct {
- Token string `json:"token"`
- Signatures json.RawMessage `json:"signatures"`
- MXID string `json:"mxid"`
- }
+ DisplayName string `json:"display_name"`
+ Signed SignedThirdPartyInvite `json:"signed"`
+}
+
+type ThirdPartyInviteEventContent struct {
+ DisplayName string `json:"display_name"`
+ KeyValidityURL string `json:"key_validity_url"`
+ PublicKey id.Ed25519 `json:"public_key"`
+ PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"`
+}
+
+type ThirdPartyInviteKey struct {
+ KeyValidityURL string `json:"key_validity_url,omitempty"`
+ PublicKey id.Ed25519 `json:"public_key"`
}
diff --git a/event/message.go b/event/message.go
index 6512f9be..3fb3dc82 100644
--- a/event/message.go
+++ b/event/message.go
@@ -8,11 +8,11 @@ package event
import (
"encoding/json"
+ "html"
+ "slices"
"strconv"
"strings"
- "golang.org/x/net/html"
-
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/id"
)
@@ -21,6 +21,24 @@ import (
// https://spec.matrix.org/v1.2/client-server-api/#mroommessage-msgtypes
type MessageType string
+func (mt MessageType) IsText() bool {
+ switch mt {
+ case MsgText, MsgNotice, MsgEmote:
+ return true
+ default:
+ return false
+ }
+}
+
+func (mt MessageType) IsMedia() bool {
+ switch mt {
+ case MsgImage, MsgVideo, MsgAudio, MsgFile, CapMsgSticker:
+ return true
+ default:
+ return false
+ }
+}
+
// Msgtypes
const (
MsgText MessageType = "m.text"
@@ -112,10 +130,68 @@ type MessageEventContent struct {
replyFallbackRemoved bool
- MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"`
- BeeperGalleryImages []*MessageEventContent `json:"com.beeper.gallery.images,omitempty"`
- BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"`
- BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"`
+ MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"`
+ BeeperGalleryImages []*MessageEventContent `json:"com.beeper.gallery.images,omitempty"`
+ BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"`
+ BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"`
+ BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"`
+ BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"`
+
+ BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"`
+
+ BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"`
+
+ MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"`
+ MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"`
+
+ MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"`
+}
+
+func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType {
+ switch content.MsgType {
+ case CapMsgSticker:
+ return CapMsgSticker
+ case "":
+ if content.URL != "" || content.File != nil {
+ return CapMsgSticker
+ }
+ case MsgImage:
+ return MsgImage
+ case MsgAudio:
+ if content.MSC3245Voice != nil {
+ return CapMsgVoice
+ }
+ return MsgAudio
+ case MsgVideo:
+ if content.Info != nil && content.Info.MauGIF {
+ return CapMsgGIF
+ }
+ return MsgVideo
+ case MsgFile:
+ return MsgFile
+ }
+ return ""
+}
+
+func (content *MessageEventContent) GetFileName() string {
+ if content.FileName != "" {
+ return content.FileName
+ }
+ return content.Body
+}
+
+func (content *MessageEventContent) GetCaption() string {
+ if content.FileName != "" && content.Body != "" && content.Body != content.FileName {
+ return content.Body
+ }
+ return ""
+}
+
+func (content *MessageEventContent) GetFormattedCaption() string {
+ if content.Format == FormatHTML && content.FormattedBody != "" {
+ return content.FormattedBody
+ }
+ return ""
}
func (content *MessageEventContent) GetRelatesTo() *RelatesTo {
@@ -139,6 +215,7 @@ func (content *MessageEventContent) SetEdit(original id.EventID) {
content.RelatesTo = (&RelatesTo{}).SetReplace(original)
if content.MsgType == MsgText || content.MsgType == MsgNotice {
content.Body = "* " + content.Body
+ content.Mentions = &Mentions{}
if content.Format == FormatHTML && len(content.FormattedBody) > 0 {
content.FormattedBody = "* " + content.FormattedBody
}
@@ -189,24 +266,56 @@ type Mentions struct {
Room bool `json:"room,omitempty"`
}
+func (m *Mentions) Add(userID id.UserID) {
+ if userID != "" && !slices.Contains(m.UserIDs, userID) {
+ m.UserIDs = append(m.UserIDs, userID)
+ }
+}
+
+func (m *Mentions) Has(userID id.UserID) bool {
+ return m != nil && slices.Contains(m.UserIDs, userID)
+}
+
+func (m *Mentions) Merge(other *Mentions) *Mentions {
+ if m == nil {
+ return other
+ } else if other == nil {
+ return m
+ }
+ return &Mentions{
+ UserIDs: slices.Concat(m.UserIDs, other.UserIDs),
+ Room: m.Room || other.Room,
+ }
+}
+
+type MSC4391BotCommandInputCustom[T any] struct {
+ Command string `json:"command"`
+ Arguments T `json:"arguments,omitempty"`
+}
+
+type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage]
+
type EncryptedFileInfo struct {
attachment.EncryptedFile
URL id.ContentURIString `json:"url"`
}
type FileInfo struct {
- MimeType string `json:"mimetype,omitempty"`
- ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"`
- ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"`
- ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"`
+ MimeType string
+ ThumbnailInfo *FileInfo
+ ThumbnailURL id.ContentURIString
+ ThumbnailFile *EncryptedFileInfo
- Blurhash string `json:"blurhash,omitempty"`
- AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
+ Blurhash string
+ AnoaBlurhash string
- Width int `json:"-"`
- Height int `json:"-"`
- Duration int `json:"-"`
- Size int `json:"-"`
+ MauGIF bool
+ IsAnimated bool
+
+ Width int
+ Height int
+ Duration int
+ Size int
}
type serializableFileInfo struct {
@@ -218,6 +327,9 @@ type serializableFileInfo struct {
Blurhash string `json:"blurhash,omitempty"`
AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
+ MauGIF bool `json:"fi.mau.gif,omitempty"`
+ IsAnimated bool `json:"is_animated,omitempty"`
+
Width json.Number `json:"w,omitempty"`
Height json.Number `json:"h,omitempty"`
Duration json.Number `json:"duration,omitempty"`
@@ -234,6 +346,9 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI
ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo),
ThumbnailFile: fileInfo.ThumbnailFile,
+ MauGIF: fileInfo.MauGIF,
+ IsAnimated: fileInfo.IsAnimated,
+
Blurhash: fileInfo.Blurhash,
AnoaBlurhash: fileInfo.AnoaBlurhash,
}
@@ -262,6 +377,8 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) {
MimeType: sfi.MimeType,
ThumbnailURL: sfi.ThumbnailURL,
ThumbnailFile: sfi.ThumbnailFile,
+ MauGIF: sfi.MauGIF,
+ IsAnimated: sfi.IsAnimated,
Blurhash: sfi.Blurhash,
AnoaBlurhash: sfi.AnoaBlurhash,
}
diff --git a/event/message_test.go b/event/message_test.go
index 562a6622..c721df35 100644
--- a/event/message_test.go
+++ b/event/message_test.go
@@ -33,7 +33,7 @@ const invalidMessageEvent = `{
func TestMessageEventContent__ParseInvalid(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(invalidMessageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) {
assert.Equal(t, id.RoomID("!bar"), evt.RoomID)
err = evt.Content.ParseRaw(evt.Type)
- assert.NotNil(t, err)
+ assert.Error(t, err)
}
const messageEvent = `{
@@ -68,7 +68,7 @@ const messageEvent = `{
func TestMessageEventContent__ParseEdit(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(messageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -110,7 +110,7 @@ const imageMessageEvent = `{
func TestMessageEventContent__ParseMedia(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(imageMessageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) {
content := evt.Content.Parsed.(*event.MessageEventContent)
assert.Equal(t, event.MsgImage, content.MsgType)
parsedURL, err := content.URL.Parse()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL)
assert.Nil(t, content.NewContent)
assert.Equal(t, "image/png", content.GetInfo().MimeType)
@@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}`
func TestMessageEventContent__Marshal(t *testing.T) {
data, err := json.Marshal(parsedMessage)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expectedMarshalResult, string(data))
}
@@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun
func TestMessageEventContent__Marshal_Custom(t *testing.T) {
data, err := json.Marshal(customParsedMessage)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expectedCustomMarshalResult, string(data))
}
diff --git a/event/poll.go b/event/poll.go
new file mode 100644
index 00000000..9082f65e
--- /dev/null
+++ b/event/poll.go
@@ -0,0 +1,64 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package event
+
+type PollResponseEventContent struct {
+ RelatesTo RelatesTo `json:"m.relates_to"`
+ Response struct {
+ Answers []string `json:"answers"`
+ } `json:"org.matrix.msc3381.poll.response"`
+}
+
+func (content *PollResponseEventContent) GetRelatesTo() *RelatesTo {
+ return &content.RelatesTo
+}
+
+func (content *PollResponseEventContent) OptionalGetRelatesTo() *RelatesTo {
+ if content.RelatesTo.Type == "" {
+ return nil
+ }
+ return &content.RelatesTo
+}
+
+func (content *PollResponseEventContent) SetRelatesTo(rel *RelatesTo) {
+ content.RelatesTo = *rel
+}
+
+type MSC1767Message struct {
+ Text string `json:"org.matrix.msc1767.text,omitempty"`
+ HTML string `json:"org.matrix.msc1767.html,omitempty"`
+ Message []ExtensibleText `json:"org.matrix.msc1767.message,omitempty"`
+}
+
+type PollStartEventContent struct {
+ RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
+ Mentions *Mentions `json:"m.mentions,omitempty"`
+ PollStart struct {
+ Kind string `json:"kind"`
+ MaxSelections int `json:"max_selections"`
+ Question MSC1767Message `json:"question"`
+ Answers []struct {
+ ID string `json:"id"`
+ MSC1767Message
+ } `json:"answers"`
+ } `json:"org.matrix.msc3381.poll.start"`
+}
+
+func (content *PollStartEventContent) GetRelatesTo() *RelatesTo {
+ if content.RelatesTo == nil {
+ content.RelatesTo = &RelatesTo{}
+ }
+ return content.RelatesTo
+}
+
+func (content *PollStartEventContent) OptionalGetRelatesTo() *RelatesTo {
+ return content.RelatesTo
+}
+
+func (content *PollStartEventContent) SetRelatesTo(rel *RelatesTo) {
+ content.RelatesTo = rel
+}
diff --git a/event/powerlevels.go b/event/powerlevels.go
index 91d56611..668eb6d3 100644
--- a/event/powerlevels.go
+++ b/event/powerlevels.go
@@ -7,8 +7,13 @@
package event
import (
+ "math"
+ "slices"
"sync"
+ "go.mau.fi/util/ptr"
+ "golang.org/x/exp/maps"
+
"maunium.net/go/mautrix/id"
)
@@ -23,6 +28,9 @@ type PowerLevelsEventContent struct {
Events map[string]int `json:"events,omitempty"`
EventsDefault int `json:"events_default,omitempty"`
+ beeperEphemeralLock sync.RWMutex
+ BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"`
+
Notifications *NotificationPowerLevels `json:"notifications,omitempty"`
StateDefaultPtr *int `json:"state_default,omitempty"`
@@ -31,25 +39,12 @@ type PowerLevelsEventContent struct {
KickPtr *int `json:"kick,omitempty"`
BanPtr *int `json:"ban,omitempty"`
RedactPtr *int `json:"redact,omitempty"`
-}
-func copyPtr(ptr *int) *int {
- if ptr == nil {
- return nil
- }
- val := *ptr
- return &val
-}
+ BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"`
-func copyMap[Key comparable](m map[Key]int) map[Key]int {
- if m == nil {
- return nil
- }
- copied := make(map[Key]int, len(m))
- for k, v := range m {
- copied[k] = v
- }
- return copied
+ // This is not a part of power levels, it's added by mautrix-go internally in certain places
+ // in order to detect creator power accurately.
+ CreateEvent *Event `json:"-"`
}
func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
@@ -57,18 +52,23 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
return nil
}
return &PowerLevelsEventContent{
- Users: copyMap(pl.Users),
+ Users: maps.Clone(pl.Users),
UsersDefault: pl.UsersDefault,
- Events: copyMap(pl.Events),
+ Events: maps.Clone(pl.Events),
EventsDefault: pl.EventsDefault,
- StateDefaultPtr: copyPtr(pl.StateDefaultPtr),
+ BeeperEphemeral: maps.Clone(pl.BeeperEphemeral),
+ StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr),
Notifications: pl.Notifications.Clone(),
- InvitePtr: copyPtr(pl.InvitePtr),
- KickPtr: copyPtr(pl.KickPtr),
- BanPtr: copyPtr(pl.BanPtr),
- RedactPtr: copyPtr(pl.RedactPtr),
+ InvitePtr: ptr.Clone(pl.InvitePtr),
+ KickPtr: ptr.Clone(pl.KickPtr),
+ BanPtr: ptr.Clone(pl.BanPtr),
+ RedactPtr: ptr.Clone(pl.RedactPtr),
+
+ BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr),
+
+ CreateEvent: pl.CreateEvent,
}
}
@@ -81,7 +81,7 @@ func (npl *NotificationPowerLevels) Clone() *NotificationPowerLevels {
return nil
}
return &NotificationPowerLevels{
- RoomPtr: copyPtr(npl.RoomPtr),
+ RoomPtr: ptr.Clone(npl.RoomPtr),
}
}
@@ -96,7 +96,7 @@ func (pl *PowerLevelsEventContent) Invite() int {
if pl.InvitePtr != nil {
return *pl.InvitePtr
}
- return 50
+ return 0
}
func (pl *PowerLevelsEventContent) Kick() int {
@@ -127,7 +127,17 @@ func (pl *PowerLevelsEventContent) StateDefault() int {
return 50
}
+func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int {
+ if pl.BeeperEphemeralDefaultPtr != nil {
+ return *pl.BeeperEphemeralDefaultPtr
+ }
+ return pl.EventsDefault
+}
+
func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
+ if pl.isCreator(userID) {
+ return math.MaxInt
+ }
pl.usersLock.RLock()
defer pl.usersLock.RUnlock()
level, ok := pl.Users[userID]
@@ -137,20 +147,58 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
return level
}
+const maxPL = 1<<53 - 1
+
func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) {
pl.usersLock.Lock()
defer pl.usersLock.Unlock()
+ if pl.isCreator(userID) {
+ return
+ }
+ if level == math.MaxInt && maxPL < math.MaxInt {
+ // Hack to avoid breaking on 32-bit systems (they're only slightly supported)
+ x := int64(maxPL)
+ level = int(x)
+ }
if level == pl.UsersDefault {
delete(pl.Users, userID)
} else {
+ if pl.Users == nil {
+ pl.Users = make(map[id.UserID]int)
+ }
pl.Users[userID] = level
}
}
-func (pl *PowerLevelsEventContent) EnsureUserLevel(userID id.UserID, level int) bool {
- existingLevel := pl.GetUserLevel(userID)
+func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int) bool {
+ return pl.EnsureUserLevelAs("", target, level)
+}
+
+func (pl *PowerLevelsEventContent) createContent() *CreateEventContent {
+ if pl.CreateEvent == nil {
+ return &CreateEventContent{}
+ }
+ return pl.CreateEvent.Content.AsCreate()
+}
+
+func (pl *PowerLevelsEventContent) isCreator(userID id.UserID) bool {
+ cc := pl.createContent()
+ return cc.SupportsCreatorPower() && (userID == pl.CreateEvent.Sender || slices.Contains(cc.AdditionalCreators, userID))
+}
+
+func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool {
+ if pl.isCreator(target) {
+ return false
+ }
+ existingLevel := pl.GetUserLevel(target)
+ if actor != "" && !pl.isCreator(actor) {
+ actorLevel := pl.GetUserLevel(actor)
+ if actorLevel <= existingLevel || actorLevel < level {
+ return false
+ }
+ }
if existingLevel != level {
- pl.SetUserLevel(userID, level)
+ pl.SetUserLevel(target, level)
return true
}
return false
@@ -169,18 +217,54 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int {
return level
}
+func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int {
+ pl.beeperEphemeralLock.RLock()
+ defer pl.beeperEphemeralLock.RUnlock()
+ level, ok := pl.BeeperEphemeral[eventType.String()]
+ if !ok {
+ return pl.BeeperEphemeralDefault()
+ }
+ return level
+}
+
+func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) {
+ pl.beeperEphemeralLock.Lock()
+ defer pl.beeperEphemeralLock.Unlock()
+ if level == pl.BeeperEphemeralDefault() {
+ delete(pl.BeeperEphemeral, eventType.String())
+ } else {
+ if pl.BeeperEphemeral == nil {
+ pl.BeeperEphemeral = make(map[string]int)
+ }
+ pl.BeeperEphemeral[eventType.String()] = level
+ }
+}
+
func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) {
pl.eventsLock.Lock()
defer pl.eventsLock.Unlock()
if (eventType.IsState() && level == pl.StateDefault()) || (!eventType.IsState() && level == pl.EventsDefault) {
delete(pl.Events, eventType.String())
} else {
+ if pl.Events == nil {
+ pl.Events = make(map[string]int)
+ }
pl.Events[eventType.String()] = level
}
}
func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) bool {
+ return pl.EnsureEventLevelAs("", eventType, level)
+}
+
+func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool {
existingLevel := pl.GetEventLevel(eventType)
+ if actor != "" && !pl.isCreator(actor) {
+ actorLevel := pl.GetUserLevel(actor)
+ if existingLevel > actorLevel || level > actorLevel {
+ return false
+ }
+ }
if existingLevel != level {
pl.SetEventLevel(eventType, level)
return true
diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go
new file mode 100644
index 00000000..f5861583
--- /dev/null
+++ b/event/powerlevels_ephemeral_test.go
@@ -0,0 +1,67 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package event_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/event"
+)
+
+func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 45,
+ }
+
+ assert.Equal(t, 45, pl.BeeperEphemeralDefault())
+
+ override := 60
+ pl.BeeperEphemeralDefaultPtr = &override
+ assert.Equal(t, 60, pl.BeeperEphemeralDefault())
+}
+
+func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 25,
+ }
+ evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
+
+ assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType))
+
+ pl.SetBeeperEphemeralLevel(evtType, 50)
+ assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType))
+ require.NotNil(t, pl.BeeperEphemeral)
+ assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()])
+
+ pl.SetBeeperEphemeralLevel(evtType, 25)
+ _, exists := pl.BeeperEphemeral[evtType.String()]
+ assert.False(t, exists)
+}
+
+func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) {
+ override := 70
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 35,
+ BeeperEphemeral: map[string]int{"com.example.ephemeral": 90},
+ BeeperEphemeralDefaultPtr: &override,
+ }
+
+ cloned := pl.Clone()
+ require.NotNil(t, cloned)
+ require.NotNil(t, cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"])
+
+ cloned.BeeperEphemeral["com.example.ephemeral"] = 99
+ *cloned.BeeperEphemeralDefaultPtr = 71
+
+ assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"])
+ assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr)
+}
diff --git a/event/relations.go b/event/relations.go
index ecd7a959..2316cbc7 100644
--- a/event/relations.go
+++ b/event/relations.go
@@ -15,10 +15,11 @@ import (
type RelationType string
const (
- RelReplace RelationType = "m.replace"
- RelReference RelationType = "m.reference"
- RelAnnotation RelationType = "m.annotation"
- RelThread RelationType = "m.thread"
+ RelReplace RelationType = "m.replace"
+ RelReference RelationType = "m.reference"
+ RelAnnotation RelationType = "m.annotation"
+ RelThread RelationType = "m.thread"
+ RelBeeperTranscription RelationType = "com.beeper.transcription"
)
type RelatesTo struct {
@@ -33,7 +34,7 @@ type RelatesTo struct {
type InReplyTo struct {
EventID id.EventID `json:"event_id,omitempty"`
- UnstableRoomID id.RoomID `json:"room_id,omitempty"`
+ UnstableRoomID id.RoomID `json:"com.beeper.cross_room_id,omitempty"`
}
func (rel *RelatesTo) Copy() *RelatesTo {
@@ -73,7 +74,7 @@ func (rel *RelatesTo) GetReplyTo() id.EventID {
}
func (rel *RelatesTo) GetNonFallbackReplyTo() id.EventID {
- if rel != nil && rel.InReplyTo != nil && !rel.IsFallingBack {
+ if rel != nil && rel.InReplyTo != nil && (rel.Type != RelThread || !rel.IsFallingBack) {
return rel.InReplyTo.EventID
}
return ""
@@ -100,6 +101,10 @@ func (rel *RelatesTo) SetReplace(mxid id.EventID) *RelatesTo {
}
func (rel *RelatesTo) SetReplyTo(mxid id.EventID) *RelatesTo {
+ if rel.Type != RelThread {
+ rel.Type = ""
+ rel.EventID = ""
+ }
rel.InReplyTo = &InReplyTo{EventID: mxid}
rel.IsFallingBack = false
return rel
diff --git a/event/reply.go b/event/reply.go
index 73f8cfc7..5f55bb80 100644
--- a/event/reply.go
+++ b/event/reply.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2020 Tulir Asokan
+// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,7 +7,6 @@
package event
import (
- "fmt"
"regexp"
"strings"
@@ -33,12 +32,13 @@ func TrimReplyFallbackText(text string) string {
}
func (content *MessageEventContent) RemoveReplyFallback() {
- if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved {
- if content.Format == FormatHTML {
- content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
+ if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML {
+ origHTML := content.FormattedBody
+ content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
+ if content.FormattedBody != origHTML {
+ content.Body = TrimReplyFallbackText(content.Body)
+ content.replyFallbackRemoved = true
}
- content.Body = TrimReplyFallbackText(content.Body)
- content.replyFallbackRemoved = true
}
}
@@ -47,52 +47,28 @@ func (content *MessageEventContent) GetReplyTo() id.EventID {
return content.RelatesTo.GetReplyTo()
}
-const ReplyFormat = `In reply to %s
%s
`
-
-func (evt *Event) GenerateReplyFallbackHTML() string {
- parsedContent, ok := evt.Content.Parsed.(*MessageEventContent)
- if !ok {
- return ""
- }
- parsedContent.RemoveReplyFallback()
- body := parsedContent.FormattedBody
- if len(body) == 0 {
- body = TextToHTML(parsedContent.Body)
- }
-
- senderDisplayName := evt.Sender
-
- return fmt.Sprintf(ReplyFormat, evt.RoomID, evt.ID, evt.Sender, senderDisplayName, body)
-}
-
-func (evt *Event) GenerateReplyFallbackText() string {
- parsedContent, ok := evt.Content.Parsed.(*MessageEventContent)
- if !ok {
- return ""
- }
- parsedContent.RemoveReplyFallback()
- body := parsedContent.Body
- lines := strings.Split(strings.TrimSpace(body), "\n")
- firstLine, lines := lines[0], lines[1:]
-
- senderDisplayName := evt.Sender
-
- var fallbackText strings.Builder
- _, _ = fmt.Fprintf(&fallbackText, "> <%s> %s", senderDisplayName, firstLine)
- for _, line := range lines {
- _, _ = fmt.Fprintf(&fallbackText, "\n> %s", line)
- }
- fallbackText.WriteString("\n\n")
- return fallbackText.String()
-}
-
func (content *MessageEventContent) SetReply(inReplyTo *Event) {
- content.RelatesTo = (&RelatesTo{}).SetReplyTo(inReplyTo.ID)
-
- if content.MsgType == MsgText || content.MsgType == MsgNotice {
- content.EnsureHasHTML()
- content.FormattedBody = inReplyTo.GenerateReplyFallbackHTML() + content.FormattedBody
- content.Body = inReplyTo.GenerateReplyFallbackText() + content.Body
- content.replyFallbackRemoved = false
+ if content.RelatesTo == nil {
+ content.RelatesTo = &RelatesTo{}
}
+ content.RelatesTo.SetReplyTo(inReplyTo.ID)
+ if content.Mentions == nil {
+ content.Mentions = &Mentions{}
+ }
+ content.Mentions.Add(inReplyTo.Sender)
+}
+
+func (content *MessageEventContent) SetThread(inReplyTo *Event) {
+ root := inReplyTo.ID
+ relatable, ok := inReplyTo.Content.Parsed.(Relatable)
+ if ok {
+ targetRoot := relatable.OptionalGetRelatesTo().GetThreadParent()
+ if targetRoot != "" {
+ root = targetRoot
+ }
+ }
+ if content.RelatesTo == nil {
+ content.RelatesTo = &RelatesTo{}
+ }
+ content.RelatesTo.SetThread(root, inReplyTo.ID)
}
diff --git a/event/state.go b/event/state.go
index d6b6cf70..ace170a5 100644
--- a/event/state.go
+++ b/event/state.go
@@ -7,6 +7,12 @@
package event
import (
+ "encoding/base64"
+ "encoding/json"
+ "slices"
+
+ "go.mau.fi/util/jsontime"
+
"maunium.net/go/mautrix/id"
)
@@ -26,8 +32,9 @@ type RoomNameEventContent struct {
// RoomAvatarEventContent represents the content of a m.room.avatar state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomavatar
type RoomAvatarEventContent struct {
- URL id.ContentURI `json:"url"`
- Info *FileInfo `json:"info,omitempty"`
+ URL id.ContentURIString `json:"url,omitempty"`
+ Info *FileInfo `json:"info,omitempty"`
+ MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"`
}
// ServerACLEventContent represents the content of a m.room.server_acl state event.
@@ -41,7 +48,52 @@ type ServerACLEventContent struct {
// TopicEventContent represents the content of a m.room.topic state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomtopic
type TopicEventContent struct {
- Topic string `json:"topic"`
+ Topic string `json:"topic"`
+ ExtensibleTopic *ExtensibleTopic `json:"m.topic,omitempty"`
+}
+
+// ExtensibleTopic represents the contents of the m.topic field within the
+// m.room.topic state event as described in [MSC3765].
+//
+// [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765
+type ExtensibleTopic = ExtensibleTextContainer
+
+type ExtensibleTextContainer struct {
+ Text []ExtensibleText `json:"m.text"`
+}
+
+func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool {
+ if c == nil || description == nil {
+ return c == description
+ }
+ return slices.Equal(c.Text, description.Text)
+}
+
+func MakeExtensibleText(text string) *ExtensibleTextContainer {
+ return &ExtensibleTextContainer{
+ Text: []ExtensibleText{{
+ Body: text,
+ MimeType: "text/plain",
+ }},
+ }
+}
+
+func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer {
+ return &ExtensibleTextContainer{
+ Text: []ExtensibleText{{
+ Body: plaintext,
+ MimeType: "text/plain",
+ }, {
+ Body: html,
+ MimeType: "text/html",
+ }},
+ }
+}
+
+// ExtensibleText represents the contents of an m.text field.
+type ExtensibleText struct {
+ MimeType string `json:"mimetype,omitempty"`
+ Body string `json:"body"`
}
// TombstoneEventContent represents the content of a m.room.tombstone state event.
@@ -51,19 +103,64 @@ type TombstoneEventContent struct {
ReplacementRoom id.RoomID `json:"replacement_room"`
}
+func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID {
+ if tec == nil {
+ return ""
+ }
+ return tec.ReplacementRoom
+}
+
type Predecessor struct {
RoomID id.RoomID `json:"room_id"`
EventID id.EventID `json:"event_id"`
}
+// Deprecated: use id.RoomVersion instead
+type RoomVersion = id.RoomVersion
+
+// Deprecated: use id.RoomVX constants instead
+const (
+ RoomV1 = id.RoomV1
+ RoomV2 = id.RoomV2
+ RoomV3 = id.RoomV3
+ RoomV4 = id.RoomV4
+ RoomV5 = id.RoomV5
+ RoomV6 = id.RoomV6
+ RoomV7 = id.RoomV7
+ RoomV8 = id.RoomV8
+ RoomV9 = id.RoomV9
+ RoomV10 = id.RoomV10
+ RoomV11 = id.RoomV11
+ RoomV12 = id.RoomV12
+)
+
// CreateEventContent represents the content of a m.room.create state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomcreate
type CreateEventContent struct {
- Type RoomType `json:"type,omitempty"`
- Creator id.UserID `json:"creator,omitempty"`
- Federate bool `json:"m.federate,omitempty"`
- RoomVersion string `json:"room_version,omitempty"`
- Predecessor *Predecessor `json:"predecessor,omitempty"`
+ Type RoomType `json:"type,omitempty"`
+ Federate *bool `json:"m.federate,omitempty"`
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ Predecessor *Predecessor `json:"predecessor,omitempty"`
+
+ // Room v12+ only
+ AdditionalCreators []id.UserID `json:"additional_creators,omitempty"`
+
+ // Deprecated: use the event sender instead
+ Creator id.UserID `json:"creator,omitempty"`
+}
+
+func (cec *CreateEventContent) GetPredecessor() (p Predecessor) {
+ if cec != nil && cec.Predecessor != nil {
+ p = *cec.Predecessor
+ }
+ return
+}
+
+func (cec *CreateEventContent) SupportsCreatorPower() bool {
+ if cec == nil {
+ return false
+ }
+ return cec.RoomVersion.PrivilegedRoomCreators()
}
// JoinRule specifies how open a room is to new members.
@@ -71,11 +168,12 @@ type CreateEventContent struct {
type JoinRule string
const (
- JoinRulePublic JoinRule = "public"
- JoinRuleKnock JoinRule = "knock"
- JoinRuleInvite JoinRule = "invite"
- JoinRuleRestricted JoinRule = "restricted"
- JoinRulePrivate JoinRule = "private"
+ JoinRulePublic JoinRule = "public"
+ JoinRuleKnock JoinRule = "knock"
+ JoinRuleInvite JoinRule = "invite"
+ JoinRuleRestricted JoinRule = "restricted"
+ JoinRuleKnockRestricted JoinRule = "knock_restricted"
+ JoinRulePrivate JoinRule = "private"
)
// JoinRulesEventContent represents the content of a m.room.join_rules state event.
@@ -139,6 +237,9 @@ type BridgeInfoSection struct {
DisplayName string `json:"displayname,omitempty"`
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
ExternalURL string `json:"external_url,omitempty"`
+
+ Receiver string `json:"fi.mau.receiver,omitempty"`
+ MessageRequest bool `json:"com.beeper.message_request,omitempty"`
}
// BridgeEventContent represents the content of a m.bridge state event.
@@ -149,6 +250,35 @@ type BridgeEventContent struct {
Protocol BridgeInfoSection `json:"protocol"`
Network *BridgeInfoSection `json:"network,omitempty"`
Channel BridgeInfoSection `json:"channel"`
+
+ BeeperRoomType string `json:"com.beeper.room_type,omitempty"`
+ BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"`
+
+ TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"`
+ TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"`
+}
+
+// DisappearingType represents the type of a disappearing message timer.
+type DisappearingType string
+
+const (
+ DisappearingTypeNone DisappearingType = ""
+ DisappearingTypeAfterRead DisappearingType = "after_read"
+ DisappearingTypeAfterSend DisappearingType = "after_send"
+)
+
+type BeeperDisappearingTimer struct {
+ Type DisappearingType `json:"type"`
+ Timer jsontime.Milliseconds `json:"timer"`
+}
+
+type marshalableBeeperDisappearingTimer BeeperDisappearingTimer
+
+func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) {
+ if bdt == nil || bdt.Type == DisappearingTypeNone {
+ return []byte("{}"), nil
+ }
+ return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt))
}
type SpaceChildEventContent struct {
@@ -162,16 +292,66 @@ type SpaceParentEventContent struct {
Canonical bool `json:"canonical,omitempty"`
}
+type PolicyRecommendation string
+
+const (
+ PolicyRecommendationBan PolicyRecommendation = "m.ban"
+ PolicyRecommendationUnstableTakedown PolicyRecommendation = "org.matrix.msc4204.takedown"
+ PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban"
+ PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban"
+)
+
+type PolicyHashes struct {
+ SHA256 string `json:"sha256"`
+}
+
+func (ph *PolicyHashes) DecodeSHA256() *[32]byte {
+ if ph == nil || ph.SHA256 == "" {
+ return nil
+ }
+ decoded, _ := base64.StdEncoding.DecodeString(ph.SHA256)
+ if len(decoded) == 32 {
+ return (*[32]byte)(decoded)
+ }
+ return nil
+}
+
// ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event.
// https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists
type ModPolicyContent struct {
- Entity string `json:"entity"`
- Reason string `json:"reason"`
- Recommendation string `json:"recommendation"`
+ Entity string `json:"entity,omitempty"`
+ Reason string `json:"reason"`
+ Recommendation PolicyRecommendation `json:"recommendation"`
+ UnstableHashes *PolicyHashes `json:"org.matrix.msc4205.hashes,omitempty"`
}
-// Deprecated: MSC2716 has been abandoned
-type InsertionMarkerContent struct {
- InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"`
- Timestamp int64 `json:"com.beeper.timestamp,omitempty"`
+func (mpc *ModPolicyContent) EntityOrHash() string {
+ if mpc.UnstableHashes != nil && mpc.UnstableHashes.SHA256 != "" {
+ return mpc.UnstableHashes.SHA256
+ }
+ return mpc.Entity
+}
+
+type ElementFunctionalMembersContent struct {
+ ServiceMembers []id.UserID `json:"service_members"`
+}
+
+func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool {
+ if slices.Contains(efmc.ServiceMembers, mxid) {
+ return false
+ }
+ efmc.ServiceMembers = append(efmc.ServiceMembers, mxid)
+ return true
+}
+
+type PolicyServerPublicKeys struct {
+ Ed25519 id.Ed25519 `json:"ed25519,omitempty"`
+}
+
+type RoomPolicyEventContent struct {
+ Via string `json:"via,omitempty"`
+ PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"`
+
+ // Deprecated, only for legacy use
+ PublicKey id.Ed25519 `json:"public_key,omitempty"`
}
diff --git a/event/type.go b/event/type.go
index 2f4f4f94..80b86728 100644
--- a/event/type.go
+++ b/event/type.go
@@ -10,6 +10,8 @@ import (
"encoding/json"
"fmt"
"strings"
+
+ "maunium.net/go/mautrix/id"
)
type RoomType string
@@ -106,23 +108,27 @@ func (et *Type) IsCustom() bool {
func (et *Type) GuessClass() TypeClass {
switch et.Type {
- case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type,
+ case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type,
StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type,
StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type,
StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type,
- StateInsertionMarker.Type:
+ StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type,
+ StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type:
return StateEventType
- case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type:
+ case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type:
return EphemeralEventType
case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type,
+ AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type,
AccountDataSecretStorageKey.Type, AccountDataSecretStorageDefaultKey.Type,
- AccountDataCrossSigningMaster.Type, AccountDataCrossSigningSelf.Type, AccountDataCrossSigningUser.Type:
+ AccountDataCrossSigningMaster.Type, AccountDataCrossSigningSelf.Type, AccountDataCrossSigningUser.Type,
+ AccountDataFullyRead.Type, AccountDataMegolmBackupKey.Type:
return AccountDataEventType
case EventRedaction.Type, EventMessage.Type, EventEncrypted.Type, EventReaction.Type, EventSticker.Type,
InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type,
InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type,
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
- CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type:
+ CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type,
+ EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type:
return MessageEventType
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
ToDeviceBeeperRoomKeyAck.Type:
@@ -145,7 +151,7 @@ func (et *Type) MarshalJSON() ([]byte, error) {
return json.Marshal(&et.Type)
}
-func (et Type) UnmarshalText(data []byte) error {
+func (et *Type) UnmarshalText(data []byte) error {
et.Type = string(data)
et.Class = et.GuessClass()
return nil
@@ -155,11 +161,11 @@ func (et Type) MarshalText() ([]byte, error) {
return []byte(et.Type), nil
}
-func (et *Type) String() string {
+func (et Type) String() string {
return et.Type
}
-func (et *Type) Repr() string {
+func (et Type) Repr() string {
return fmt.Sprintf("%s (%s)", et.Type, et.Class.Name())
}
@@ -172,6 +178,7 @@ var (
StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType}
StateGuestAccess = Type{"m.room.guest_access", StateEventType}
StateMember = Type{"m.room.member", StateEventType}
+ StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType}
StatePowerLevels = Type{"m.room.power_levels", StateEventType}
StateRoomName = Type{"m.room.name", StateEventType}
StateTopic = Type{"m.room.topic", StateEventType}
@@ -188,8 +195,20 @@ var (
StateSpaceChild = Type{"m.space.child", StateEventType}
StateSpaceParent = Type{"m.space.parent", StateEventType}
- // Deprecated: MSC2716 has been abandoned
- StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType}
+ StateRoomPolicy = Type{"m.room.policy", StateEventType}
+ StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType}
+
+ StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType}
+ StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType}
+ StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType}
+ StateUnstablePolicyRoom = Type{"org.matrix.mjolnir.rule.room", StateEventType}
+ StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType}
+ StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType}
+
+ StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType}
+ StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType}
+ StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType}
+ StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType}
)
// Message events
@@ -200,12 +219,15 @@ var (
EventReaction = Type{"m.reaction", MessageEventType}
EventSticker = Type{"m.sticker", MessageEventType}
- InRoomVerificationStart = Type{"m.key.verification.start", MessageEventType}
InRoomVerificationReady = Type{"m.key.verification.ready", MessageEventType}
+ InRoomVerificationStart = Type{"m.key.verification.start", MessageEventType}
+ InRoomVerificationDone = Type{"m.key.verification.done", MessageEventType}
+ InRoomVerificationCancel = Type{"m.key.verification.cancel", MessageEventType}
+
+ // SAS Verification Events
InRoomVerificationAccept = Type{"m.key.verification.accept", MessageEventType}
InRoomVerificationKey = Type{"m.key.verification.key", MessageEventType}
InRoomVerificationMAC = Type{"m.key.verification.mac", MessageEventType}
- InRoomVerificationCancel = Type{"m.key.verification.cancel", MessageEventType}
CallInvite = Type{"m.call.invite", MessageEventType}
CallCandidates = Type{"m.call.candidates", MessageEventType}
@@ -215,14 +237,24 @@ var (
CallNegotiate = Type{"m.call.negotiate", MessageEventType}
CallHangup = Type{"m.call.hangup", MessageEventType}
- BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
+ BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
+ BeeperTranscription = Type{"com.beeper.transcription", MessageEventType}
+ BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType}
+ BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType}
+ BeeperSendState = Type{"com.beeper.send_state", MessageEventType}
+
+ EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType}
+ EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType}
+ EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType}
)
// Ephemeral events
var (
- EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
- EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
- EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
+ EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
+ EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType}
+ BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType}
)
// Account data events
@@ -232,29 +264,39 @@ var (
AccountDataRoomTags = Type{"m.tag", AccountDataEventType}
AccountDataFullyRead = Type{"m.fully_read", AccountDataEventType}
AccountDataIgnoredUserList = Type{"m.ignored_user_list", AccountDataEventType}
+ AccountDataMarkedUnread = Type{"m.marked_unread", AccountDataEventType}
+ AccountDataBeeperMute = Type{"com.beeper.mute", AccountDataEventType}
AccountDataSecretStorageDefaultKey = Type{"m.secret_storage.default_key", AccountDataEventType}
AccountDataSecretStorageKey = Type{"m.secret_storage.key", AccountDataEventType}
- AccountDataCrossSigningMaster = Type{"m.cross_signing.master", AccountDataEventType}
- AccountDataCrossSigningUser = Type{"m.cross_signing.user_signing", AccountDataEventType}
- AccountDataCrossSigningSelf = Type{"m.cross_signing.self_signing", AccountDataEventType}
+ AccountDataCrossSigningMaster = Type{string(id.SecretXSMaster), AccountDataEventType}
+ AccountDataCrossSigningUser = Type{string(id.SecretXSUserSigning), AccountDataEventType}
+ AccountDataCrossSigningSelf = Type{string(id.SecretXSSelfSigning), AccountDataEventType}
+ AccountDataMegolmBackupKey = Type{string(id.SecretMegolmBackupV1), AccountDataEventType}
)
// Device-to-device events
var (
- ToDeviceRoomKey = Type{"m.room_key", ToDeviceEventType}
- ToDeviceRoomKeyRequest = Type{"m.room_key_request", ToDeviceEventType}
- ToDeviceForwardedRoomKey = Type{"m.forwarded_room_key", ToDeviceEventType}
- ToDeviceEncrypted = Type{"m.room.encrypted", ToDeviceEventType}
- ToDeviceRoomKeyWithheld = Type{"m.room_key.withheld", ToDeviceEventType}
- ToDeviceDummy = Type{"m.dummy", ToDeviceEventType}
+ ToDeviceRoomKey = Type{"m.room_key", ToDeviceEventType}
+ ToDeviceRoomKeyRequest = Type{"m.room_key_request", ToDeviceEventType}
+ ToDeviceForwardedRoomKey = Type{"m.forwarded_room_key", ToDeviceEventType}
+ ToDeviceEncrypted = Type{"m.room.encrypted", ToDeviceEventType}
+ ToDeviceRoomKeyWithheld = Type{"m.room_key.withheld", ToDeviceEventType}
+ ToDeviceSecretRequest = Type{"m.secret.request", ToDeviceEventType}
+ ToDeviceSecretSend = Type{"m.secret.send", ToDeviceEventType}
+ ToDeviceDummy = Type{"m.dummy", ToDeviceEventType}
+
ToDeviceVerificationRequest = Type{"m.key.verification.request", ToDeviceEventType}
+ ToDeviceVerificationReady = Type{"m.key.verification.ready", ToDeviceEventType}
ToDeviceVerificationStart = Type{"m.key.verification.start", ToDeviceEventType}
- ToDeviceVerificationAccept = Type{"m.key.verification.accept", ToDeviceEventType}
- ToDeviceVerificationKey = Type{"m.key.verification.key", ToDeviceEventType}
- ToDeviceVerificationMAC = Type{"m.key.verification.mac", ToDeviceEventType}
+ ToDeviceVerificationDone = Type{"m.key.verification.done", ToDeviceEventType}
ToDeviceVerificationCancel = Type{"m.key.verification.cancel", ToDeviceEventType}
+ // SAS Verification Events
+ ToDeviceVerificationAccept = Type{"m.key.verification.accept", ToDeviceEventType}
+ ToDeviceVerificationKey = Type{"m.key.verification.key", ToDeviceEventType}
+ ToDeviceVerificationMAC = Type{"m.key.verification.mac", ToDeviceEventType}
+
ToDeviceOrgMatrixRoomKeyWithheld = Type{"org.matrix.room_key.withheld", ToDeviceEventType}
ToDeviceBeeperRoomKeyAck = Type{"com.beeper.room_key.ack", ToDeviceEventType}
diff --git a/event/verification.go b/event/verification.go
index 8410904d..6101896f 100644
--- a/event/verification.go
+++ b/event/verification.go
@@ -7,301 +7,302 @@
package event
import (
+ "go.mau.fi/util/jsonbytes"
+ "go.mau.fi/util/jsontime"
+
"maunium.net/go/mautrix/id"
)
type VerificationMethod string
-const VerificationMethodSAS VerificationMethod = "m.sas.v1"
+const (
+ VerificationMethodSAS VerificationMethod = "m.sas.v1"
-// VerificationRequestEventContent represents the content of a m.key.verification.request to_device event.
-// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationrequest
-type VerificationRequestEventContent struct {
- // The device ID which is initiating the request.
- FromDevice id.DeviceID `json:"from_device"`
- // An opaque identifier for the verification request. Must be unique with respect to the devices involved.
- TransactionID string `json:"transaction_id,omitempty"`
- // The verification methods supported by the sender.
- Methods []VerificationMethod `json:"methods"`
- // The POSIX timestamp in milliseconds for when the request was made.
- Timestamp int64 `json:"timestamp,omitempty"`
- // The user that the event is sent to for in-room verification.
- To id.UserID `json:"to,omitempty"`
- // Original event ID for in-room verification.
+ VerificationMethodReciprocate VerificationMethod = "m.reciprocate.v1"
+ VerificationMethodQRCodeShow VerificationMethod = "m.qr_code.show.v1"
+ VerificationMethodQRCodeScan VerificationMethod = "m.qr_code.scan.v1"
+)
+
+type VerificationTransactionable interface {
+ GetTransactionID() id.VerificationTransactionID
+ SetTransactionID(id.VerificationTransactionID)
+}
+
+// ToDeviceVerificationEvent contains the fields common to all to-device
+// verification events.
+type ToDeviceVerificationEvent struct {
+ // TransactionID is an opaque identifier for the verification request. Must
+ // be unique with respect to the devices involved.
+ TransactionID id.VerificationTransactionID `json:"transaction_id,omitempty"`
+}
+
+var _ VerificationTransactionable = (*ToDeviceVerificationEvent)(nil)
+
+func (ve *ToDeviceVerificationEvent) GetTransactionID() id.VerificationTransactionID {
+ return ve.TransactionID
+}
+
+func (ve *ToDeviceVerificationEvent) SetTransactionID(id id.VerificationTransactionID) {
+ ve.TransactionID = id
+}
+
+// InRoomVerificationEvent contains the fields common to all in-room
+// verification events.
+type InRoomVerificationEvent struct {
+ // RelatesTo indicates the m.key.verification.request that this message is
+ // related to. Note that for encrypted messages, this property should be in
+ // the unencrypted portion of the event.
RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
}
-func (vrec *VerificationRequestEventContent) SupportsVerificationMethod(meth VerificationMethod) bool {
- for _, supportedMeth := range vrec.Methods {
- if supportedMeth == meth {
- return true
- }
+var _ Relatable = (*InRoomVerificationEvent)(nil)
+
+func (ve *InRoomVerificationEvent) GetRelatesTo() *RelatesTo {
+ if ve.RelatesTo == nil {
+ ve.RelatesTo = &RelatesTo{}
}
- return false
+ return ve.RelatesTo
+}
+
+func (ve *InRoomVerificationEvent) OptionalGetRelatesTo() *RelatesTo {
+ return ve.RelatesTo
+}
+
+func (ve *InRoomVerificationEvent) SetRelatesTo(rel *RelatesTo) {
+ ve.RelatesTo = rel
+}
+
+// VerificationRequestEventContent represents the content of an
+// [m.key.verification.request] to-device event as described in [Section
+// 11.12.2.1] of the Spec.
+//
+// For the in-room version, use a standard [MessageEventContent] struct.
+//
+// [m.key.verification.request]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationrequest
+// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#key-verification-framework
+type VerificationRequestEventContent struct {
+ ToDeviceVerificationEvent
+ // FromDevice is the device ID which is initiating the request.
+ FromDevice id.DeviceID `json:"from_device"`
+ // Methods is a list of the verification methods supported by the sender.
+ Methods []VerificationMethod `json:"methods"`
+ // Timestamp is the time at which the request was made.
+ Timestamp jsontime.UnixMilli `json:"timestamp,omitempty"`
+}
+
+// VerificationRequestEventContentFromMessage converts an in-room verification
+// request message event to a [VerificationRequestEventContent].
+func VerificationRequestEventContentFromMessage(evt *Event) *VerificationRequestEventContent {
+ content := evt.Content.AsMessage()
+ return &VerificationRequestEventContent{
+ ToDeviceVerificationEvent: ToDeviceVerificationEvent{
+ TransactionID: id.VerificationTransactionID(evt.ID),
+ },
+ Timestamp: jsontime.UMInt(evt.Timestamp),
+ FromDevice: content.FromDevice,
+ Methods: content.Methods,
+ }
+}
+
+// VerificationReadyEventContent represents the content of an
+// [m.key.verification.ready] event (both the to-device and the in-room
+// version) as described in [Section 11.12.2.1] of the Spec.
+//
+// [m.key.verification.ready]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationready
+// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#key-verification-framework
+type VerificationReadyEventContent struct {
+ ToDeviceVerificationEvent
+ InRoomVerificationEvent
+
+ // FromDevice is the device ID which is initiating the request.
+ FromDevice id.DeviceID `json:"from_device"`
+ // Methods is a list of the verification methods supported by the sender.
+ Methods []VerificationMethod `json:"methods"`
}
type KeyAgreementProtocol string
const (
- KeyAgreementCurve25519 KeyAgreementProtocol = "curve25519"
- KeyAgreementCurve25519HKDFSHA256 KeyAgreementProtocol = "curve25519-hkdf-sha256"
+ KeyAgreementProtocolCurve25519 KeyAgreementProtocol = "curve25519"
+ KeyAgreementProtocolCurve25519HKDFSHA256 KeyAgreementProtocol = "curve25519-hkdf-sha256"
)
type VerificationHashMethod string
-const VerificationHashSHA256 VerificationHashMethod = "sha256"
+const VerificationHashMethodSHA256 VerificationHashMethod = "sha256"
type MACMethod string
-const HKDFHMACSHA256 MACMethod = "hkdf-hmac-sha256"
+const (
+ MACMethodHKDFHMACSHA256 MACMethod = "hkdf-hmac-sha256"
+ MACMethodHKDFHMACSHA256V2 MACMethod = "hkdf-hmac-sha256.v2"
+)
type SASMethod string
const (
- SASDecimal SASMethod = "decimal"
- SASEmoji SASMethod = "emoji"
+ SASMethodDecimal SASMethod = "decimal"
+ SASMethodEmoji SASMethod = "emoji"
)
-// VerificationStartEventContent represents the content of a m.key.verification.start to_device event.
-// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationstartmsasv1
+// VerificationStartEventContent represents the content of an
+// [m.key.verification.start] event (both the to-device and the in-room
+// version) as described in [Section 11.12.2.1] of the Spec.
+//
+// This struct also contains the fields for an [m.key.verification.start] event
+// using the [VerificationMethodSAS] method as described in [Section
+// 11.12.2.2.2] and an [m.key.verification.start] using
+// [VerificationMethodReciprocate] as described in [Section 11.12.2.4.2].
+//
+// [m.key.verification.start]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationstart
+// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#key-verification-framework
+// [Section 11.12.2.2.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-sas
+// [Section 11.12.2.4.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-qr-codes
type VerificationStartEventContent struct {
- // The device ID which is initiating the process.
+ ToDeviceVerificationEvent
+ InRoomVerificationEvent
+
+ // FromDevice is the device ID which is initiating the request.
FromDevice id.DeviceID `json:"from_device"`
- // An opaque identifier for the verification process. Must be unique with respect to the devices involved.
- TransactionID string `json:"transaction_id,omitempty"`
- // The verification method to use.
+ // Method is the verification method to use.
Method VerificationMethod `json:"method"`
- // The key agreement protocols the sending device understands.
- KeyAgreementProtocols []KeyAgreementProtocol `json:"key_agreement_protocols"`
- // The hash methods the sending device understands.
- Hashes []VerificationHashMethod `json:"hashes"`
- // The message authentication codes that the sending device understands.
+ // NextMethod is an optional method to use to verify the other user's key.
+ // Applicable when the method chosen only verifies one user’s key. This
+ // field will never be present if the method verifies keys both ways.
+ NextMethod VerificationMethod `json:"next_method,omitempty"`
+
+ // Hashes are the hash methods the sending device understands. This field
+ // is only applicable when the method is m.sas.v1.
+ Hashes []VerificationHashMethod `json:"hashes,omitempty"`
+ // KeyAgreementProtocols is the list of key agreement protocols the sending
+ // device understands. This field is only applicable when the method is
+ // m.sas.v1.
+ KeyAgreementProtocols []KeyAgreementProtocol `json:"key_agreement_protocols,omitempty"`
+ // MessageAuthenticationCodes is a list of the MAC methods that the sending
+ // device understands. This field is only applicable when the method is
+ // m.sas.v1.
MessageAuthenticationCodes []MACMethod `json:"message_authentication_codes"`
- // The SAS methods the sending device (and the sending device's user) understands.
+ // ShortAuthenticationString is a list of SAS methods the sending device
+ // (and the sending device's user) understands. This field is only
+ // applicable when the method is m.sas.v1.
ShortAuthenticationString []SASMethod `json:"short_authentication_string"`
- // The user that the event is sent to for in-room verification.
- To id.UserID `json:"to,omitempty"`
- // Original event ID for in-room verification.
- RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
+
+ // Secret is the shared secret from the QR code. This field is only
+ // applicable when the method is m.reciprocate.v1.
+ Secret jsonbytes.UnpaddedBytes `json:"secret,omitempty"`
}
-func (vsec *VerificationStartEventContent) SupportsKeyAgreementProtocol(proto KeyAgreementProtocol) bool {
- for _, supportedProto := range vsec.KeyAgreementProtocols {
- if supportedProto == proto {
- return true
- }
- }
- return false
-}
-
-func (vsec *VerificationStartEventContent) SupportsHashMethod(alg VerificationHashMethod) bool {
- for _, supportedAlg := range vsec.Hashes {
- if supportedAlg == alg {
- return true
- }
- }
- return false
-}
-
-func (vsec *VerificationStartEventContent) SupportsMACMethod(meth MACMethod) bool {
- for _, supportedMeth := range vsec.MessageAuthenticationCodes {
- if supportedMeth == meth {
- return true
- }
- }
- return false
-}
-
-func (vsec *VerificationStartEventContent) SupportsSASMethod(meth SASMethod) bool {
- for _, supportedMeth := range vsec.ShortAuthenticationString {
- if supportedMeth == meth {
- return true
- }
- }
- return false
-}
-
-func (vsec *VerificationStartEventContent) GetRelatesTo() *RelatesTo {
- if vsec.RelatesTo == nil {
- vsec.RelatesTo = &RelatesTo{}
- }
- return vsec.RelatesTo
-}
-
-func (vsec *VerificationStartEventContent) OptionalGetRelatesTo() *RelatesTo {
- return vsec.RelatesTo
-}
-
-func (vsec *VerificationStartEventContent) SetRelatesTo(rel *RelatesTo) {
- vsec.RelatesTo = rel
-}
-
-// VerificationReadyEventContent represents the content of a m.key.verification.ready event.
-// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationready
-type VerificationReadyEventContent struct {
- // The device ID which accepted the process.
- FromDevice id.DeviceID `json:"from_device"`
- // The verification methods supported by the sender.
- Methods []VerificationMethod `json:"methods"`
- // Original event ID for in-room verification.
- RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
-}
-
-var _ Relatable = (*VerificationReadyEventContent)(nil)
-
-func (vrec *VerificationReadyEventContent) GetRelatesTo() *RelatesTo {
- if vrec.RelatesTo == nil {
- vrec.RelatesTo = &RelatesTo{}
- }
- return vrec.RelatesTo
-}
-
-func (vrec *VerificationReadyEventContent) OptionalGetRelatesTo() *RelatesTo {
- return vrec.RelatesTo
-}
-
-func (vrec *VerificationReadyEventContent) SetRelatesTo(rel *RelatesTo) {
- vrec.RelatesTo = rel
-}
-
-// VerificationAcceptEventContent represents the content of a m.key.verification.accept to_device event.
-// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationaccept
-type VerificationAcceptEventContent struct {
- // An opaque identifier for the verification process. Must be the same as the one used for the m.key.verification.start message.
- TransactionID string `json:"transaction_id,omitempty"`
- // The verification method to use.
- Method VerificationMethod `json:"method"`
- // The key agreement protocol the device is choosing to use, out of the options in the m.key.verification.start message.
- KeyAgreementProtocol KeyAgreementProtocol `json:"key_agreement_protocol"`
- // The hash method the device is choosing to use, out of the options in the m.key.verification.start message.
- Hash VerificationHashMethod `json:"hash"`
- // The message authentication code the device is choosing to use, out of the options in the m.key.verification.start message.
- MessageAuthenticationCode MACMethod `json:"message_authentication_code"`
- // The SAS methods both devices involved in the verification process understand. Must be a subset of the options in the m.key.verification.start message.
- ShortAuthenticationString []SASMethod `json:"short_authentication_string"`
- // The hash (encoded as unpadded base64) of the concatenation of the device's ephemeral public key (encoded as unpadded base64) and the canonical JSON representation of the m.key.verification.start message.
- Commitment string `json:"commitment"`
- // The user that the event is sent to for in-room verification.
- To id.UserID `json:"to,omitempty"`
- // Original event ID for in-room verification.
- RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
-}
-
-func (vaec *VerificationAcceptEventContent) GetRelatesTo() *RelatesTo {
- if vaec.RelatesTo == nil {
- vaec.RelatesTo = &RelatesTo{}
- }
- return vaec.RelatesTo
-}
-
-func (vaec *VerificationAcceptEventContent) OptionalGetRelatesTo() *RelatesTo {
- return vaec.RelatesTo
-}
-
-func (vaec *VerificationAcceptEventContent) SetRelatesTo(rel *RelatesTo) {
- vaec.RelatesTo = rel
-}
-
-// VerificationKeyEventContent represents the content of a m.key.verification.key to_device event.
-// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationkey
-type VerificationKeyEventContent struct {
- // An opaque identifier for the verification process. Must be the same as the one used for the m.key.verification.start message.
- TransactionID string `json:"transaction_id,omitempty"`
- // The device's ephemeral public key, encoded as unpadded base64.
- Key string `json:"key"`
- // The user that the event is sent to for in-room verification.
- To id.UserID `json:"to,omitempty"`
- // Original event ID for in-room verification.
- RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
-}
-
-func (vkec *VerificationKeyEventContent) GetRelatesTo() *RelatesTo {
- if vkec.RelatesTo == nil {
- vkec.RelatesTo = &RelatesTo{}
- }
- return vkec.RelatesTo
-}
-
-func (vkec *VerificationKeyEventContent) OptionalGetRelatesTo() *RelatesTo {
- return vkec.RelatesTo
-}
-
-func (vkec *VerificationKeyEventContent) SetRelatesTo(rel *RelatesTo) {
- vkec.RelatesTo = rel
-}
-
-// VerificationMacEventContent represents the content of a m.key.verification.mac to_device event.
-// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationmac
-type VerificationMacEventContent struct {
- // An opaque identifier for the verification process. Must be the same as the one used for the m.key.verification.start message.
- TransactionID string `json:"transaction_id,omitempty"`
- // A map of the key ID to the MAC of the key, using the algorithm in the verification process. The MAC is encoded as unpadded base64.
- Mac map[id.KeyID]string `json:"mac"`
- // The MAC of the comma-separated, sorted, list of key IDs given in the mac property, encoded as unpadded base64.
- Keys string `json:"keys"`
- // The user that the event is sent to for in-room verification.
- To id.UserID `json:"to,omitempty"`
- // Original event ID for in-room verification.
- RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
-}
-
-func (vmec *VerificationMacEventContent) GetRelatesTo() *RelatesTo {
- if vmec.RelatesTo == nil {
- vmec.RelatesTo = &RelatesTo{}
- }
- return vmec.RelatesTo
-}
-
-func (vmec *VerificationMacEventContent) OptionalGetRelatesTo() *RelatesTo {
- return vmec.RelatesTo
-}
-
-func (vmec *VerificationMacEventContent) SetRelatesTo(rel *RelatesTo) {
- vmec.RelatesTo = rel
+// VerificationDoneEventContent represents the content of an
+// [m.key.verification.done] event (both the to-device and the in-room version)
+// as described in [Section 11.12.2.1] of the Spec.
+//
+// This type is an alias for [VerificationRelatable] since there are no
+// additional fields defined by the spec.
+//
+// [m.key.verification.done]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationdone
+// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationdone
+type VerificationDoneEventContent struct {
+ ToDeviceVerificationEvent
+ InRoomVerificationEvent
}
type VerificationCancelCode string
const (
- VerificationCancelByUser VerificationCancelCode = "m.user"
- VerificationCancelByTimeout VerificationCancelCode = "m.timeout"
- VerificationCancelUnknownTransaction VerificationCancelCode = "m.unknown_transaction"
- VerificationCancelUnknownMethod VerificationCancelCode = "m.unknown_method"
- VerificationCancelUnexpectedMessage VerificationCancelCode = "m.unexpected_message"
- VerificationCancelKeyMismatch VerificationCancelCode = "m.key_mismatch"
- VerificationCancelUserMismatch VerificationCancelCode = "m.user_mismatch"
- VerificationCancelInvalidMessage VerificationCancelCode = "m.invalid_message"
- VerificationCancelAccepted VerificationCancelCode = "m.accepted"
- VerificationCancelSASMismatch VerificationCancelCode = "m.mismatched_sas"
- VerificationCancelCommitmentMismatch VerificationCancelCode = "m.mismatched_commitment"
+ VerificationCancelCodeUser VerificationCancelCode = "m.user"
+ VerificationCancelCodeTimeout VerificationCancelCode = "m.timeout"
+ VerificationCancelCodeUnknownTransaction VerificationCancelCode = "m.unknown_transaction"
+ VerificationCancelCodeUnknownMethod VerificationCancelCode = "m.unknown_method"
+ VerificationCancelCodeUnexpectedMessage VerificationCancelCode = "m.unexpected_message"
+ VerificationCancelCodeKeyMismatch VerificationCancelCode = "m.key_mismatch"
+ VerificationCancelCodeUserMismatch VerificationCancelCode = "m.user_mismatch"
+ VerificationCancelCodeInvalidMessage VerificationCancelCode = "m.invalid_message"
+ VerificationCancelCodeAccepted VerificationCancelCode = "m.accepted"
+ VerificationCancelCodeSASMismatch VerificationCancelCode = "m.mismatched_sas"
+ VerificationCancelCodeCommitmentMismatch VerificationCancelCode = "m.mismatched_commitment"
+
+ // Non-spec codes
+ VerificationCancelCodeInternalError VerificationCancelCode = "com.beeper.internal_error"
+ VerificationCancelCodeMasterKeyNotTrusted VerificationCancelCode = "com.beeper.master_key_not_trusted" // the master key is not trusted by this device, but the QR code that was scanned was from a device that doesn't trust the master key
)
-// VerificationCancelEventContent represents the content of a m.key.verification.cancel to_device event.
-// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationcancel
+// VerificationCancelEventContent represents the content of an
+// [m.key.verification.cancel] event (both the to-device and the in-room
+// version) as described in [Section 11.12.2.1] of the Spec.
+//
+// [m.key.verification.cancel]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationcancel
+// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationdone
type VerificationCancelEventContent struct {
- // The opaque identifier for the verification process/request.
- TransactionID string `json:"transaction_id,omitempty"`
- // A human readable description of the code. The client should only rely on this string if it does not understand the code.
- Reason string `json:"reason"`
- // The error code for why the process/request was cancelled by the user.
+ ToDeviceVerificationEvent
+ InRoomVerificationEvent
+
+ // Code is the error code for why the process/request was cancelled by the
+ // user.
Code VerificationCancelCode `json:"code"`
- // The user that the event is sent to for in-room verification.
- To id.UserID `json:"to,omitempty"`
- // Original event ID for in-room verification.
- RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
+ // Reason is a human readable description of the code. The client should
+ // only rely on this string if it does not understand the code.
+ Reason string `json:"reason"`
}
-func (vcec *VerificationCancelEventContent) GetRelatesTo() *RelatesTo {
- if vcec.RelatesTo == nil {
- vcec.RelatesTo = &RelatesTo{}
- }
- return vcec.RelatesTo
+// VerificationAcceptEventContent represents the content of an
+// [m.key.verification.accept] event (both the to-device and the in-room
+// version) as described in [Section 11.12.2.2.2] of the Spec.
+//
+// [m.key.verification.accept]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationaccept
+// [Section 11.12.2.2.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-sas
+type VerificationAcceptEventContent struct {
+ ToDeviceVerificationEvent
+ InRoomVerificationEvent
+
+ // Commitment is the hash of the concatenation of the device's ephemeral
+ // public key (encoded as unpadded base64) and the canonical JSON
+ // representation of the m.key.verification.start message.
+ Commitment jsonbytes.UnpaddedBytes `json:"commitment"`
+ // Hash is the hash method the device is choosing to use, out of the
+ // options in the m.key.verification.start message.
+ Hash VerificationHashMethod `json:"hash"`
+ // KeyAgreementProtocol is the key agreement protocol the device is
+ // choosing to use, out of the options in the m.key.verification.start
+ // message.
+ KeyAgreementProtocol KeyAgreementProtocol `json:"key_agreement_protocol"`
+ // MessageAuthenticationCode is the message authentication code the device
+ // is choosing to use, out of the options in the m.key.verification.start
+ // message.
+ MessageAuthenticationCode MACMethod `json:"message_authentication_code"`
+ // ShortAuthenticationString is a list of SAS methods both devices involved
+ // in the verification process understand. Must be a subset of the options
+ // in the m.key.verification.start message.
+ ShortAuthenticationString []SASMethod `json:"short_authentication_string"`
}
-func (vcec *VerificationCancelEventContent) OptionalGetRelatesTo() *RelatesTo {
- return vcec.RelatesTo
+// VerificationKeyEventContent represents the content of an
+// [m.key.verification.key] event (both the to-device and the in-room version)
+// as described in [Section 11.12.2.2.2] of the Spec.
+//
+// [m.key.verification.key]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationkey
+// [Section 11.12.2.2.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-sas
+type VerificationKeyEventContent struct {
+ ToDeviceVerificationEvent
+ InRoomVerificationEvent
+
+ // Key is the device’s ephemeral public key.
+ Key jsonbytes.UnpaddedBytes `json:"key"`
}
-func (vcec *VerificationCancelEventContent) SetRelatesTo(rel *RelatesTo) {
- vcec.RelatesTo = rel
+// VerificationMACEventContent represents the content of an
+// [m.key.verification.mac] event (both the to-device and the in-room version)
+// as described in [Section 11.12.2.2.2] of the Spec.
+//
+// [m.key.verification.mac]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationmac
+// [Section 11.12.2.2.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-sas
+type VerificationMACEventContent struct {
+ ToDeviceVerificationEvent
+ InRoomVerificationEvent
+
+ // Keys is the MAC of the comma-separated, sorted, list of key IDs given in
+ // the MAC property.
+ Keys jsonbytes.UnpaddedBytes `json:"keys"`
+ // MAC is a map of the key ID to the MAC of the key, using the algorithm in
+ // the verification process.
+ MAC map[id.KeyID]jsonbytes.UnpaddedBytes `json:"mac"`
}
diff --git a/event/voip.go b/event/voip.go
index 28f56c95..cd8364a1 100644
--- a/event/voip.go
+++ b/event/voip.go
@@ -76,7 +76,7 @@ func (cv *CallVersion) Int() (int, error) {
type BaseCallEventContent struct {
CallID string `json:"call_id"`
PartyID string `json:"party_id"`
- Version CallVersion `json:"version"`
+ Version CallVersion `json:"version,omitempty"`
}
type CallInviteEventContent struct {
diff --git a/example/go.mod b/example/go.mod
deleted file mode 100644
index f78b3fa0..00000000
--- a/example/go.mod
+++ /dev/null
@@ -1,27 +0,0 @@
-module maunium.net/go/mautrix/example
-
-go 1.20
-
-require (
- github.com/chzyer/readline v1.5.1
- github.com/mattn/go-sqlite3 v1.14.19
- github.com/rs/zerolog v1.31.0
- maunium.net/go/mautrix v0.16.3-0.20240113165612-308e3583b06f
-)
-
-require (
- github.com/mattn/go-colorable v0.1.13 // indirect
- github.com/mattn/go-isatty v0.0.19 // indirect
- github.com/tidwall/gjson v1.17.0 // indirect
- github.com/tidwall/match v1.1.1 // indirect
- github.com/tidwall/pretty v1.2.0 // indirect
- github.com/tidwall/sjson v1.2.5 // indirect
- go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 // indirect
- golang.org/x/crypto v0.17.0 // indirect
- golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect
- golang.org/x/net v0.19.0 // indirect
- golang.org/x/sys v0.15.0 // indirect
- maunium.net/go/maulogger/v2 v2.4.1 // indirect
-)
-
-//replace maunium.net/go/mautrix => ../
diff --git a/example/go.sum b/example/go.sum
deleted file mode 100644
index 0a3092ed..00000000
--- a/example/go.sum
+++ /dev/null
@@ -1,51 +0,0 @@
-github.com/DATA-DOG/go-sqlmock v1.5.1 h1:FK6RCIUSfmbnI/imIICmboyQBkOckutaa6R5YYlLZyo=
-github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
-github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
-github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
-github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
-github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
-github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
-github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
-github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
-github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
-github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
-github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
-github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
-github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
-github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI=
-github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
-github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
-github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
-github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
-github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
-github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
-github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM=
-github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
-github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
-github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
-github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
-github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
-github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
-github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
-go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8=
-go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
-golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
-golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
-golang.org/x/exp v0.0.0-20231226003508-02704c960a9b h1:kLiC65FbiHWFAOu+lxwNPujcsl8VYyTYYEZnsOO1WK4=
-golang.org/x/exp v0.0.0-20231226003508-02704c960a9b/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
-golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
-golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
-golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
-golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
-maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
-maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
-maunium.net/go/mautrix v0.16.3-0.20240113165612-308e3583b06f h1:6uzyAxrjqGv2SbTAnIK3LI6mo1fILWOga6uNyId+6yM=
-maunium.net/go/mautrix v0.16.3-0.20240113165612-308e3583b06f/go.mod h1:eRQu5ED1ODsP+xq1K9l1AOD+O9FMkAhodd/RVc3Bkqg=
diff --git a/example/main.go b/example/main.go
index f799409c..2bf4bef3 100644
--- a/example/main.go
+++ b/example/main.go
@@ -20,6 +20,7 @@ import (
"github.com/chzyer/readline"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog"
+ "go.mau.fi/util/exzerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/cryptohelper"
@@ -57,6 +58,7 @@ func main() {
if !*debug {
log = log.Level(zerolog.InfoLevel)
}
+ exzerolog.SetupDefaults(&log)
client.Log = log
var lastRoomID id.RoomID
@@ -141,7 +143,7 @@ func main() {
if err != nil {
log.Error().Err(err).Msg("Failed to send event")
} else {
- log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent")
+ log.Info().Stringer("event_id", resp.EventID).Msg("Event sent")
}
}
cancelSync()
diff --git a/federation/cache.go b/federation/cache.go
new file mode 100644
index 00000000..24154974
--- /dev/null
+++ b/federation/cache.go
@@ -0,0 +1,153 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "sync"
+ "time"
+)
+
+// ResolutionCache is an interface for caching resolved server names.
+type ResolutionCache interface {
+ StoreResolution(*ResolvedServerName)
+ // LoadResolution loads a resolved server name from the cache.
+ // Expired entries MUST NOT be returned.
+ LoadResolution(serverName string) (*ResolvedServerName, error)
+}
+
+type KeyCache interface {
+ StoreKeys(*ServerKeyResponse)
+ StoreFetchError(serverName string, err error)
+ ShouldReQuery(serverName string) bool
+ LoadKeys(serverName string) (*ServerKeyResponse, error)
+}
+
+type InMemoryCache struct {
+ MinKeyRefetchDelay time.Duration
+
+ resolutions map[string]*ResolvedServerName
+ resolutionsLock sync.RWMutex
+ keys map[string]*ServerKeyResponse
+ lastReQueryAt map[string]time.Time
+ lastError map[string]*resolutionErrorCache
+ keysLock sync.RWMutex
+}
+
+var (
+ _ ResolutionCache = (*InMemoryCache)(nil)
+ _ KeyCache = (*InMemoryCache)(nil)
+)
+
+func NewInMemoryCache() *InMemoryCache {
+ return &InMemoryCache{
+ resolutions: make(map[string]*ResolvedServerName),
+ keys: make(map[string]*ServerKeyResponse),
+ lastReQueryAt: make(map[string]time.Time),
+ lastError: make(map[string]*resolutionErrorCache),
+ MinKeyRefetchDelay: 1 * time.Hour,
+ }
+}
+
+func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) {
+ c.resolutionsLock.Lock()
+ defer c.resolutionsLock.Unlock()
+ c.resolutions[resolution.ServerName] = resolution
+}
+
+func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) {
+ c.resolutionsLock.RLock()
+ defer c.resolutionsLock.RUnlock()
+ resolution, ok := c.resolutions[serverName]
+ if !ok || time.Until(resolution.Expires) < 0 {
+ return nil, nil
+ }
+ return resolution, nil
+}
+
+func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) {
+ c.keysLock.Lock()
+ defer c.keysLock.Unlock()
+ c.keys[keys.ServerName] = keys
+ delete(c.lastError, keys.ServerName)
+}
+
+type resolutionErrorCache struct {
+ Error error
+ Time time.Time
+ Count int
+}
+
+const MaxBackoff = 7 * 24 * time.Hour
+
+func (rec *resolutionErrorCache) ShouldRetry() bool {
+ backoff := time.Duration(math.Exp(float64(rec.Count))) * time.Second
+ return time.Since(rec.Time) > backoff
+}
+
+var ErrRecentKeyQueryFailed = errors.New("last retry was too recent")
+
+func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) {
+ c.keysLock.RLock()
+ defer c.keysLock.RUnlock()
+ keys, ok := c.keys[serverName]
+ if !ok || time.Until(keys.ValidUntilTS.Time) < 0 {
+ err, ok := c.lastError[serverName]
+ if ok && !err.ShouldRetry() {
+ return nil, fmt.Errorf(
+ "%w (%s ago) and failed with %w",
+ ErrRecentKeyQueryFailed,
+ time.Since(err.Time).String(),
+ err.Error,
+ )
+ }
+ return nil, nil
+ }
+ return keys, nil
+}
+
+func (c *InMemoryCache) StoreFetchError(serverName string, err error) {
+ c.keysLock.Lock()
+ defer c.keysLock.Unlock()
+ errorCache, ok := c.lastError[serverName]
+ if ok {
+ errorCache.Time = time.Now()
+ errorCache.Error = err
+ errorCache.Count++
+ } else {
+ c.lastError[serverName] = &resolutionErrorCache{Error: err, Time: time.Now(), Count: 1}
+ }
+}
+
+func (c *InMemoryCache) ShouldReQuery(serverName string) bool {
+ c.keysLock.Lock()
+ defer c.keysLock.Unlock()
+ lastQuery, ok := c.lastReQueryAt[serverName]
+ if ok && time.Since(lastQuery) < c.MinKeyRefetchDelay {
+ return false
+ }
+ c.lastReQueryAt[serverName] = time.Now()
+ return true
+}
+
+type noopCache struct{}
+
+func (*noopCache) StoreKeys(_ *ServerKeyResponse) {}
+func (*noopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil }
+func (*noopCache) StoreFetchError(_ string, _ error) {}
+func (*noopCache) ShouldReQuery(_ string) bool { return true }
+func (*noopCache) StoreResolution(_ *ResolvedServerName) {}
+func (*noopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil }
+
+var (
+ _ ResolutionCache = (*noopCache)(nil)
+ _ KeyCache = (*noopCache)(nil)
+)
+
+var NoopCache *noopCache
diff --git a/federation/client.go b/federation/client.go
new file mode 100644
index 00000000..183fb5d1
--- /dev/null
+++ b/federation/client.go
@@ -0,0 +1,605 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+ "time"
+
+ "go.mau.fi/util/exslices"
+ "go.mau.fi/util/jsontime"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+type Client struct {
+ HTTP *http.Client
+ ServerName string
+ UserAgent string
+ Key *SigningKey
+
+ ResponseSizeLimit int64
+}
+
+func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client {
+ return &Client{
+ HTTP: &http.Client{
+ Transport: NewServerResolvingTransport(cache),
+ Timeout: 120 * time.Second,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ // Federation requests do not allow redirects.
+ return http.ErrUseLastResponse
+ },
+ },
+ UserAgent: mautrix.DefaultUserAgent,
+ ServerName: serverName,
+ Key: key,
+
+ ResponseSizeLimit: mautrix.DefaultResponseSizeLimit,
+ }
+}
+
+func (c *Client) Version(ctx context.Context, serverName string) (resp *RespServerVersion, err error) {
+ err = c.MakeRequest(ctx, serverName, false, http.MethodGet, URLPath{"v1", "version"}, nil, &resp)
+ return
+}
+
+func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *ServerKeyResponse, err error) {
+ err = c.MakeRequest(ctx, serverName, false, http.MethodGet, KeyURLPath{"v2", "server"}, nil, &resp)
+ return
+}
+
+func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *QueryKeysResponse, err error) {
+ err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp)
+ return
+}
+
+type PDU = json.RawMessage
+type EDU = json.RawMessage
+
+type ReqSendTransaction struct {
+ Destination string `json:"destination"`
+ TxnID string `json:"-"`
+
+ Origin string `json:"origin"`
+ OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"`
+ PDUs []PDU `json:"pdus"`
+ EDUs []EDU `json:"edus,omitempty"`
+}
+
+type PDUProcessingResult struct {
+ Error string `json:"error,omitempty"`
+}
+
+type RespSendTransaction struct {
+ PDUs map[id.EventID]PDUProcessingResult `json:"pdus"`
+}
+
+func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) {
+ err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp)
+ return
+}
+
+type RespGetEventAuthChain struct {
+ AuthChain []PDU `json:"auth_chain"`
+}
+
+func (c *Client) GetEventAuthChain(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetEventAuthChain, err error) {
+ err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event_auth", roomID, eventID}, nil, &resp)
+ return
+}
+
+type ReqBackfill struct {
+ ServerName string
+ RoomID id.RoomID
+ Limit int
+ BackfillFrom []id.EventID
+}
+
+type RespBackfill struct {
+ Origin string `json:"origin"`
+ OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"`
+ PDUs []PDU `json:"pdus"`
+}
+
+func (c *Client) Backfill(ctx context.Context, req *ReqBackfill) (resp *RespBackfill, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.ServerName,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "backfill", req.RoomID},
+ Query: url.Values{
+ "limit": {strconv.Itoa(req.Limit)},
+ "v": exslices.CastToString[string](req.BackfillFrom),
+ },
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+type ReqGetMissingEvents struct {
+ ServerName string `json:"-"`
+ RoomID id.RoomID `json:"-"`
+ EarliestEvents []id.EventID `json:"earliest_events"`
+ LatestEvents []id.EventID `json:"latest_events"`
+ Limit int `json:"limit,omitempty"`
+ MinDepth int `json:"min_depth,omitempty"`
+}
+
+type RespGetMissingEvents struct {
+ Events []PDU `json:"events"`
+}
+
+func (c *Client) GetMissingEvents(ctx context.Context, req *ReqGetMissingEvents) (resp *RespGetMissingEvents, err error) {
+ err = c.MakeRequest(ctx, req.ServerName, true, http.MethodPost, URLPath{"v1", "get_missing_events", req.RoomID}, req, &resp)
+ return
+}
+
+func (c *Client) GetEvent(ctx context.Context, serverName string, eventID id.EventID) (resp *RespBackfill, err error) {
+ err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event", eventID}, nil, &resp)
+ return
+}
+
+type RespGetState struct {
+ AuthChain []PDU `json:"auth_chain"`
+ PDUs []PDU `json:"pdus"`
+}
+
+func (c *Client) GetState(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetState, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: serverName,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "state", roomID},
+ Query: url.Values{
+ "event_id": {string(eventID)},
+ },
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+type RespGetStateIDs struct {
+ AuthChain []id.EventID `json:"auth_chain_ids"`
+ PDUs []id.EventID `json:"pdu_ids"`
+}
+
+func (c *Client) GetStateIDs(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetStateIDs, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: serverName,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "state_ids", roomID},
+ Query: url.Values{
+ "event_id": {string(eventID)},
+ },
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) TimestampToEvent(ctx context.Context, serverName string, roomID id.RoomID, timestamp time.Time, dir mautrix.Direction) (resp *mautrix.RespTimestampToEvent, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: serverName,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "timestamp_to_event", roomID},
+ Query: url.Values{
+ "dir": {string(dir)},
+ "ts": {strconv.FormatInt(timestamp.UnixMilli(), 10)},
+ },
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) QueryProfile(ctx context.Context, serverName string, userID id.UserID) (resp *mautrix.RespUserProfile, err error) {
+ err = c.Query(ctx, serverName, "profile", url.Values{"user_id": {userID.String()}}, &resp)
+ return
+}
+
+func (c *Client) QueryDirectory(ctx context.Context, serverName string, roomAlias id.RoomAlias) (resp *mautrix.RespAliasResolve, err error) {
+ err = c.Query(ctx, serverName, "directory", url.Values{"room_alias": {roomAlias.String()}}, &resp)
+ return
+}
+
+func (c *Client) Query(ctx context.Context, serverName, queryType string, queryParams url.Values, respStruct any) (err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: serverName,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "query", queryType},
+ Query: queryParams,
+ Authenticate: true,
+ ResponseJSON: respStruct,
+ })
+ return
+}
+
+func queryToValues(query map[string]string) url.Values {
+ values := make(url.Values, len(query))
+ for k, v := range query {
+ values[k] = []string{v}
+ }
+ return values
+}
+
+func (c *Client) PublicRooms(ctx context.Context, serverName string, req *mautrix.ReqPublicRooms) (resp *mautrix.RespPublicRooms, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: serverName,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "publicRooms"},
+ Query: queryToValues(req.Query()),
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+type RespOpenIDUserInfo struct {
+ Sub id.UserID `json:"sub"`
+}
+
+func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken string) (resp *RespOpenIDUserInfo, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: serverName,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "openid", "userinfo"},
+ Query: url.Values{"access_token": {accessToken}},
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+type ReqMakeJoin struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+ SupportedVersions []id.RoomVersion
+}
+
+type RespMakeJoin struct {
+ RoomVersion id.RoomVersion `json:"room_version"`
+ Event PDU `json:"event"`
+}
+
+type ReqSendJoin struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ OmitMembers bool
+ Event PDU
+ Via string
+}
+
+type ReqSendKnock struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type RespSendJoin struct {
+ AuthChain []PDU `json:"auth_chain"`
+ Event PDU `json:"event"`
+ MembersOmitted bool `json:"members_omitted"`
+ ServersInRoom []string `json:"servers_in_room"`
+ State []PDU `json:"state"`
+}
+
+type RespSendKnock struct {
+ KnockRoomState []PDU `json:"knock_room_state"`
+}
+
+type ReqSendInvite struct {
+ RoomID id.RoomID `json:"-"`
+ UserID id.UserID `json:"-"`
+ Event PDU `json:"event"`
+ InviteRoomState []PDU `json:"invite_room_state"`
+ RoomVersion id.RoomVersion `json:"room_version"`
+}
+
+type RespSendInvite struct {
+ Event PDU `json:"event"`
+}
+
+type ReqMakeLeave struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+}
+
+type ReqSendLeave struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type (
+ ReqMakeKnock = ReqMakeJoin
+ RespMakeKnock = RespMakeJoin
+ RespMakeLeave = RespMakeJoin
+)
+
+func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_join", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_join", req.RoomID, req.EventID},
+ Query: url.Values{
+ "omit_members": {strconv.FormatBool(req.OmitMembers)},
+ },
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.UserID.Homeserver(),
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "invite", req.RoomID, req.UserID},
+ Authenticate: true,
+ RequestJSON: req,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ })
+ return
+}
+
+type URLPath []any
+
+func (fup URLPath) FullPath() []any {
+ return append([]any{"_matrix", "federation"}, []any(fup)...)
+}
+
+type KeyURLPath []any
+
+func (fkup KeyURLPath) FullPath() []any {
+ return append([]any{"_matrix", "key"}, []any(fkup)...)
+}
+
+type RequestParams struct {
+ ServerName string
+ Method string
+ Path mautrix.PrefixableURLPath
+ Query url.Values
+ Authenticate bool
+ RequestJSON any
+
+ ResponseJSON any
+ DontReadBody bool
+}
+
+func (c *Client) MakeRequest(ctx context.Context, serverName string, authenticate bool, method string, path mautrix.PrefixableURLPath, reqJSON, respJSON any) error {
+ _, _, err := c.MakeFullRequest(ctx, RequestParams{
+ ServerName: serverName,
+ Method: method,
+ Path: path,
+ Authenticate: authenticate,
+ RequestJSON: reqJSON,
+ ResponseJSON: respJSON,
+ })
+ return err
+}
+
+func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]byte, *http.Response, error) {
+ req, err := c.compileRequest(ctx, params)
+ if err != nil {
+ return nil, nil, err
+ }
+ resp, err := c.HTTP.Do(req)
+ if err != nil {
+ return nil, nil, mautrix.HTTPError{
+ Request: req,
+ Response: resp,
+
+ Message: "request error",
+ WrappedError: err,
+ }
+ }
+ if !params.DontReadBody {
+ defer resp.Body.Close()
+ }
+ var body []byte
+ if resp.StatusCode >= 300 {
+ body, err = mautrix.ParseErrorResponse(req, resp)
+ return body, resp, err
+ } else if params.ResponseJSON != nil || !params.DontReadBody {
+ if resp.ContentLength > c.ResponseSizeLimit {
+ return body, resp, mautrix.HTTPError{
+ Request: req,
+ Response: resp,
+
+ Message: "not reading response",
+ WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024),
+ }
+ }
+ body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1))
+ if err == nil && len(body) > int(c.ResponseSizeLimit) {
+ err = mautrix.ErrBodyReadReachedLimit
+ }
+ if err != nil {
+ return body, resp, mautrix.HTTPError{
+ Request: req,
+ Response: resp,
+
+ Message: "failed to read response body",
+ WrappedError: err,
+ }
+ }
+ if params.ResponseJSON != nil {
+ err = json.Unmarshal(body, params.ResponseJSON)
+ if err != nil {
+ return body, resp, mautrix.HTTPError{
+ Request: req,
+ Response: resp,
+
+ Message: "failed to unmarshal response JSON",
+ ResponseBody: string(body),
+ WrappedError: err,
+ }
+ }
+ }
+ }
+ return body, resp, nil
+}
+
+func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*http.Request, error) {
+ reqURL := mautrix.BuildURL(&url.URL{
+ Scheme: "matrix-federation",
+ Host: params.ServerName,
+ }, params.Path.FullPath()...)
+ reqURL.RawQuery = params.Query.Encode()
+ var reqJSON json.RawMessage
+ var reqBody io.Reader
+ if params.RequestJSON != nil {
+ var err error
+ reqJSON, err = json.Marshal(params.RequestJSON)
+ if err != nil {
+ return nil, mautrix.HTTPError{
+ Message: "failed to marshal JSON",
+ WrappedError: err,
+ }
+ }
+ reqBody = bytes.NewReader(reqJSON)
+ }
+ req, err := http.NewRequestWithContext(ctx, params.Method, reqURL.String(), reqBody)
+ if err != nil {
+ return nil, mautrix.HTTPError{
+ Message: "failed to create request",
+ WrappedError: err,
+ }
+ }
+ req.Header.Set("User-Agent", c.UserAgent)
+ if params.Authenticate {
+ if c.ServerName == "" || c.Key == nil {
+ return nil, mautrix.HTTPError{
+ Message: "client not configured for authentication",
+ }
+ }
+ auth, err := (&signableRequest{
+ Method: req.Method,
+ URI: reqURL.RequestURI(),
+ Origin: c.ServerName,
+ Destination: params.ServerName,
+ Content: reqJSON,
+ }).Sign(c.Key)
+ if err != nil {
+ return nil, mautrix.HTTPError{
+ Message: "failed to sign request",
+ WrappedError: err,
+ }
+ }
+ req.Header.Set("Authorization", auth)
+ }
+ return req, nil
+}
+
+type signableRequest struct {
+ Method string `json:"method"`
+ URI string `json:"uri"`
+ Origin string `json:"origin"`
+ Destination string `json:"destination"`
+ Content json.RawMessage `json:"content,omitempty"`
+}
+
+func (r *signableRequest) Verify(key id.SigningKey, sig string) error {
+ message, err := json.Marshal(r)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ return signutil.VerifyJSONRaw(key, sig, message)
+}
+
+func (r *signableRequest) Sign(key *SigningKey) (string, error) {
+ sig, err := key.SignJSON(r)
+ if err != nil {
+ return "", err
+ }
+ return XMatrixAuth{
+ Origin: r.Origin,
+ Destination: r.Destination,
+ KeyID: key.ID,
+ Signature: sig,
+ }.String(), nil
+}
diff --git a/federation/client_test.go b/federation/client_test.go
new file mode 100644
index 00000000..ece399ea
--- /dev/null
+++ b/federation/client_test.go
@@ -0,0 +1,23 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/federation"
+)
+
+func TestClient_Version(t *testing.T) {
+ cli := federation.NewClient("", nil, nil)
+ resp, err := cli.Version(context.TODO(), "maunium.net")
+ require.NoError(t, err)
+ require.Equal(t, "Synapse", resp.Server.Name)
+}
diff --git a/federation/context.go b/federation/context.go
new file mode 100644
index 00000000..eedb2dc1
--- /dev/null
+++ b/federation/context.go
@@ -0,0 +1,42 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "context"
+ "net/http"
+)
+
+type contextKey int
+
+const (
+ contextKeyIPPort contextKey = iota
+ contextKeyDestinationServer
+ contextKeyOriginServer
+)
+
+func DestinationServerNameFromRequest(r *http.Request) string {
+ return DestinationServerName(r.Context())
+}
+
+func DestinationServerName(ctx context.Context) string {
+ if dest, ok := ctx.Value(contextKeyDestinationServer).(string); ok {
+ return dest
+ }
+ return ""
+}
+
+func OriginServerNameFromRequest(r *http.Request) string {
+ return OriginServerName(r.Context())
+}
+
+func OriginServerName(ctx context.Context) string {
+ if origin, ok := ctx.Value(contextKeyOriginServer).(string); ok {
+ return origin
+ }
+ return ""
+}
diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go
new file mode 100644
index 00000000..c72933c2
--- /dev/null
+++ b/federation/eventauth/eventauth.go
@@ -0,0 +1,851 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth
+
+import (
+ "encoding/json"
+ "encoding/json/jsontext"
+ "errors"
+ "fmt"
+ "slices"
+ "strconv"
+ "strings"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exgjson"
+ "go.mau.fi/util/exstrings"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+type AuthFailError struct {
+ Index string
+ Message string
+ Wrapped error
+}
+
+func (afe AuthFailError) Error() string {
+ if afe.Message != "" {
+ return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message)
+ } else if afe.Wrapped != nil {
+ return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error())
+ }
+ return fmt.Sprintf("fail %s", afe.Index)
+}
+
+func (afe AuthFailError) Unwrap() error {
+ return afe.Wrapped
+}
+
+var mFederatePath = exgjson.Path("m.federate")
+
+var (
+ ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"}
+ ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"}
+ ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"}
+ ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion}
+ ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"}
+ ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"}
+
+ ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"}
+ ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"}
+ ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"}
+ ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"}
+
+ ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"}
+ ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"}
+ ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"}
+ ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"}
+ ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"}
+ ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"}
+ ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"}
+ ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"}
+ ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"}
+
+ ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"}
+
+ ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"}
+ ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"}
+ ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"}
+ ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"}
+ ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"}
+ ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"}
+ ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"}
+ ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"}
+ ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"}
+ ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"}
+ ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"}
+ ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"}
+ ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"}
+ ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"}
+ ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"}
+ ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"}
+ ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"}
+ ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"}
+ ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"}
+ ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"}
+ ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"}
+ ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"}
+ ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"}
+ ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"}
+ ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"}
+ ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"}
+ ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"}
+ ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"}
+
+ ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"}
+
+ ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"}
+
+ ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"}
+
+ ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"}
+
+ ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"}
+ ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"}
+ ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"}
+ ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"}
+ ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"}
+ ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"}
+ ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"}
+)
+
+func isRejected(evt *pdu.PDU) bool {
+ return evt.InternalMeta.Rejected
+}
+
+type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error)
+
+func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error {
+ if evt.Type == event.StateCreate.Type {
+ // 1. If type is m.room.create:
+ return authorizeCreate(roomVersion, evt)
+ }
+ var createEvt *pdu.PDU
+ if roomVersion.RoomIDIsCreateEventID() {
+ // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event,
+ // with the sigil ! instead of $, reject.
+ if len(evt.RoomID) != 44 {
+ return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID))
+ } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil {
+ return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err)
+ } else if len(createEvts) != 1 {
+ return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID)
+ } else if isRejected(createEvts[0]) {
+ return ErrRejectedCreateEvent
+ } else {
+ createEvt = createEvts[0]
+ }
+ }
+ authEvents, err := getEvents(evt.AuthEvents)
+ if err != nil {
+ return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err)
+ }
+ expectedAuthEvents := evt.AuthEventSelection(roomVersion)
+ deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents))
+ // 3. Considering the event’s auth_events:
+ for i, ae := range authEvents {
+ authEvtID := evt.AuthEvents[i]
+ if ae == nil {
+ return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID)
+ } else if ae.StateKey == nil {
+ // This approximately falls under rule 3.2.
+ return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID)
+ }
+ key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey}
+ if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound {
+ // 3.1. If there are duplicate entries for a given type and state_key pair, reject.
+ return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID)
+ } else if !expectedAuthEvents.Has(key) {
+ // 3.2. If there are entries whose type and state_key don’t match those specified by
+ // the auth events selection algorithm described in the server specification, reject.
+ return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey)
+ } else if isRejected(ae) {
+ // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject.
+ return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID)
+ } else if ae.RoomID != evt.RoomID {
+ // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject.
+ return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID)
+ } else {
+ deduplicator[key] = authEvtID
+ }
+ if ae.Type == event.StateCreate.Type {
+ if createEvt == nil {
+ createEvt = ae
+ } else {
+ // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+
+ panic(fmt.Errorf("impossible case: multiple create events found in auth events"))
+ }
+ }
+ }
+ if createEvt == nil {
+ // This comes either from auth_events or room_id depending on the room version.
+ // The checks above make sure it's from the right source.
+ return ErrNoCreateEvent
+ }
+ if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() {
+ // 4. If the content of the m.room.create event in the room state has the property m.federate set to false,
+ // and the sender domain of the event does not match the sender domain of the create event, reject.
+ return ErrFederationDisabled
+ }
+ if evt.Type == event.StateMember.Type {
+ // 5. If type is m.room.member:
+ return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey)
+ }
+ senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
+ if senderMembership != event.MembershipJoin {
+ // 6. If the sender’s current membership state is not join, reject.
+ return ErrNotInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderPL := powerLevels.GetUserLevel(evt.Sender)
+ if evt.Type == event.StateThirdPartyInvite.Type {
+ // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level.
+ if senderPL >= powerLevels.Invite() {
+ return nil
+ }
+ return ErrInsufficientPowerForThirdPartyInvite
+ }
+ typeClass := event.MessageEventType
+ if evt.StateKey != nil {
+ typeClass = event.StateEventType
+ }
+ evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass})
+ if evtLevel > senderPL {
+ // 8. If the event type’s required power level is greater than the sender’s power level, reject.
+ return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL)
+ }
+
+ if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() {
+ // 9. If the event has a state_key that starts with an @ and does not match the sender, reject.
+ return ErrMismatchingPrivateStateKey
+ }
+
+ if evt.Type == event.StatePowerLevels.Type {
+ // 10. If type is m.room.power_levels:
+ return authorizePowerLevels(roomVersion, evt, createEvt, authEvents)
+ }
+
+ // 11. Otherwise, allow.
+ return nil
+}
+
+var ErrUserIDNotAString = errors.New("not a string")
+var ErrUserIDNotValid = errors.New("not a valid user ID")
+
+func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error {
+ if userID.Type != gjson.String {
+ return ErrUserIDNotAString
+ }
+ // In a future room version, user IDs will have stricter validation
+ _, _, err := id.UserID(userID.Str).Parse()
+ if err != nil {
+ return ErrUserIDNotValid
+ }
+ return nil
+}
+
+func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error {
+ if len(evt.PrevEvents) > 0 {
+ // 1.1. If it has any prev_events, reject.
+ return ErrCreateHasPrevEvents
+ }
+ if roomVersion.RoomIDIsCreateEventID() {
+ if evt.RoomID != "" {
+ // 1.2. If the event has a room_id, reject.
+ return ErrCreateHasRoomID
+ }
+ } else {
+ _, _, server := id.ParseCommonIdentifier(evt.RoomID)
+ if server == "" || server != evt.Sender.Homeserver() {
+ // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject.
+ return ErrRoomIDDoesntMatchSender
+ }
+ }
+ if !roomVersion.IsKnown() {
+ // 1.3. If content.room_version is present and is not a recognised version, reject.
+ return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion)
+ }
+ if roomVersion.PrivilegedRoomCreators() {
+ additionalCreators := gjson.GetBytes(evt.Content, "additional_creators")
+ if additionalCreators.Exists() {
+ if !additionalCreators.IsArray() {
+ return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators)
+ }
+ for i, item := range additionalCreators.Array() {
+ // 1.4. If additional_creators is present in content and is not an array of strings
+ // where each string passes the same user ID validation applied to sender, reject.
+ if err := isValidUserID(roomVersion, item); err != nil {
+ return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err)
+ }
+ }
+ }
+ }
+ if roomVersion.CreatorInContent() {
+ // 1.4. (v10 and below) If content has no creator property, reject.
+ if !gjson.GetBytes(evt.Content, "creator").Exists() {
+ return ErrMissingCreator
+ }
+ }
+ // 1.5. Otherwise, allow.
+ return nil
+}
+
+func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error {
+ membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str)
+ if evt.StateKey == nil {
+ // 5.1. If there is no state_key property, or no membership property in content, reject.
+ return ErrMemberNotState
+ }
+ authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str)
+ if authorizedVia != "" {
+ homeserver := authorizedVia.Homeserver()
+ err := evt.VerifySignature(roomVersion, homeserver, getKey)
+ if err != nil {
+ // 5.2. If content has a join_authorised_via_users_server key:
+ // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject.
+ return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err)
+ }
+ }
+ targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave"))
+ senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
+ switch membership {
+ case event.MembershipJoin:
+ createEvtID, err := createEvt.GetEventID(roomVersion)
+ if err != nil {
+ return fmt.Errorf("failed to get create event ID: %w", err)
+ }
+ creator := createEvt.Sender.String()
+ if roomVersion.CreatorInContent() {
+ creator = gjson.GetBytes(evt.Content, "creator").Str
+ }
+ if len(evt.PrevEvents) == 1 &&
+ len(evt.AuthEvents) <= 1 &&
+ evt.PrevEvents[0] == createEvtID &&
+ *evt.StateKey == creator {
+ // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow.
+ return nil
+ }
+ // Spec wart: this would make more sense before the check above.
+ // Now you can set anyone as the sender of the first join.
+ if evt.Sender.String() != *evt.StateKey {
+ // 5.3.2. If the sender does not match state_key, reject.
+ return ErrCantJoinOtherUser
+ }
+
+ if senderMembership == event.MembershipBan {
+ // 5.3.3. If the sender is banned, reject.
+ return ErrCantJoinBanned
+ }
+
+ joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
+ switch joinRule {
+ case event.JoinRuleKnock:
+ if !roomVersion.Knocks() {
+ return ErrInvalidJoinRule
+ }
+ fallthrough
+ case event.JoinRuleInvite:
+ // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join.
+ if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
+ return nil
+ }
+ return ErrCantJoinWithoutInvite
+ case event.JoinRuleKnockRestricted:
+ if !roomVersion.KnockRestricted() {
+ return ErrInvalidJoinRule
+ }
+ fallthrough
+ case event.JoinRuleRestricted:
+ if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() {
+ return ErrInvalidJoinRule
+ }
+ if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
+ // 5.3.5.1. If membership state is join or invite, allow.
+ return nil
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() {
+ // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject.
+ return ErrAuthoriserCantInvite
+ }
+ authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave)))
+ if authorizerMembership != event.MembershipJoin {
+ return ErrAuthoriserNotInRoom
+ }
+ // 5.3.5.3. Otherwise, allow.
+ return nil
+ case event.JoinRulePublic:
+ // 5.3.6. If the join_rule is public, allow.
+ return nil
+ default:
+ // 5.3.7. Otherwise, reject.
+ return ErrInvalidJoinRule
+ }
+ case event.MembershipInvite:
+ tpiVal := gjson.GetBytes(evt.Content, "third_party_invite")
+ if tpiVal.Exists() {
+ if targetPrevMembership == event.MembershipBan {
+ return ErrThirdPartyInviteBanned
+ }
+ signed := tpiVal.Get("signed")
+ mxid := signed.Get("mxid").Str
+ token := signed.Get("token").Str
+ if mxid == "" || token == "" {
+ // 5.4.1.2. If content.third_party_invite does not have a signed property, reject.
+ // 5.4.1.3. If signed does not have mxid and token properties, reject.
+ return ErrThirdPartyInviteMissingFields
+ }
+ if mxid != *evt.StateKey {
+ // 5.4.1.4. If mxid does not match state_key, reject.
+ return ErrThirdPartyInviteMXIDMismatch
+ }
+ tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token)
+ if tpiEvt == nil {
+ // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject.
+ return ErrThirdPartyInviteNotFound
+ }
+ if tpiEvt.Sender != evt.Sender {
+ // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject.
+ return ErrThirdPartyInviteSenderMismatch
+ }
+ var keys []id.Ed25519
+ const ed25519Base64Len = 43
+ oldPubKey := gjson.GetBytes(evt.Content, "public_key.token")
+ if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len {
+ keys = append(keys, id.Ed25519(oldPubKey.Str))
+ }
+ gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.Number {
+ return false
+ }
+ if value.Type == gjson.String && len(value.Str) == ed25519Base64Len {
+ keys = append(keys, id.Ed25519(value.Str))
+ }
+ return true
+ })
+ rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str))
+ var validated bool
+ for _, key := range keys {
+ if signutil.VerifyJSONAny(key, rawSigned) == nil {
+ validated = true
+ }
+ }
+ if validated {
+ // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow.
+ return nil
+ }
+ // 4.4.1.8. Otherwise, reject.
+ return ErrThirdPartyInviteNotSigned
+ }
+ if senderMembership != event.MembershipJoin {
+ // 5.4.2. If the sender’s current membership state is not join, reject.
+ return ErrInviterNotInRoom
+ }
+ // 5.4.3. If target user’s current membership state is join or ban, reject.
+ if targetPrevMembership == event.MembershipJoin {
+ return ErrInviteTargetAlreadyInRoom
+ } else if targetPrevMembership == event.MembershipBan {
+ return ErrInviteTargetBanned
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() {
+ // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow.
+ return nil
+ }
+ // 5.4.5. Otherwise, reject.
+ return ErrInsufficientPermissionForInvite
+ case event.MembershipLeave:
+ if evt.Sender.String() == *evt.StateKey {
+ // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock.
+ if senderMembership == event.MembershipInvite ||
+ senderMembership == event.MembershipJoin ||
+ (senderMembership == event.MembershipKnock && roomVersion.Knocks()) {
+ return nil
+ }
+ return ErrCantLeaveWithoutBeingInRoom
+ }
+ if senderMembership != event.MembershipJoin {
+ // 5.5.2. If the sender’s current membership state is not join, reject.
+ return ErrCantKickWithoutBeingInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderLevel := powerLevels.GetUserLevel(evt.Sender)
+ if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() {
+ // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject.
+ return ErrInsufficientPermissionForUnban
+ }
+ if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
+ // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow.
+ return nil
+ }
+ // TODO separate errors for < kick and < target user level?
+ // 5.5.5. Otherwise, reject.
+ return ErrInsufficientPermissionForKick
+ case event.MembershipBan:
+ if senderMembership != event.MembershipJoin {
+ // 5.6.1. If the sender’s current membership state is not join, reject.
+ return ErrCantBanWithoutBeingInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderLevel := powerLevels.GetUserLevel(evt.Sender)
+ if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
+ // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow.
+ return nil
+ }
+ // 5.6.3. Otherwise, reject.
+ return ErrInsufficientPermissionForBan
+ case event.MembershipKnock:
+ joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
+ validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock
+ validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted
+ if !validKnockRule && !validKnockRestrictedRule {
+ // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject.
+ return ErrNotKnockableRoom
+ }
+ if evt.Sender.String() != *evt.StateKey {
+ // 5.7.2. If the sender does not match state_key, reject.
+ return ErrCantKnockOtherUser
+ }
+ if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin {
+ // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow.
+ return nil
+ }
+ // 5.7.4. Otherwise, reject.
+ return ErrCantKnockWhileInRoom
+ default:
+ // 5.8. Otherwise, the membership is unknown. Reject.
+ return ErrUnknownMembership
+ }
+}
+
+func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error {
+ if roomVersion.ValidatePowerLevelInts() {
+ for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
+ res := gjson.GetBytes(evt.Content, key)
+ if !res.Exists() {
+ continue
+ }
+ if parseIntWithVersion(roomVersion, res) == nil {
+ // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject.
+ return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
+ }
+ }
+ for _, key := range []string{"events", "notifications"} {
+ obj := gjson.GetBytes(evt.Content, key)
+ if !obj.Exists() {
+ continue
+ }
+ // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject.
+ if !obj.IsObject() {
+ return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
+ }
+ var err error
+ // 10.2. [...] are not an object with values that are integers, reject.
+ obj.ForEach(func(innerKey, value gjson.Result) bool {
+ if parseIntWithVersion(roomVersion, value) == nil {
+ err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str)
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ }
+ }
+ var creators []id.UserID
+ if roomVersion.PrivilegedRoomCreators() {
+ creators = append(creators, createEvt.Sender)
+ gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool {
+ creators = append(creators, id.UserID(value.Str))
+ return true
+ })
+ }
+ users := gjson.GetBytes(evt.Content, "users")
+ if users.Exists() {
+ if !users.IsObject() {
+ // 10.3. If the users property in content is not an object [...], reject.
+ return fmt.Errorf("%w users", ErrTopLevelPLNotInteger)
+ }
+ var err error
+ users.ForEach(func(key, value gjson.Result) bool {
+ if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil {
+ // 10.3. [...] is not an object with keys that are valid user IDs [...], reject.
+ err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr)
+ return false
+ }
+ if parseIntWithVersion(roomVersion, value) == nil {
+ // 10.3. [...] is not an object [...] with values that are integers, reject.
+ err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str)
+ return false
+ }
+ // creators is only filled if the room version has privileged room creators
+ if slices.Contains(creators, id.UserID(key.Str)) {
+ // 10.4. If the users property in content contains the sender of the m.room.create event or any of
+ // the additional_creators array (if present) from the content of the m.room.create event, reject.
+ err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str)
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ }
+ oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "")
+ if oldPL == nil {
+ // 10.5. If there is no previous m.room.power_levels event in the room, allow.
+ return nil
+ }
+ if slices.Contains(creators, evt.Sender) {
+ // Skip remaining checks for creators
+ return nil
+ }
+ senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String())))
+ if senderPLPtr == nil {
+ senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default"))
+ if senderPLPtr == nil {
+ senderPLPtr = ptr.Ptr(0)
+ }
+ }
+ for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
+ oldVal := gjson.GetBytes(oldPL.Content, key)
+ newVal := gjson.GetBytes(evt.Content, key)
+ if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil {
+ return err
+ }
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "events", "",
+ gjson.GetBytes(oldPL.Content, "events"),
+ gjson.GetBytes(evt.Content, "events"),
+ ); err != nil {
+ return err
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "notifications", "",
+ gjson.GetBytes(oldPL.Content, "notifications"),
+ gjson.GetBytes(evt.Content, "notifications"),
+ ); err != nil {
+ return err
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "users", evt.Sender.String(),
+ gjson.GetBytes(oldPL.Content, "users"),
+ gjson.GetBytes(evt.Content, "users"),
+ ); err != nil {
+ return err
+ }
+ return nil
+}
+
+func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) {
+ old.ForEach(func(key, value gjson.Result) bool {
+ newVal := new.Get(exgjson.Path(key.Str))
+ err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal)
+ if err == nil && ownID != "" && key.Str != ownID {
+ parsedOldVal := parseIntWithVersion(roomVersion, value)
+ parsedNewVal := parseIntWithVersion(roomVersion, newVal)
+ if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal {
+ err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal)
+ }
+ }
+ return err == nil
+ })
+ if err != nil {
+ return
+ }
+ new.ForEach(func(key, value gjson.Result) bool {
+ err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value)
+ return err == nil
+ })
+ return
+}
+
+func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error {
+ oldVal := parseIntWithVersion(roomVersion, old)
+ newVal := parseIntWithVersion(roomVersion, new)
+ if oldVal == nil {
+ if newVal == nil || *newVal <= maxVal {
+ return nil
+ }
+ } else if newVal == nil {
+ if *oldVal <= maxVal {
+ return nil
+ }
+ } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) {
+ return nil
+ }
+ return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal)
+}
+
+func stringifyForError(val gjson.Result) string {
+ if !val.Exists() {
+ return "null"
+ }
+ return val.Raw
+}
+
+func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU {
+ for _, evt := range events {
+ if evt.Type == evtType && *evt.StateKey == stateKey {
+ return evt
+ }
+ }
+ return nil
+}
+
+func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T {
+ return reader(findEvent(events, evtType, stateKey))
+}
+
+func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string {
+ return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string {
+ if evt == nil {
+ return defVal
+ }
+ res := gjson.GetBytes(evt.Content, fieldPath)
+ if res.Type != gjson.String {
+ return defVal
+ }
+ return res.Str
+ })
+}
+
+func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) {
+ var err error
+ powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent {
+ if evt == nil {
+ return nil
+ }
+ content := evt.Content
+ out := &event.PowerLevelsEventContent{}
+ if !roomVersion.ValidatePowerLevelInts() {
+ safeParsePowerLevels(content, out)
+ } else {
+ err = json.Unmarshal(content, out)
+ }
+ return out
+ })
+ if err != nil {
+ // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ if roomVersion.PrivilegedRoomCreators() {
+ if powerLevels == nil {
+ powerLevels = &event.PowerLevelsEventContent{}
+ }
+ powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ } else if powerLevels == nil {
+ powerLevels = &event.PowerLevelsEventContent{
+ Users: map[id.UserID]int{
+ createEvt.Sender: 100,
+ },
+ }
+ }
+ return powerLevels, nil
+}
+
+func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int {
+ if roomVersion.ValidatePowerLevelInts() {
+ if val.Type != gjson.Number {
+ return nil
+ }
+ return ptr.Ptr(int(val.Int()))
+ }
+ return parsePythonInt(val)
+}
+
+func parsePythonInt(val gjson.Result) *int {
+ switch val.Type {
+ case gjson.True:
+ return ptr.Ptr(1)
+ case gjson.False:
+ return ptr.Ptr(0)
+ case gjson.Number:
+ return ptr.Ptr(int(val.Int()))
+ case gjson.String:
+ // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand
+ num, err := strconv.Atoi(strings.TrimSpace(val.Str))
+ if err != nil {
+ return nil
+ }
+ return &num
+ default:
+ // Python int() doesn't accept nulls, arrays or dicts
+ return nil
+ }
+}
+
+func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) {
+ *into = event.PowerLevelsEventContent{
+ Users: make(map[id.UserID]int),
+ UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))),
+ Events: make(map[string]int),
+ EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))),
+ Notifications: nil, // irrelevant for event auth
+ StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")),
+ InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")),
+ KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")),
+ BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")),
+ RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")),
+ }
+ gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.String {
+ return false
+ }
+ val := parsePythonInt(value)
+ if val != nil {
+ into.Events[key.Str] = *val
+ }
+ return true
+ })
+ gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.String {
+ return false
+ }
+ val := parsePythonInt(value)
+ if val == nil {
+ return false
+ }
+ userID := id.UserID(key.Str)
+ if _, _, err := userID.Parse(); err != nil {
+ return false
+ }
+ into.Users[userID] = *val
+ return true
+ })
+}
diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go
new file mode 100644
index 00000000..d316f3c8
--- /dev/null
+++ b/federation/eventauth/eventauth_internal_test.go
@@ -0,0 +1,66 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+type pythonIntTest struct {
+ Name string
+ Input string
+ Expected int64
+}
+
+var pythonIntTests = []pythonIntTest{
+ {"True", `true`, 1},
+ {"False", `false`, 0},
+ {"SmallFloat", `3.1415`, 3},
+ {"SmallFloatRoundDown", `10.999999999999999`, 10},
+ {"SmallFloatRoundUp", `10.9999999999999999`, 11},
+ {"BigFloatRoundDown", `1000000.9999999999`, 1000000},
+ {"BigFloatRoundUp", `1000000.99999999999`, 1000001},
+ {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992},
+ {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994},
+ {"Int64", `9223372036854775807`, 9223372036854775807},
+ {"Int64String", `"9223372036854775807"`, 9223372036854775807},
+ {"String", `"123"`, 123},
+ {"InvalidFloatInString", `"123.456"`, 0},
+ {"StringWithPlusSign", `"+123"`, 123},
+ {"StringWithMinusSign", `"-123"`, -123},
+ {"StringWithSpaces", `" 123 "`, 123},
+ {"StringWithSpacesAndSign", `" -123 "`, -123},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0},
+ {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0},
+ {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0},
+ //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456},
+}
+
+func TestParsePythonInt(t *testing.T) {
+ for _, test := range pythonIntTests {
+ t.Run(test.Name, func(t *testing.T) {
+ output := parsePythonInt(gjson.Parse(test.Input))
+ if strings.HasPrefix(test.Name, "Invalid") {
+ assert.Nil(t, output)
+ } else {
+ require.NotNil(t, output)
+ assert.Equal(t, int(test.Expected), *output)
+ }
+ })
+ }
+}
diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go
new file mode 100644
index 00000000..e3c5cd76
--- /dev/null
+++ b/federation/eventauth/eventauth_test.go
@@ -0,0 +1,85 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth_test
+
+import (
+ "embed"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "errors"
+ "io"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/federation/eventauth"
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+//go:embed *.jsonl
+var data embed.FS
+
+type eventMap map[id.EventID]*pdu.PDU
+
+func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) {
+ output := make([]*pdu.PDU, len(ids))
+ for i, evtID := range ids {
+ output[i] = em[evtID]
+ }
+ return output, nil
+}
+
+func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) {
+ return "", time.Time{}, nil
+}
+
+func TestAuthorize(t *testing.T) {
+ files := exerrors.Must(data.ReadDir("."))
+ for _, file := range files {
+ t.Run(file.Name(), func(t *testing.T) {
+ decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name())))
+ events := make(eventMap)
+ var roomVersion *id.RoomVersion
+ for i := 1; ; i++ {
+ var evt *pdu.PDU
+ err := json.UnmarshalDecode(decoder, &evt)
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ require.NoError(t, err)
+ if roomVersion == nil {
+ require.Equal(t, evt.Type, "m.room.create")
+ roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str))
+ }
+ expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str
+ evtID, err := evt.GetEventID(*roomVersion)
+ require.NoError(t, err)
+ require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i)
+
+ // TODO allow redacted events
+ assert.True(t, evt.VerifyContentHash(), i)
+
+ events[evtID] = evt
+ err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey)
+ if err != nil {
+ evt.InternalMeta.Rejected = true
+ }
+ // TODO allow testing intentionally rejected events
+ assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type)
+ }
+ })
+ }
+
+}
diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl
new file mode 100644
index 00000000..2b751de3
--- /dev/null
+++ b/federation/eventauth/testroom-v12-success.jsonl
@@ -0,0 +1,21 @@
+{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}}
+{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}}
+{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}}
+{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}}
+{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
+{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}}
+{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}}
+{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}}
+{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}}
+{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}}
diff --git a/federation/httpclient.go b/federation/httpclient.go
new file mode 100644
index 00000000..2f8dbb4f
--- /dev/null
+++ b/federation/httpclient.go
@@ -0,0 +1,92 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "sync"
+)
+
+// ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests.
+// It only allows requests using the "matrix-federation" scheme.
+type ServerResolvingTransport struct {
+ ResolveOpts *ResolveServerNameOpts
+ Transport *http.Transport
+ Dialer *net.Dialer
+
+ cache ResolutionCache
+
+ resolveLocks map[string]*sync.Mutex
+ resolveLocksLock sync.Mutex
+}
+
+func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport {
+ if cache == nil {
+ cache = NewInMemoryCache()
+ }
+ srt := &ServerResolvingTransport{
+ resolveLocks: make(map[string]*sync.Mutex),
+ cache: cache,
+ Dialer: &net.Dialer{},
+ }
+ srt.Transport = &http.Transport{
+ DialContext: srt.DialContext,
+ }
+ return srt
+}
+
+var _ http.RoundTripper = (*ServerResolvingTransport)(nil)
+
+func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+ addrs, ok := ctx.Value(contextKeyIPPort).([]string)
+ if !ok {
+ return nil, fmt.Errorf("no IP:port in context")
+ }
+ return srt.Dialer.DialContext(ctx, network, addrs[0])
+}
+
+func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) {
+ if request.URL.Scheme != "matrix-federation" {
+ return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme)
+ }
+ resolved, err := srt.resolve(request.Context(), request.URL.Host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to resolve server name: %w", err)
+ }
+ request = request.WithContext(context.WithValue(request.Context(), contextKeyIPPort, resolved.IPPort))
+ request.URL.Scheme = "https"
+ request.URL.Host = resolved.HostHeader
+ request.Host = resolved.HostHeader
+ return srt.Transport.RoundTrip(request)
+}
+
+func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) {
+ srt.resolveLocksLock.Lock()
+ lock, ok := srt.resolveLocks[serverName]
+ if !ok {
+ lock = &sync.Mutex{}
+ srt.resolveLocks[serverName] = lock
+ }
+ srt.resolveLocksLock.Unlock()
+
+ lock.Lock()
+ defer lock.Unlock()
+ res, err := srt.cache.LoadResolution(serverName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache: %w", err)
+ } else if res != nil {
+ return res, nil
+ } else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil {
+ return nil, err
+ } else {
+ srt.cache.StoreResolution(res)
+ return res, nil
+ }
+}
diff --git a/federation/keyserver.go b/federation/keyserver.go
new file mode 100644
index 00000000..d32ba5cf
--- /dev/null
+++ b/federation/keyserver.go
@@ -0,0 +1,205 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "encoding/json"
+ "net/http"
+ "strconv"
+ "time"
+
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/hlog"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
+ "go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
+ "go.mau.fi/util/requestlog"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/id"
+)
+
+type ServerVersion struct {
+ Name string `json:"name"`
+ Version string `json:"version"`
+}
+
+// ServerKeyProvider is an interface that returns private server keys for server key requests.
+type ServerKeyProvider interface {
+ Get(r *http.Request) (serverName string, key *SigningKey)
+}
+
+// StaticServerKey is an implementation of [ServerKeyProvider] that always returns the same server name and key.
+type StaticServerKey struct {
+ ServerName string
+ Key *SigningKey
+}
+
+func (ssk *StaticServerKey) Get(r *http.Request) (serverName string, key *SigningKey) {
+ return ssk.ServerName, ssk.Key
+}
+
+// KeyServer implements a basic Matrix key server that can serve its own keys, plus the federation version endpoint.
+//
+// It does not implement querying keys of other servers, nor any other federation endpoints.
+type KeyServer struct {
+ KeyProvider ServerKeyProvider
+ Version ServerVersion
+ WellKnownTarget string
+ OtherKeys KeyCache
+}
+
+// Register registers the key server endpoints to the given router.
+func (ks *KeyServer) Register(r *http.ServeMux, log zerolog.Logger) {
+ r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown)
+ r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion)
+ keyRouter := http.NewServeMux()
+ keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey)
+ keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys)
+ keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys)
+ errorBodies := exhttp.ErrorBodies{
+ NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
+ MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
+ }
+ r.Handle("/_matrix/key/", exhttp.ApplyMiddleware(
+ keyRouter,
+ exhttp.StripPrefix("/_matrix/key"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
+}
+
+// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint.
+type RespWellKnown struct {
+ Server string `json:"m.server"`
+}
+
+// GetWellKnown implements the `GET /.well-known/matrix/server` endpoint
+//
+// https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver
+func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) {
+ if ks.WellKnownTarget == "" {
+ mautrix.MNotFound.WithMessage("No well-known target set").Write(w)
+ } else {
+ exhttp.WriteJSONResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget})
+ }
+}
+
+// RespServerVersion is the response body for the `GET /_matrix/federation/v1/version` endpoint
+type RespServerVersion struct {
+ Server ServerVersion `json:"server"`
+}
+
+// GetServerVersion implements the `GET /_matrix/federation/v1/version` endpoint
+//
+// https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version
+func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) {
+ exhttp.WriteJSONResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version})
+}
+
+// GetServerKey implements the `GET /_matrix/key/v2/server` endpoint.
+//
+// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2server
+func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) {
+ domain, key := ks.KeyProvider.Get(r)
+ if key == nil {
+ mautrix.MNotFound.WithMessage("No signing key found for %q", r.Host).Write(w)
+ } else {
+ exhttp.WriteJSONResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil))
+ }
+}
+
+// ReqQueryKeys is the request body for the `POST /_matrix/key/v2/query` endpoint
+type ReqQueryKeys struct {
+ ServerKeys map[string]map[id.KeyID]QueryKeysCriteria `json:"server_keys"`
+}
+
+type QueryKeysCriteria struct {
+ MinimumValidUntilTS jsontime.UnixMilli `json:"minimum_valid_until_ts"`
+}
+
+// PostQueryKeysResponse is the response body for the `POST /_matrix/key/v2/query` endpoint
+type PostQueryKeysResponse struct {
+ ServerKeys map[string]*ServerKeyResponse `json:"server_keys"`
+}
+
+// PostQueryKeys implements the `POST /_matrix/key/v2/query` endpoint
+//
+// https://spec.matrix.org/v1.9/server-server-api/#post_matrixkeyv2query
+func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) {
+ var req ReqQueryKeys
+ err := json.NewDecoder(r.Body).Decode(&req)
+ if err != nil {
+ mautrix.MBadJSON.WithMessage("failed to parse request: %v", err).Write(w)
+ return
+ }
+
+ resp := &PostQueryKeysResponse{
+ ServerKeys: make(map[string]*ServerKeyResponse),
+ }
+ for serverName, keys := range req.ServerKeys {
+ domain, key := ks.KeyProvider.Get(r)
+ if domain != serverName {
+ continue
+ }
+ for keyID, criteria := range keys {
+ if key.ID == keyID && criteria.MinimumValidUntilTS.Before(time.Now().Add(24*time.Hour)) {
+ resp.ServerKeys[serverName] = key.GenerateKeyResponse(serverName, nil)
+ }
+ }
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, resp)
+}
+
+// GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint
+type GetQueryKeysResponse struct {
+ ServerKeys []*ServerKeyResponse `json:"server_keys"`
+}
+
+// GetQueryKeys implements the `GET /_matrix/key/v2/query/{serverName}` endpoint
+//
+// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
+func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
+ serverName := r.PathValue("serverName")
+ minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
+ minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
+ if err != nil && minimumValidUntilTSString != "" {
+ mautrix.MInvalidParam.WithMessage("failed to parse ?minimum_valid_until_ts: %v", err).Write(w)
+ return
+ } else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) {
+ mautrix.MInvalidParam.WithMessage("minimum_valid_until_ts may not be more than 24 hours in the future").Write(w)
+ return
+ }
+ resp := &GetQueryKeysResponse{
+ ServerKeys: []*ServerKeyResponse{},
+ }
+ domain, key := ks.KeyProvider.Get(r)
+ if domain == serverName {
+ if key != nil {
+ resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil))
+ }
+ } else if ks.OtherKeys != nil {
+ otherKey, err := ks.OtherKeys.LoadKeys(serverName)
+ if err != nil {
+ mautrix.MUnknown.WithMessage("Failed to load keys from cache").Write(w)
+ return
+ }
+ if key != nil && domain != "" {
+ signature, err := key.SignJSON(otherKey)
+ if err == nil {
+ otherKey.Signatures[domain] = map[id.KeyID]string{
+ key.ID: signature,
+ }
+ }
+ }
+ resp.ServerKeys = append(resp.ServerKeys, otherKey)
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, resp)
+}
diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go
new file mode 100644
index 00000000..16706fe5
--- /dev/null
+++ b/federation/pdu/auth.go
@@ -0,0 +1,71 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "slices"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exgjson"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type StateKey struct {
+ Type string
+ StateKey string
+}
+
+var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token")
+
+type AuthEventSelection []StateKey
+
+func (aes *AuthEventSelection) Add(evtType, stateKey string) {
+ key := StateKey{Type: evtType, StateKey: stateKey}
+ if !aes.Has(key) {
+ *aes = append(*aes, key)
+ }
+}
+
+func (aes *AuthEventSelection) Has(key StateKey) bool {
+ return slices.Contains(*aes, key)
+}
+
+func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) {
+ if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
+ return AuthEventSelection{}
+ }
+ keys = make(AuthEventSelection, 0, 3)
+ if !roomVersion.RoomIDIsCreateEventID() {
+ keys.Add(event.StateCreate.Type, "")
+ }
+ keys.Add(event.StatePowerLevels.Type, "")
+ keys.Add(event.StateMember.Type, pdu.Sender.String())
+ if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
+ keys.Add(event.StateMember.Type, *pdu.StateKey)
+ membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
+ if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
+ keys.Add(event.StateJoinRules.Type, "")
+ }
+ if membership == event.MembershipInvite {
+ thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
+ if thirdPartyInviteToken != "" {
+ keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
+ }
+ }
+ if membership == event.MembershipJoin && roomVersion.RestrictedJoins() {
+ authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str
+ if authorizedVia != "" {
+ keys.Add(event.StateMember.Type, authorizedVia)
+ }
+ }
+ }
+ return
+}
diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go
new file mode 100644
index 00000000..38ef83e9
--- /dev/null
+++ b/federation/pdu/hash.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+
+ "maunium.net/go/mautrix/id"
+)
+
+func (pdu *PDU) CalculateContentHash() ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ pduClone := pdu.Clone()
+ pduClone.Signatures = nil
+ pduClone.Unsigned = nil
+ pduClone.Hashes = nil
+ rawJSON, err := marshalCanonical(pduClone)
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *PDU) FillContentHash() error {
+ if pdu == nil {
+ return ErrPDUIsNil
+ } else if pdu.Hashes != nil {
+ return nil
+ } else if hash, err := pdu.CalculateContentHash(); err != nil {
+ return err
+ } else {
+ pdu.Hashes = &Hashes{SHA256: hash[:]}
+ return nil
+ }
+}
+
+func (pdu *PDU) VerifyContentHash() bool {
+ if pdu == nil || pdu.Hashes == nil {
+ return false
+ }
+ calculatedHash, err := pdu.CalculateContentHash()
+ if err != nil {
+ return false
+ }
+ return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
+}
+
+func (pdu *PDU) GetRoomID() (id.RoomID, error) {
+ if pdu == nil {
+ return "", ErrPDUIsNil
+ } else if pdu.Type != "m.room.create" {
+ return "", fmt.Errorf("room ID can only be calculated for m.room.create events")
+ } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() {
+ return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion)
+ } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil {
+ return "", fmt.Errorf("failed to calculate event ID: %w", err)
+ } else {
+ return id.RoomID(evtID), nil
+ }
+}
+
+var UseInternalMetaForGetEventID = false
+
+func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
+ if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" {
+ return pdu.InternalMeta.EventID, nil
+ }
+ return pdu.calculateEventID(roomVersion, '$')
+}
+
+func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
+ if err := pdu.FillContentHash(); err != nil {
+ return [32]byte{}, err
+ }
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) {
+ referenceHash, err := pdu.GetReferenceHash(roomVersion)
+ if err != nil {
+ return "", err
+ }
+ eventID := make([]byte, 44)
+ eventID[0] = prefix
+ switch roomVersion.EventIDFormat() {
+ case id.EventIDFormatCustom:
+ return "", fmt.Errorf("*pdu.PDU can only be used for room v3+")
+ case id.EventIDFormatBase64:
+ base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:])
+ case id.EventIDFormatURLSafeBase64:
+ base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:])
+ default:
+ return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat())
+ }
+ return id.EventID(eventID), nil
+}
diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go
new file mode 100644
index 00000000..17417e12
--- /dev/null
+++ b/federation/pdu/hash_test.go
@@ -0,0 +1,55 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/base64"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exerrors"
+)
+
+func TestPDU_CalculateContentHash(t *testing.T) {
+ for _, test := range testPDUs {
+ if test.redacted {
+ continue
+ }
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ contentHash := exerrors.Must(parsed.CalculateContentHash())
+ assert.Equal(
+ t,
+ base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
+ base64.RawStdEncoding.EncodeToString(contentHash[:]),
+ )
+ })
+ }
+}
+
+func TestPDU_VerifyContentHash(t *testing.T) {
+ for _, test := range testPDUs {
+ if test.redacted {
+ continue
+ }
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ assert.True(t, parsed.VerifyContentHash())
+ })
+ }
+}
+
+func TestPDU_GetEventID(t *testing.T) {
+ for _, test := range testPDUs {
+ t.Run(test.name, func(t *testing.T) {
+ gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion))
+ assert.Equal(t, test.eventID, gotEventID)
+ })
+ }
+}
diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go
new file mode 100644
index 00000000..17db6995
--- /dev/null
+++ b/federation/pdu/pdu.go
@@ -0,0 +1,156 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "bytes"
+ "crypto/ed25519"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/jsonbytes"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU.
+//
+// The input time is the timestamp of the event. The function should attempt to fetch a key that is
+// valid at or after this time, but if that is not possible, the latest available key should be
+// returned without an error. The verify function will do its own validity checking based on the
+// returned valid until timestamp.
+type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error)
+
+type AnyPDU interface {
+ GetRoomID() (id.RoomID, error)
+ GetEventID(roomVersion id.RoomVersion) (id.EventID, error)
+ GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error)
+ CalculateContentHash() ([32]byte, error)
+ FillContentHash() error
+ VerifyContentHash() bool
+ Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error
+ VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error
+ ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error)
+ AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection)
+}
+
+var (
+ _ AnyPDU = (*PDU)(nil)
+ _ AnyPDU = (*RoomV1PDU)(nil)
+)
+
+type InternalMeta struct {
+ EventID id.EventID `json:"event_id,omitempty"`
+ Rejected bool `json:"rejected,omitempty"`
+ Extra map[string]any `json:",unknown"`
+}
+
+type PDU struct {
+ AuthEvents []id.EventID `json:"auth_events"`
+ Content jsontext.Value `json:"content"`
+ Depth int64 `json:"depth"`
+ Hashes *Hashes `json:"hashes,omitzero"`
+ OriginServerTS int64 `json:"origin_server_ts"`
+ PrevEvents []id.EventID `json:"prev_events"`
+ Redacts *id.EventID `json:"redacts,omitzero"`
+ RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events
+ Sender id.UserID `json:"sender"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
+ StateKey *string `json:"state_key,omitzero"`
+ Type string `json:"type"`
+ Unsigned jsontext.Value `json:"unsigned,omitzero"`
+ InternalMeta InternalMeta `json:"-"`
+
+ Unknown jsontext.Value `json:",unknown"`
+
+ // Deprecated legacy fields
+ DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
+ DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
+ DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
+}
+
+var ErrPDUIsNil = errors.New("PDU is nil")
+
+type Hashes struct {
+ SHA256 jsonbytes.UnpaddedBytes `json:"sha256"`
+
+ Unknown jsontext.Value `json:",unknown"`
+}
+
+func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
+ if pdu.Type == "m.room.create" && roomVersion == "" {
+ roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str)
+ }
+ evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
+ if pdu.StateKey != nil {
+ evtType.Class = event.StateEventType
+ }
+ eventID, err := pdu.GetEventID(roomVersion)
+ if err != nil {
+ return nil, err
+ }
+ roomID := pdu.RoomID
+ if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() {
+ roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1))
+ }
+ evt := &event.Event{
+ StateKey: pdu.StateKey,
+ Sender: pdu.Sender,
+ Type: evtType,
+ Timestamp: pdu.OriginServerTS,
+ ID: eventID,
+ RoomID: roomID,
+ Redacts: ptr.Val(pdu.Redacts),
+ }
+ err = json.Unmarshal(pdu.Content, &evt.Content)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal content: %w", err)
+ }
+ return evt, nil
+}
+
+func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) {
+ if signature == "" {
+ return
+ }
+ if pdu.Signatures == nil {
+ pdu.Signatures = make(map[string]map[id.KeyID]string)
+ }
+ if _, ok := pdu.Signatures[serverName]; !ok {
+ pdu.Signatures[serverName] = make(map[id.KeyID]string)
+ }
+ pdu.Signatures[serverName][keyID] = signature
+}
+
+func marshalCanonical(data any) (jsontext.Value, error) {
+ marshaledBytes, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+ marshaled := jsontext.Value(marshaledBytes)
+ err = marshaled.Canonicalize()
+ if err != nil {
+ return nil, err
+ }
+ check := canonicaljson.CanonicalJSONAssumeValid(marshaled)
+ if !bytes.Equal(marshaled, check) {
+ fmt.Println(string(marshaled))
+ fmt.Println(string(check))
+ return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled))
+ }
+ return marshaled, nil
+}
diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go
new file mode 100644
index 00000000..59d7c3a6
--- /dev/null
+++ b/federation/pdu/pdu_test.go
@@ -0,0 +1,193 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/json/v2"
+ "time"
+
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+type serverKey struct {
+ key id.SigningKey
+ validUntilTS time.Time
+}
+
+type serverDetails struct {
+ serverName string
+ keys map[id.KeyID]serverKey
+}
+
+func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
+ if serverName != sd.serverName {
+ return "", time.Time{}, nil
+ }
+ key, ok := sd.keys[keyID]
+ if ok {
+ return key.key, key.validUntilTS, nil
+ }
+ return "", time.Time{}, nil
+}
+
+var mauniumNet = serverDetails{
+ serverName: "maunium.net",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_xxeS": {
+ key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var envsNet = serverDetails{
+ serverName: "envs.net",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_zIqy": {
+ key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0",
+ validUntilTS: time.UnixMilli(1722360538068),
+ },
+ "ed25519:wuJyKT": {
+ key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var matrixOrg = serverDetails{
+ serverName: "matrix.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:auto": {
+ key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw",
+ validUntilTS: time.UnixMilli(1576767829750),
+ },
+ "ed25519:a_RXGa": {
+ key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var continuwuityOrg = serverDetails{
+ serverName: "continuwuity.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:PwHlNsFu": {
+ key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var novaAstraltechOrg = serverDetails{
+ serverName: "nova.astraltech.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_afpo": {
+ key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+
+type testPDU struct {
+ name string
+ pdu string
+ eventID id.EventID
+ roomVersion id.RoomVersion
+ redacted bool
+ serverDetails
+}
+
+var roomV4MessageTestPDU = testPDU{
+ name: "m.room.message in v4 room",
+ pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`,
+ eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}
+
+var roomV12MessageTestPDU = testPDU{
+ name: "m.room.message in v12 room",
+ pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`,
+ eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}
+
+var testPDUs = []testPDU{roomV4MessageTestPDU, {
+ name: "m.room.message in v5 room",
+ pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`,
+ eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA",
+ roomVersion: id.RoomV5,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.message in v10 room",
+ pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`,
+ eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY",
+ roomVersion: id.RoomV10,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.message in v11 room",
+ pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`,
+ eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`,
+ roomVersion: id.RoomV11,
+ serverDetails: mauniumNet,
+}, roomV12MessageTestPDU, {
+ name: "m.room.create in v4 room",
+ pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`,
+ eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY",
+ roomVersion: id.RoomV4,
+ serverDetails: matrixOrg,
+}, {
+ name: "m.room.create in v10 room",
+ pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`,
+ eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ",
+ roomVersion: id.RoomV10,
+ serverDetails: envsNet,
+}, {
+ name: "m.room.create in v12 room",
+ pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`,
+ eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member in v4 room",
+ pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`,
+ eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member in v10 room",
+ pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`,
+ eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs",
+ roomVersion: id.RoomV10,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member of creator in v12 room",
+ pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`,
+ eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}, {
+ name: "custom message event in v4 room",
+ pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`,
+ eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}, {
+ name: "redacted m.room.member event in v11 room with 2 signatures",
+ pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`,
+ eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ",
+ roomVersion: id.RoomV11,
+ serverDetails: novaAstraltechOrg,
+ redacted: true,
+}}
+
+func parsePDU(pdu string) (out *pdu.PDU) {
+ exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
+ return
+}
diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go
new file mode 100644
index 00000000..d7ee0c15
--- /dev/null
+++ b/federation/pdu/redact.go
@@ -0,0 +1,111 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "encoding/json/jsontext"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.mau.fi/util/exgjson"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/id"
+)
+
+func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value {
+ filtered := jsontext.Value("{}")
+ var err error
+ for _, path := range allowedPaths {
+ res := gjson.GetBytes(object, path)
+ if res.Exists() {
+ var raw jsontext.Value
+ if res.Index > 0 {
+ raw = object[res.Index : res.Index+len(res.Raw)]
+ } else {
+ raw = jsontext.Value(res.Raw)
+ }
+ filtered, err = sjson.SetRawBytes(filtered, path, raw)
+ if err != nil {
+ panic(err)
+ }
+ }
+ }
+ return filtered
+}
+
+func (pdu *PDU) Clone() *PDU {
+ return ptr.Clone(pdu)
+}
+
+func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU {
+ pdu.Signatures = nil
+ return pdu.Redact(roomVersion)
+}
+
+var emptyObject = jsontext.Value("{}")
+
+func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value {
+ switch eventType {
+ case "m.room.member":
+ allowedPaths := []string{"membership"}
+ if roomVersion.RestrictedJoinsFix() {
+ allowedPaths = append(allowedPaths, "join_authorised_via_users_server")
+ }
+ if roomVersion.UpdatedRedactionRules() {
+ allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed"))
+ }
+ return filteredObject(content, allowedPaths...)
+ case "m.room.create":
+ if !roomVersion.UpdatedRedactionRules() {
+ return filteredObject(content, "creator")
+ }
+ return content
+ case "m.room.join_rules":
+ if roomVersion.RestrictedJoins() {
+ return filteredObject(content, "join_rule", "allow")
+ }
+ return filteredObject(content, "join_rule")
+ case "m.room.power_levels":
+ allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"}
+ if roomVersion.UpdatedRedactionRules() {
+ allowedKeys = append(allowedKeys, "invite")
+ }
+ return filteredObject(content, allowedKeys...)
+ case "m.room.history_visibility":
+ return filteredObject(content, "history_visibility")
+ case "m.room.redaction":
+ if roomVersion.RedactsInContent() {
+ return filteredObject(content, "redacts")
+ }
+ return emptyObject
+ case "m.room.aliases":
+ if roomVersion.SpecialCasedAliasesAuth() {
+ return filteredObject(content, "aliases")
+ }
+ return emptyObject
+ default:
+ return emptyObject
+ }
+}
+
+func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU {
+ pdu.Unknown = nil
+ pdu.Unsigned = nil
+ if roomVersion.UpdatedRedactionRules() {
+ pdu.DeprecatedPrevState = nil
+ pdu.DeprecatedOrigin = nil
+ pdu.DeprecatedMembership = nil
+ }
+ if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() {
+ pdu.Redacts = nil
+ }
+ pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
+ return pdu
+}
diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go
new file mode 100644
index 00000000..04e7c5ef
--- /dev/null
+++ b/federation/pdu/signature.go
@@ -0,0 +1,60 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "fmt"
+ "time"
+
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
+ err := pdu.FillContentHash()
+ if err != nil {
+ return err
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
+ }
+ signature := ed25519.Sign(privateKey, rawJSON)
+ pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature))
+ return nil
+}
+
+func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
+ }
+ verified := false
+ for keyID, sig := range pdu.Signatures[serverName] {
+ originServerTS := time.UnixMilli(pdu.OriginServerTS)
+ key, validUntil, err := getKey(serverName, keyID, originServerTS)
+ if err != nil {
+ return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
+ } else if key == "" {
+ return fmt.Errorf("key %s not found for %s", keyID, serverName)
+ } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() {
+ return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS)
+ } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
+ return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
+ } else {
+ verified = true
+ }
+ }
+ if !verified {
+ return fmt.Errorf("no verifiable signatures found for server %s", serverName)
+ }
+ return nil
+}
diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go
new file mode 100644
index 00000000..01df5076
--- /dev/null
+++ b/federation/pdu/signature_test.go
@@ -0,0 +1,102 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "encoding/json/jsontext"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+func TestPDU_VerifySignature(t *testing.T) {
+ for _, test := range testPDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
+ assert.NoError(t, err)
+ })
+ }
+}
+
+func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ return
+ })
+ assert.Error(t, err)
+}
+
+func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) {
+ test := roomV4MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ key = test.keys[keyID].key
+ validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ return
+ })
+ assert.NoError(t, err)
+}
+
+func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ key = test.keys[keyID].key
+ validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ return
+ })
+ assert.Error(t, err)
+}
+
+func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ for _, sigs := range parsed.Signatures {
+ for key := range sigs {
+ sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC"
+ }
+ }
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
+ assert.Error(t, err)
+}
+
+func TestPDU_Sign(t *testing.T) {
+ pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil))
+ evt := &pdu.PDU{
+ AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"},
+ Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`),
+ Depth: 123,
+ OriginServerTS: 1755384351627,
+ PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"},
+ RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ Sender: "@tulir:example.com",
+ Type: "m.room.message",
+ }
+ err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey)
+ require.NoError(t, err)
+ err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ if serverName == "example.com" && keyID == "ed25519:rand" {
+ key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey))
+ validUntil = time.Now()
+ }
+ return
+ })
+ require.NoError(t, err)
+
+}
diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go
new file mode 100644
index 00000000..9557f8ab
--- /dev/null
+++ b/federation/pdu/v1.go
@@ -0,0 +1,277 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/ed25519"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "fmt"
+ "time"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+type V1EventReference struct {
+ ID id.EventID
+ Hashes Hashes
+}
+
+var (
+ _ json.UnmarshalerFrom = (*V1EventReference)(nil)
+ _ json.MarshalerTo = (*V1EventReference)(nil)
+)
+
+func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error {
+ return json.MarshalEncode(enc, []any{er.ID, er.Hashes})
+}
+
+func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error {
+ var ref V1EventReference
+ var data []jsontext.Value
+ if err := json.UnmarshalDecode(dec, &data); err != nil {
+ return err
+ } else if len(data) != 2 {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data))
+ } else if err = json.Unmarshal(data[0], &ref.ID); err != nil {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err)
+ } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err)
+ }
+ *er = ref
+ return nil
+}
+
+type RoomV1PDU struct {
+ AuthEvents []V1EventReference `json:"auth_events"`
+ Content jsontext.Value `json:"content"`
+ Depth int64 `json:"depth"`
+ EventID id.EventID `json:"event_id"`
+ Hashes *Hashes `json:"hashes,omitzero"`
+ OriginServerTS int64 `json:"origin_server_ts"`
+ PrevEvents []V1EventReference `json:"prev_events"`
+ Redacts *id.EventID `json:"redacts,omitzero"`
+ RoomID id.RoomID `json:"room_id"`
+ Sender id.UserID `json:"sender"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
+ StateKey *string `json:"state_key,omitzero"`
+ Type string `json:"type"`
+ Unsigned jsontext.Value `json:"unsigned,omitzero"`
+
+ Unknown jsontext.Value `json:",unknown"`
+
+ // Deprecated legacy fields
+ DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
+ DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
+ DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
+}
+
+func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) {
+ return pdu.RoomID, nil
+}
+
+func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion)
+ }
+ return pdu.EventID, nil
+}
+
+func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU {
+ pdu.Signatures = nil
+ return pdu.Redact(roomVersion)
+}
+
+func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU {
+ pdu.Unknown = nil
+ pdu.Unsigned = nil
+ if pdu.Type != "m.room.redaction" {
+ pdu.Redacts = nil
+ }
+ pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
+ return pdu
+}
+
+func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion)
+ }
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
+ if err := pdu.FillContentHash(); err != nil {
+ return [32]byte{}, err
+ }
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ pduClone := pdu.Clone()
+ pduClone.Signatures = nil
+ pduClone.Unsigned = nil
+ pduClone.Hashes = nil
+ rawJSON, err := marshalCanonical(pduClone)
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *RoomV1PDU) FillContentHash() error {
+ if pdu == nil {
+ return ErrPDUIsNil
+ } else if pdu.Hashes != nil {
+ return nil
+ } else if hash, err := pdu.CalculateContentHash(); err != nil {
+ return err
+ } else {
+ pdu.Hashes = &Hashes{SHA256: hash[:]}
+ return nil
+ }
+}
+
+func (pdu *RoomV1PDU) VerifyContentHash() bool {
+ if pdu == nil || pdu.Hashes == nil {
+ return false
+ }
+ calculatedHash, err := pdu.CalculateContentHash()
+ if err != nil {
+ return false
+ }
+ return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
+}
+
+func (pdu *RoomV1PDU) Clone() *RoomV1PDU {
+ return ptr.Clone(pdu)
+}
+
+func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion)
+ }
+ err := pdu.FillContentHash()
+ if err != nil {
+ return err
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
+ }
+ signature := ed25519.Sign(privateKey, rawJSON)
+ if pdu.Signatures == nil {
+ pdu.Signatures = make(map[string]map[id.KeyID]string)
+ }
+ if _, ok := pdu.Signatures[serverName]; !ok {
+ pdu.Signatures[serverName] = make(map[id.KeyID]string)
+ }
+ pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature)
+ return nil
+}
+
+func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion)
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
+ }
+ verified := false
+ for keyID, sig := range pdu.Signatures[serverName] {
+ originServerTS := time.UnixMilli(pdu.OriginServerTS)
+ key, _, err := getKey(serverName, keyID, originServerTS)
+ if err != nil {
+ return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
+ } else if key == "" {
+ return fmt.Errorf("key %s not found for %s", keyID, serverName)
+ } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
+ return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
+ } else {
+ verified = true
+ }
+ }
+ if !verified {
+ return fmt.Errorf("no verifiable signatures found for server %s", serverName)
+ }
+ return nil
+}
+
+func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool {
+ switch roomVersion {
+ case id.RoomV0, id.RoomV1, id.RoomV2:
+ return true
+ default:
+ return false
+ }
+}
+
+func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion)
+ }
+ evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
+ if pdu.StateKey != nil {
+ evtType.Class = event.StateEventType
+ }
+ evt := &event.Event{
+ StateKey: pdu.StateKey,
+ Sender: pdu.Sender,
+ Type: evtType,
+ Timestamp: pdu.OriginServerTS,
+ ID: pdu.EventID,
+ RoomID: pdu.RoomID,
+ Redacts: ptr.Val(pdu.Redacts),
+ }
+ err := json.Unmarshal(pdu.Content, &evt.Content)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal content: %w", err)
+ }
+ return evt, nil
+}
+
+func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) {
+ if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
+ return AuthEventSelection{}
+ }
+ keys = make(AuthEventSelection, 0, 3)
+ keys.Add(event.StateCreate.Type, "")
+ keys.Add(event.StatePowerLevels.Type, "")
+ keys.Add(event.StateMember.Type, pdu.Sender.String())
+ if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
+ keys.Add(event.StateMember.Type, *pdu.StateKey)
+ membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
+ if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
+ keys.Add(event.StateJoinRules.Type, "")
+ }
+ if membership == event.MembershipInvite {
+ thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
+ if thirdPartyInviteToken != "" {
+ keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
+ }
+ }
+ }
+ return
+}
diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go
new file mode 100644
index 00000000..ecf2dbd2
--- /dev/null
+++ b/federation/pdu/v1_test.go
@@ -0,0 +1,86 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/base64"
+ "encoding/json/v2"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+var testV1PDUs = []testPDU{{
+ name: "m.room.message in v1 room",
+ pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`,
+ eventID: "$16738169022163bokdi:maunium.net",
+ roomVersion: id.RoomV1,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.create in v1 room",
+ pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`,
+ eventID: "$143454825711DhCxH:matrix.org",
+ roomVersion: id.RoomV1,
+ serverDetails: matrixOrg,
+}, {
+ name: "m.room.member in v1 room",
+ pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`,
+ eventID: "$15660585693272iEryv:maunium.net",
+ roomVersion: id.RoomV1,
+ serverDetails: mauniumNet,
+}}
+
+func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) {
+ exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
+ return
+}
+
+func TestRoomV1PDU_CalculateContentHash(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ contentHash := exerrors.Must(parsed.CalculateContentHash())
+ assert.Equal(
+ t,
+ base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
+ base64.RawStdEncoding.EncodeToString(contentHash[:]),
+ )
+ })
+ }
+}
+
+func TestRoomV1PDU_VerifyContentHash(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ assert.True(t, parsed.VerifyContentHash())
+ })
+ }
+}
+
+func TestRoomV1PDU_VerifySignature(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
+ key, ok := test.keys[keyID]
+ if ok {
+ return key.key, key.validUntilTS, nil
+ }
+ return "", time.Time{}, nil
+ })
+ assert.NoError(t, err)
+ })
+ }
+}
diff --git a/federation/resolution.go b/federation/resolution.go
new file mode 100644
index 00000000..a3188266
--- /dev/null
+++ b/federation/resolution.go
@@ -0,0 +1,198 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/rs/zerolog"
+
+ "maunium.net/go/mautrix"
+)
+
+type ResolvedServerName struct {
+ ServerName string `json:"server_name"`
+ HostHeader string `json:"host_header"`
+ IPPort []string `json:"ip_port"`
+ Expires time.Time `json:"expires"`
+}
+
+type ResolveServerNameOpts struct {
+ HTTPClient *http.Client
+ DNSClient *net.Resolver
+}
+
+var (
+ ErrInvalidServerName = errors.New("invalid server name")
+)
+
+// ResolveServerName implements the full server discovery algorithm as specified in https://spec.matrix.org/v1.11/server-server-api/#resolving-server-names
+func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveServerNameOpts) (*ResolvedServerName, error) {
+ var opt ResolveServerNameOpts
+ if len(opts) > 0 && opts[0] != nil {
+ opt = *opts[0]
+ }
+ if opt.HTTPClient == nil {
+ opt.HTTPClient = http.DefaultClient
+ }
+ if opt.DNSClient == nil {
+ opt.DNSClient = net.DefaultResolver
+ }
+ output := ResolvedServerName{
+ ServerName: serverName,
+ HostHeader: serverName,
+ IPPort: []string{serverName},
+ Expires: time.Now().Add(24 * time.Hour),
+ }
+ hostname, port, ok := ParseServerName(serverName)
+ if !ok {
+ return nil, ErrInvalidServerName
+ }
+ // Steps 1 and 2: handle IP literals and hostnames with port
+ if net.ParseIP(hostname) != nil || port != 0 {
+ if port == 0 {
+ port = 8448
+ }
+ output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))}
+ return &output, nil
+ }
+ // Step 3: resolve .well-known
+ wellKnown, expiry, err := RequestWellKnown(ctx, opt.HTTPClient, hostname)
+ if err != nil {
+ zerolog.Ctx(ctx).Trace().
+ Str("server_name", serverName).
+ Err(err).
+ Msg("Failed to get well-known data")
+ } else if wellKnown != nil {
+ output.Expires = expiry
+ output.HostHeader = wellKnown.Server
+ wkHost, wkPort, ok := ParseServerName(wellKnown.Server)
+ if ok {
+ hostname, port = wkHost, wkPort
+ }
+ // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
+ if net.ParseIP(hostname) != nil || port != 0 {
+ if port == 0 {
+ port = 8448
+ }
+ output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))}
+ return &output, nil
+ }
+ }
+ // Step 3.3, 3.4, 4 and 5: resolve SRV records
+ srv, err := RequestSRV(ctx, opt.DNSClient, hostname)
+ if err != nil {
+ // TODO log more noisily for abnormal errors?
+ zerolog.Ctx(ctx).Trace().
+ Str("server_name", serverName).
+ Str("hostname", hostname).
+ Err(err).
+ Msg("Failed to get SRV record")
+ } else if len(srv) > 0 {
+ output.IPPort = make([]string, len(srv))
+ for i, record := range srv {
+ output.IPPort[i] = net.JoinHostPort(strings.TrimRight(record.Target, "."), strconv.Itoa(int(record.Port)))
+ }
+ return &output, nil
+ }
+ // Step 6 or 3.5: no SRV records were found, so default to port 8448
+ output.IPPort = []string{net.JoinHostPort(hostname, "8448")}
+ return &output, nil
+}
+
+// RequestSRV resolves the `_matrix-fed._tcp` SRV record for the given hostname.
+// If the new matrix-fed record is not found, it falls back to the old `_matrix._tcp` record.
+func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net.SRV, error) {
+ _, target, err := cli.LookupSRV(ctx, "matrix-fed", "tcp", hostname)
+ var dnsErr *net.DNSError
+ if err != nil && errors.As(err, &dnsErr) && dnsErr.IsNotFound {
+ _, target, err = cli.LookupSRV(ctx, "matrix", "tcp", hostname)
+ }
+ return target, err
+}
+
+func parseCacheControl(resp *http.Response) time.Duration {
+ cc := resp.Header.Get("Cache-Control")
+ if cc == "" {
+ return 0
+ }
+ parts := strings.Split(cc, ",")
+ for _, part := range parts {
+ kv := strings.SplitN(strings.TrimSpace(part), "=", 1)
+ switch kv[0] {
+ case "no-cache", "no-store":
+ return 0
+ case "max-age":
+ if len(kv) < 2 {
+ continue
+ }
+ maxAge, err := strconv.Atoi(kv[1])
+ if err != nil || maxAge < 0 {
+ continue
+ }
+ age, _ := strconv.Atoi(resp.Header.Get("Age"))
+ return time.Duration(maxAge-age) * time.Second
+ }
+ }
+ return 0
+}
+
+const (
+ MinCacheDuration = 1 * time.Hour
+ MaxCacheDuration = 72 * time.Hour
+ DefaultCacheDuration = 24 * time.Hour
+)
+
+// RequestWellKnown sends a request to the well-known endpoint of a server and returns the response,
+// plus the time when the cache should expire.
+func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) {
+ wellKnownURL := url.URL{
+ Scheme: "https",
+ Host: hostname,
+ Path: "/.well-known/matrix/server",
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL.String(), nil)
+ if err != nil {
+ return nil, time.Time{}, fmt.Errorf("failed to prepare request: %w", err)
+ }
+ resp, err := cli.Do(req)
+ if err != nil {
+ return nil, time.Time{}, fmt.Errorf("failed to send request: %w", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
+ } else if resp.ContentLength > mautrix.WellKnownMaxSize {
+ return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength)
+ }
+ var respData RespWellKnown
+ err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData)
+ if err != nil {
+ return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
+ } else if respData.Server == "" {
+ return nil, time.Time{}, errors.New("server name not found in response")
+ }
+ cacheDuration := parseCacheControl(resp)
+ if cacheDuration <= 0 {
+ cacheDuration = DefaultCacheDuration
+ } else if cacheDuration < MinCacheDuration {
+ cacheDuration = MinCacheDuration
+ } else if cacheDuration > MaxCacheDuration {
+ cacheDuration = MaxCacheDuration
+ }
+ return &respData, time.Now().Add(24 * time.Hour), nil
+}
diff --git a/federation/resolution_test.go b/federation/resolution_test.go
new file mode 100644
index 00000000..62200454
--- /dev/null
+++ b/federation/resolution_test.go
@@ -0,0 +1,115 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/federation"
+)
+
+type resolveTestCase struct {
+ name string
+ serverName string
+ expected federation.ResolvedServerName
+}
+
+func TestResolveServerName(t *testing.T) {
+ // See https://t2bot.io/docs/resolvematrix/ for more info on the RM test cases
+ testCases := []resolveTestCase{{
+ "maunium",
+ "maunium.net",
+ federation.ResolvedServerName{
+ HostHeader: "federation.mau.chat",
+ IPPort: []string{"meow.host.mau.fi:443"},
+ },
+ }, {
+ "IP literal",
+ "135.181.208.158",
+ federation.ResolvedServerName{
+ HostHeader: "135.181.208.158",
+ IPPort: []string{"135.181.208.158:8448"},
+ },
+ }, {
+ "IP literal with port",
+ "135.181.208.158:8447",
+ federation.ResolvedServerName{
+ HostHeader: "135.181.208.158:8447",
+ IPPort: []string{"135.181.208.158:8447"},
+ },
+ }, {
+ "RM Step 2",
+ "2.s.resolvematrix.dev:7652",
+ federation.ResolvedServerName{
+ HostHeader: "2.s.resolvematrix.dev:7652",
+ IPPort: []string{"2.s.resolvematrix.dev:7652"},
+ },
+ }, {
+ "RM Step 3B",
+ "3b.s.resolvematrix.dev",
+ federation.ResolvedServerName{
+ HostHeader: "wk.3b.s.resolvematrix.dev:7753",
+ IPPort: []string{"wk.3b.s.resolvematrix.dev:7753"},
+ },
+ }, {
+ "RM Step 3C",
+ "3c.s.resolvematrix.dev",
+ federation.ResolvedServerName{
+ HostHeader: "wk.3c.s.resolvematrix.dev",
+ IPPort: []string{"srv.wk.3c.s.resolvematrix.dev:7754"},
+ },
+ }, {
+ "RM Step 3C MSC4040",
+ "3c.msc4040.s.resolvematrix.dev",
+ federation.ResolvedServerName{
+ HostHeader: "wk.3c.msc4040.s.resolvematrix.dev",
+ IPPort: []string{"srv.wk.3c.msc4040.s.resolvematrix.dev:7053"},
+ },
+ }, {
+ "RM Step 3D",
+ "3d.s.resolvematrix.dev",
+ federation.ResolvedServerName{
+ HostHeader: "wk.3d.s.resolvematrix.dev",
+ IPPort: []string{"wk.3d.s.resolvematrix.dev:8448"},
+ },
+ }, {
+ "RM Step 4",
+ "4.s.resolvematrix.dev",
+ federation.ResolvedServerName{
+ HostHeader: "4.s.resolvematrix.dev",
+ IPPort: []string{"srv.4.s.resolvematrix.dev:7855"},
+ },
+ }, {
+ "RM Step 4 MSC4040",
+ "4.msc4040.s.resolvematrix.dev",
+ federation.ResolvedServerName{
+ HostHeader: "4.msc4040.s.resolvematrix.dev",
+ IPPort: []string{"srv.4.msc4040.s.resolvematrix.dev:7054"},
+ },
+ }, {
+ "RM Step 5",
+ "5.s.resolvematrix.dev",
+ federation.ResolvedServerName{
+ HostHeader: "5.s.resolvematrix.dev",
+ IPPort: []string{"5.s.resolvematrix.dev:8448"},
+ },
+ }}
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ tc.expected.ServerName = tc.serverName
+ resp, err := federation.ResolveServerName(context.TODO(), tc.serverName)
+ require.NoError(t, err)
+ resp.Expires = time.Time{}
+ assert.Equal(t, tc.expected, *resp)
+ })
+ }
+}
diff --git a/federation/serverauth.go b/federation/serverauth.go
new file mode 100644
index 00000000..cd300341
--- /dev/null
+++ b/federation/serverauth.go
@@ -0,0 +1,264 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "slices"
+ "strings"
+ "sync"
+
+ "github.com/rs/zerolog"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/id"
+)
+
+type ServerAuth struct {
+ Keys KeyCache
+ Client *Client
+ GetDestination func(XMatrixAuth) string
+ MaxBodySize int64
+
+ keyFetchLocks map[string]*sync.Mutex
+ keyFetchLocksLock sync.Mutex
+}
+
+func NewServerAuth(client *Client, keyCache KeyCache, getDestination func(auth XMatrixAuth) string) *ServerAuth {
+ return &ServerAuth{
+ Keys: keyCache,
+ Client: client,
+ GetDestination: getDestination,
+ MaxBodySize: 50 * 1024 * 1024,
+ keyFetchLocks: make(map[string]*sync.Mutex),
+ }
+}
+
+var MUnauthorized = mautrix.RespError{ErrCode: "M_UNAUTHORIZED", StatusCode: http.StatusUnauthorized}
+
+var (
+ errMissingAuthHeader = MUnauthorized.WithMessage("Missing Authorization header")
+ errInvalidAuthHeader = MUnauthorized.WithMessage("Authorization header does not start with X-Matrix")
+ errMalformedAuthHeader = MUnauthorized.WithMessage("X-Matrix value is missing required components")
+ errInvalidDestination = MUnauthorized.WithMessage("Invalid destination in X-Matrix header")
+ errFailedToQueryKeys = MUnauthorized.WithMessage("Failed to query server keys")
+ errInvalidSelfSignatures = MUnauthorized.WithMessage("Server keys don't have valid self-signatures")
+ errRequestBodyTooLarge = mautrix.MTooLarge.WithMessage("Request body too large")
+ errInvalidJSONBody = mautrix.MBadJSON.WithMessage("Request body is not valid JSON")
+ errBodyReadFailed = mautrix.MUnknown.WithMessage("Failed to read request body")
+ errInvalidRequestSignature = MUnauthorized.WithMessage("Failed to verify request signature")
+)
+
+type XMatrixAuth struct {
+ Origin string
+ Destination string
+ KeyID id.KeyID
+ Signature string
+}
+
+func (xma XMatrixAuth) String() string {
+ return fmt.Sprintf(
+ `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`,
+ xma.Origin,
+ xma.Destination,
+ xma.KeyID,
+ xma.Signature,
+ )
+}
+
+func ParseXMatrixAuth(auth string) (xma XMatrixAuth) {
+ auth = strings.TrimPrefix(auth, "X-Matrix ")
+ // TODO upgrade to strings.SplitSeq after Go 1.24 is the minimum
+ for _, part := range strings.Split(auth, ",") {
+ part = strings.TrimSpace(part)
+ eqIdx := strings.Index(part, "=")
+ if eqIdx == -1 || strings.Count(part, "=") > 1 {
+ continue
+ }
+ val := strings.Trim(part[eqIdx+1:], "\"")
+ switch strings.ToLower(part[:eqIdx]) {
+ case "origin":
+ xma.Origin = val
+ case "destination":
+ xma.Destination = val
+ case "key":
+ xma.KeyID = id.KeyID(val)
+ case "sig":
+ xma.Signature = val
+ }
+ }
+ return
+}
+
+func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string, keyID id.KeyID) (*ServerKeyResponse, error) {
+ res, err := sa.Keys.LoadKeys(serverName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache: %w", err)
+ } else if res.HasKey(keyID) {
+ return res, nil
+ }
+
+ sa.keyFetchLocksLock.Lock()
+ lock, ok := sa.keyFetchLocks[serverName]
+ if !ok {
+ lock = &sync.Mutex{}
+ sa.keyFetchLocks[serverName] = lock
+ }
+ sa.keyFetchLocksLock.Unlock()
+
+ lock.Lock()
+ defer lock.Unlock()
+ res, err = sa.Keys.LoadKeys(serverName)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache: %w", err)
+ } else if res != nil {
+ if res.HasKey(keyID) {
+ return res, nil
+ } else if !sa.Keys.ShouldReQuery(serverName) {
+ zerolog.Ctx(ctx).Trace().
+ Str("server_name", serverName).
+ Stringer("key_id", keyID).
+ Msg("Not sending key request for missing key ID, last query was too recent")
+ return res, nil
+ }
+ }
+ res, err = sa.Client.ServerKeys(ctx, serverName)
+ if err != nil {
+ sa.Keys.StoreFetchError(serverName, err)
+ return nil, err
+ }
+ sa.Keys.StoreKeys(res)
+ return res, nil
+}
+
+type fixedLimitedReader struct {
+ R io.Reader
+ N int64
+ Err error
+}
+
+func (l *fixedLimitedReader) Read(p []byte) (n int, err error) {
+ if l.N <= 0 {
+ return 0, l.Err
+ }
+ if int64(len(p)) > l.N {
+ p = p[0:l.N]
+ }
+ n, err = l.R.Read(p)
+ l.N -= int64(n)
+ return
+}
+
+func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.RespError) {
+ defer func() {
+ _ = r.Body.Close()
+ }()
+ log := zerolog.Ctx(r.Context())
+ if r.ContentLength > sa.MaxBodySize {
+ return nil, &errRequestBodyTooLarge
+ }
+ auth := r.Header.Get("Authorization")
+ if auth == "" {
+ return nil, &errMissingAuthHeader
+ } else if !strings.HasPrefix(auth, "X-Matrix ") {
+ return nil, &errInvalidAuthHeader
+ }
+ parsed := ParseXMatrixAuth(auth)
+ if parsed.Origin == "" || parsed.KeyID == "" || parsed.Signature == "" {
+ log.Trace().Str("auth_header", auth).Msg("Malformed X-Matrix header")
+ return nil, &errMalformedAuthHeader
+ }
+ destination := sa.GetDestination(parsed)
+ if destination == "" || (parsed.Destination != "" && parsed.Destination != destination) {
+ log.Trace().
+ Str("got_destination", parsed.Destination).
+ Str("expected_destination", destination).
+ Msg("Invalid destination in X-Matrix header")
+ return nil, &errInvalidDestination
+ }
+ resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin, parsed.KeyID)
+ if err != nil {
+ if !errors.Is(err, ErrRecentKeyQueryFailed) {
+ log.Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to query keys to authenticate request")
+ } else {
+ log.Trace().Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to query keys to authenticate request (cached error)")
+ }
+ return nil, &errFailedToQueryKeys
+ } else if err := resp.VerifySelfSignature(); err != nil {
+ log.Trace().Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to validate self-signatures of server keys")
+ return nil, &errInvalidSelfSignatures
+ }
+ key, ok := resp.VerifyKeys[parsed.KeyID]
+ if !ok {
+ keys := slices.Collect(maps.Keys(resp.VerifyKeys))
+ log.Trace().
+ Stringer("expected_key_id", parsed.KeyID).
+ Any("found_key_ids", keys).
+ Msg("Didn't find expected key ID to verify request")
+ return nil, ptr.Ptr(MUnauthorized.WithMessage("Key ID %q not found (got %v)", parsed.KeyID, keys))
+ }
+ var reqBody []byte
+ if r.ContentLength != 0 && r.Method != http.MethodGet && r.Method != http.MethodHead {
+ reqBody, err = io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: errRequestBodyTooLarge})
+ if errors.Is(err, errRequestBodyTooLarge) {
+ return nil, &errRequestBodyTooLarge
+ } else if err != nil {
+ log.Err(err).
+ Str("server_name", parsed.Origin).
+ Msg("Failed to read request body to authenticate")
+ return nil, &errBodyReadFailed
+ } else if !json.Valid(reqBody) {
+ return nil, &errInvalidJSONBody
+ }
+ }
+ err = (&signableRequest{
+ Method: r.Method,
+ URI: r.URL.RequestURI(),
+ Origin: parsed.Origin,
+ Destination: destination,
+ Content: reqBody,
+ }).Verify(key.Key, parsed.Signature)
+ if err != nil {
+ log.Trace().Err(err).Msg("Request has invalid signature")
+ return nil, &errInvalidRequestSignature
+ }
+ ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination)
+ ctx = context.WithValue(ctx, contextKeyOriginServer, parsed.Origin)
+ ctx = log.With().
+ Str("origin_server_name", parsed.Origin).
+ Str("destination_server_name", destination).
+ Logger().WithContext(ctx)
+ modifiedReq := r.WithContext(ctx)
+ if reqBody != nil {
+ modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody))
+ }
+ return modifiedReq, nil
+}
+
+func (sa *ServerAuth) AuthenticateMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if modifiedReq, err := sa.Authenticate(r); err != nil {
+ err.Write(w)
+ } else {
+ next.ServeHTTP(w, modifiedReq)
+ }
+ })
+}
diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go
new file mode 100644
index 00000000..f99fc6cf
--- /dev/null
+++ b/federation/serverauth_test.go
@@ -0,0 +1,29 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/federation"
+)
+
+func TestServerKeyResponse_VerifySelfSignature(t *testing.T) {
+ cli := federation.NewClient("", nil, nil)
+ ctx := context.Background()
+ for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} {
+ t.Run(name, func(t *testing.T) {
+ resp, err := cli.ServerKeys(ctx, name)
+ require.NoError(t, err)
+ assert.NoError(t, resp.VerifySelfSignature())
+ })
+ }
+}
diff --git a/federation/servername.go b/federation/servername.go
new file mode 100644
index 00000000..33590712
--- /dev/null
+++ b/federation/servername.go
@@ -0,0 +1,95 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "net"
+ "strconv"
+ "strings"
+)
+
+func isSpecCompliantIPv6(host string) bool {
+ // IPv6address = 2*45IPv6char
+ // IPv6char = DIGIT / %x41-46 / %x61-66 / ":" / "."
+ // ; 0-9, A-F, a-f, :, .
+ if len(host) < 2 || len(host) > 45 {
+ return false
+ }
+ for _, ch := range host {
+ if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F') && ch != ':' && ch != '.' {
+ return false
+ }
+ }
+ return true
+}
+
+func isValidIPv4Chunk(str string) bool {
+ if len(str) == 0 || len(str) > 3 {
+ return false
+ }
+ for _, ch := range str {
+ if ch < '0' || ch > '9' {
+ return false
+ }
+ }
+ return true
+
+}
+
+func isSpecCompliantIPv4(host string) bool {
+ // IPv4address = 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT
+ if len(host) < 7 || len(host) > 15 {
+ return false
+ }
+ parts := strings.Split(host, ".")
+ return len(parts) == 4 &&
+ isValidIPv4Chunk(parts[0]) &&
+ isValidIPv4Chunk(parts[1]) &&
+ isValidIPv4Chunk(parts[2]) &&
+ isValidIPv4Chunk(parts[3])
+}
+
+func isSpecCompliantDNSName(host string) bool {
+ // dns-name = 1*255dns-char
+ // dns-char = DIGIT / ALPHA / "-" / "."
+ if len(host) == 0 || len(host) > 255 {
+ return false
+ }
+ for _, ch := range host {
+ if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'z') && (ch < 'A' || ch > 'Z') && ch != '-' && ch != '.' {
+ return false
+ }
+ }
+ return true
+}
+
+// ParseServerName parses the port and hostname from a Matrix server name and validates that
+// it matches the grammar specified in https://spec.matrix.org/v1.11/appendices/#server-name
+func ParseServerName(serverName string) (host string, port uint16, ok bool) {
+ if len(serverName) == 0 || len(serverName) > 255 {
+ return
+ }
+ colonIdx := strings.LastIndexByte(serverName, ':')
+ if colonIdx > 0 {
+ u64Port, err := strconv.ParseUint(serverName[colonIdx+1:], 10, 16)
+ if err == nil {
+ port = uint16(u64Port)
+ serverName = serverName[:colonIdx]
+ }
+ }
+ if serverName[0] == '[' {
+ if serverName[len(serverName)-1] != ']' {
+ return
+ }
+ host = serverName[1 : len(serverName)-1]
+ ok = isSpecCompliantIPv6(host) && net.ParseIP(host) != nil
+ } else {
+ host = serverName
+ ok = isSpecCompliantDNSName(host) || isSpecCompliantIPv4(host)
+ }
+ return
+}
diff --git a/federation/servername_test.go b/federation/servername_test.go
new file mode 100644
index 00000000..156d692f
--- /dev/null
+++ b/federation/servername_test.go
@@ -0,0 +1,64 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+
+ "maunium.net/go/mautrix/federation"
+)
+
+type parseTestCase struct {
+ name string
+ serverName string
+ hostname string
+ port uint16
+}
+
+func TestParseServerName(t *testing.T) {
+ testCases := []parseTestCase{{
+ "Domain",
+ "matrix.org",
+ "matrix.org",
+ 0,
+ }, {
+ "Domain with port",
+ "matrix.org:8448",
+ "matrix.org",
+ 8448,
+ }, {
+ "IPv4 literal",
+ "1.2.3.4",
+ "1.2.3.4",
+ 0,
+ }, {
+ "IPv4 literal with port",
+ "1.2.3.4:8448",
+ "1.2.3.4",
+ 8448,
+ }, {
+ "IPv6 literal",
+ "[1234:5678::abcd]",
+ "1234:5678::abcd",
+ 0,
+ }, {
+ "IPv6 literal with port",
+ "[1234:5678::abcd]:8448",
+ "1234:5678::abcd",
+ 8448,
+ }}
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ hostname, port, ok := federation.ParseServerName(tc.serverName)
+ assert.True(t, ok)
+ assert.Equal(t, tc.hostname, hostname)
+ assert.Equal(t, tc.port, port)
+ })
+ }
+}
diff --git a/federation/signingkey.go b/federation/signingkey.go
new file mode 100644
index 00000000..a4ad9679
--- /dev/null
+++ b/federation/signingkey.go
@@ -0,0 +1,164 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package federation
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/tidwall/sjson"
+ "go.mau.fi/util/jsontime"
+
+ "maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+// SigningKey is a Matrix federation signing key pair.
+type SigningKey struct {
+ ID id.KeyID
+ Pub id.SigningKey
+ Priv ed25519.PrivateKey
+}
+
+// SynapseString returns a string representation of the private key compatible with Synapse's .signing.key file format.
+//
+// The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function.
+func (sk *SigningKey) SynapseString() string {
+ alg, keyID := sk.ID.Parse()
+ return fmt.Sprintf("%s %s %s", alg, keyID, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
+}
+
+// ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey.
+func ParseSynapseKey(key string) (*SigningKey, error) {
+ parts := strings.Split(key, " ")
+ if len(parts) != 3 {
+ return nil, fmt.Errorf("invalid key format (expected 3 space-separated parts, got %d)", len(parts))
+ } else if parts[0] != string(id.KeyAlgorithmEd25519) {
+ return nil, fmt.Errorf("unsupported key algorithm %s (only ed25519 is supported)", parts[0])
+ }
+ seed, err := base64.RawStdEncoding.DecodeString(parts[2])
+ if err != nil {
+ return nil, fmt.Errorf("invalid private key: %w", err)
+ }
+ priv := ed25519.NewKeyFromSeed(seed)
+ pub := base64.RawStdEncoding.EncodeToString(priv.Public().(ed25519.PublicKey))
+ return &SigningKey{
+ ID: id.NewKeyID(id.KeyAlgorithmEd25519, parts[1]),
+ Pub: id.SigningKey(pub),
+ Priv: priv,
+ }, nil
+}
+
+// GenerateSigningKey generates a new random signing key.
+func GenerateSigningKey() *SigningKey {
+ pub, priv, err := ed25519.GenerateKey(nil)
+ if err != nil {
+ panic(err)
+ }
+ return &SigningKey{
+ ID: id.NewKeyID(id.KeyAlgorithmEd25519, base64.RawURLEncoding.EncodeToString(pub[:4])),
+ Pub: id.SigningKey(base64.RawStdEncoding.EncodeToString(pub)),
+ Priv: priv,
+ }
+}
+
+// ServerKeyResponse is the response body for the `GET /_matrix/key/v2/server` endpoint.
+// It's also used inside the query endpoint response structs.
+type ServerKeyResponse struct {
+ ServerName string `json:"server_name"`
+ VerifyKeys map[id.KeyID]ServerVerifyKey `json:"verify_keys"`
+ OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
+ ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"`
+
+ Raw json.RawMessage `json:"-"`
+}
+
+type QueryKeysResponse struct {
+ ServerKeys []*ServerKeyResponse `json:"server_keys"`
+}
+
+func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool {
+ if skr == nil {
+ return false
+ } else if _, ok := skr.VerifyKeys[keyID]; ok {
+ return true
+ }
+ return false
+}
+
+func (skr *ServerKeyResponse) VerifySelfSignature() error {
+ for keyID, key := range skr.VerifyKeys {
+ if err := signutil.VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil {
+ return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err)
+ }
+ }
+ return nil
+}
+
+type marshalableSKR ServerKeyResponse
+
+func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error {
+ skr.Raw = data
+ return json.Unmarshal(data, (*marshalableSKR)(skr))
+}
+
+type ServerVerifyKey struct {
+ Key id.SigningKey `json:"key"`
+}
+
+func (svk *ServerVerifyKey) Decode() (ed25519.PublicKey, error) {
+ return base64.RawStdEncoding.DecodeString(string(svk.Key))
+}
+
+type OldVerifyKey struct {
+ Key id.SigningKey `json:"key"`
+ ExpiredTS jsontime.UnixMilli `json:"expired_ts"`
+}
+
+func (sk *SigningKey) SignJSON(data any) (string, error) {
+ marshaled, err := json.Marshal(data)
+ if err != nil {
+ return "", err
+ }
+ marshaled, err = sjson.DeleteBytes(marshaled, "signatures")
+ if err != nil {
+ return "", err
+ }
+ return base64.RawStdEncoding.EncodeToString(sk.SignRawJSON(marshaled)), nil
+}
+
+func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte {
+ return ed25519.Sign(sk.Priv, canonicaljson.CanonicalJSONAssumeValid(data))
+}
+
+// GenerateKeyResponse generates a key response signed by this key with the given server name and optionally some old verify keys.
+func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[id.KeyID]OldVerifyKey) *ServerKeyResponse {
+ skr := &ServerKeyResponse{
+ ServerName: serverName,
+ OldVerifyKeys: oldVerifyKeys,
+ ValidUntilTS: jsontime.UM(time.Now().Add(24 * time.Hour)),
+ VerifyKeys: map[id.KeyID]ServerVerifyKey{
+ sk.ID: {Key: sk.Pub},
+ },
+ }
+ signature, err := sk.SignJSON(skr)
+ if err != nil {
+ panic(err)
+ }
+ skr.Signatures = map[string]map[id.KeyID]string{
+ serverName: {
+ sk.ID: signature,
+ },
+ }
+ return skr
+}
diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go
new file mode 100644
index 00000000..ea0e7886
--- /dev/null
+++ b/federation/signutil/verify.go
@@ -0,0 +1,106 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package signutil
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.mau.fi/util/exgjson"
+
+ "maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/id"
+)
+
+var ErrSignatureNotFound = errors.New("signature not found")
+var ErrInvalidSignature = errors.New("invalid signature")
+
+func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error {
+ var err error
+ message, ok := data.(json.RawMessage)
+ if !ok {
+ message, err = json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ }
+ sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID)))
+ if sigVal.Type != gjson.String {
+ return ErrSignatureNotFound
+ }
+ message, err = sjson.DeleteBytes(message, "signatures")
+ if err != nil {
+ return fmt.Errorf("failed to delete signatures: %w", err)
+ }
+ message, err = sjson.DeleteBytes(message, "unsigned")
+ if err != nil {
+ return fmt.Errorf("failed to delete unsigned: %w", err)
+ }
+ return VerifyJSONRaw(key, sigVal.Str, message)
+}
+
+func VerifyJSONAny(key id.SigningKey, data any) error {
+ var err error
+ message, ok := data.(json.RawMessage)
+ if !ok {
+ message, err = json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ }
+ sigs := gjson.GetBytes(message, "signatures")
+ if !sigs.IsObject() {
+ return ErrSignatureNotFound
+ }
+ message, err = sjson.DeleteBytes(message, "signatures")
+ if err != nil {
+ return fmt.Errorf("failed to delete signatures: %w", err)
+ }
+ message, err = sjson.DeleteBytes(message, "unsigned")
+ if err != nil {
+ return fmt.Errorf("failed to delete unsigned: %w", err)
+ }
+ var validated bool
+ sigs.ForEach(func(_, value gjson.Result) bool {
+ if !value.IsObject() {
+ return true
+ }
+ value.ForEach(func(_, value gjson.Result) bool {
+ if value.Type != gjson.String {
+ return true
+ }
+ validated = VerifyJSONRaw(key, value.Str, message) == nil
+ return !validated
+ })
+ return !validated
+ })
+ if !validated {
+ return ErrInvalidSignature
+ }
+ return nil
+}
+
+func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error {
+ sigBytes, err := base64.RawStdEncoding.DecodeString(sig)
+ if err != nil {
+ return fmt.Errorf("failed to decode signature: %w", err)
+ }
+ keyBytes, err := base64.RawStdEncoding.DecodeString(string(key))
+ if err != nil {
+ return fmt.Errorf("failed to decode key: %w", err)
+ }
+ message = canonicaljson.CanonicalJSONAssumeValid(message)
+ if !ed25519.Verify(keyBytes, message, sigBytes) {
+ return ErrInvalidSignature
+ }
+ return nil
+}
diff --git a/filter.go b/filter.go
index fd6de7a0..54973dab 100644
--- a/filter.go
+++ b/filter.go
@@ -19,43 +19,45 @@ const (
// Filter is used by clients to specify how the server should filter responses to e.g. sync requests
// Specified by: https://spec.matrix.org/v1.2/client-server-api/#filtering
type Filter struct {
- AccountData FilterPart `json:"account_data,omitempty"`
+ AccountData *FilterPart `json:"account_data,omitempty"`
EventFields []string `json:"event_fields,omitempty"`
EventFormat EventFormat `json:"event_format,omitempty"`
- Presence FilterPart `json:"presence,omitempty"`
- Room RoomFilter `json:"room,omitempty"`
+ Presence *FilterPart `json:"presence,omitempty"`
+ Room *RoomFilter `json:"room,omitempty"`
+
+ BeeperToDevice *FilterPart `json:"com.beeper.to_device,omitempty"`
}
// RoomFilter is used to define filtering rules for room events
type RoomFilter struct {
- AccountData FilterPart `json:"account_data,omitempty"`
- Ephemeral FilterPart `json:"ephemeral,omitempty"`
+ AccountData *FilterPart `json:"account_data,omitempty"`
+ Ephemeral *FilterPart `json:"ephemeral,omitempty"`
IncludeLeave bool `json:"include_leave,omitempty"`
NotRooms []id.RoomID `json:"not_rooms,omitempty"`
Rooms []id.RoomID `json:"rooms,omitempty"`
- State FilterPart `json:"state,omitempty"`
- Timeline FilterPart `json:"timeline,omitempty"`
+ State *FilterPart `json:"state,omitempty"`
+ Timeline *FilterPart `json:"timeline,omitempty"`
}
// FilterPart is used to define filtering rules for specific categories of events
type FilterPart struct {
- NotRooms []id.RoomID `json:"not_rooms,omitempty"`
- Rooms []id.RoomID `json:"rooms,omitempty"`
- Limit int `json:"limit,omitempty"`
- NotSenders []id.UserID `json:"not_senders,omitempty"`
- NotTypes []event.Type `json:"not_types,omitempty"`
- Senders []id.UserID `json:"senders,omitempty"`
- Types []event.Type `json:"types,omitempty"`
- ContainsURL *bool `json:"contains_url,omitempty"`
-
- LazyLoadMembers bool `json:"lazy_load_members,omitempty"`
- IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"`
+ NotRooms []id.RoomID `json:"not_rooms,omitempty"`
+ Rooms []id.RoomID `json:"rooms,omitempty"`
+ Limit int `json:"limit,omitempty"`
+ NotSenders []id.UserID `json:"not_senders,omitempty"`
+ NotTypes []event.Type `json:"not_types,omitempty"`
+ Senders []id.UserID `json:"senders,omitempty"`
+ Types []event.Type `json:"types,omitempty"`
+ ContainsURL *bool `json:"contains_url,omitempty"`
+ LazyLoadMembers bool `json:"lazy_load_members,omitempty"`
+ IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"`
+ UnreadThreadNotifications bool `json:"unread_thread_notifications,omitempty"`
}
// Validate checks if the filter contains valid property values
func (filter *Filter) Validate() error {
if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation {
- return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]")
+ return errors.New("bad event_format value")
}
return nil
}
@@ -67,7 +69,7 @@ func DefaultFilter() Filter {
EventFields: nil,
EventFormat: "client",
Presence: DefaultFilterPart(),
- Room: RoomFilter{
+ Room: &RoomFilter{
AccountData: DefaultFilterPart(),
Ephemeral: DefaultFilterPart(),
IncludeLeave: false,
@@ -80,8 +82,8 @@ func DefaultFilter() Filter {
}
// DefaultFilterPart returns the default filter part used by the Matrix server if no filter is provided in the request
-func DefaultFilterPart() FilterPart {
- return FilterPart{
+func DefaultFilterPart() *FilterPart {
+ return &FilterPart{
NotRooms: nil,
Rooms: nil,
Limit: 20,
diff --git a/format/htmlparser.go b/format/htmlparser.go
index eb2a662b..e0507d93 100644
--- a/format/htmlparser.go
+++ b/format/htmlparser.go
@@ -7,13 +7,16 @@
package format
import (
+ "context"
"fmt"
"math"
"strconv"
"strings"
+ "go.mau.fi/util/exstrings"
"golang.org/x/net/html"
+ "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -33,14 +36,16 @@ func (ts TagStack) Has(tag string) bool {
}
type Context struct {
+ Ctx context.Context
ReturnData map[string]any
TagStack TagStack
PreserveWhitespace bool
}
-func NewContext() Context {
+func NewContext(ctx context.Context) Context {
return Context{
+ Ctx: ctx,
ReturnData: map[string]any{},
TagStack: make(TagStack, 0, 4),
}
@@ -62,10 +67,15 @@ type LinkConverter func(text, href string, ctx Context) string
type ColorConverter func(text, fg, bg string, ctx Context) string
type CodeBlockConverter func(code, language string, ctx Context) string
type PillConverter func(displayname, mxid, eventID string, ctx Context) string
+type ImageConverter func(src, alt, title, width, height string, isEmoji bool) string
-func DefaultPillConverter(displayname, mxid, eventID string, _ Context) string {
+const ContextKeyMentions = "_mentions"
+
+func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string {
switch {
case len(mxid) == 0, mxid[0] == '@':
+ existingMentions, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID)
+ ctx.ReturnData[ContextKeyMentions] = append(existingMentions, id.UserID(mxid))
// User link, always just show the displayname
return displayname
case len(eventID) > 0:
@@ -83,6 +93,30 @@ func DefaultPillConverter(displayname, mxid, eventID string, _ Context) string {
}
}
+func onlyBacktickCount(line string) (count int) {
+ for i := 0; i < len(line); i++ {
+ if line[i] != '`' {
+ return -1
+ }
+ count++
+ }
+ return
+}
+
+func DefaultMonospaceBlockConverter(code, language string, ctx Context) string {
+ if len(code) == 0 || code[len(code)-1] != '\n' {
+ code += "\n"
+ }
+ fence := "```"
+ for line := range strings.SplitSeq(code, "\n") {
+ count := onlyBacktickCount(strings.TrimSpace(line))
+ if count >= len(fence) {
+ fence = strings.Repeat("`", count+1)
+ }
+ }
+ return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence)
+}
+
// HTMLParser is a somewhat customizable Matrix HTML parser.
type HTMLParser struct {
PillConverter PillConverter
@@ -93,12 +127,15 @@ type HTMLParser struct {
ItalicConverter TextConverter
StrikethroughConverter TextConverter
UnderlineConverter TextConverter
+ MathConverter TextConverter
+ MathBlockConverter TextConverter
LinkConverter LinkConverter
SpoilerConverter SpoilerConverter
ColorConverter ColorConverter
MonospaceBlockConverter CodeBlockConverter
MonospaceConverter TextConverter
TextConverter TextConverter
+ ImageConverter ImageConverter
}
// TaggedString is a string that also contains a HTML tag.
@@ -175,25 +212,6 @@ func (parser *HTMLParser) listToString(node *html.Node, ctx Context) string {
return strings.Join(children, "\n")
}
-func LongestSequence(in string, of rune) int {
- currentSeq := 0
- maxSeq := 0
- for _, chr := range in {
- if chr == of {
- currentSeq++
- } else {
- if currentSeq > maxSeq {
- maxSeq = currentSeq
- }
- currentSeq = 0
- }
- }
- if currentSeq > maxSeq {
- maxSeq = currentSeq
- }
- return maxSeq
-}
-
func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
switch node.Data {
@@ -220,14 +238,23 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) stri
if parser.MonospaceConverter != nil {
return parser.MonospaceConverter(str, ctx)
}
- surround := strings.Repeat("`", LongestSequence(str, '`')+1)
- return fmt.Sprintf("%s%s%s", surround, str, surround)
+ return SafeMarkdownCode(str)
}
return str
}
func (parser *HTMLParser) spanToString(node *html.Node, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
+ if node.Data == "span" || node.Data == "div" {
+ math, _ := parser.maybeGetAttribute(node, "data-mx-maths")
+ if math != "" && parser.MathConverter != nil {
+ if node.Data == "div" && parser.MathBlockConverter != nil {
+ str = parser.MathBlockConverter(math, ctx)
+ } else {
+ str = parser.MathConverter(math, ctx)
+ }
+ }
+ }
if node.Data == "span" {
reason, isSpoiler := parser.maybeGetAttribute(node, "data-mx-spoiler")
if isSpoiler {
@@ -284,12 +311,28 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string {
}
if parser.LinkConverter != nil {
return parser.LinkConverter(str, href, ctx)
- } else if str == href {
+ } else if str == href ||
+ str == strings.TrimPrefix(href, "mailto:") ||
+ str == strings.TrimPrefix(href, "http://") ||
+ str == strings.TrimPrefix(href, "https://") {
return str
}
return fmt.Sprintf("%s (%s)", str, href)
}
+func (parser *HTMLParser) imgToString(node *html.Node, ctx Context) string {
+ src := parser.getAttribute(node, "src")
+ alt := parser.getAttribute(node, "alt")
+ title := parser.getAttribute(node, "title")
+ width := parser.getAttribute(node, "width")
+ height := parser.getAttribute(node, "height")
+ _, isEmoji := parser.maybeGetAttribute(node, "data-mx-emoticon")
+ if parser.ImageConverter != nil {
+ return parser.ImageConverter(src, alt, title, width, height, isEmoji)
+ }
+ return alt
+}
+
func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
ctx = ctx.WithTag(node.Data)
switch node.Data {
@@ -309,8 +352,12 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
return parser.linkToString(node, ctx)
case "p":
return parser.nodeToTagAwareString(node.FirstChild, ctx)
+ case "img":
+ return parser.imgToString(node, ctx)
case "hr":
return parser.HorizontalLine
+ case "input":
+ return parser.inputToString(node, ctx)
case "pre":
var preStr, language string
if node.FirstChild != nil && node.FirstChild.Type == html.ElementNode && node.FirstChild.Data == "code" {
@@ -325,20 +372,28 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
if parser.MonospaceBlockConverter != nil {
return parser.MonospaceBlockConverter(preStr, language, ctx)
}
- if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' {
- preStr += "\n"
- }
- return fmt.Sprintf("```%s\n%s```", language, preStr)
+ return DefaultMonospaceBlockConverter(preStr, language, ctx)
default:
return parser.nodeToTagAwareString(node.FirstChild, ctx)
}
}
+func (parser *HTMLParser) inputToString(node *html.Node, ctx Context) string {
+ if len(ctx.TagStack) > 1 && ctx.TagStack[len(ctx.TagStack)-2] == "li" {
+ _, checked := parser.maybeGetAttribute(node, "checked")
+ if checked {
+ return "[x]"
+ }
+ return "[ ]"
+ }
+ return parser.nodeToTagAwareString(node.FirstChild, ctx)
+}
+
func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString {
switch node.Type {
case html.TextNode:
if !ctx.PreserveWhitespace {
- node.Data = strings.Replace(node.Data, "\n", "", -1)
+ node.Data = exstrings.CollapseSpaces(strings.ReplaceAll(node.Data, "\n", ""))
}
if parser.TextConverter != nil {
node.Data = parser.TextConverter(node.Data, ctx)
@@ -404,6 +459,35 @@ func (parser *HTMLParser) Parse(htmlData string, ctx Context) string {
return parser.nodeToTagAwareString(node, ctx)
}
+var TextHTMLParser = &HTMLParser{
+ TabsToSpaces: 4,
+ Newline: "\n",
+ HorizontalLine: "\n---\n",
+ PillConverter: DefaultPillConverter,
+}
+
+var MarkdownHTMLParser = &HTMLParser{
+ TabsToSpaces: 4,
+ Newline: "\n",
+ HorizontalLine: "\n---\n",
+ PillConverter: DefaultPillConverter,
+ LinkConverter: func(text, href string, ctx Context) string {
+ if text == href {
+ return fmt.Sprintf("<%s>", href)
+ }
+ return fmt.Sprintf("[%s](%s)", text, href)
+ },
+ MathConverter: func(s string, c Context) string {
+ return fmt.Sprintf("$%s$", s)
+ },
+ MathBlockConverter: func(s string, c Context) string {
+ return fmt.Sprintf("$$\n%s\n$$", s)
+ },
+ UnderlineConverter: func(s string, c Context) string {
+ return fmt.Sprintf("%s", s)
+ },
+}
+
// HTMLToText converts Matrix HTML into text with the default settings.
func HTMLToText(html string) string {
return (&HTMLParser{
@@ -411,23 +495,26 @@ func HTMLToText(html string) string {
Newline: "\n",
HorizontalLine: "\n---\n",
PillConverter: DefaultPillConverter,
- }).Parse(html, NewContext())
+ }).Parse(html, NewContext(context.TODO()))
+}
+
+func HTMLToMarkdownFull(parser *HTMLParser, html string) (parsed string, mentions *event.Mentions) {
+ if parser == nil {
+ parser = MarkdownHTMLParser
+ }
+ ctx := NewContext(context.TODO())
+ parsed = parser.Parse(html, ctx)
+ mentionList, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID)
+ mentions = &event.Mentions{
+ UserIDs: mentionList,
+ }
+ return
}
// HTMLToMarkdown converts Matrix HTML into markdown with the default settings.
//
// Currently, the only difference to HTMLToText is how links are formatted.
func HTMLToMarkdown(html string) string {
- return (&HTMLParser{
- TabsToSpaces: 4,
- Newline: "\n",
- HorizontalLine: "\n---\n",
- PillConverter: DefaultPillConverter,
- LinkConverter: func(text, href string, ctx Context) string {
- if text == href {
- return text
- }
- return fmt.Sprintf("[%s](%s)", text, href)
- },
- }).Parse(html, NewContext())
+ parsed, _ := HTMLToMarkdownFull(nil, html)
+ return parsed
}
diff --git a/format/markdown.go b/format/markdown.go
index fa2a8e8a..77ced0dc 100644
--- a/format/markdown.go
+++ b/format/markdown.go
@@ -8,14 +8,17 @@ package format
import (
"fmt"
+ "regexp"
"strings"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/extension"
"github.com/yuin/goldmark/renderer/html"
+ "go.mau.fi/util/exstrings"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format/mdext"
+ "maunium.net/go/mautrix/id"
)
const paragraphStart = ""
@@ -39,6 +42,55 @@ func UnwrapSingleParagraph(html string) string {
return html
}
+var mdEscapeRegex = regexp.MustCompile("([\\\\`*_[\\]()])")
+
+func EscapeMarkdown(text string) string {
+ text = mdEscapeRegex.ReplaceAllString(text, "\\$1")
+ text = strings.ReplaceAll(text, ">", ">")
+ text = strings.ReplaceAll(text, "<", "<")
+ return text
+}
+
+type uriAble interface {
+ String() string
+ URI() *id.MatrixURI
+}
+
+func MarkdownMention(id uriAble) string {
+ return MarkdownMentionWithName(id.String(), id)
+}
+
+func MarkdownMentionWithName(name string, id uriAble) string {
+ return MarkdownLink(name, id.URI().MatrixToURL())
+}
+
+func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string {
+ if name == "" {
+ name = id.String()
+ }
+ return MarkdownLink(name, id.URI(via...).MatrixToURL())
+}
+
+func MarkdownLink(name string, url string) string {
+ return fmt.Sprintf("[%s](%s)", EscapeMarkdown(name), EscapeMarkdown(url))
+}
+
+func SafeMarkdownCode[T ~string](textInput T) string {
+ if textInput == "" {
+ return "` `"
+ }
+ text := strings.ReplaceAll(string(textInput), "\n", " ")
+ backtickCount := exstrings.LongestSequenceOf(text, '`')
+ if backtickCount == 0 {
+ return fmt.Sprintf("`%s`", text)
+ }
+ quotes := strings.Repeat("`", backtickCount+1)
+ if text[0] == '`' || text[len(text)-1] == '`' {
+ return fmt.Sprintf("%s %s %s", quotes, text, quotes)
+ }
+ return fmt.Sprintf("%s%s%s", quotes, text, quotes)
+}
+
func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.MessageEventContent {
var buf strings.Builder
err := renderer.Convert([]byte(text), &buf)
@@ -49,20 +101,30 @@ func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.Message
return HTMLToContent(htmlBody)
}
-func HTMLToContent(html string) event.MessageEventContent {
- text := HTMLToMarkdown(html)
+func TextToContent(text string) event.MessageEventContent {
+ return event.MessageEventContent{
+ MsgType: event.MsgText,
+ Body: text,
+ Mentions: &event.Mentions{},
+ }
+}
+
+func HTMLToContentFull(renderer *HTMLParser, html string) event.MessageEventContent {
+ text, mentions := HTMLToMarkdownFull(renderer, html)
if html != text {
return event.MessageEventContent{
FormattedBody: html,
Format: event.FormatHTML,
MsgType: event.MsgText,
Body: text,
+ Mentions: mentions,
}
}
- return event.MessageEventContent{
- MsgType: event.MsgText,
- Body: text,
- }
+ return TextToContent(text)
+}
+
+func HTMLToContent(html string) event.MessageEventContent {
+ return HTMLToContentFull(nil, html)
}
func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEventContent {
@@ -78,9 +140,6 @@ func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEve
htmlBody = strings.Replace(text, "\n", "
", -1)
return HTMLToContent(htmlBody)
} else {
- return event.MessageEventContent{
- MsgType: event.MsgText,
- Body: text,
- }
+ return TextToContent(text)
}
}
diff --git a/format/markdown_test.go b/format/markdown_test.go
index 179de6b6..46ea4886 100644
--- a/format/markdown_test.go
+++ b/format/markdown_test.go
@@ -17,17 +17,20 @@ import (
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/format/mdext"
+ "maunium.net/go/mautrix/id"
)
func TestRenderMarkdown_PlainText(t *testing.T) {
content := format.RenderMarkdown("hello world", true, true)
- assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content)
+ assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content)
content = format.RenderMarkdown("hello world", true, false)
- assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content)
+ assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content)
content = format.RenderMarkdown("hello world", false, true)
- assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content)
+ assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content)
content = format.RenderMarkdown("hello world", false, false)
- assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content)
+ assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content)
+ content = format.RenderMarkdown(`mention`, false, false)
+ assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "mention", Mentions: &event.Mentions{}}, content)
}
func TestRenderMarkdown_EscapeHTML(t *testing.T) {
@@ -37,6 +40,7 @@ func TestRenderMarkdown_EscapeHTML(t *testing.T) {
Body: "hello world",
Format: event.FormatHTML,
FormattedBody: "<b>hello world</b>",
+ Mentions: &event.Mentions{},
}, content)
}
@@ -47,6 +51,7 @@ func TestRenderMarkdown_HTML(t *testing.T) {
Body: "**hello world**",
Format: event.FormatHTML,
FormattedBody: "hello world",
+ Mentions: &event.Mentions{},
}, content)
content = format.RenderMarkdown("hello world", true, true)
@@ -55,6 +60,18 @@ func TestRenderMarkdown_HTML(t *testing.T) {
Body: "**hello world**",
Format: event.FormatHTML,
FormattedBody: "hello world",
+ Mentions: &event.Mentions{},
+ }, content)
+
+ content = format.RenderMarkdown(`[mention](https://matrix.to/#/@user:example.com)`, true, false)
+ assert.Equal(t, event.MessageEventContent{
+ MsgType: event.MsgText,
+ Body: "mention",
+ Format: event.FormatHTML,
+ FormattedBody: `mention`,
+ Mentions: &event.Mentions{
+ UserIDs: []id.UserID{"@user:example.com"},
+ },
}, content)
}
@@ -141,3 +158,56 @@ func TestRenderMarkdown_DiscordUnderline(t *testing.T) {
assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", ""))
}
}
+
+var mathTests = map[string]string{
+ "$foo$": `foo`,
+ "hello $foo$ world": `hello foo world`,
+ "$$\nfoo\nbar\n$$": `
foo
bar`,
+ "`$foo$`": `$foo$`,
+ "```\n$foo$\n```": `$foo$\n
`,
+ "~~meow $foo$ asd~~": `meow foo asd`,
+ "$5 or $10": `$5 or $10`,
+ "5$ or 10$": `5$ or 10$`,
+ "$5 or 10$": `5 or 10`,
+ "$*500*$": `*500*`,
+ "$$\n*500*\n$$": `*500*`,
+
+ // TODO: This doesn't work :(
+ // Maybe same reason as the spoiler wrapping not working?
+ //"~~$foo$~~": `foo`,
+}
+
+func TestRenderMarkdown_Math(t *testing.T) {
+ renderer := goldmark.New(goldmark.WithExtensions(extension.Strikethrough, mdext.Math, mdext.EscapeHTML), format.HTMLOptions)
+ for markdown, html := range mathTests {
+ rendered := format.UnwrapSingleParagraph(render(renderer, markdown))
+ assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "\\n"), "with input %q", markdown)
+ }
+}
+
+var customEmojiTests = map[string]string{
+ ``: `
`,
+}
+
+func TestRenderMarkdown_CustomEmoji(t *testing.T) {
+ renderer := goldmark.New(goldmark.WithExtensions(mdext.CustomEmoji), format.HTMLOptions)
+ for markdown, html := range customEmojiTests {
+ rendered := format.UnwrapSingleParagraph(render(renderer, markdown))
+ assert.Equal(t, html, rendered, "with input %q", markdown)
+ }
+}
+
+var codeTests = map[string]string{
+ "meow": "`meow`",
+ "me`ow": "``me`ow``",
+ "`me`ow": "`` `me`ow ``",
+ "me`ow`": "`` me`ow` ``",
+ "`meow`": "`` `meow` ``",
+ "`````````": "`````````` ````````` ``````````",
+}
+
+func TestSafeMarkdownCode(t *testing.T) {
+ for input, expected := range codeTests {
+ assert.Equal(t, expected, format.SafeMarkdownCode(input), "with input %q", input)
+ }
+}
diff --git a/format/mdext/customemoji.go b/format/mdext/customemoji.go
new file mode 100644
index 00000000..2884a5ea
--- /dev/null
+++ b/format/mdext/customemoji.go
@@ -0,0 +1,73 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mdext
+
+import (
+ "bytes"
+
+ "github.com/yuin/goldmark"
+ "github.com/yuin/goldmark/ast"
+ "github.com/yuin/goldmark/renderer"
+ "github.com/yuin/goldmark/util"
+)
+
+type extCustomEmoji struct{}
+type customEmojiRenderer struct {
+ funcs functionCapturer
+}
+
+// CustomEmoji is an extension that converts certain markdown images into Matrix custom emojis.
+var CustomEmoji = &extCustomEmoji{}
+
+type functionCapturer struct {
+ renderImage renderer.NodeRendererFunc
+ renderText renderer.NodeRendererFunc
+ renderString renderer.NodeRendererFunc
+}
+
+func (fc *functionCapturer) Register(kind ast.NodeKind, rendererFunc renderer.NodeRendererFunc) {
+ switch kind {
+ case ast.KindImage:
+ fc.renderImage = rendererFunc
+ case ast.KindText:
+ fc.renderText = rendererFunc
+ case ast.KindString:
+ fc.renderString = rendererFunc
+ }
+}
+
+var (
+ _ renderer.NodeRendererFuncRegisterer = (*functionCapturer)(nil)
+ _ renderer.Option = (*functionCapturer)(nil)
+)
+
+func (fc *functionCapturer) SetConfig(cfg *renderer.Config) {
+ cfg.NodeRenderers[0].Value.(renderer.NodeRenderer).RegisterFuncs(fc)
+}
+
+func (eeh *extCustomEmoji) Extend(m goldmark.Markdown) {
+ var fc functionCapturer
+ m.Renderer().AddOptions(&fc)
+ m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(&customEmojiRenderer{fc}, 0)))
+}
+
+func (cer *customEmojiRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
+ reg.Register(ast.KindImage, cer.renderImage)
+}
+
+var emojiPrefix = []byte("Emoji: ")
+var mxcPrefix = []byte("mxc://")
+
+func (cer *customEmojiRenderer) renderImage(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
+ n, ok := node.(*ast.Image)
+ if ok && entering && bytes.HasPrefix(n.Title, emojiPrefix) && bytes.HasPrefix(n.Destination, mxcPrefix) {
+ n.Title = bytes.TrimPrefix(n.Title, emojiPrefix)
+ n.SetAttributeString("data-mx-emoticon", nil)
+ n.SetAttributeString("height", "32")
+ }
+ return cer.funcs.renderImage(w, source, node, entering)
+}
diff --git a/format/mdext/indentableparagraph.go b/format/mdext/indentableparagraph.go
new file mode 100644
index 00000000..a6ebd6c0
--- /dev/null
+++ b/format/mdext/indentableparagraph.go
@@ -0,0 +1,28 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mdext
+
+import (
+ "github.com/yuin/goldmark"
+ "github.com/yuin/goldmark/parser"
+ "github.com/yuin/goldmark/util"
+)
+
+// indentableParagraphParser is the default paragraph parser with CanAcceptIndentedLine.
+// Used when disabling CodeBlockParser (as disabling it without a replacement will make indented blocks disappear).
+type indentableParagraphParser struct {
+ parser.BlockParser
+}
+
+var defaultIndentableParagraphParser = &indentableParagraphParser{BlockParser: parser.NewParagraphParser()}
+
+func (b *indentableParagraphParser) CanAcceptIndentedLine() bool {
+ return true
+}
+
+// FixIndentedParagraphs is a goldmark option which fixes indented paragraphs when disabling CodeBlockParser.
+var FixIndentedParagraphs = goldmark.WithParserOptions(parser.WithBlockParsers(util.Prioritized(defaultIndentableParagraphParser, 500)))
diff --git a/format/mdext/math.go b/format/mdext/math.go
new file mode 100644
index 00000000..e6a6ecc5
--- /dev/null
+++ b/format/mdext/math.go
@@ -0,0 +1,240 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mdext
+
+import (
+ "bytes"
+ "fmt"
+ stdhtml "html"
+ "regexp"
+ "strings"
+ "unicode"
+
+ "github.com/yuin/goldmark"
+ "github.com/yuin/goldmark/ast"
+ "github.com/yuin/goldmark/parser"
+ "github.com/yuin/goldmark/renderer"
+ "github.com/yuin/goldmark/renderer/html"
+ "github.com/yuin/goldmark/text"
+ "github.com/yuin/goldmark/util"
+)
+
+var astKindMath = ast.NewNodeKind("Math")
+
+type astMath struct {
+ ast.BaseInline
+ value []byte
+}
+
+func (n *astMath) Dump(source []byte, level int) {
+ ast.DumpHelper(n, source, level, nil, nil)
+}
+
+func (n *astMath) Kind() ast.NodeKind {
+ return astKindMath
+}
+
+type astMathBlock struct {
+ ast.BaseBlock
+}
+
+func (n *astMathBlock) Dump(source []byte, level int) {
+ ast.DumpHelper(n, source, level, nil, nil)
+}
+
+func (n *astMathBlock) Kind() ast.NodeKind {
+ return astKindMath
+}
+
+type inlineMathParser struct{}
+
+var defaultInlineMathParser = &inlineMathParser{}
+
+func NewInlineMathParser() parser.InlineParser {
+ return defaultInlineMathParser
+}
+
+const mathDelimiter = '$'
+
+func (s *inlineMathParser) Trigger() []byte {
+ return []byte{mathDelimiter}
+}
+
+// This ignores lines where there's no space after the closing $ to avoid false positives
+var latexInlineRegexp = regexp.MustCompile(`^(\$[^$]*\$)(?:$|\s)`)
+
+func (s *inlineMathParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node {
+ before := block.PrecendingCharacter()
+ // Ignore lines where the opening $ comes after a letter or number to avoid false positives
+ if unicode.IsLetter(before) || unicode.IsNumber(before) {
+ return nil
+ }
+ line, segment := block.PeekLine()
+ idx := latexInlineRegexp.FindSubmatchIndex(line)
+ if idx == nil {
+ return nil
+ }
+ block.Advance(idx[3])
+ return &astMath{
+ value: block.Value(text.NewSegment(segment.Start+1, segment.Start+idx[3]-1)),
+ }
+}
+
+func (s *inlineMathParser) CloseBlock(parent ast.Node, pc parser.Context) {
+ // nothing to do
+}
+
+type blockMathParser struct{}
+
+var defaultBlockMathParser = &blockMathParser{}
+
+func NewBlockMathParser() parser.BlockParser {
+ return defaultBlockMathParser
+}
+
+var mathBlockInfoKey = parser.NewContextKey()
+
+type mathBlockData struct {
+ indent int
+ length int
+ node ast.Node
+}
+
+func (b *blockMathParser) Trigger() []byte {
+ return []byte{'$'}
+}
+
+func (b *blockMathParser) Open(parent ast.Node, reader text.Reader, pc parser.Context) (ast.Node, parser.State) {
+ line, _ := reader.PeekLine()
+ pos := pc.BlockOffset()
+ if pos < 0 || (line[pos] != mathDelimiter) {
+ return nil, parser.NoChildren
+ }
+ findent := pos
+ i := pos
+ for ; i < len(line) && line[i] == mathDelimiter; i++ {
+ }
+ oFenceLength := i - pos
+ if oFenceLength < 2 {
+ return nil, parser.NoChildren
+ }
+ if i < len(line)-1 {
+ rest := line[i:]
+ left := util.TrimLeftSpaceLength(rest)
+ right := util.TrimRightSpaceLength(rest)
+ if left < len(rest)-right {
+ value := rest[left : len(rest)-right]
+ if bytes.IndexByte(value, mathDelimiter) > -1 {
+ return nil, parser.NoChildren
+ }
+ }
+ }
+ node := &astMathBlock{}
+ pc.Set(mathBlockInfoKey, &mathBlockData{findent, oFenceLength, node})
+ return node, parser.NoChildren
+
+}
+
+func (b *blockMathParser) Continue(node ast.Node, reader text.Reader, pc parser.Context) parser.State {
+ line, segment := reader.PeekLine()
+ fdata := pc.Get(mathBlockInfoKey).(*mathBlockData)
+
+ w, pos := util.IndentWidth(line, reader.LineOffset())
+ if w < 4 {
+ i := pos
+ for ; i < len(line) && line[i] == mathDelimiter; i++ {
+ }
+ length := i - pos
+ if length >= fdata.length && util.IsBlank(line[i:]) {
+ newline := 1
+ if line[len(line)-1] != '\n' {
+ newline = 0
+ }
+ reader.Advance(segment.Stop - segment.Start - newline + segment.Padding)
+ return parser.Close
+ }
+ }
+ pos, padding := util.IndentPositionPadding(line, reader.LineOffset(), segment.Padding, fdata.indent)
+ if pos < 0 {
+ pos = util.FirstNonSpacePosition(line)
+ if pos < 0 {
+ pos = 0
+ }
+ padding = 0
+ }
+ seg := text.NewSegmentPadding(segment.Start+pos, segment.Stop, padding)
+ seg.ForceNewline = true // EOF as newline
+ node.Lines().Append(seg)
+ reader.AdvanceAndSetPadding(segment.Stop-segment.Start-pos-1, padding)
+ return parser.Continue | parser.NoChildren
+}
+
+func (b *blockMathParser) Close(node ast.Node, reader text.Reader, pc parser.Context) {
+ fdata := pc.Get(mathBlockInfoKey).(*mathBlockData)
+ if fdata.node == node {
+ pc.Set(mathBlockInfoKey, nil)
+ }
+}
+
+func (b *blockMathParser) CanInterruptParagraph() bool {
+ return true
+}
+
+func (b *blockMathParser) CanAcceptIndentedLine() bool {
+ return false
+}
+
+type mathHTMLRenderer struct {
+ html.Config
+}
+
+func NewMathHTMLRenderer(opts ...html.Option) renderer.NodeRenderer {
+ r := &mathHTMLRenderer{
+ Config: html.NewConfig(),
+ }
+ for _, opt := range opts {
+ opt.SetHTMLOption(&r.Config)
+ }
+ return r
+}
+
+func (r *mathHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
+ reg.Register(astKindMath, r.renderMath)
+}
+
+func (r *mathHTMLRenderer) renderMath(w util.BufWriter, source []byte, n ast.Node, entering bool) (ast.WalkStatus, error) {
+ if entering {
+ tag := "span"
+ var tex string
+ switch typed := n.(type) {
+ case *astMathBlock:
+ tag = "div"
+ tex = string(n.Lines().Value(source))
+ case *astMath:
+ tex = string(typed.value)
+ }
+ tex = stdhtml.EscapeString(strings.TrimSpace(tex))
+ _, _ = fmt.Fprintf(w, `<%s data-mx-maths="%s">%s%s>`, tag, tex, strings.ReplaceAll(tex, "\n", "
"), tag)
+ }
+ return ast.WalkSkipChildren, nil
+}
+
+type math struct{}
+
+// Math is an extension that allow you to use math like '$$text$$'.
+var Math = &math{}
+
+func (e *math) Extend(m goldmark.Markdown) {
+ m.Parser().AddOptions(parser.WithInlineParsers(
+ util.Prioritized(NewInlineMathParser(), 500),
+ ), parser.WithBlockParsers(
+ util.Prioritized(NewBlockMathParser(), 850),
+ ))
+ m.Renderer().AddOptions(renderer.WithNodeRenderers(
+ util.Prioritized(NewMathHTMLRenderer(), 500),
+ ))
+}
diff --git a/format/mdext/shortemphasis.go b/format/mdext/shortemphasis.go
new file mode 100644
index 00000000..62190326
--- /dev/null
+++ b/format/mdext/shortemphasis.go
@@ -0,0 +1,96 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mdext
+
+import (
+ "github.com/yuin/goldmark"
+ "github.com/yuin/goldmark/ast"
+ "github.com/yuin/goldmark/parser"
+ "github.com/yuin/goldmark/text"
+ "github.com/yuin/goldmark/util"
+)
+
+var ShortEmphasis goldmark.Extender = &shortEmphasisExtender{}
+
+type shortEmphasisExtender struct{}
+
+func (s *shortEmphasisExtender) Extend(m goldmark.Markdown) {
+ m.Parser().AddOptions(parser.WithInlineParsers(
+ util.Prioritized(&italicsParser{}, 500),
+ util.Prioritized(&boldParser{}, 500),
+ ))
+}
+
+type italicsDelimiterProcessor struct{}
+
+func (p *italicsDelimiterProcessor) IsDelimiter(b byte) bool {
+ return b == '_'
+}
+
+func (p *italicsDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool {
+ return opener.Char == closer.Char
+}
+
+func (p *italicsDelimiterProcessor) OnMatch(consumes int) ast.Node {
+ return ast.NewEmphasis(1)
+}
+
+var defaultItalicsDelimiterProcessor = &italicsDelimiterProcessor{}
+
+type italicsParser struct{}
+
+func (s *italicsParser) Trigger() []byte {
+ return []byte{'_'}
+}
+
+func (s *italicsParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node {
+ before := block.PrecendingCharacter()
+ line, segment := block.PeekLine()
+ node := parser.ScanDelimiter(line, before, 1, defaultItalicsDelimiterProcessor)
+ if node == nil || node.OriginalLength > 1 || before == '_' {
+ return nil
+ }
+ node.Segment = segment.WithStop(segment.Start + node.OriginalLength)
+ block.Advance(node.OriginalLength)
+ pc.PushDelimiter(node)
+ return node
+}
+
+type boldDelimiterProcessor struct{}
+
+func (p *boldDelimiterProcessor) IsDelimiter(b byte) bool {
+ return b == '*'
+}
+
+func (p *boldDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool {
+ return opener.Char == closer.Char
+}
+
+func (p *boldDelimiterProcessor) OnMatch(consumes int) ast.Node {
+ return ast.NewEmphasis(2)
+}
+
+var defaultBoldDelimiterProcessor = &boldDelimiterProcessor{}
+
+type boldParser struct{}
+
+func (s *boldParser) Trigger() []byte {
+ return []byte{'*'}
+}
+
+func (s *boldParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node {
+ before := block.PrecendingCharacter()
+ line, segment := block.PeekLine()
+ node := parser.ScanDelimiter(line, before, 1, defaultBoldDelimiterProcessor)
+ if node == nil || node.OriginalLength > 1 || before == '*' {
+ return nil
+ }
+ node.Segment = segment.WithStop(segment.Start + node.OriginalLength)
+ block.Advance(node.OriginalLength)
+ pc.PushDelimiter(node)
+ return node
+}
diff --git a/format/mdext/shortstrike.go b/format/mdext/shortstrike.go
new file mode 100644
index 00000000..00328f22
--- /dev/null
+++ b/format/mdext/shortstrike.go
@@ -0,0 +1,76 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mdext
+
+import (
+ "github.com/yuin/goldmark"
+ gast "github.com/yuin/goldmark/ast"
+ "github.com/yuin/goldmark/extension"
+ "github.com/yuin/goldmark/extension/ast"
+ "github.com/yuin/goldmark/parser"
+ "github.com/yuin/goldmark/renderer"
+ "github.com/yuin/goldmark/text"
+ "github.com/yuin/goldmark/util"
+)
+
+var ShortStrike goldmark.Extender = &shortStrikeExtender{length: 1}
+var LongStrike goldmark.Extender = &shortStrikeExtender{length: 2}
+
+type shortStrikeExtender struct {
+ length int
+}
+
+func (s *shortStrikeExtender) Extend(m goldmark.Markdown) {
+ m.Parser().AddOptions(parser.WithInlineParsers(
+ util.Prioritized(&strikethroughParser{length: s.length}, 500),
+ ))
+ m.Renderer().AddOptions(renderer.WithNodeRenderers(
+ util.Prioritized(extension.NewStrikethroughHTMLRenderer(), 500),
+ ))
+}
+
+type strikethroughDelimiterProcessor struct{}
+
+func (p *strikethroughDelimiterProcessor) IsDelimiter(b byte) bool {
+ return b == '~'
+}
+
+func (p *strikethroughDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool {
+ return opener.Char == closer.Char
+}
+
+func (p *strikethroughDelimiterProcessor) OnMatch(consumes int) gast.Node {
+ return ast.NewStrikethrough()
+}
+
+var defaultStrikethroughDelimiterProcessor = &strikethroughDelimiterProcessor{}
+
+type strikethroughParser struct {
+ length int
+}
+
+func (s *strikethroughParser) Trigger() []byte {
+ return []byte{'~'}
+}
+
+func (s *strikethroughParser) Parse(parent gast.Node, block text.Reader, pc parser.Context) gast.Node {
+ before := block.PrecendingCharacter()
+ line, segment := block.PeekLine()
+ node := parser.ScanDelimiter(line, before, 1, defaultStrikethroughDelimiterProcessor)
+ if node == nil || node.OriginalLength != s.length || before == '~' {
+ return nil
+ }
+
+ node.Segment = segment.WithStop(segment.Start + node.OriginalLength)
+ block.Advance(node.OriginalLength)
+ pc.PushDelimiter(node)
+ return node
+}
+
+func (s *strikethroughParser) CloseBlock(parent gast.Node, pc parser.Context) {
+ // nothing to do
+}
diff --git a/go.mod b/go.mod
index 48ff59e0..49a1d4e4 100644
--- a/go.mod
+++ b/go.mod
@@ -1,35 +1,42 @@
module maunium.net/go/mautrix
-go 1.20
+go 1.25.0
+
+toolchain go1.26.0
require (
- github.com/gorilla/mux v1.8.0
- github.com/gorilla/websocket v1.5.0
- github.com/lib/pq v1.10.9
- github.com/mattn/go-sqlite3 v1.14.19
- github.com/rs/zerolog v1.31.0
- github.com/stretchr/testify v1.8.4
- github.com/tidwall/gjson v1.17.0
+ filippo.io/edwards25519 v1.2.0
+ github.com/chzyer/readline v1.5.1
+ github.com/coder/websocket v1.8.14
+ github.com/lib/pq v1.11.2
+ github.com/mattn/go-sqlite3 v1.14.34
+ github.com/rs/xid v1.6.0
+ github.com/rs/zerolog v1.34.0
+ github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
+ github.com/stretchr/testify v1.11.1
+ github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
- github.com/yuin/goldmark v1.6.0
- go.mau.fi/util v0.3.0
- go.mau.fi/zeroconfig v0.1.2
- golang.org/x/crypto v0.18.0
- golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3
- golang.org/x/net v0.20.0
+ github.com/yuin/goldmark v1.7.16
+ go.mau.fi/util v0.9.6
+ go.mau.fi/zeroconfig v0.2.0
+ golang.org/x/crypto v0.48.0
+ golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa
+ golang.org/x/net v0.50.0
+ golang.org/x/sync v0.19.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0
- maunium.net/go/maulogger/v2 v2.4.1
)
require (
- github.com/coreos/go-systemd/v22 v22.5.0 // indirect
+ github.com/coreos/go-systemd/v22 v22.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
- github.com/mattn/go-colorable v0.1.13 // indirect
- github.com/mattn/go-isatty v0.0.19 // indirect
+ github.com/mattn/go-colorable v0.1.14 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
- github.com/tidwall/pretty v1.2.0 // indirect
- golang.org/x/sys v0.16.0 // indirect
+ github.com/tidwall/pretty v1.2.1 // indirect
+ golang.org/x/sys v0.41.0 // indirect
+ golang.org/x/text v0.34.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)
diff --git a/go.sum b/go.sum
index 9061a651..871a5156 100644
--- a/go.sum
+++ b/go.sum
@@ -1,56 +1,77 @@
-github.com/DATA-DOG/go-sqlmock v1.5.1 h1:FK6RCIUSfmbnI/imIICmboyQBkOckutaa6R5YYlLZyo=
-github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
+filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
+filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
+github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
+github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
+github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
+github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
+github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
+github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
+github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
+github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
+github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
+github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
+github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo=
+github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
-github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
-github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
-github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
-github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
-github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
-github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
+github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs=
+github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
+github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
-github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI=
-github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
+github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
-github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
-github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
-github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
-github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
+github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
+github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
+github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
+github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
+github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
-github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM=
-github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
+github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
-github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
+github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
-github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
-github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
-go.mau.fi/util v0.3.0 h1:Lt3lbRXP6ZBqTINK0EieRWor3zEwwwrDT14Z5N8RUCs=
-go.mau.fi/util v0.3.0/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
-go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
-go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
-golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
-golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
-golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 h1:hNQpMuAJe5CtcUqCXaWga3FHu+kQvCqcsoVaQgSV60o=
-golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
-golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
-golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
+github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE=
+github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
+go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts=
+go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI=
+go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU=
+go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w=
+golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
+golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
+golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
+golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
+golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
+golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
-golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
+golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
+golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
@@ -59,5 +80,3 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
-maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
-maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
diff --git a/id/contenturi.go b/id/contenturi.go
index cfd00c3e..67127b6c 100644
--- a/id/contenturi.go
+++ b/id/contenturi.go
@@ -12,12 +12,19 @@ import (
"encoding/json"
"errors"
"fmt"
+ "regexp"
"strings"
)
var (
- InvalidContentURI = errors.New("invalid Matrix content URI")
- InputNotJSONString = errors.New("input doesn't look like a JSON string")
+ ErrInvalidContentURI = errors.New("invalid Matrix content URI")
+ ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ InvalidContentURI = ErrInvalidContentURI
+ InputNotJSONString = ErrInputNotJSONString
)
// ContentURIString is a string that's expected to be a Matrix content URI.
@@ -54,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !strings.HasPrefix(uri, "mxc://") {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = uri[6 : 6+index]
parsed.FileID = uri[6+index+1:]
@@ -70,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !bytes.HasPrefix(uri, mxcBytes) {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = string(uri[6 : 6+index])
parsed.FileID = string(uri[6+index+1:])
@@ -85,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) {
*uri = ContentURI{}
return nil
} else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' {
- return InputNotJSONString
+ return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString)
}
parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1])
if err != nil {
@@ -156,3 +163,21 @@ func (uri ContentURI) CUString() ContentURIString {
func (uri ContentURI) IsEmpty() bool {
return len(uri.Homeserver) == 0 || len(uri.FileID) == 0
}
+
+var simpleHomeserverRegex = regexp.MustCompile(`^[a-zA-Z0-9.:-]+$`)
+
+func (uri ContentURI) IsValid() bool {
+ return IsValidMediaID(uri.FileID) && uri.Homeserver != "" && simpleHomeserverRegex.MatchString(uri.Homeserver)
+}
+
+func IsValidMediaID(mediaID string) bool {
+ if len(mediaID) == 0 {
+ return false
+ }
+ for _, char := range mediaID {
+ if (char < 'A' || char > 'Z') && (char < 'a' || char > 'z') && (char < '0' || char > '9') && char != '-' && char != '_' {
+ return false
+ }
+ }
+ return true
+}
diff --git a/id/crypto.go b/id/crypto.go
index 84fcd67f..ee857f78 100644
--- a/id/crypto.go
+++ b/id/crypto.go
@@ -7,8 +7,11 @@
package id
import (
+ "encoding/base64"
"fmt"
"strings"
+
+ "go.mau.fi/util/random"
)
// OlmMsgType is an Olm message type
@@ -44,6 +47,47 @@ const (
XSUsageUserSigning CrossSigningUsage = "user_signing"
)
+type KeyBackupAlgorithm string
+
+const (
+ KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
+)
+
+type KeySource string
+
+func (source KeySource) String() string {
+ return string(source)
+}
+
+func (source KeySource) Int() int {
+ switch source {
+ case KeySourceDirect:
+ return 100
+ case KeySourceBackup:
+ return 90
+ case KeySourceImport:
+ return 80
+ case KeySourceForward:
+ return 50
+ default:
+ return 0
+ }
+}
+
+const (
+ KeySourceDirect KeySource = "direct"
+ KeySourceBackup KeySource = "backup"
+ KeySourceImport KeySource = "import"
+ KeySourceForward KeySource = "forward"
+)
+
+// BackupVersion is an arbitrary string that identifies a server side key backup.
+type KeyBackupVersion string
+
+func (version KeyBackupVersion) String() string {
+ return string(version)
+}
+
// A SessionID is an arbitrary string that identifies an Olm or Megolm session.
type SessionID string
@@ -59,6 +103,12 @@ func (ed25519 Ed25519) String() string {
return string(ed25519)
}
+func (ed25519 Ed25519) Bytes() []byte {
+ val, _ := base64.RawStdEncoding.DecodeString(string(ed25519))
+ // TODO handle errors
+ return val
+}
+
func (ed25519 Ed25519) Fingerprint() string {
spacedSigningKey := make([]byte, len(ed25519)+(len(ed25519)-1)/4)
var ptr = 0
@@ -82,6 +132,12 @@ func (curve25519 Curve25519) String() string {
return string(curve25519)
}
+func (curve25519 Curve25519) Bytes() []byte {
+ val, _ := base64.RawStdEncoding.DecodeString(string(curve25519))
+ // TODO handle errors
+ return val
+}
+
// A DeviceID is an arbitrary string that references a specific device.
type DeviceID string
@@ -147,3 +203,29 @@ type CrossSigningKey struct {
Key Ed25519
First Ed25519
}
+
+// Secret storage keys
+type Secret string
+
+func (s Secret) String() string {
+ return string(s)
+}
+
+const (
+ SecretXSMaster Secret = "m.cross_signing.master"
+ SecretXSSelfSigning Secret = "m.cross_signing.self_signing"
+ SecretXSUserSigning Secret = "m.cross_signing.user_signing"
+ SecretMegolmBackupV1 Secret = "m.megolm_backup.v1"
+)
+
+// VerificationTransactionID is a unique identifier for a verification
+// transaction.
+type VerificationTransactionID string
+
+func NewVerificationTransactionID() VerificationTransactionID {
+ return VerificationTransactionID(random.String(32))
+}
+
+func (t VerificationTransactionID) String() string {
+ return string(t)
+}
diff --git a/id/matrixuri.go b/id/matrixuri.go
index 5ec403e9..d5c78bc7 100644
--- a/id/matrixuri.go
+++ b/id/matrixuri.go
@@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{
func (uri *MatrixURI) getQuery() url.Values {
q := make(url.Values)
- if uri.Via != nil && len(uri.Via) > 0 {
+ if len(uri.Via) > 0 {
q["via"] = uri.Via
}
if len(uri.Action) > 0 {
@@ -65,6 +65,9 @@ func (uri *MatrixURI) getQuery() url.Values {
// String converts the parsed matrix: URI back into the string representation.
func (uri *MatrixURI) String() string {
+ if uri == nil {
+ return ""
+ }
parts := []string{
SigilToPathSegment[uri.Sigil1],
url.PathEscape(uri.MXID1),
@@ -81,6 +84,9 @@ func (uri *MatrixURI) String() string {
// MatrixToURL converts to parsed matrix: URI into a matrix.to URL
func (uri *MatrixURI) MatrixToURL() string {
+ if uri == nil {
+ return ""
+ }
fragment := fmt.Sprintf("#/%s", url.PathEscape(uri.PrimaryIdentifier()))
if uri.Sigil2 != 0 {
fragment = fmt.Sprintf("%s/%s", fragment, url.PathEscape(uri.SecondaryIdentifier()))
@@ -96,13 +102,16 @@ func (uri *MatrixURI) MatrixToURL() string {
// PrimaryIdentifier returns the first Matrix identifier in the URI.
// Currently room IDs, room aliases and user IDs can be in the primary identifier slot.
func (uri *MatrixURI) PrimaryIdentifier() string {
+ if uri == nil {
+ return ""
+ }
return fmt.Sprintf("%c%s", uri.Sigil1, uri.MXID1)
}
// SecondaryIdentifier returns the second Matrix identifier in the URI.
// Currently only event IDs can be in the secondary identifier slot.
func (uri *MatrixURI) SecondaryIdentifier() string {
- if uri.Sigil2 == 0 {
+ if uri == nil || uri.Sigil2 == 0 {
return ""
}
return fmt.Sprintf("%c%s", uri.Sigil2, uri.MXID2)
@@ -110,7 +119,7 @@ func (uri *MatrixURI) SecondaryIdentifier() string {
// UserID returns the user ID from the URI if the primary identifier is a user ID.
func (uri *MatrixURI) UserID() UserID {
- if uri.Sigil1 == '@' {
+ if uri != nil && uri.Sigil1 == '@' {
return UserID(uri.PrimaryIdentifier())
}
return ""
@@ -118,7 +127,7 @@ func (uri *MatrixURI) UserID() UserID {
// RoomID returns the room ID from the URI if the primary identifier is a room ID.
func (uri *MatrixURI) RoomID() RoomID {
- if uri.Sigil1 == '!' {
+ if uri != nil && uri.Sigil1 == '!' {
return RoomID(uri.PrimaryIdentifier())
}
return ""
@@ -126,7 +135,7 @@ func (uri *MatrixURI) RoomID() RoomID {
// RoomAlias returns the room alias from the URI if the primary identifier is a room alias.
func (uri *MatrixURI) RoomAlias() RoomAlias {
- if uri.Sigil1 == '#' {
+ if uri != nil && uri.Sigil1 == '#' {
return RoomAlias(uri.PrimaryIdentifier())
}
return ""
@@ -134,7 +143,7 @@ func (uri *MatrixURI) RoomAlias() RoomAlias {
// EventID returns the event ID from the URI if the primary identifier is a room ID or alias and the secondary identifier is an event ID.
func (uri *MatrixURI) EventID() EventID {
- if (uri.Sigil1 == '!' || uri.Sigil1 == '#') && uri.Sigil2 == '$' {
+ if uri != nil && (uri.Sigil1 == '!' || uri.Sigil1 == '#') && uri.Sigil2 == '$' {
return EventID(uri.SecondaryIdentifier())
}
return ""
@@ -201,10 +210,14 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) {
if len(parts[1]) == 0 {
return nil, ErrEmptySecondSegment
}
- parsed.MXID1 = parts[1]
+ var err error
+ parsed.MXID1, err = url.PathUnescape(parts[1])
+ if err != nil {
+ return nil, fmt.Errorf("failed to url decode second segment %q: %w", parts[1], err)
+ }
// Step 6: if the first part is a room and the URI has 4 segments, construct a second level identifier
- if (parsed.Sigil1 == '!' || parsed.Sigil1 == '#') && len(parts) == 4 {
+ if parsed.Sigil1 == '!' && len(parts) == 4 {
// a: find the sigil from the third segment
switch parts[2] {
case "e", "event":
@@ -217,7 +230,10 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) {
if len(parts[3]) == 0 {
return nil, ErrEmptyFourthSegment
}
- parsed.MXID2 = parts[3]
+ parsed.MXID2, err = url.PathUnescape(parts[3])
+ if err != nil {
+ return nil, fmt.Errorf("failed to url decode fourth segment %q: %w", parts[3], err)
+ }
}
// Step 7: parse the query and extract via and action items
diff --git a/id/matrixuri_test.go b/id/matrixuri_test.go
index d26d4bfd..90a0754d 100644
--- a/id/matrixuri_test.go
+++ b/id/matrixuri_test.go
@@ -16,12 +16,11 @@ import (
)
var (
- roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"}
- roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}}
- roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"}
- roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"}
- roomAliasEventLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"}
- userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"}
+ roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"}
+ roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}}
+ roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"}
+ roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"}
+ userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"}
escapeRoomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "meow & 🐈️:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF/dtndJ0j9je+kIK3XpV1s"}
)
@@ -31,7 +30,6 @@ func TestMatrixURI_MatrixToURL(t *testing.T) {
assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.MatrixToURL())
assert.Equal(t, "https://matrix.to/#/%23someroom:example.org", roomAliasLink.MatrixToURL())
assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.MatrixToURL())
- assert.Equal(t, "https://matrix.to/#/%23someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.MatrixToURL())
assert.Equal(t, "https://matrix.to/#/@user:example.org", userLink.MatrixToURL())
assert.Equal(t, "https://matrix.to/#/%21meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/$uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.MatrixToURL())
}
@@ -41,7 +39,6 @@ func TestMatrixURI_String(t *testing.T) {
assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.String())
assert.Equal(t, "matrix:r/someroom:example.org", roomAliasLink.String())
assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.String())
- assert.Equal(t, "matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.String())
assert.Equal(t, "matrix:u/user:example.org", userLink.String())
assert.Equal(t, "matrix:roomid/meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/e/uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.String())
}
@@ -80,8 +77,12 @@ func TestParseMatrixURI_RoomID(t *testing.T) {
parsedVia, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org")
require.NoError(t, err)
require.NotNil(t, parsedVia)
+ parsedEncoded, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl%3Aexample.org")
+ require.NoError(t, err)
+ require.NotNil(t, parsedEncoded)
assert.Equal(t, roomIDLink, *parsed)
+ assert.Equal(t, roomIDLink, *parsedEncoded)
assert.Equal(t, roomIDViaLink, *parsedVia)
}
@@ -98,19 +99,11 @@ func TestParseMatrixURI_UserID(t *testing.T) {
}
func TestParseMatrixURI_EventID(t *testing.T) {
- parsed1, err := id.ParseMatrixURI("matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
+ parsed, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
require.NoError(t, err)
- require.NotNil(t, parsed1)
- parsed2, err := id.ParseMatrixURI("matrix:room/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
- require.NoError(t, err)
- require.NotNil(t, parsed2)
- parsed3, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
- require.NoError(t, err)
- require.NotNil(t, parsed3)
+ require.NotNil(t, parsed)
- assert.Equal(t, roomAliasEventLink, *parsed1)
- assert.Equal(t, roomAliasEventLink, *parsed2)
- assert.Equal(t, roomIDEventLink, *parsed3)
+ assert.Equal(t, roomIDEventLink, *parsed)
}
func TestParseMatrixToURL_RoomAlias(t *testing.T) {
@@ -158,21 +151,13 @@ func TestParseMatrixToURL_UserID(t *testing.T) {
}
func TestParseMatrixToURL_EventID(t *testing.T) {
- parsed1, err := id.ParseMatrixToURL("https://matrix.to/#/#someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
+ parsed, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
require.NoError(t, err)
- require.NotNil(t, parsed1)
- parsed2, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
+ require.NotNil(t, parsed)
+ parsedEncoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
require.NoError(t, err)
- require.NotNil(t, parsed2)
- parsed1Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%23someroom:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
- require.NoError(t, err)
- require.NotNil(t, parsed1)
- parsed2Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s")
- require.NoError(t, err)
- require.NotNil(t, parsed2)
+ require.NotNil(t, parsedEncoded)
- assert.Equal(t, roomAliasEventLink, *parsed1)
- assert.Equal(t, roomAliasEventLink, *parsed1Encoded)
- assert.Equal(t, roomIDEventLink, *parsed2)
- assert.Equal(t, roomIDEventLink, *parsed2Encoded)
+ assert.Equal(t, roomIDEventLink, *parsed)
+ assert.Equal(t, roomIDEventLink, *parsedEncoded)
}
diff --git a/id/opaque.go b/id/opaque.go
index 16863b95..c1ad4988 100644
--- a/id/opaque.go
+++ b/id/opaque.go
@@ -32,11 +32,17 @@ type EventID string
// https://github.com/matrix-org/matrix-doc/pull/2716
type BatchID string
+// A DelayID is a string identifying a delayed event.
+type DelayID string
+
func (roomID RoomID) String() string {
return string(roomID)
}
func (roomID RoomID) URI(via ...string) *MatrixURI {
+ if roomID == "" {
+ return nil
+ }
return &MatrixURI{
Sigil1: '!',
MXID1: string(roomID)[1:],
@@ -45,6 +51,11 @@ func (roomID RoomID) URI(via ...string) *MatrixURI {
}
func (roomID RoomID) EventURI(eventID EventID, via ...string) *MatrixURI {
+ if roomID == "" {
+ return nil
+ } else if eventID == "" {
+ return roomID.URI(via...)
+ }
return &MatrixURI{
Sigil1: '!',
MXID1: string(roomID)[1:],
@@ -59,13 +70,20 @@ func (roomAlias RoomAlias) String() string {
}
func (roomAlias RoomAlias) URI() *MatrixURI {
+ if roomAlias == "" {
+ return nil
+ }
return &MatrixURI{
Sigil1: '#',
MXID1: string(roomAlias)[1:],
}
}
+// Deprecated: room alias event links should not be used. Use room IDs instead.
func (roomAlias RoomAlias) EventURI(eventID EventID) *MatrixURI {
+ if roomAlias == "" {
+ return nil
+ }
return &MatrixURI{
Sigil1: '#',
MXID1: string(roomAlias)[1:],
diff --git a/id/roomversion.go b/id/roomversion.go
new file mode 100644
index 00000000..578c10bd
--- /dev/null
+++ b/id/roomversion.go
@@ -0,0 +1,265 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package id
+
+import (
+ "errors"
+ "fmt"
+ "slices"
+)
+
+type RoomVersion string
+
+const (
+ RoomV0 RoomVersion = "" // No room version, used for rooms created before room versions were introduced, equivalent to v1
+ RoomV1 RoomVersion = "1"
+ RoomV2 RoomVersion = "2"
+ RoomV3 RoomVersion = "3"
+ RoomV4 RoomVersion = "4"
+ RoomV5 RoomVersion = "5"
+ RoomV6 RoomVersion = "6"
+ RoomV7 RoomVersion = "7"
+ RoomV8 RoomVersion = "8"
+ RoomV9 RoomVersion = "9"
+ RoomV10 RoomVersion = "10"
+ RoomV11 RoomVersion = "11"
+ RoomV12 RoomVersion = "12"
+)
+
+func (rv RoomVersion) Equals(versions ...RoomVersion) bool {
+ return slices.Contains(versions, rv)
+}
+
+func (rv RoomVersion) NotEquals(versions ...RoomVersion) bool {
+ return !rv.Equals(versions...)
+}
+
+var ErrUnknownRoomVersion = errors.New("unknown room version")
+
+func (rv RoomVersion) unknownVersionError() error {
+ return fmt.Errorf("%w %s", ErrUnknownRoomVersion, rv)
+}
+
+func (rv RoomVersion) IsKnown() bool {
+ switch rv {
+ case RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11, RoomV12:
+ return true
+ default:
+ return false
+ }
+}
+
+type StateResVersion int
+
+const (
+ // StateResV1 is the original state resolution algorithm.
+ StateResV1 StateResVersion = 0
+ // StateResV2 is state resolution v2 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1759
+ StateResV2 StateResVersion = 1
+ // StateResV2_1 is state resolution v2.1 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/4297
+ StateResV2_1 StateResVersion = 2
+)
+
+// StateResVersion returns the version of the state resolution algorithm used by this room version.
+func (rv RoomVersion) StateResVersion() StateResVersion {
+ switch rv {
+ case RoomV0, RoomV1:
+ return StateResV1
+ case RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11:
+ return StateResV2
+ case RoomV12:
+ return StateResV2_1
+ default:
+ panic(rv.unknownVersionError())
+ }
+}
+
+type EventIDFormat int
+
+const (
+ // EventIDFormatCustom is the original format used by room v1 and v2.
+ // Event IDs in this format are an arbitrary string followed by a colon and the server name.
+ EventIDFormatCustom EventIDFormat = 0
+ // EventIDFormatBase64 is the format used by room v3 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1659.
+ // Event IDs in this format are the standard unpadded base64-encoded SHA256 reference hash of the event.
+ EventIDFormatBase64 EventIDFormat = 1
+ // EventIDFormatURLSafeBase64 is the format used by room v4 and later introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/2002.
+ // Event IDs in this format are the url-safe unpadded base64-encoded SHA256 reference hash of the event.
+ EventIDFormatURLSafeBase64 EventIDFormat = 2
+)
+
+// EventIDFormat returns the format of event IDs used by this room version.
+func (rv RoomVersion) EventIDFormat() EventIDFormat {
+ switch rv {
+ case RoomV0, RoomV1, RoomV2:
+ return EventIDFormatCustom
+ case RoomV3:
+ return EventIDFormatBase64
+ default:
+ return EventIDFormatURLSafeBase64
+ }
+}
+
+/////////////////////
+// Room v5 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2077
+
+// EnforceSigningKeyValidity returns true if the `valid_until_ts` field of federation signing keys
+// must be enforced on received events.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2076
+func (rv RoomVersion) EnforceSigningKeyValidity() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4)
+}
+
+/////////////////////
+// Room v6 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2240
+
+// SpecialCasedAliasesAuth returns true if the `m.room.aliases` event authorization is special cased
+// to only always allow servers to modify the state event with their own server name as state key.
+// This also implies that the `aliases` field is protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2432
+func (rv RoomVersion) SpecialCasedAliasesAuth() bool {
+ return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+// ForbidFloatsAndBigInts returns true if floats and integers greater than 2^53-1 or lower than -2^53+1 are forbidden everywhere.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2540
+func (rv RoomVersion) ForbidFloatsAndBigInts() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+// NotificationsPowerLevels returns true if the `notifications` field in `m.room.power_levels` is validated in event auth.
+// However, the field is not protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2209
+func (rv RoomVersion) NotificationsPowerLevels() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+/////////////////////
+// Room v7 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2998
+
+// Knocks returns true if the `knock` join rule is supported.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2403
+func (rv RoomVersion) Knocks() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6)
+}
+
+/////////////////////
+// Room v8 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3289
+
+// RestrictedJoins returns true if the `restricted` join rule is supported.
+// This also implies that the `allow` field in the `m.room.join_rules` event is supported and protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3083
+func (rv RoomVersion) RestrictedJoins() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7)
+}
+
+/////////////////////
+// Room v9 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3375
+
+// RestrictedJoinsFix returns true if the `join_authorised_via_users_server` field in `m.room.member` events is protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3375
+func (rv RoomVersion) RestrictedJoinsFix() bool {
+ return rv.RestrictedJoins() && rv != RoomV8
+}
+
+//////////////////////
+// Room v10 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3604
+
+// ValidatePowerLevelInts returns true if the known values in `m.room.power_levels` must be integers (and not strings).
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3667
+func (rv RoomVersion) ValidatePowerLevelInts() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
+}
+
+// KnockRestricted returns true if the `knock_restricted` join rule is supported.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3787
+func (rv RoomVersion) KnockRestricted() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
+}
+
+//////////////////////
+// Room v11 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3820
+
+// CreatorInContent returns true if the `m.room.create` event has a `creator` field in content.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2175
+func (rv RoomVersion) CreatorInContent() bool {
+ return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+// RedactsInContent returns true if the `m.room.redaction` event has the `redacts` field in content instead of at the top level.
+// The redaction protection is also moved from the top level to the content field.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2174
+// (and https://github.com/matrix-org/matrix-spec-proposals/pull/2176 for the redaction protection).
+func (rv RoomVersion) RedactsInContent() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+// UpdatedRedactionRules returns true if various updates to the redaction algorithm are applied.
+//
+// Specifically:
+//
+// * the `membership`, `origin`, and `prev_state` fields at the top level of all events are no longer protected.
+// * the entire content of `m.room.create` is protected.
+// * the `redacts` field in `m.room.redaction` content is protected instead of the top-level field.
+// * the `m.room.power_levels` event protects the `invite` field in content.
+// * the `signed` field inside the `third_party_invite` field in content of `m.room.member` events is protected.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2176,
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3821, and
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3989
+func (rv RoomVersion) UpdatedRedactionRules() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+//////////////////////
+// Room v12 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/4304
+
+// Return value of StateResVersion was changed to StateResV2_1
+
+// PrivilegedRoomCreators returns true if the creator(s) of a room always have infinite power level.
+// This also implies that the `m.room.create` event has an `additional_creators` field,
+// and that the creators can't be present in the `m.room.power_levels` event.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/4289
+func (rv RoomVersion) PrivilegedRoomCreators() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
+}
+
+// RoomIDIsCreateEventID returns true if the ID of rooms is the same as the ID of the `m.room.create` event.
+// This also implies that `m.room.create` events do not have a `room_id` field.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/4291
+func (rv RoomVersion) RoomIDIsCreateEventID() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
+}
diff --git a/id/servername.go b/id/servername.go
new file mode 100644
index 00000000..923705b6
--- /dev/null
+++ b/id/servername.go
@@ -0,0 +1,58 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package id
+
+import (
+ "regexp"
+ "strconv"
+)
+
+type ParsedServerNameType int
+
+const (
+ ServerNameDNS ParsedServerNameType = iota
+ ServerNameIPv4
+ ServerNameIPv6
+)
+
+type ParsedServerName struct {
+ Type ParsedServerNameType
+ Host string
+ Port int
+}
+
+var ServerNameRegex = regexp.MustCompile(`^(?:\[([0-9A-Fa-f:.]{2,45})]|(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|([0-9A-Za-z.-]{1,255}))(?::(\d{1,5}))?$`)
+
+func ValidateServerName(serverName string) bool {
+ return len(serverName) <= 255 && len(serverName) > 0 && ServerNameRegex.MatchString(serverName)
+}
+
+func ParseServerName(serverName string) *ParsedServerName {
+ if len(serverName) > 255 || len(serverName) < 1 {
+ return nil
+ }
+ match := ServerNameRegex.FindStringSubmatch(serverName)
+ if len(match) != 5 {
+ return nil
+ }
+ port, _ := strconv.Atoi(match[4])
+ parsed := &ParsedServerName{
+ Port: port,
+ }
+ switch {
+ case match[1] != "":
+ parsed.Type = ServerNameIPv6
+ parsed.Host = match[1]
+ case match[2] != "":
+ parsed.Type = ServerNameIPv4
+ parsed.Host = match[2]
+ case match[3] != "":
+ parsed.Type = ServerNameDNS
+ parsed.Host = match[3]
+ }
+ return parsed
+}
diff --git a/id/trust.go b/id/trust.go
index 04f6e36b..6255093e 100644
--- a/id/trust.go
+++ b/id/trust.go
@@ -16,6 +16,7 @@ type TrustState int
const (
TrustStateBlacklisted TrustState = -100
+ TrustStateDeviceKeyMismatch TrustState = -5
TrustStateUnset TrustState = 0
TrustStateUnknownDevice TrustState = 10
TrustStateForwarded TrustState = 20
@@ -23,7 +24,7 @@ const (
TrustStateCrossSignedTOFU TrustState = 100
TrustStateCrossSignedVerified TrustState = 200
TrustStateVerified TrustState = 300
- TrustStateInvalid TrustState = (1 << 31) - 1
+ TrustStateInvalid TrustState = -2147483647
)
func (ts *TrustState) UnmarshalText(data []byte) error {
@@ -44,6 +45,8 @@ func ParseTrustState(val string) TrustState {
switch strings.ToLower(val) {
case "blacklisted":
return TrustStateBlacklisted
+ case "device-key-mismatch":
+ return TrustStateDeviceKeyMismatch
case "unverified":
return TrustStateUnset
case "cross-signed-untrusted":
@@ -67,6 +70,8 @@ func (ts TrustState) String() string {
switch ts {
case TrustStateBlacklisted:
return "blacklisted"
+ case TrustStateDeviceKeyMismatch:
+ return "device-key-mismatch"
case TrustStateUnset:
return "unverified"
case TrustStateCrossSignedUntrusted:
diff --git a/id/userid.go b/id/userid.go
index 3aae3b21..726a0d58 100644
--- a/id/userid.go
+++ b/id/userid.go
@@ -30,25 +30,41 @@ func NewEncodedUserID(localpart, homeserver string) UserID {
}
var (
- ErrInvalidUserID = errors.New("is not a valid user ID")
- ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed")
- ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters")
- ErrEmptyLocalpart = errors.New("empty localparts are not allowed")
+ ErrInvalidUserID = errors.New("is not a valid user ID")
+ ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed")
+ ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters")
+ ErrEmptyLocalpart = errors.New("empty localparts are not allowed")
+ ErrNoncompliantServerPart = errors.New("is not a valid server name")
)
+// ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format
+func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, localpart, homeserver string) {
+ if len(identifier) == 0 {
+ return
+ }
+ sigil = identifier[0]
+ strIdentifier := string(identifier)
+ colonIdx := strings.IndexByte(strIdentifier, ':')
+ if colonIdx > 0 {
+ localpart = strIdentifier[1:colonIdx]
+ homeserver = strIdentifier[colonIdx+1:]
+ } else {
+ localpart = strIdentifier[1:]
+ }
+ return
+}
+
// Parse parses the user ID into the localpart and server name.
//
// Note that this only enforces very basic user ID formatting requirements: user IDs start with
// a @, and contain a : after the @. If you want to enforce localpart validity, see the
// ParseAndValidate and ValidateUserLocalpart functions.
func (userID UserID) Parse() (localpart, homeserver string, err error) {
- if len(userID) == 0 || userID[0] != '@' || !strings.ContainsRune(string(userID), ':') {
- // This error wrapping lets you use errors.Is() nicely even though the message contains the user ID
+ var sigil byte
+ sigil, localpart, homeserver = ParseCommonIdentifier(userID)
+ if sigil != '@' || homeserver == "" {
err = fmt.Errorf("'%s' %w", userID, ErrInvalidUserID)
- return
}
- parts := strings.SplitN(string(userID), ":", 2)
- localpart, homeserver = strings.TrimPrefix(parts[0], "@"), parts[1]
return
}
@@ -66,6 +82,9 @@ func (userID UserID) Homeserver() string {
//
// This does not parse or validate the user ID. Use the ParseAndValidate method if you want to ensure the user ID is valid first.
func (userID UserID) URI() *MatrixURI {
+ if userID == "" {
+ return nil
+ }
return &MatrixURI{
Sigil1: '@',
MXID1: string(userID)[1:],
@@ -85,21 +104,32 @@ func ValidateUserLocalpart(localpart string) error {
return nil
}
-// ParseAndValidate parses the user ID into the localpart and server name like Parse,
-// and also validates that the localpart is allowed according to the user identifiers spec.
-func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.Parse()
+// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts.
+// This should be used with care: there are real users still using historical localparts.
+func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) {
+ localpart, homeserver, err = userID.ParseAndValidateRelaxed()
if err == nil {
err = ValidateUserLocalpart(localpart)
}
- if err == nil && len(userID) > UserIDMaxLength {
+ return
+}
+
+// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse,
+// and also validates that the user ID is not too long and that the server name is valid.
+func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) {
+ if len(userID) > UserIDMaxLength {
err = ErrUserIDTooLong
+ return
+ }
+ localpart, homeserver, err = userID.Parse()
+ if err == nil && !ValidateServerName(homeserver) {
+ err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart)
}
return
}
func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.ParseAndValidate()
+ localpart, homeserver, err = userID.ParseAndValidateStrict()
if err == nil {
localpart, err = DecodeUserLocalpart(localpart)
}
@@ -189,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) {
for i := 0; i < len(strBytes); i++ {
b := strBytes[i]
if !isValidByte(b) {
- return "", fmt.Errorf("Byte pos %d: Invalid byte", i)
+ return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b)
}
if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _
if i+1 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after underscore at %d", i)
}
if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i)
+ return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i)
}
if strBytes[i+1] == '_' {
outputBuffer.WriteByte('_')
@@ -207,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) {
i++ // skip next byte since we just handled it
} else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8
if i+2 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after equals sign at %d", i)
}
dst := make([]byte, 1)
_, err := hex.Decode(dst, strBytes[i+1:i+3])
diff --git a/id/userid_test.go b/id/userid_test.go
index 359bc687..57a88066 100644
--- a/id/userid_test.go
+++ b/id/userid_test.go
@@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) {
assert.True(t, errors.Is(err, id.ErrInvalidUserID))
}
-func TestUserID_ParseAndValidate_Invalid(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) {
const inputUserID = "@s p a c e:maunium.net"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart))
}
-func TestUserID_ParseAndValidate_Empty(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) {
const inputUserID = "@:ponies.im"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrEmptyLocalpart))
}
-func TestUserID_ParseAndValidate_Long(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Long(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrUserIDTooLong))
}
-func TestUserID_ParseAndValidate_NotLong(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.NoError(t, err)
}
@@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) {
const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8"
const inputServerName = "example.com"
userID := id.NewEncodedUserID(inputLocalpart, inputServerName)
- parsedLocalpart, parsedServerName, err := userID.ParseAndValidate()
+ parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict()
assert.NoError(t, err)
assert.Equal(t, encodedLocalpart, parsedLocalpart)
assert.Equal(t, inputServerName, parsedServerName)
diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go
new file mode 100644
index 00000000..4d2bc7cf
--- /dev/null
+++ b/mediaproxy/mediaproxy.go
@@ -0,0 +1,525 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mediaproxy
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "mime"
+ "mime/multipart"
+ "net/http"
+ "net/textproto"
+ "net/url"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/hlog"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
+ "go.mau.fi/util/ptr"
+ "go.mau.fi/util/requestlog"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/federation"
+ "maunium.net/go/mautrix/id"
+)
+
+type GetMediaResponse interface {
+ isGetMediaResponse()
+}
+
+func (*GetMediaResponseURL) isGetMediaResponse() {}
+func (*GetMediaResponseData) isGetMediaResponse() {}
+func (*GetMediaResponseCallback) isGetMediaResponse() {}
+func (*GetMediaResponseFile) isGetMediaResponse() {}
+
+type GetMediaResponseURL struct {
+ URL string
+ ExpiresAt time.Time
+}
+
+type GetMediaResponseWriter interface {
+ GetMediaResponse
+ io.WriterTo
+ GetContentType() string
+ GetContentLength() int64
+}
+
+var (
+ _ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil)
+ _ GetMediaResponseWriter = (*GetMediaResponseData)(nil)
+)
+
+type GetMediaResponseData struct {
+ Reader io.ReadCloser
+ ContentType string
+ ContentLength int64
+}
+
+func (d *GetMediaResponseData) WriteTo(w io.Writer) (int64, error) {
+ return io.Copy(w, d.Reader)
+}
+
+func (d *GetMediaResponseData) GetContentType() string {
+ return d.ContentType
+}
+
+func (d *GetMediaResponseData) GetContentLength() int64 {
+ return d.ContentLength
+}
+
+type GetMediaResponseCallback struct {
+ Callback func(w io.Writer) (int64, error)
+ ContentType string
+ ContentLength int64
+}
+
+func (d *GetMediaResponseCallback) WriteTo(w io.Writer) (int64, error) {
+ return d.Callback(w)
+}
+
+func (d *GetMediaResponseCallback) GetContentLength() int64 {
+ return d.ContentLength
+}
+
+func (d *GetMediaResponseCallback) GetContentType() string {
+ return d.ContentType
+}
+
+type FileMeta struct {
+ ContentType string
+ ReplacementFile string
+}
+
+type GetMediaResponseFile struct {
+ Callback func(w *os.File) (*FileMeta, error)
+}
+
+type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error)
+
+type MediaProxy struct {
+ KeyServer *federation.KeyServer
+ ServerAuth *federation.ServerAuth
+
+ GetMedia GetMediaFunc
+ PrepareProxyRequest func(*http.Request)
+
+ serverName string
+ serverKey *federation.SigningKey
+
+ FederationRouter *http.ServeMux
+ ClientMediaRouter *http.ServeMux
+}
+
+func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
+ parsed, err := federation.ParseSynapseKey(serverKey)
+ if err != nil {
+ return nil, err
+ }
+ mp := &MediaProxy{
+ serverName: serverName,
+ serverKey: parsed,
+ GetMedia: getMedia,
+ KeyServer: &federation.KeyServer{
+ KeyProvider: &federation.StaticServerKey{
+ ServerName: serverName,
+ Key: parsed,
+ },
+ WellKnownTarget: fmt.Sprintf("%s:443", serverName),
+ Version: federation.ServerVersion{
+ Name: "mautrix-go media proxy",
+ Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
+ },
+ },
+ }
+ mp.FederationRouter = http.NewServeMux()
+ mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation)
+ mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation)
+ mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion)
+ mp.ClientMediaRouter = http.NewServeMux()
+ mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported)
+ return mp, nil
+}
+
+type BasicConfig struct {
+ ServerName string `yaml:"server_name" json:"server_name"`
+ ServerKey string `yaml:"server_key" json:"server_key"`
+ FederationAuth bool `yaml:"federation_auth" json:"federation_auth"`
+ WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"`
+}
+
+func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error) {
+ mp, err := New(cfg.ServerName, cfg.ServerKey, getMedia)
+ if err != nil {
+ return nil, err
+ }
+ if cfg.WellKnownResponse != "" {
+ mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse
+ }
+ if cfg.FederationAuth {
+ mp.EnableServerAuth(nil, nil)
+ }
+ return mp, nil
+}
+
+type ServerConfig struct {
+ Hostname string `yaml:"hostname" json:"hostname"`
+ Port uint16 `yaml:"port" json:"port"`
+}
+
+func (mp *MediaProxy) Listen(cfg ServerConfig) error {
+ router := http.NewServeMux()
+ mp.RegisterRoutes(router, zerolog.Nop())
+ return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
+}
+
+func (mp *MediaProxy) GetServerName() string {
+ return mp.serverName
+}
+
+func (mp *MediaProxy) GetServerKey() *federation.SigningKey {
+ return mp.serverKey
+}
+
+func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache federation.KeyCache) {
+ if keyCache == nil {
+ keyCache = federation.NewInMemoryCache()
+ }
+ if client == nil {
+ resCache, _ := keyCache.(federation.ResolutionCache)
+ client = federation.NewClient(mp.serverName, mp.serverKey, resCache)
+ }
+ mp.ServerAuth = federation.NewServerAuth(client, keyCache, func(auth federation.XMatrixAuth) string {
+ return mp.GetServerName()
+ })
+}
+
+func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux, log zerolog.Logger) {
+ errorBodies := exhttp.ErrorBodies{
+ NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
+ MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
+ }
+ router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware(
+ mp.FederationRouter,
+ exhttp.StripPrefix("/_matrix/federation"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
+ router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware(
+ mp.ClientMediaRouter,
+ exhttp.StripPrefix("/_matrix/client/v1/media"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ exhttp.CORSMiddleware,
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
+ mp.KeyServer.Register(router, log)
+}
+
+var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
+
+func queryToMap(vals url.Values) map[string]string {
+ m := make(map[string]string, len(vals))
+ for k, v := range vals {
+ m[k] = v[0]
+ }
+ return m
+}
+
+func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
+ mediaID := r.PathValue("mediaID")
+ if !id.IsValidMediaID(mediaID) {
+ mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w)
+ return nil
+ }
+ resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query()))
+ if err != nil {
+ var mautrixRespError mautrix.RespError
+ if errors.Is(err, ErrInvalidMediaIDSyntax) {
+ mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
+ } else if errors.As(err, &mautrixRespError) {
+ mautrixRespError.Write(w)
+ } else {
+ zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL")
+ mautrix.MNotFound.WithMessage("Media not found").Write(w)
+ }
+ return nil
+ }
+ return resp
+}
+
+func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Writer {
+ mpw := multipart.NewWriter(w)
+ w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1))
+ w.WriteHeader(http.StatusOK)
+ metaPart, err := mpw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {"application/json"},
+ })
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to create multipart metadata field")
+ return nil
+ }
+ _, err = metaPart.Write([]byte(`{}`))
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to write multipart metadata field")
+ return nil
+ }
+ return mpw
+}
+
+func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) {
+ if mp.ServerAuth != nil {
+ var err *mautrix.RespError
+ r, err = mp.ServerAuth.Authenticate(r)
+ if err != nil {
+ err.Write(w)
+ return
+ }
+ }
+ ctx := r.Context()
+ log := zerolog.Ctx(ctx)
+
+ resp := mp.getMedia(w, r)
+ if resp == nil {
+ return
+ }
+
+ var mpw *multipart.Writer
+ if urlResp, ok := resp.(*GetMediaResponseURL); ok {
+ mpw = startMultipart(ctx, w)
+ if mpw == nil {
+ return
+ }
+ _, err := mpw.CreatePart(textproto.MIMEHeader{
+ "Location": {urlResp.URL},
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to create multipart redirect field")
+ return
+ }
+ } else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
+ responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
+ mpw = startMultipart(ctx, w)
+ if mpw == nil {
+ return fmt.Errorf("failed to start multipart writer")
+ }
+ dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {mimeType},
+ })
+ if err != nil {
+ return fmt.Errorf("failed to create multipart data field: %w", err)
+ }
+ _, err = wt.WriteTo(dataPart)
+ return err
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to do media proxy with temp file")
+ if !responseStarted {
+ var mautrixRespError mautrix.RespError
+ if errors.As(err, &mautrixRespError) {
+ mautrixRespError.Write(w)
+ } else {
+ mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
+ }
+ }
+ return
+ }
+ } else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
+ mpw = startMultipart(ctx, w)
+ if mpw == nil {
+ return
+ }
+ dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {dataResp.GetContentType()},
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to create multipart data field")
+ return
+ }
+ _, err = dataResp.WriteTo(dataPart)
+ if err != nil {
+ log.Err(err).Msg("Failed to write multipart data field")
+ return
+ }
+ } else {
+ panic(fmt.Errorf("unknown GetMediaResponse type %T", resp))
+ }
+ err := mpw.Close()
+ if err != nil {
+ log.Err(err).Msg("Failed to close multipart writer")
+ return
+ }
+}
+
+func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName string) {
+ w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
+ contentDisposition := "attachment"
+ switch mimeType {
+ case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif",
+ "image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime",
+ "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav",
+ "audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf":
+ contentDisposition = "inline"
+ }
+ if fileName != "" {
+ contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{
+ "filename": fileName,
+ })
+ }
+ w.Header().Set("Content-Disposition", contentDisposition)
+ w.Header().Set("Content-Type", mimeType)
+}
+
+func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ log := zerolog.Ctx(ctx)
+ if r.PathValue("serverName") != mp.serverName {
+ mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
+ return
+ }
+ resp := mp.getMedia(w, r)
+ if resp == nil {
+ return
+ }
+
+ if urlResp, ok := resp.(*GetMediaResponseURL); ok {
+ w.Header().Set("Location", urlResp.URL)
+ expirySeconds := (time.Until(urlResp.ExpiresAt) - 5*time.Minute).Seconds()
+ if urlResp.ExpiresAt.IsZero() {
+ w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
+ } else if expirySeconds > 0 {
+ cacheControl := fmt.Sprintf("public, max-age=%d, immutable", int(expirySeconds))
+ w.Header().Set("Cache-Control", cacheControl)
+ } else {
+ w.Header().Set("Cache-Control", "no-store")
+ }
+ w.WriteHeader(http.StatusTemporaryRedirect)
+ } else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
+ responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
+ mp.addHeaders(w, mimeType, r.PathValue("fileName"))
+ w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
+ w.WriteHeader(http.StatusOK)
+ _, err := wt.WriteTo(w)
+ return err
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to do media proxy with temp file")
+ if !responseStarted {
+ var mautrixRespError mautrix.RespError
+ if errors.As(err, &mautrixRespError) {
+ mautrixRespError.Write(w)
+ } else {
+ mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
+ }
+ }
+ }
+ } else if writerResp, ok := resp.(GetMediaResponseWriter); ok {
+ if dataResp, ok := writerResp.(*GetMediaResponseData); ok {
+ defer dataResp.Reader.Close()
+ }
+ mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName"))
+ if writerResp.GetContentLength() != 0 {
+ w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10))
+ }
+ w.WriteHeader(http.StatusOK)
+ _, err := writerResp.WriteTo(w)
+ if err != nil {
+ log.Err(err).Msg("Failed to write media data")
+ }
+ } else {
+ panic(fmt.Errorf("unknown GetMediaResponse type %T", resp))
+ }
+}
+
+func doTempFileDownload(
+ data *GetMediaResponseFile,
+ respond func(w io.WriterTo, size int64, mimeType string) error,
+) (bool, error) {
+ tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*")
+ if err != nil {
+ return false, fmt.Errorf("failed to create temp file: %w", err)
+ }
+ origTempFile := tempFile
+ defer func() {
+ _ = origTempFile.Close()
+ _ = os.Remove(origTempFile.Name())
+ }()
+ meta, err := data.Callback(tempFile)
+ if err != nil {
+ return false, err
+ }
+ if meta.ReplacementFile != "" {
+ tempFile, err = os.Open(meta.ReplacementFile)
+ if err != nil {
+ return false, fmt.Errorf("failed to open replacement file: %w", err)
+ }
+ defer func() {
+ _ = tempFile.Close()
+ _ = os.Remove(origTempFile.Name())
+ }()
+ } else {
+ _, err = tempFile.Seek(0, io.SeekStart)
+ if err != nil {
+ return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ }
+ }
+ fileInfo, err := tempFile.Stat()
+ if err != nil {
+ return false, fmt.Errorf("failed to stat temp file: %w", err)
+ }
+ mimeType := meta.ContentType
+ if mimeType == "" {
+ buf := make([]byte, 512)
+ n, err := tempFile.Read(buf)
+ if err != nil {
+ return false, fmt.Errorf("failed to read temp file to detect mime: %w", err)
+ }
+ buf = buf[:n]
+ _, err = tempFile.Seek(0, io.SeekStart)
+ if err != nil {
+ return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ }
+ mimeType = http.DetectContentType(buf)
+ }
+ err = respond(tempFile, fileInfo.Size(), mimeType)
+ if err != nil {
+ return true, err
+ }
+ return true, nil
+}
+
+var (
+ ErrUploadNotSupported = mautrix.MUnrecognized.
+ WithMessage("This is a media proxy and does not support media uploads.").
+ WithStatus(http.StatusNotImplemented)
+ ErrPreviewURLNotSupported = mautrix.MUnrecognized.
+ WithMessage("This is a media proxy and does not support URL previews.").
+ WithStatus(http.StatusNotImplemented)
+)
+
+func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) {
+ ErrUploadNotSupported.Write(w)
+}
+
+func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) {
+ ErrPreviewURLNotSupported.Write(w)
+}
diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go
new file mode 100644
index 00000000..507c24a5
--- /dev/null
+++ b/mockserver/mockserver.go
@@ -0,0 +1,307 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mockserver
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
+ "github.com/stretchr/testify/require"
+ "go.mau.fi/util/dbutil"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
+ "go.mau.fi/util/random"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/crypto"
+ "maunium.net/go/mautrix/crypto/cryptohelper"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+func mustDecode(r *http.Request, data any) {
+ exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data))
+}
+
+type userAndDeviceID struct {
+ UserID id.UserID
+ DeviceID id.DeviceID
+}
+
+type MockServer struct {
+ Router *http.ServeMux
+ Server *httptest.Server
+
+ AccessTokenToUserID map[string]userAndDeviceID
+ DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
+ AccountData map[id.UserID]map[event.Type]json.RawMessage
+ DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
+ OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey
+ MasterKeys map[id.UserID]mautrix.CrossSigningKeys
+ SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+ UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+
+ PopOTKs bool
+ MemoryStore bool
+}
+
+func Create(t testing.TB) *MockServer {
+ t.Helper()
+
+ server := MockServer{
+ AccessTokenToUserID: map[string]userAndDeviceID{},
+ DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
+ AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ PopOTKs: true,
+ MemoryStore: true,
+ }
+
+ router := http.NewServeMux()
+ router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
+ router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
+ router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim)
+ router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
+ router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
+ router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
+ router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp)
+ router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload)
+ server.Router = router
+ server.Server = httptest.NewServer(router)
+ t.Cleanup(server.Server.Close)
+ return &server
+}
+
+func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID {
+ authHeader := r.Header.Get("Authorization")
+ authHeader = strings.TrimPrefix(authHeader, "Bearer ")
+ userID, ok := ms.AccessTokenToUserID[authHeader]
+ if !ok {
+ panic("no user ID found for access token " + authHeader)
+ }
+ return userID
+}
+
+func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
+ exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
+}
+
+func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
+ var loginReq mautrix.ReqLogin
+ mustDecode(r, &loginReq)
+
+ deviceID := loginReq.DeviceID
+ if deviceID == "" {
+ deviceID = id.DeviceID(random.String(10))
+ }
+
+ accessToken := random.String(30)
+ userID := id.UserID(loginReq.Identifier.User)
+ ms.AccessTokenToUserID[accessToken] = userAndDeviceID{
+ UserID: userID,
+ DeviceID: deviceID,
+ }
+
+ exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{
+ AccessToken: accessToken,
+ DeviceID: deviceID,
+ UserID: userID,
+ })
+}
+
+func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqSendToDevice
+ mustDecode(r, &req)
+ evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
+
+ for user, devices := range req.Messages {
+ for device, content := range devices {
+ if _, ok := ms.DeviceInbox[user]; !ok {
+ ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{}
+ }
+ content.ParseRaw(evtType)
+ ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{
+ Sender: ms.getUserID(r).UserID,
+ Type: evtType,
+ Content: *content,
+ })
+ }
+ }
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
+ userID := id.UserID(r.PathValue("userID"))
+ eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
+
+ jsonData, _ := io.ReadAll(r.Body)
+ if _, ok := ms.AccountData[userID]; !ok {
+ ms.AccountData[userID] = map[event.Type]json.RawMessage{}
+ }
+ ms.AccountData[userID][eventType] = json.RawMessage(jsonData)
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqQueryKeys
+ mustDecode(r, &req)
+ resp := mautrix.RespQueryKeys{
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ }
+ for user := range req.DeviceKeys {
+ resp.MasterKeys[user] = ms.MasterKeys[user]
+ resp.UserSigningKeys[user] = ms.UserSigningKeys[user]
+ resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user]
+ resp.DeviceKeys[user] = ms.DeviceKeys[user]
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
+}
+
+func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqClaimKeys
+ mustDecode(r, &req)
+ resp := mautrix.RespClaimKeys{
+ OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
+ }
+ for user, devices := range req.OneTimeKeys {
+ resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
+ for device := range devices {
+ keys := ms.OneTimeKeys[user][device]
+ for keyID, key := range keys {
+ if ms.PopOTKs {
+ delete(keys, keyID)
+ }
+ resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{
+ keyID: key,
+ }
+ break
+ }
+ }
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
+}
+
+func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqUploadKeys
+ mustDecode(r, &req)
+
+ uid := ms.getUserID(r)
+ userID := uid.UserID
+ if _, ok := ms.DeviceKeys[userID]; !ok {
+ ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
+ }
+ if _, ok := ms.OneTimeKeys[userID]; !ok {
+ ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
+ }
+
+ if req.DeviceKeys != nil {
+ ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys
+ }
+ otks, ok := ms.OneTimeKeys[userID][uid.DeviceID]
+ if !ok {
+ otks = map[id.KeyID]mautrix.OneTimeKey{}
+ ms.OneTimeKeys[userID][uid.DeviceID] = otks
+ }
+ if req.OneTimeKeys != nil {
+ maps.Copy(otks, req.OneTimeKeys)
+ }
+
+ exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{
+ OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)},
+ })
+}
+
+func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.UploadCrossSigningKeysReq[any]
+ mustDecode(r, &req)
+
+ userID := ms.getUserID(r).UserID
+ ms.MasterKeys[userID] = req.Master
+ ms.SelfSigningKeys[userID] = req.SelfSigning
+ ms.UserSigningKeys[userID] = req.UserSigning
+
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
+ t.Helper()
+ if ctx == nil {
+ ctx = context.TODO()
+ }
+ client, err := mautrix.NewClient(ms.Server.URL, "", "")
+ require.NoError(t, err)
+ client.Client = ms.Server.Client()
+
+ _, err = client.Login(ctx, &mautrix.ReqLogin{
+ Type: mautrix.AuthTypePassword,
+ Identifier: mautrix.UserIdentifier{
+ Type: mautrix.IdentifierTypeUser,
+ User: userID.String(),
+ },
+ DeviceID: deviceID,
+ Password: "password",
+ StoreCredentials: true,
+ })
+ require.NoError(t, err)
+
+ var store any
+ if ms.MemoryStore {
+ store = crypto.NewMemoryStore(nil)
+ client.StateStore = mautrix.NewMemoryStateStore()
+ } else {
+ store, err = dbutil.NewFromConfig("", dbutil.Config{
+ PoolConfig: dbutil.PoolConfig{
+ Type: "sqlite3-fk-wal",
+ URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)),
+ MaxOpenConns: 5,
+ MaxIdleConns: 1,
+ },
+ }, nil)
+ require.NoError(t, err)
+ }
+ cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store)
+ require.NoError(t, err)
+ client.Crypto = cryptoHelper
+
+ err = cryptoHelper.Init(ctx)
+ require.NoError(t, err)
+
+ machineLog := globallog.Logger.With().
+ Stringer("my_user_id", userID).
+ Stringer("my_device_id", deviceID).
+ Logger()
+ cryptoHelper.Machine().Log = &machineLog
+
+ err = cryptoHelper.Machine().ShareKeys(ctx, 50)
+ require.NoError(t, err)
+
+ return client, cryptoHelper.Machine().CryptoStore
+}
+
+func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) {
+ t.Helper()
+
+ for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {
+ client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt)
+ ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:]
+ }
+}
diff --git a/pushrules/action.go b/pushrules/action.go
index 9838e88b..b5a884b2 100644
--- a/pushrules/action.go
+++ b/pushrules/action.go
@@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error {
if ok {
action.Action = ActionSetTweak
action.Tweak = PushActionTweak(tweak)
- action.Value, _ = val["value"]
+ action.Value = val["value"]
}
}
return nil
diff --git a/pushrules/action_test.go b/pushrules/action_test.go
index a8f68415..3c0aa168 100644
--- a/pushrules/action_test.go
+++ b/pushrules/action_test.go
@@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
err = pa.UnmarshalJSON([]byte(`9001`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`"foo"`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.PushActionType("foo"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.ActionSetTweak, pa.Action)
assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak)
@@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data)
}
@@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, []byte(`"something else"`), data)
}
diff --git a/pushrules/condition.go b/pushrules/condition.go
index 435178fb..caa717de 100644
--- a/pushrules/condition.go
+++ b/pushrules/condition.go
@@ -15,10 +15,10 @@ import (
"unicode"
"github.com/tidwall/gjson"
+ "go.mau.fi/util/glob"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
- "maunium.net/go/mautrix/pushrules/glob"
)
// Room is an interface with the functions that are needed for processing room-specific push conditions
@@ -27,6 +27,11 @@ type Room interface {
GetMemberCount() int
}
+type PowerLevelfulRoom interface {
+ Room
+ GetPowerLevels() *event.PowerLevelsEventContent
+}
+
// EventfulRoom is an extension of Room to support MSC3664.
type EventfulRoom interface {
Room
@@ -38,11 +43,12 @@ type PushCondKind string
// The allowed push condition kinds as specified in https://spec.matrix.org/v1.2/client-server-api/#conditions-1
const (
- KindEventMatch PushCondKind = "event_match"
- KindContainsDisplayName PushCondKind = "contains_display_name"
- KindRoomMemberCount PushCondKind = "room_member_count"
- KindEventPropertyIs PushCondKind = "event_property_is"
- KindEventPropertyContains PushCondKind = "event_property_contains"
+ KindEventMatch PushCondKind = "event_match"
+ KindContainsDisplayName PushCondKind = "contains_display_name"
+ KindRoomMemberCount PushCondKind = "room_member_count"
+ KindEventPropertyIs PushCondKind = "event_property_is"
+ KindEventPropertyContains PushCondKind = "event_property_contains"
+ KindSenderNotificationPermission PushCondKind = "sender_notification_permission"
// MSC3664: https://github.com/matrix-org/matrix-spec-proposals/pull/3664
@@ -82,6 +88,8 @@ func (cond *PushCondition) Match(room Room, evt *event.Event) bool {
return cond.matchDisplayName(room, evt)
case KindRoomMemberCount:
return cond.matchMemberCount(room)
+ case KindSenderNotificationPermission:
+ return cond.matchSenderNotificationPermission(room, evt.Sender, cond.Key)
default:
return false
}
@@ -219,11 +227,11 @@ func (cond *PushCondition) matchValue(evt *event.Event) bool {
switch cond.Kind {
case KindEventMatch, KindRelatedEventMatch, KindUnstableRelatedEventMatch:
- pattern, err := glob.Compile(cond.Pattern)
- if err != nil {
+ pattern := glob.CompileWithImplicitContains(cond.Pattern)
+ if pattern == nil {
return false
}
- return pattern.MatchString(stringifyForPushCondition(val))
+ return pattern.Match(stringifyForPushCondition(val))
case KindEventPropertyIs:
return valueEquals(val, cond.Value)
case KindEventPropertyContains:
@@ -334,3 +342,18 @@ func (cond *PushCondition) matchMemberCount(room Room) bool {
return false
}
}
+
+func (cond *PushCondition) matchSenderNotificationPermission(room Room, sender id.UserID, key string) bool {
+ if key != "room" {
+ return false
+ }
+ plRoom, ok := room.(PowerLevelfulRoom)
+ if !ok {
+ return false
+ }
+ pls := plRoom.GetPowerLevels()
+ if pls == nil {
+ return false
+ }
+ return pls.GetUserLevel(sender) >= pls.Notifications.Room()
+}
diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go
index 0d3eaf7a..37af3e34 100644
--- a/pushrules/condition_test.go
+++ b/pushrules/condition_test.go
@@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi
}
}
-func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition {
- return &pushrules.PushCondition{
- Kind: pushrules.KindEventPropertyContains,
- Key: key,
- Value: value,
- }
-}
-
func TestPushCondition_Match_InvalidKind(t *testing.T) {
condition := &pushrules.PushCondition{
Kind: pushrules.PushCondKind("invalid"),
diff --git a/pushrules/glob/LICENSE b/pushrules/glob/LICENSE
deleted file mode 100644
index cb00d952..00000000
--- a/pushrules/glob/LICENSE
+++ /dev/null
@@ -1,22 +0,0 @@
-Glob is licensed under the MIT "Expat" License:
-
-Copyright (c) 2016: Zachary Yedidia.
-
-Permission is hereby granted, free of charge, to any person obtaining
-a copy of this software and associated documentation files (the
-"Software"), to deal in the Software without restriction, including
-without limitation the rights to use, copy, modify, merge, publish,
-distribute, sublicense, and/or sell copies of the Software, and to
-permit persons to whom the Software is furnished to do so, subject to
-the following conditions:
-
-The above copyright notice and this permission notice shall be
-included in all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/pushrules/glob/README.md b/pushrules/glob/README.md
deleted file mode 100644
index e2e6c649..00000000
--- a/pushrules/glob/README.md
+++ /dev/null
@@ -1,28 +0,0 @@
-# String globbing in Go
-
-[](http://godoc.org/github.com/zyedidia/glob)
-
-This package adds support for globs in Go.
-
-It simply converts glob expressions to regexps. I try to follow the standard defined [here](http://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_13).
-
-# Example
-
-```go
-package main
-
-import "github.com/zyedidia/glob"
-
-func main() {
- glob, err := glob.Compile("{*.go,*.c}")
- if err != nil {
- // Error
- }
-
- glob.Match([]byte("test.c")) // true
- glob.Match([]byte("hello.go")) // true
- glob.Match([]byte("test.d")) // false
-}
-```
-
-You can call all the same functions on a glob that you can call on a regexp.
diff --git a/pushrules/glob/glob.go b/pushrules/glob/glob.go
deleted file mode 100644
index c270dbc5..00000000
--- a/pushrules/glob/glob.go
+++ /dev/null
@@ -1,108 +0,0 @@
-// Package glob provides objects for matching strings with globs
-package glob
-
-import "regexp"
-
-// Glob is a wrapper of *regexp.Regexp.
-// It should contain a glob expression compiled into a regular expression.
-type Glob struct {
- *regexp.Regexp
-}
-
-// Compile a takes a glob expression as a string and transforms it
-// into a *Glob object (which is really just a regular expression)
-// Compile also returns a possible error.
-func Compile(pattern string) (*Glob, error) {
- r, err := globToRegex(pattern)
- return &Glob{r}, err
-}
-
-func globToRegex(glob string) (*regexp.Regexp, error) {
- regex := ""
- inGroup := 0
- inClass := 0
- firstIndexInClass := -1
- arr := []byte(glob)
-
- hasGlobCharacters := false
-
- for i := 0; i < len(arr); i++ {
- ch := arr[i]
-
- switch ch {
- case '\\':
- i++
- if i >= len(arr) {
- regex += "\\"
- } else {
- next := arr[i]
- switch next {
- case ',':
- // Nothing
- case 'Q', 'E':
- regex += "\\\\"
- default:
- regex += "\\"
- }
- regex += string(next)
- }
- case '*':
- if inClass == 0 {
- regex += ".*"
- } else {
- regex += "*"
- }
- hasGlobCharacters = true
- case '?':
- if inClass == 0 {
- regex += "."
- } else {
- regex += "?"
- }
- hasGlobCharacters = true
- case '[':
- inClass++
- firstIndexInClass = i + 1
- regex += "["
- hasGlobCharacters = true
- case ']':
- inClass--
- regex += "]"
- case '.', '(', ')', '+', '|', '^', '$', '@', '%':
- if inClass == 0 || (firstIndexInClass == i && ch == '^') {
- regex += "\\"
- }
- regex += string(ch)
- hasGlobCharacters = true
- case '!':
- if firstIndexInClass == i {
- regex += "^"
- } else {
- regex += "!"
- }
- hasGlobCharacters = true
- case '{':
- inGroup++
- regex += "("
- hasGlobCharacters = true
- case '}':
- inGroup--
- regex += ")"
- case ',':
- if inGroup > 0 {
- regex += "|"
- hasGlobCharacters = true
- } else {
- regex += ","
- }
- default:
- regex += string(ch)
- }
- }
-
- if hasGlobCharacters {
- return regexp.Compile("^" + regex + "$")
- } else {
- return regexp.Compile(regex)
- }
-}
diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go
index a531ca28..a5a0f5e7 100644
--- a/pushrules/pushrules_test.go
+++ b/pushrules/pushrules_test.go
@@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) {
},
}
pushRuleset, err := pushrules.EventToPushRules(evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.NotNil(t, pushRuleset)
assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{})
diff --git a/pushrules/rule.go b/pushrules/rule.go
index 0f7436f3..cf659695 100644
--- a/pushrules/rule.go
+++ b/pushrules/rule.go
@@ -8,10 +8,14 @@ package pushrules
import (
"encoding/gob"
+ "regexp"
+ "strings"
+
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/glob"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
- "maunium.net/go/mautrix/pushrules/glob"
)
func init() {
@@ -164,13 +168,20 @@ func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool {
}
func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool {
- pattern, err := glob.Compile(rule.Pattern)
- if err != nil {
- return false
- }
msg, ok := evt.Content.Raw["body"].(string)
if !ok {
return false
}
+ var buf strings.Builder
+ // As per https://spec.matrix.org/unstable/client-server-api/#push-rules, content rules are case-insensitive
+ // and must match whole words, so wrap the converted glob in (?i) and \b.
+ buf.WriteString(`(?i)\b`)
+ // strings.Builder will never return errors
+ exerrors.PanicIfNotNil(glob.ToRegexPattern(rule.Pattern, &buf))
+ buf.WriteString(`\b`)
+ pattern, err := regexp.Compile(buf.String())
+ if err != nil {
+ return false
+ }
return pattern.MatchString(msg)
}
diff --git a/pushrules/rule_test.go b/pushrules/rule_test.go
index 803c721e..7ff839a7 100644
--- a/pushrules/rule_test.go
+++ b/pushrules/rule_test.go
@@ -186,6 +186,34 @@ func TestPushRule_Match_Content(t *testing.T) {
assert.True(t, rule.Match(blankTestRoom, evt))
}
+func TestPushRule_Match_WordBoundary(t *testing.T) {
+ rule := &pushrules.PushRule{
+ Type: pushrules.ContentRule,
+ Enabled: true,
+ Pattern: "test",
+ }
+
+ evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
+ MsgType: event.MsgEmote,
+ Body: "is testing pushrules",
+ })
+ assert.False(t, rule.Match(blankTestRoom, evt))
+}
+
+func TestPushRule_Match_CaseInsensitive(t *testing.T) {
+ rule := &pushrules.PushRule{
+ Type: pushrules.ContentRule,
+ Enabled: true,
+ Pattern: "test",
+ }
+
+ evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
+ MsgType: event.MsgEmote,
+ Body: "is TeSt-InG pushrules",
+ })
+ assert.True(t, rule.Match(blankTestRoom, evt))
+}
+
func TestPushRule_Match_Content_Fail(t *testing.T) {
rule := &pushrules.PushRule{
Type: pushrules.ContentRule,
diff --git a/pushrules/ruleset.go b/pushrules/ruleset.go
index 609997b4..c42d4799 100644
--- a/pushrules/ruleset.go
+++ b/pushrules/ruleset.go
@@ -68,6 +68,9 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) {
var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}}
func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) {
+ if rs == nil {
+ return nil
+ }
// Add push rule collections to array in priority order
arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride}
// Loop until one of the push rule collections matches the room/event combo.
diff --git a/requests.go b/requests.go
index 6e00346a..cc8b7266 100644
--- a/requests.go
+++ b/requests.go
@@ -2,8 +2,11 @@ package mautrix
import (
"encoding/json"
+ "fmt"
"strconv"
+ "time"
+ "maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
@@ -37,20 +40,40 @@ const (
type Direction rune
+func (d Direction) MarshalJSON() ([]byte, error) {
+ return json.Marshal(string(d))
+}
+
+func (d *Direction) UnmarshalJSON(data []byte) error {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ switch str {
+ case "f":
+ *d = DirectionForward
+ case "b":
+ *d = DirectionBackward
+ default:
+ return fmt.Errorf("invalid direction %q, must be 'f' or 'b'", str)
+ }
+ return nil
+}
+
const (
DirectionForward Direction = 'f'
DirectionBackward Direction = 'b'
)
// ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
-type ReqRegister struct {
+type ReqRegister[UIAType any] struct {
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"`
InhibitLogin bool `json:"inhibit_login,omitempty"`
RefreshToken bool `json:"refresh_token,omitempty"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
// Type for registration, only used for appservice user registrations
// https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions
@@ -82,6 +105,7 @@ type ReqLogin struct {
Token string `json:"token,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"`
+ RefreshToken bool `json:"refresh_token,omitempty"`
// Whether or not the returned credentials should be stored in the Client
StoreCredentials bool `json:"-"`
@@ -89,6 +113,10 @@ type ReqLogin struct {
StoreHomeserverURL bool `json:"-"`
}
+type ReqPutDevice struct {
+ DisplayName string `json:"display_name,omitempty"`
+}
+
type ReqUIAuthFallback struct {
Session string `json:"session"`
User string `json:"user"`
@@ -96,8 +124,9 @@ type ReqUIAuthFallback struct {
type ReqUIAuthLogin struct {
BaseAuthData
- User string `json:"user"`
- Password string `json:"password"`
+ User string `json:"user,omitempty"`
+ Password string `json:"password,omitempty"`
+ Token string `json:"token,omitempty"`
}
// ReqCreateRoom is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom
@@ -112,12 +141,17 @@ type ReqCreateRoom struct {
InitialState []*event.Event `json:"initial_state,omitempty"`
Preset string `json:"preset,omitempty"`
IsDirect bool `json:"is_direct,omitempty"`
- RoomVersion string `json:"room_version,omitempty"`
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"`
- MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"`
- BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"`
+ MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"`
+ MeowCreateTS int64 `json:"fi.mau.origin_server_ts,omitempty"`
+ BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"`
+ BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"`
+ BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"`
+ BeeperBridgeName string `json:"com.beeper.bridge_name,omitempty"`
+ BeeperBridgeAccountID string `json:"com.beeper.bridge_account_id,omitempty"`
}
// ReqRedact is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid
@@ -127,12 +161,37 @@ type ReqRedact struct {
Extra map[string]interface{}
}
+type ReqRedactUser struct {
+ Reason string `json:"reason"`
+ Limit int `json:"-"`
+}
+
type ReqMembers struct {
At string `json:"at"`
Membership event.Membership `json:"membership,omitempty"`
NotMembership event.Membership `json:"not_membership,omitempty"`
}
+type ReqJoinRoom struct {
+ Via []string `json:"-"`
+ Reason string `json:"reason,omitempty"`
+ ThirdPartySigned any `json:"third_party_signed,omitempty"`
+}
+
+type ReqKnockRoom struct {
+ Via []string `json:"-"`
+ Reason string `json:"reason,omitempty"`
+}
+
+type ReqSearchUserDirectory struct {
+ SearchTerm string `json:"search_term"`
+ Limit int `json:"limit,omitempty"`
+}
+
+type ReqMutualRooms struct {
+ From string `json:"-"`
+}
+
// ReqInvite3PID is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1
// It is also a JSON object used in https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom
type ReqInvite3PID struct {
@@ -161,6 +220,8 @@ type ReqKickUser struct {
type ReqBanUser struct {
Reason string `json:"reason,omitempty"`
UserID id.UserID `json:"user_id"`
+
+ MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"`
}
// ReqUnbanUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban
@@ -176,7 +237,8 @@ type ReqTyping struct {
}
type ReqPresence struct {
- Presence event.Presence `json:"presence"`
+ Presence event.Presence `json:"presence"`
+ StatusMsg string `json:"status_msg,omitempty"`
}
type ReqAliasCreate struct {
@@ -184,11 +246,11 @@ type ReqAliasCreate struct {
}
type OneTimeKey struct {
- Key id.Curve25519 `json:"key"`
- Fallback bool `json:"fallback,omitempty"`
- Signatures Signatures `json:"signatures,omitempty"`
- Unsigned map[string]any `json:"unsigned,omitempty"`
- IsSigned bool `json:"-"`
+ Key id.Curve25519 `json:"key"`
+ Fallback bool `json:"fallback,omitempty"`
+ Signatures signatures.Signatures `json:"signatures,omitempty"`
+ Unsigned map[string]any `json:"unsigned,omitempty"`
+ IsSigned bool `json:"-"`
// Raw data in the one-time key. This must be used for signature verification to ensure unrecognized fields
// aren't thrown away (because that would invalidate the signature).
@@ -221,7 +283,7 @@ func (otk *OneTimeKey) MarshalJSON() ([]byte, error) {
type ReqUploadKeys struct {
DeviceKeys *DeviceKeys `json:"device_keys,omitempty"`
- OneTimeKeys map[id.KeyID]OneTimeKey `json:"one_time_keys"`
+ OneTimeKeys map[id.KeyID]OneTimeKey `json:"one_time_keys,omitempty"`
}
type ReqKeysSignatures struct {
@@ -230,7 +292,7 @@ type ReqKeysSignatures struct {
Algorithms []id.Algorithm `json:"algorithms,omitempty"`
Usage []id.CrossSigningUsage `json:"usage,omitempty"`
Keys map[id.KeyID]string `json:"keys"`
- Signatures Signatures `json:"signatures"`
+ Signatures signatures.Signatures `json:"signatures"`
}
type ReqUploadSignatures map[id.UserID]map[string]ReqKeysSignatures
@@ -240,15 +302,15 @@ type DeviceKeys struct {
DeviceID id.DeviceID `json:"device_id"`
Algorithms []id.Algorithm `json:"algorithms"`
Keys KeyMap `json:"keys"`
- Signatures Signatures `json:"signatures"`
+ Signatures signatures.Signatures `json:"signatures"`
Unsigned map[string]interface{} `json:"unsigned,omitempty"`
}
type CrossSigningKeys struct {
- UserID id.UserID `json:"user_id"`
- Usage []id.CrossSigningUsage `json:"usage"`
- Keys map[id.KeyID]id.Ed25519 `json:"keys"`
- Signatures map[id.UserID]map[id.KeyID]string `json:"signatures,omitempty"`
+ UserID id.UserID `json:"user_id"`
+ Usage []id.CrossSigningUsage `json:"usage"`
+ Keys map[id.KeyID]id.Ed25519 `json:"keys"`
+ Signatures signatures.Signatures `json:"signatures,omitempty"`
}
func (csk *CrossSigningKeys) FirstKey() id.Ed25519 {
@@ -258,11 +320,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 {
return ""
}
-type UploadCrossSigningKeysReq struct {
+type UploadCrossSigningKeysReq[UIAType any] struct {
Master CrossSigningKeys `json:"master_key"`
SelfSigning CrossSigningKeys `json:"self_signing_key"`
UserSigning CrossSigningKeys `json:"user_signing_key"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type KeyMap map[id.DeviceKeyID]string
@@ -283,8 +345,6 @@ func (km KeyMap) GetCurve25519(deviceID id.DeviceID) id.Curve25519 {
return id.Curve25519(val)
}
-type Signatures map[id.UserID]map[id.KeyID]string
-
type ReqQueryKeys struct {
DeviceKeys DeviceKeysRequest `json:"device_keys"`
Timeout int64 `json:"timeout,omitempty"`
@@ -306,20 +366,40 @@ type ReqSendToDevice struct {
Messages map[id.UserID]map[id.DeviceID]*event.Content `json:"messages"`
}
+type ReqSendEvent struct {
+ Timestamp int64
+ TransactionID string
+ UnstableDelay time.Duration
+ UnstableStickyDuration time.Duration
+ DontEncrypt bool
+ MeowEventID id.EventID
+}
+
+type ReqDelayedEvents struct {
+ DelayID id.DelayID `json:"-"`
+ Status event.DelayStatus `json:"-"`
+ NextBatch string `json:"-"`
+}
+
+type ReqUpdateDelayedEvent struct {
+ DelayID id.DelayID `json:"-"`
+ Action event.DelayAction `json:"action"`
+}
+
// ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid
type ReqDeviceInfo struct {
DisplayName string `json:"display_name,omitempty"`
}
// ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid
-type ReqDeleteDevice struct {
- Auth interface{} `json:"auth,omitempty"`
+type ReqDeleteDevice[UIAType any] struct {
+ Auth UIAType `json:"auth,omitempty"`
}
// ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices
-type ReqDeleteDevices struct {
+type ReqDeleteDevices[UIAType any] struct {
Devices []id.DeviceID `json:"devices"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type ReqPutPushRule struct {
@@ -331,18 +411,6 @@ type ReqPutPushRule struct {
Pattern string `json:"pattern"`
}
-// Deprecated: MSC2716 was abandoned
-type ReqBatchSend struct {
- PrevEventID id.EventID `json:"-"`
- BatchID id.BatchID `json:"-"`
-
- BeeperNewMessages bool `json:"-"`
- BeeperMarkReadBy id.UserID `json:"-"`
-
- StateEventsAtStart []*event.Event `json:"state_events_at_start"`
- Events []*event.Event `json:"events"`
-}
-
type ReqBeeperBatchSend struct {
// ForwardIfNoMessages should be set to true if the batch should be forward
// backfilled if there are no messages currently in the room.
@@ -363,10 +431,48 @@ type ReqSetReadMarkers struct {
BeeperFullyReadExtra interface{} `json:"com.beeper.fully_read.extra,omitempty"`
}
+type BeeperInboxDone struct {
+ Delta int64 `json:"at_delta"`
+ AtOrder int64 `json:"at_order"`
+}
+
+type ReqSetBeeperInboxState struct {
+ MarkedUnread *bool `json:"marked_unread,omitempty"`
+ Done *BeeperInboxDone `json:"done,omitempty"`
+ ReadMarkers *ReqSetReadMarkers `json:"read_markers,omitempty"`
+}
+
type ReqSendReceipt struct {
ThreadID string `json:"thread_id,omitempty"`
}
+type ReqPublicRooms struct {
+ IncludeAllNetworks bool
+ Limit int
+ Since string
+ ThirdPartyInstanceID string
+}
+
+func (req *ReqPublicRooms) Query() map[string]string {
+ query := map[string]string{}
+ if req == nil {
+ return query
+ }
+ if req.IncludeAllNetworks {
+ query["include_all_networks"] = "true"
+ }
+ if req.Limit > 0 {
+ query["limit"] = strconv.Itoa(req.Limit)
+ }
+ if req.Since != "" {
+ query["since"] = req.Since
+ }
+ if req.ThirdPartyInstanceID != "" {
+ query["third_party_instance_id"] = req.ThirdPartyInstanceID
+ }
+ return query
+}
+
// ReqHierarchy contains the parameters for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
//
// As it's a GET method, there is no JSON body, so this is only query parameters.
@@ -429,22 +535,84 @@ type ReqBeeperSplitRoom struct {
Parts []BeeperSplitRoomPart `json:"parts"`
}
-type ReqRoomKeysVersionCreate struct {
- Algorithm string `json:"algorithm"`
- AuthData json.RawMessage `json:"auth_data"`
+type ReqRoomKeysVersionCreate[A any] struct {
+ Algorithm id.KeyBackupAlgorithm `json:"algorithm"`
+ AuthData A `json:"auth_data"`
}
-type ReqRoomKeysUpdate struct {
- Rooms map[id.RoomID]ReqRoomKeysRoomUpdate `json:"rooms"`
+type ReqRoomKeysVersionUpdate[A any] struct {
+ Algorithm id.KeyBackupAlgorithm `json:"algorithm"`
+ AuthData A `json:"auth_data"`
+ Version id.KeyBackupVersion `json:"version,omitempty"`
}
-type ReqRoomKeysRoomUpdate struct {
- Sessions map[id.SessionID]ReqRoomKeysSessionUpdate `json:"sessions"`
+type ReqKeyBackup struct {
+ Rooms map[id.RoomID]ReqRoomKeyBackup `json:"rooms"`
}
-type ReqRoomKeysSessionUpdate struct {
+type ReqRoomKeyBackup struct {
+ Sessions map[id.SessionID]ReqKeyBackupData `json:"sessions"`
+}
+
+type ReqKeyBackupData struct {
FirstMessageIndex int `json:"first_message_index"`
ForwardedCount int `json:"forwarded_count"`
IsVerified bool `json:"is_verified"`
SessionData json.RawMessage `json:"session_data"`
}
+
+type ReqReport struct {
+ Reason string `json:"reason,omitempty"`
+ Score int `json:"score,omitempty"`
+}
+
+type ReqGetRelations struct {
+ RelationType event.RelationType
+ EventType event.Type
+
+ Dir Direction
+ From string
+ To string
+ Limit int
+ Recurse bool
+}
+
+func (rgr *ReqGetRelations) PathSuffix() ClientURLPath {
+ if rgr.RelationType != "" {
+ if rgr.EventType.Type != "" {
+ return ClientURLPath{rgr.RelationType, rgr.EventType.Type}
+ }
+ return ClientURLPath{rgr.RelationType}
+ }
+ return ClientURLPath{}
+}
+
+func (rgr *ReqGetRelations) Query() map[string]string {
+ query := map[string]string{}
+ if rgr.Dir != 0 {
+ query["dir"] = string(rgr.Dir)
+ }
+ if rgr.From != "" {
+ query["from"] = rgr.From
+ }
+ if rgr.To != "" {
+ query["to"] = rgr.To
+ }
+ if rgr.Limit > 0 {
+ query["limit"] = strconv.Itoa(rgr.Limit)
+ }
+ if rgr.Recurse {
+ query["recurse"] = "true"
+ }
+ return query
+}
+
+// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type ReqSuspend struct {
+ Suspended bool `json:"suspended"`
+}
+
+// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type ReqLocked struct {
+ Locked bool `json:"locked"`
+}
diff --git a/responses.go b/responses.go
index 69eb4b8f..4fbe1fbc 100644
--- a/responses.go
+++ b/responses.go
@@ -4,13 +4,16 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "maps"
"reflect"
+ "slices"
"strconv"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -32,6 +35,11 @@ type RespJoinRoom struct {
RoomID id.RoomID `json:"room_id"`
}
+// RespKnockRoom is the JSON response for https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias
+type RespKnockRoom struct {
+ RoomID id.RoomID `json:"room_id"`
+}
+
// RespLeaveRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave
type RespLeaveRoom struct{}
@@ -97,6 +105,29 @@ type RespContext struct {
// RespSendEvent is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid
type RespSendEvent struct {
EventID id.EventID `json:"event_id"`
+
+ UnstableDelayID id.DelayID `json:"delay_id,omitempty"`
+}
+
+type RespUpdateDelayedEvent struct{}
+
+type RespDelayedEvents struct {
+ Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"`
+ Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"`
+ NextBatch string `json:"next_batch,omitempty"`
+
+ // Deprecated: Synapse implementation still returns this
+ DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"`
+ // Deprecated: Synapse implementation still returns this
+ FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"`
+}
+
+type RespRedactUserEvents struct {
+ IsMoreEvents bool `json:"is_more_events"`
+ RedactedEvents struct {
+ Total int `json:"total"`
+ SoftFailed int `json:"soft_failed"`
+ } `json:"redacted_events"`
}
// RespMediaConfig is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixmediav3config
@@ -111,26 +142,18 @@ type RespMediaUpload struct {
// RespCreateMXC is the JSON response for https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create
type RespCreateMXC struct {
- ContentURI id.ContentURI `json:"content_uri"`
- UnusedExpiresAt int `json:"unused_expires_at,omitempty"`
+ ContentURI id.ContentURI `json:"content_uri"`
+ UnusedExpiresAt jsontime.UnixMilli `json:"unused_expires_at,omitempty"`
UnstableUploadURL string `json:"com.beeper.msc3870.upload_url,omitempty"`
+
+ // Beeper extensions for uploading unique media only once
+ BeeperUniqueID string `json:"com.beeper.unique_id,omitempty"`
+ BeeperCompletedAt jsontime.UnixMilli `json:"com.beeper.completed_at,omitempty"`
}
// RespPreviewURL is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url
-type RespPreviewURL struct {
- CanonicalURL string `json:"og:url,omitempty"`
- Title string `json:"og:title,omitempty"`
- Type string `json:"og:type,omitempty"`
- Description string `json:"og:description,omitempty"`
-
- ImageURL id.ContentURIString `json:"og:image,omitempty"`
-
- ImageSize int `json:"matrix:image:size,omitempty"`
- ImageWidth int `json:"og:image:width,omitempty"`
- ImageHeight int `json:"og:image:height,omitempty"`
- ImageType string `json:"og:image:type,omitempty"`
-}
+type RespPreviewURL = event.LinkPreview
// RespUserInteractive is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#user-interactive-authentication-api
type RespUserInteractive struct {
@@ -163,8 +186,89 @@ type RespUserDisplayName struct {
}
type RespUserProfile struct {
- DisplayName string `json:"displayname"`
- AvatarURL id.ContentURI `json:"avatar_url"`
+ DisplayName string `json:"displayname,omitempty"`
+ AvatarURL id.ContentURI `json:"avatar_url,omitempty"`
+ Extra map[string]any `json:"-"`
+}
+
+type marshalableUserProfile RespUserProfile
+
+func (r *RespUserProfile) UnmarshalJSON(data []byte) error {
+ err := json.Unmarshal(data, &r.Extra)
+ if err != nil {
+ return err
+ }
+ r.DisplayName, _ = r.Extra["displayname"].(string)
+ avatarURL, _ := r.Extra["avatar_url"].(string)
+ if avatarURL != "" {
+ r.AvatarURL, _ = id.ParseContentURI(avatarURL)
+ }
+ delete(r.Extra, "displayname")
+ delete(r.Extra, "avatar_url")
+ return nil
+}
+
+func (r *RespUserProfile) MarshalJSON() ([]byte, error) {
+ if len(r.Extra) == 0 {
+ return json.Marshal((*marshalableUserProfile)(r))
+ }
+ marshalMap := maps.Clone(r.Extra)
+ if r.DisplayName != "" {
+ marshalMap["displayname"] = r.DisplayName
+ } else {
+ delete(marshalMap, "displayname")
+ }
+ if !r.AvatarURL.IsEmpty() {
+ marshalMap["avatar_url"] = r.AvatarURL.String()
+ } else {
+ delete(marshalMap, "avatar_url")
+ }
+ return json.Marshal(marshalMap)
+}
+
+type RespSearchUserDirectory struct {
+ Limited bool `json:"limited"`
+ Results []*UserDirectoryEntry `json:"results"`
+}
+
+type UserDirectoryEntry struct {
+ RespUserProfile
+ UserID id.UserID `json:"user_id"`
+}
+
+func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error {
+ err := r.RespUserProfile.UnmarshalJSON(data)
+ if err != nil {
+ return err
+ }
+ userIDStr, _ := r.Extra["user_id"].(string)
+ r.UserID = id.UserID(userIDStr)
+ delete(r.Extra, "user_id")
+ return nil
+}
+
+func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) {
+ if r.Extra == nil {
+ r.Extra = make(map[string]any)
+ }
+ r.Extra["user_id"] = r.UserID.String()
+ return r.RespUserProfile.MarshalJSON()
+}
+
+type RespMutualRooms struct {
+ Joined []id.RoomID `json:"joined"`
+ NextBatch string `json:"next_batch,omitempty"`
+ Count int `json:"count,omitempty"`
+}
+
+type RespRoomSummary struct {
+ PublicRoomInfo
+
+ Membership event.Membership `json:"membership,omitempty"`
+
+ UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"`
+ UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"`
+ UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"`
}
// RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable
@@ -215,6 +319,9 @@ type RespLogin struct {
DeviceID id.DeviceID `json:"device_id"`
UserID id.UserID `json:"user_id"`
WellKnown *ClientWellKnown `json:"well_known,omitempty"`
+
+ RefreshToken string `json:"refresh_token,omitempty"`
+ ExpiresInMS int64 `json:"expires_in_ms,omitempty"`
}
// RespLogout is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout
@@ -235,6 +342,24 @@ type LazyLoadSummary struct {
InvitedMemberCount *int `json:"m.invited_member_count,omitempty"`
}
+func (lls *LazyLoadSummary) MemberCount() int {
+ if lls == nil {
+ return 0
+ }
+ return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount)
+}
+
+func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool {
+ if lls == other {
+ return true
+ } else if lls == nil || other == nil {
+ return false
+ }
+ return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) &&
+ ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) &&
+ slices.Equal(lls.Heroes, other.Heroes)
+}
+
type SyncEventsList struct {
Events []*event.Event `json:"events,omitempty"`
}
@@ -321,9 +446,16 @@ func (slr SyncLeftRoom) MarshalJSON() ([]byte, error) {
return marshalAndDeleteEmpty((marshalableSyncLeftRoom)(slr), syncLeftRoomPathsToDelete)
}
+type BeeperInboxPreviewEvent struct {
+ EventID id.EventID `json:"event_id"`
+ Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
+ Event *event.Event `json:"event,omitempty"`
+}
+
type SyncJoinedRoom struct {
Summary LazyLoadSummary `json:"summary"`
State SyncEventsList `json:"state"`
+ StateAfter *SyncEventsList `json:"state_after,omitempty"`
Timeline SyncTimeline `json:"timeline"`
Ephemeral SyncEventsList `json:"ephemeral"`
AccountData SyncEventsList `json:"account_data"`
@@ -331,6 +463,8 @@ type SyncJoinedRoom struct {
UnreadNotifications *UnreadNotificationCounts `json:"unread_notifications,omitempty"`
// https://github.com/matrix-org/matrix-spec-proposals/pull/2654
MSC2654UnreadCount *int `json:"org.matrix.msc2654.unread_count,omitempty"`
+ // Beeper extension
+ BeeperInboxPreview *BeeperInboxPreviewEvent `json:"com.beeper.inbox.preview,omitempty"`
}
type UnreadNotificationCounts struct {
@@ -347,16 +481,7 @@ func (sjr SyncJoinedRoom) MarshalJSON() ([]byte, error) {
}
type SyncInvitedRoom struct {
- Summary LazyLoadSummary `json:"summary"`
- State SyncEventsList `json:"invite_state"`
-}
-
-type marshalableSyncInvitedRoom SyncInvitedRoom
-
-var syncInvitedRoomPathsToDelete = []string{"summary"}
-
-func (sir SyncInvitedRoom) MarshalJSON() ([]byte, error) {
- return marshalAndDeleteEmpty((marshalableSyncInvitedRoom)(sir), syncInvitedRoomPathsToDelete)
+ State SyncEventsList `json:"invite_state"`
}
type SyncKnockedRoom struct {
@@ -421,29 +546,19 @@ type RespDeviceInfo struct {
LastSeenTS int64 `json:"last_seen_ts"`
}
-// Deprecated: MSC2716 was abandoned
-type RespBatchSend struct {
- StateEventIDs []id.EventID `json:"state_event_ids"`
- EventIDs []id.EventID `json:"event_ids"`
-
- InsertionEventID id.EventID `json:"insertion_event_id"`
- BatchEventID id.EventID `json:"batch_event_id"`
- BaseInsertionEventID id.EventID `json:"base_insertion_event_id"`
-
- NextBatchID id.BatchID `json:"next_batch_id"`
-}
-
type RespBeeperBatchSend struct {
EventIDs []id.EventID `json:"event_ids"`
}
// RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities
type RespCapabilities struct {
- RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
- ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
- SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
- SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
- ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
+ RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
+ ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
+ SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
+ SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
+ ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
+ GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"`
+ UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"`
Custom map[string]interface{} `json:"-"`
}
@@ -552,29 +667,44 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool {
return available
}
+type CapUnstableAccountModeration struct {
+ Suspend bool `json:"suspend"`
+ Lock bool `json:"lock"`
+}
+
+type RespPublicRooms struct {
+ Chunk []*PublicRoomInfo `json:"chunk"`
+ NextBatch string `json:"next_batch,omitempty"`
+ PrevBatch string `json:"prev_batch,omitempty"`
+ TotalRoomCountEstimate int `json:"total_room_count_estimate"`
+}
+
+type PublicRoomInfo struct {
+ RoomID id.RoomID `json:"room_id"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"`
+ GuestCanJoin bool `json:"guest_can_join"`
+ JoinRule event.JoinRule `json:"join_rule,omitempty"`
+ Name string `json:"name,omitempty"`
+ NumJoinedMembers int `json:"num_joined_members"`
+ RoomType event.RoomType `json:"room_type"`
+ Topic string `json:"topic,omitempty"`
+ WorldReadable bool `json:"world_readable"`
+
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ Encryption id.Algorithm `json:"encryption,omitempty"`
+ AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"`
+}
+
// RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
type RespHierarchy struct {
- NextBatch string `json:"next_batch,omitempty"`
- Rooms []ChildRoomsChunk `json:"rooms"`
+ NextBatch string `json:"next_batch,omitempty"`
+ Rooms []*ChildRoomsChunk `json:"rooms"`
}
type ChildRoomsChunk struct {
- AvatarURL id.ContentURI `json:"avatar_url,omitempty"`
- CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"`
- ChildrenState []StrippedStateWithTime `json:"children_state"`
- GuestCanJoin bool `json:"guest_can_join"`
- JoinRule event.JoinRule `json:"join_rule,omitempty"`
- Name string `json:"name,omitempty"`
- NumJoinedMembers int `json:"num_joined_members"`
- RoomID id.RoomID `json:"room_id"`
- RoomType event.RoomType `json:"room_type"`
- Topic string `json:"topic,omitempty"`
- WorldReadble bool `json:"world_readable"`
-}
-
-type StrippedStateWithTime struct {
- event.StrippedState
- Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
+ PublicRoomInfo
+ ChildrenState []*event.Event `json:"children_state"`
}
type RespAppservicePing struct {
@@ -593,33 +723,77 @@ type RespTimestampToEvent struct {
}
type RespRoomKeysVersionCreate struct {
- Version string `json:"version"`
+ Version id.KeyBackupVersion `json:"version"`
}
-type RespRoomKeysVersion struct {
- Algorithm string `json:"algorithm"`
- AuthData json.RawMessage `json:"auth_data"`
- Count int `json:"count"`
- ETag string `json:"etag"`
- Version string `json:"version"`
+type RespRoomKeysVersion[A any] struct {
+ Algorithm id.KeyBackupAlgorithm `json:"algorithm"`
+ AuthData A `json:"auth_data"`
+ Count int `json:"count"`
+ ETag string `json:"etag"`
+ Version id.KeyBackupVersion `json:"version"`
}
-type RespRoomKeys struct {
- Rooms map[id.RoomID]RespRoomKeysRoom `json:"rooms"`
+type RespRoomKeys[S any] struct {
+ Rooms map[id.RoomID]RespRoomKeyBackup[S] `json:"rooms"`
}
-type RespRoomKeysRoom struct {
- Sessions map[id.SessionID]RespRoomKeysSession `json:"sessions"`
+type RespRoomKeyBackup[S any] struct {
+ Sessions map[id.SessionID]RespKeyBackupData[S] `json:"sessions"`
}
-type RespRoomKeysSession struct {
- FirstMessageIndex int `json:"first_message_index"`
- ForwardedCount int `json:"forwarded_count"`
- IsVerified bool `json:"is_verified"`
- SessionData json.RawMessage `json:"session_data"`
+type RespKeyBackupData[S any] struct {
+ FirstMessageIndex int `json:"first_message_index"`
+ ForwardedCount int `json:"forwarded_count"`
+ IsVerified bool `json:"is_verified"`
+ SessionData S `json:"session_data"`
}
type RespRoomKeysUpdate struct {
Count int `json:"count"`
ETag string `json:"etag"`
}
+
+type RespOpenIDToken struct {
+ AccessToken string `json:"access_token"`
+ ExpiresIn int `json:"expires_in"`
+ MatrixServerName string `json:"matrix_server_name"`
+ TokenType string `json:"token_type"` // Always "Bearer"
+}
+
+type RespGetRelations struct {
+ Chunk []*event.Event `json:"chunk"`
+ NextBatch string `json:"next_batch,omitempty"`
+ PrevBatch string `json:"prev_batch,omitempty"`
+ RecursionDepth int `json:"recursion_depth,omitempty"`
+}
+
+// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type RespSuspended struct {
+ Suspended bool `json:"suspended"`
+}
+
+// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type RespLocked struct {
+ Locked bool `json:"locked"`
+}
+
+type ConnectionInfo struct {
+ IP string `json:"ip,omitempty"`
+ LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"`
+ UserAgent string `json:"user_agent,omitempty"`
+}
+
+type SessionInfo struct {
+ Connections []ConnectionInfo `json:"connections,omitempty"`
+}
+
+type DeviceInfo struct {
+ Sessions []SessionInfo `json:"sessions,omitempty"`
+}
+
+// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid
+type RespWhoIs struct {
+ UserID id.UserID `json:"user_id,omitempty"`
+ Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"`
+}
diff --git a/responses_test.go b/responses_test.go
index b23d85ad..73d82635 100644
--- a/responses_test.go
+++ b/responses_test.go
@@ -8,7 +8,6 @@ package mautrix_test
import (
"encoding/json"
- "fmt"
"testing"
"github.com/stretchr/testify/assert"
@@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) {
var caps mautrix.RespCapabilities
err := json.Unmarshal([]byte(sampleData), &caps)
require.NoError(t, err)
- fmt.Println(caps)
require.NotNil(t, caps.RoomVersions)
assert.Equal(t, "9", caps.RoomVersions.Default)
diff --git a/room.go b/room.go
index c3ddb7e6..4292bff5 100644
--- a/room.go
+++ b/room.go
@@ -5,8 +5,6 @@ import (
"maunium.net/go/mautrix/id"
)
-type RoomStateMap = map[event.Type]map[string]*event.Event
-
// Room represents a single Matrix room.
type Room struct {
ID id.RoomID
@@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) {
// GetStateEvent returns the state event for the given type/state_key combo, or nil.
func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event {
- stateEventMap, _ := room.State[eventType]
- evt, _ := stateEventMap[stateKey]
+ stateEventMap := room.State[eventType]
+ evt := stateEventMap[stateKey]
return evt
}
diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go
index cd94215d..11957dfa 100644
--- a/sqlstatestore/statestore.go
+++ b/sqlstatestore/statestore.go
@@ -17,7 +17,9 @@ import (
"strings"
"github.com/rs/zerolog"
+ "go.mau.fi/util/confusable"
"go.mau.fi/util/dbutil"
+ "go.mau.fi/util/exslices"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -37,6 +39,8 @@ const VersionTableName = "mx_version"
type SQLStateStore struct {
*dbutil.Database
IsBridge bool
+
+ DisableNameDisambiguation bool
}
func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge bool) *SQLStateStore {
@@ -58,6 +62,9 @@ func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID)
}
func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error {
+ if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
_, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
return err
}
@@ -65,6 +72,7 @@ func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID
type Member struct {
id.UserID
event.MemberEventContent
+ NameSkeleton [32]byte
}
func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) {
@@ -80,14 +88,11 @@ func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID
query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ","))
}
rows, err := store.Query(ctx, query, args...)
- if err != nil {
- return nil, err
- }
members := make(map[id.UserID]*event.MemberEventContent)
- return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) {
+ return members, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (ret Member, err error) {
err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL)
return
- }).Iter(func(m Member) (bool, error) {
+ }, err).Iter(func(m Member) (bool, error) {
members[m.UserID] = &m.MemberEventContent
return true, nil
})
@@ -154,10 +159,7 @@ func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserI
`
}
rows, err := store.Query(ctx, query, userID)
- if err != nil {
- return nil, err
- }
- return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
+ return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
}
func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
@@ -183,6 +185,11 @@ func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID,
}
func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ } else if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
@@ -190,14 +197,101 @@ func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID,
return err
}
+const insertUserProfileQuery = `
+ INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton)
+ VALUES ($1, $2, $3, $4, $5, $6)
+ ON CONFLICT (room_id, user_id) DO UPDATE
+ SET membership=excluded.membership,
+ displayname=excluded.displayname,
+ avatar_url=excluded.avatar_url,
+ name_skeleton=excluded.name_skeleton
+`
+
+type userProfileRow struct {
+ UserID id.UserID
+ Membership event.Membership
+ Displayname string
+ AvatarURL id.ContentURIString
+ NameSkeleton []byte
+}
+
+func (u *userProfileRow) GetMassInsertValues() [5]any {
+ return [5]any{u.UserID, u.Membership, u.Displayname, u.AvatarURL, u.NameSkeleton}
+}
+
+var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
+
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
- _, err := store.Exec(ctx, `
- INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
- ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
- `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ } else if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
+ var nameSkeleton []byte
+ if !store.DisableNameDisambiguation && len(member.Displayname) > 0 {
+ nameSkeletonArr := confusable.SkeletonHash(member.Displayname)
+ nameSkeleton = nameSkeletonArr[:]
+ }
+ _, err := store.Exec(ctx, insertUserProfileQuery, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton)
return err
}
+func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) {
+ if store.DisableNameDisambiguation {
+ return nil, nil
+ }
+ skeleton := confusable.SkeletonHash(name)
+ rows, err := store.Query(ctx, "SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND name_skeleton=$2 AND user_id<>$3", roomID, skeleton[:], currentUser)
+ return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
+}
+
+const userProfileMassInsertBatchSize = 500
+
+func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ return store.DoTxn(ctx, nil, func(ctx context.Context) error {
+ err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
+ if err != nil {
+ return fmt.Errorf("failed to clear cached members: %w", err)
+ }
+ rows := make([]*userProfileRow, min(len(evts), userProfileMassInsertBatchSize))
+ for _, evtsChunk := range exslices.Chunk(evts, userProfileMassInsertBatchSize) {
+ rows = rows[:0]
+ for _, evt := range evtsChunk {
+ content, ok := evt.Content.Parsed.(*event.MemberEventContent)
+ if !ok {
+ continue
+ }
+ row := &userProfileRow{
+ UserID: id.UserID(*evt.StateKey),
+ Membership: content.Membership,
+ Displayname: content.Displayname,
+ AvatarURL: content.AvatarURL,
+ }
+ if !store.DisableNameDisambiguation && len(content.Displayname) > 0 {
+ nameSkeletonArr := confusable.SkeletonHash(content.Displayname)
+ row.NameSkeleton = nameSkeletonArr[:]
+ }
+ rows = append(rows, row)
+ }
+ query, args := userProfileMassInserter.Build([1]any{roomID}, rows)
+ _, err = store.Exec(ctx, query, args...)
+ if err != nil {
+ return fmt.Errorf("failed to insert members: %w", err)
+ }
+ }
+ if len(onlyMemberships) == 0 {
+ err = store.MarkMembersFetched(ctx, roomID)
+ if err != nil {
+ return fmt.Errorf("failed to mark members as fetched: %w", err)
+ }
+ }
+ return nil
+ })
+}
+
func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
query := "DELETE FROM mx_user_profile WHERE room_id=$1"
params := make([]any, len(memberships)+1)
@@ -211,10 +305,57 @@ func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.Ro
query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ","))
}
_, err := store.Exec(ctx, query, params...)
+ if err != nil {
+ return err
+ }
+ _, err = store.Exec(ctx, "UPDATE mx_room_state SET members_fetched=false WHERE room_id=$1", roomID)
return err
}
+func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (fetched bool, err error) {
+ err = store.QueryRow(ctx, "SELECT COALESCE(members_fetched, false) FROM mx_room_state WHERE room_id=$1", roomID).Scan(&fetched)
+ if errors.Is(err, sql.ErrNoRows) {
+ err = nil
+ }
+ return
+}
+
+func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ _, err := store.Exec(ctx, `
+ INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true)
+ ON CONFLICT (room_id) DO UPDATE SET members_fetched=true
+ `, roomID)
+ return err
+}
+
+type userAndMembership struct {
+ UserID id.UserID
+ event.MemberEventContent
+}
+
+func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
+ rows, err := store.Query(ctx, "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
+ if err != nil {
+ return nil, err
+ }
+ output := make(map[id.UserID]*event.MemberEventContent)
+ err = dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (res userAndMembership, err error) {
+ err = row.Scan(&res.UserID, &res.Membership, &res.Displayname, &res.AvatarURL)
+ return
+ }, err).Iter(func(member userAndMembership) (bool, error) {
+ output[member.UserID] = &member.MemberEventContent
+ return true, nil
+ })
+ return output, err
+}
+
func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
contentBytes, err := json.Marshal(content)
if err != nil {
return fmt.Errorf("failed to marshal content JSON: %w", err)
@@ -229,7 +370,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro
func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
var data []byte
err := store.
- QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
+ QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID).
Scan(&data)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
@@ -252,6 +393,9 @@ func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (
}
func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
@@ -260,89 +404,92 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID
}
func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
+ levels = &event.PowerLevelsEventContent{}
err = store.
- QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
- Scan(&dbutil.JSON{Data: &levels})
+ QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent})
if errors.Is(err, sql.ErrNoRows) {
- err = nil
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ if levels.CreateEvent != nil {
+ err = levels.CreateEvent.Content.ParseRaw(event.StateCreate)
}
return
}
func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
- if store.Dialect == dbutil.Postgres {
- var powerLevel int
- err := store.
- QueryRow(ctx, `
- SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
- FROM mx_room_state WHERE room_id=$1
- `, roomID, userID).
- Scan(&powerLevel)
- return powerLevel, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
- }
- return levels.GetUserLevel(userID), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
}
+ return levels.GetUserLevel(userID), nil
}
func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
- if store.Dialect == dbutil.Postgres {
- defaultType := "events_default"
- defaultValue := 0
- if eventType.IsState() {
- defaultType = "state_default"
- defaultValue = 50
- }
- var powerLevel int
- err := store.
- QueryRow(ctx, `
- SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
- FROM mx_room_state WHERE room_id=$1
- `, roomID, eventType.Type, defaultType, defaultValue).
- Scan(&powerLevel)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- powerLevel = defaultValue
- }
- return powerLevel, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
- }
- return levels.GetEventLevel(eventType), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
}
+ return levels.GetEventLevel(eventType), nil
}
func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
- if store.Dialect == dbutil.Postgres {
- defaultType := "events_default"
- defaultValue := 0
- if eventType.IsState() {
- defaultType = "state_default"
- defaultValue = 50
- }
- var hasPower bool
- err := store.
- QueryRow(ctx, `SELECT
- COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
- >=
- COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
- FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
- Scan(&hasPower)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- hasPower = defaultValue == 0
- }
- return hasPower, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return false, err
- }
- return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return false, err
}
+ return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
+}
+
+func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
+ if evt.Type != event.StateCreate {
+ return fmt.Errorf("invalid event type for create event: %s", evt.Type)
+ } else if evt.RoomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ _, err := store.Exec(ctx, `
+ INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2)
+ ON CONFLICT (room_id) DO UPDATE SET create_event=excluded.create_event
+ `, evt.RoomID, dbutil.JSON{Data: evt})
+ return err
+}
+
+func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) {
+ err = store.
+ QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &evt})
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ if evt != nil {
+ err = evt.Content.ParseRaw(event.StateCreate)
+ }
+ return
+}
+
+func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ _, err := store.Exec(ctx, `
+ INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2)
+ ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules
+ `, roomID, dbutil.JSON{Data: rules})
+ return err
+}
+
+func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) {
+ levels = &event.JoinRulesEventContent{}
+ err = store.
+ QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &levels})
+ if errors.Is(err, sql.ErrNoRows) {
+ levels = nil
+ err = nil
+ }
+ return
}
diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql
index 41c2b9a1..4679f1c6 100644
--- a/sqlstatestore/v00-latest-revision.sql
+++ b/sqlstatestore/v00-latest-revision.sql
@@ -1,4 +1,4 @@
--- v0 -> v5: Latest revision
+-- v0 -> v10 (compatible with v3+): Latest revision
CREATE TABLE mx_registrations (
user_id TEXT PRIMARY KEY
@@ -8,16 +8,25 @@ CREATE TABLE mx_registrations (
CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock');
CREATE TABLE mx_user_profile (
- room_id TEXT,
- user_id TEXT,
- membership membership NOT NULL,
- displayname TEXT NOT NULL DEFAULT '',
- avatar_url TEXT NOT NULL DEFAULT '',
+ room_id TEXT,
+ user_id TEXT,
+ membership membership NOT NULL,
+ displayname TEXT NOT NULL DEFAULT '',
+ avatar_url TEXT NOT NULL DEFAULT '',
+
+ name_skeleton bytea,
+
PRIMARY KEY (room_id, user_id)
);
+CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, membership);
+CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton);
+
CREATE TABLE mx_room_state (
- room_id TEXT PRIMARY KEY,
- power_levels jsonb,
- encryption jsonb
+ room_id TEXT PRIMARY KEY,
+ power_levels jsonb,
+ encryption jsonb,
+ create_event jsonb,
+ join_rules jsonb,
+ members_fetched BOOLEAN NOT NULL DEFAULT false
);
diff --git a/sqlstatestore/v05-mark-encryption-state-resync.go b/sqlstatestore/v05-mark-encryption-state-resync.go
index bf44d308..b7f2b1c2 100644
--- a/sqlstatestore/v05-mark-encryption-state-resync.go
+++ b/sqlstatestore/v05-mark-encryption-state-resync.go
@@ -8,7 +8,7 @@ import (
)
func init() {
- UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(ctx context.Context, db *dbutil.Database) error {
+ UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", dbutil.TxnModeOn, func(ctx context.Context, db *dbutil.Database) error {
portalExists, err := db.TableExists(ctx, "portal")
if err != nil {
return fmt.Errorf("failed to check if portal table exists")
diff --git a/sqlstatestore/v06-displayname-disambiguation.go b/sqlstatestore/v06-displayname-disambiguation.go
new file mode 100644
index 00000000..d0d1d502
--- /dev/null
+++ b/sqlstatestore/v06-displayname-disambiguation.go
@@ -0,0 +1,55 @@
+// Copyright (c) 2024 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package sqlstatestore
+
+import (
+ "context"
+
+ "go.mau.fi/util/confusable"
+ "go.mau.fi/util/dbutil"
+
+ "maunium.net/go/mautrix/id"
+)
+
+type roomUserName struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Name string
+}
+
+func init() {
+ UpgradeTable.Register(-1, 6, 3, "Add disambiguation column for user profiles", dbutil.TxnModeOn, func(ctx context.Context, db *dbutil.Database) error {
+ _, err := db.Exec(ctx, `
+ ALTER TABLE mx_user_profile ADD COLUMN name_skeleton bytea;
+ CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, membership);
+ CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton);
+ `)
+ if err != nil {
+ return err
+ }
+ const ChunkSize = 1000
+ const GetEntriesChunkQuery = "SELECT room_id, user_id, displayname FROM mx_user_profile WHERE displayname<>'' LIMIT $1 OFFSET $2"
+ const SetSkeletonHashQuery = `UPDATE mx_user_profile SET name_skeleton = $3 WHERE room_id = $1 AND user_id = $2`
+ for offset := 0; ; offset += ChunkSize {
+ entries, err := dbutil.NewSimpleReflectRowIter[roomUserName](db.Query(ctx, GetEntriesChunkQuery, ChunkSize, offset)).AsList()
+ if err != nil {
+ return err
+ }
+ for _, entry := range entries {
+ skel := confusable.SkeletonHash(entry.Name)
+ _, err = db.Exec(ctx, SetSkeletonHashQuery, entry.RoomID, entry.UserID, skel[:])
+ if err != nil {
+ return err
+ }
+ }
+ if len(entries) < ChunkSize {
+ break
+ }
+ }
+ return nil
+ })
+}
diff --git a/sqlstatestore/v07-full-member-flag.sql b/sqlstatestore/v07-full-member-flag.sql
new file mode 100644
index 00000000..32f2ef6c
--- /dev/null
+++ b/sqlstatestore/v07-full-member-flag.sql
@@ -0,0 +1,2 @@
+-- v7 (compatible with v3+): Add flag for whether the full member list has been fetched
+ALTER TABLE mx_room_state ADD COLUMN members_fetched BOOLEAN NOT NULL DEFAULT false;
diff --git a/sqlstatestore/v08-create-event.sql b/sqlstatestore/v08-create-event.sql
new file mode 100644
index 00000000..9f1b55c9
--- /dev/null
+++ b/sqlstatestore/v08-create-event.sql
@@ -0,0 +1,2 @@
+-- v8 (compatible with v3+): Add create event to room state table
+ALTER TABLE mx_room_state ADD COLUMN create_event jsonb;
diff --git a/sqlstatestore/v09-clear-empty-room-ids.sql b/sqlstatestore/v09-clear-empty-room-ids.sql
new file mode 100644
index 00000000..ca951068
--- /dev/null
+++ b/sqlstatestore/v09-clear-empty-room-ids.sql
@@ -0,0 +1,3 @@
+-- v9 (compatible with v3+): Clear invalid rows
+DELETE FROM mx_room_state WHERE room_id='';
+DELETE FROM mx_user_profile WHERE room_id='' OR user_id='';
diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql
new file mode 100644
index 00000000..3074c46a
--- /dev/null
+++ b/sqlstatestore/v10-join-rules.sql
@@ -0,0 +1,2 @@
+-- v10 (compatible with v3+): Add join rules to room state table
+ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb;
diff --git a/statestore.go b/statestore.go
index 8fe5f8b3..2bd498dd 100644
--- a/statestore.go
+++ b/statestore.go
@@ -8,6 +8,7 @@ package mautrix
import (
"context"
+ "maps"
"sync"
"github.com/rs/zerolog"
@@ -26,21 +27,41 @@ type StateStore interface {
TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error
SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error
+ IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error)
ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error
+ ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
+ SetCreate(ctx context.Context, evt *event.Event) error
+ GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error)
+
+ GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error)
+ SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error
+
+ HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error)
+ MarkMembersFetched(ctx context.Context, roomID id.RoomID) error
+ GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error)
+
SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error
IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error)
GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error)
}
+type StateStoreUpdater interface {
+ UpdateState(ctx context.Context, evt *event.Event)
+}
+
func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
if store == nil || evt == nil || evt.StateKey == nil {
return
}
+ if directUpdater, ok := store.(StateStoreUpdater); ok {
+ directUpdater.UpdateState(ctx, evt)
+ return
+ }
// We only care about events without a state key (power levels, encryption) or member events with state key
if evt.Type != event.StateMember && evt.GetStateKey() != "" {
return
@@ -53,6 +74,19 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
err = store.SetPowerLevels(ctx, evt.RoomID, content)
case *event.EncryptionEventContent:
err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
+ case *event.CreateEventContent:
+ err = store.SetCreate(ctx, evt)
+ case *event.JoinRulesEventContent:
+ err = store.SetJoinRules(ctx, evt.RoomID, content)
+ default:
+ switch evt.Type {
+ case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate:
+ zerolog.Ctx(ctx).Warn().
+ Stringer("event_id", evt.ID).
+ Str("event_type", evt.Type.Type).
+ Type("content_type", evt.Content.Parsed).
+ Msg("Got known event type with unknown content type in UpdateStateStore")
+ }
}
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
@@ -72,23 +106,30 @@ func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event)
}
type MemoryStateStore struct {
- Registrations map[id.UserID]bool `json:"registrations"`
- Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"`
- PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
- Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"`
+ Registrations map[id.UserID]bool `json:"registrations"`
+ Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"`
+ MembersFetched map[id.RoomID]bool `json:"members_fetched"`
+ PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
+ Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"`
+ Create map[id.RoomID]*event.Event `json:"create"`
+ JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"`
registrationsLock sync.RWMutex
membersLock sync.RWMutex
powerLevelsLock sync.RWMutex
encryptionLock sync.RWMutex
+ joinRulesLock sync.RWMutex
}
func NewMemoryStateStore() StateStore {
return &MemoryStateStore{
- Registrations: make(map[id.UserID]bool),
- Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent),
- PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
- Encryption: make(map[id.RoomID]*event.EncryptionEventContent),
+ Registrations: make(map[id.UserID]bool),
+ Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent),
+ MembersFetched: make(map[id.RoomID]bool),
+ PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
+ Encryption: make(map[id.RoomID]*event.EncryptionEventContent),
+ Create: make(map[id.RoomID]*event.Event),
+ JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent),
}
}
@@ -143,6 +184,11 @@ func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID,
return member, err
}
+func (store *MemoryStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) {
+ // TODO implement?
+ return nil, nil
+}
+
func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) {
store.membersLock.RLock()
defer store.membersLock.RUnlock()
@@ -223,9 +269,40 @@ func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.R
}
}
}
+ store.MembersFetched[roomID] = false
return nil
}
+func (store *MemoryStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) {
+ store.membersLock.RLock()
+ defer store.membersLock.RUnlock()
+ return store.MembersFetched[roomID], nil
+}
+
+func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
+ store.membersLock.Lock()
+ defer store.membersLock.Unlock()
+ store.MembersFetched[roomID] = true
+ return nil
+}
+
+func (store *MemoryStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
+ _ = store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
+ for _, evt := range evts {
+ UpdateStateStore(ctx, store, evt)
+ }
+ if len(onlyMemberships) == 0 {
+ _ = store.MarkMembersFetched(ctx, roomID)
+ }
+ return nil
+}
+
+func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
+ store.membersLock.RLock()
+ defer store.membersLock.RUnlock()
+ return maps.Clone(store.Members[roomID]), nil
+}
+
func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
store.powerLevelsLock.Lock()
store.PowerLevels[roomID] = levels
@@ -236,6 +313,9 @@ func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomI
func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
store.powerLevelsLock.RLock()
levels = store.PowerLevels[roomID]
+ if levels != nil && levels.CreateEvent == nil {
+ levels.CreateEvent = store.Create[roomID]
+ }
store.powerLevelsLock.RUnlock()
return
}
@@ -252,6 +332,23 @@ func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.Room
return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
}
+func (store *MemoryStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
+ store.powerLevelsLock.Lock()
+ store.Create[evt.RoomID] = evt
+ if pls, ok := store.PowerLevels[evt.RoomID]; ok && pls.CreateEvent == nil {
+ pls.CreateEvent = evt
+ }
+ store.powerLevelsLock.Unlock()
+ return nil
+}
+
+func (store *MemoryStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) {
+ store.powerLevelsLock.RLock()
+ evt := store.Create[roomID]
+ store.powerLevelsLock.RUnlock()
+ return evt, nil
+}
+
func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
store.encryptionLock.Lock()
store.Encryption[roomID] = content
@@ -265,7 +362,31 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R
return store.Encryption[roomID], nil
}
+func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error {
+ store.joinRulesLock.Lock()
+ store.JoinRules[roomID] = content
+ store.joinRulesLock.Unlock()
+ return nil
+}
+
+func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) {
+ store.joinRulesLock.RLock()
+ defer store.joinRulesLock.RUnlock()
+ return store.JoinRules[roomID], nil
+}
+
func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
cfg, err := store.GetEncryptionEvent(ctx, roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
}
+
+func (store *MemoryStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) (rooms []id.RoomID, err error) {
+ store.membersLock.RLock()
+ defer store.membersLock.RUnlock()
+ for roomID, members := range store.Members {
+ if _, ok := members[userID]; ok {
+ rooms = append(rooms, roomID)
+ }
+ }
+ return rooms, nil
+}
diff --git a/synapseadmin/client.go b/synapseadmin/client.go
index 775b4b13..6925ca7d 100644
--- a/synapseadmin/client.go
+++ b/synapseadmin/client.go
@@ -14,9 +14,9 @@ import (
//
// https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html
type Client struct {
- *mautrix.Client
+ Client *mautrix.Client
}
func (cli *Client) BuildAdminURL(path ...any) string {
- return cli.BuildURL(mautrix.SynapseAdminURLPath(path))
+ return cli.Client.BuildURL(mautrix.SynapseAdminURLPath(path))
}
diff --git a/synapseadmin/register.go b/synapseadmin/register.go
index d7a94f6f..05e0729a 100644
--- a/synapseadmin/register.go
+++ b/synapseadmin/register.go
@@ -73,11 +73,7 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string {
// This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided.
func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) {
var resp respGetRegisterNonce
- _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{
- Method: http.MethodGet,
- URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}),
- ResponseJSON: &resp,
- })
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "register"), nil, &resp)
if err != nil {
return "", err
}
@@ -97,12 +93,7 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string
}
req.SHA1Checksum = req.Sign(sharedSecret)
var resp mautrix.RespRegister
- _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{
- Method: http.MethodPost,
- URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}),
- RequestJSON: req,
- ResponseJSON: &resp,
- })
+ _, err = cli.Client.MakeRequest(ctx, http.MethodPost, cli.BuildAdminURL("v1", "register"), &req, &resp)
if err != nil {
return nil, err
}
diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go
new file mode 100644
index 00000000..0925b748
--- /dev/null
+++ b/synapseadmin/roomapi.go
@@ -0,0 +1,250 @@
+// Copyright (c) 2023 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package synapseadmin
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "strconv"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type ReqListRoom struct {
+ SearchTerm string
+ OrderBy string
+ Direction mautrix.Direction
+ From int
+ Limit int
+}
+
+func (req *ReqListRoom) BuildQuery() map[string]string {
+ query := map[string]string{
+ "from": strconv.Itoa(req.From),
+ }
+ if req.SearchTerm != "" {
+ query["search_term"] = req.SearchTerm
+ }
+ if req.OrderBy != "" {
+ query["order_by"] = req.OrderBy
+ }
+ if req.Direction != 0 {
+ query["dir"] = string(req.Direction)
+ }
+ if req.Limit != 0 {
+ query["limit"] = strconv.Itoa(req.Limit)
+ }
+ return query
+}
+
+type RoomInfo struct {
+ RoomID id.RoomID `json:"room_id"`
+ Name string `json:"name"`
+ CanonicalAlias id.RoomAlias `json:"canonical_alias"`
+ JoinedMembers int `json:"joined_members"`
+ JoinedLocalMembers int `json:"joined_local_members"`
+ Version string `json:"version"`
+ Creator id.UserID `json:"creator"`
+ Encryption id.Algorithm `json:"encryption"`
+ Federatable bool `json:"federatable"`
+ Public bool `json:"public"`
+ JoinRules event.JoinRule `json:"join_rules"`
+ GuestAccess event.GuestAccess `json:"guest_access"`
+ HistoryVisibility event.HistoryVisibility `json:"history_visibility"`
+ StateEvents int `json:"state_events"`
+ RoomType event.RoomType `json:"room_type"`
+}
+
+type RespListRooms struct {
+ Rooms []RoomInfo `json:"rooms"`
+ Offset int `json:"offset"`
+ TotalRooms int `json:"total_rooms"`
+ NextBatch int `json:"next_batch"`
+ PrevBatch int `json:"prev_batch"`
+}
+
+// ListRooms returns a list of rooms on the server.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api
+func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) {
+ var resp RespListRooms
+ reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ return resp, err
+}
+
+func (cli *Client) RoomInfo(ctx context.Context, roomID id.RoomID) (resp *RoomInfo, err error) {
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomID)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ return
+}
+
+type RespRoomMessages = mautrix.RespMessages
+
+// RoomMessages returns a list of messages in a room.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#room-messages-api
+func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to string, dir mautrix.Direction, filter *mautrix.FilterPart, limit int) (resp *RespRoomMessages, err error) {
+ query := map[string]string{
+ "from": from,
+ "dir": string(dir),
+ }
+ if filter != nil {
+ filterJSON, err := json.Marshal(filter)
+ if err != nil {
+ return nil, err
+ }
+ query["filter"] = string(filterJSON)
+ }
+ if to != "" {
+ query["to"] = to
+ }
+ if limit != 0 {
+ query["limit"] = strconv.Itoa(limit)
+ }
+ urlPath := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
+ return resp, err
+}
+
+type ReqDeleteRoom struct {
+ Purge bool `json:"purge,omitempty"`
+ ForcePurge bool `json:"force_purge,omitempty"`
+ Block bool `json:"block,omitempty"`
+ Message string `json:"message,omitempty"`
+ RoomName string `json:"room_name,omitempty"`
+ NewRoomUserID id.UserID `json:"new_room_user_id,omitempty"`
+}
+
+type RespDeleteRoom struct {
+ DeleteID string `json:"delete_id"`
+}
+
+type RespDeleteRoomResult struct {
+ KickedUsers []id.UserID `json:"kicked_users,omitempty"`
+ FailedToKickUsers []id.UserID `json:"failed_to_kick_users,omitempty"`
+ LocalAliases []id.RoomAlias `json:"local_aliases,omitempty"`
+ NewRoomID id.RoomID `json:"new_room_id,omitempty"`
+}
+
+type RespDeleteRoomStatus struct {
+ Status string `json:"status,omitempty"`
+ Error string `json:"error,omitempty"`
+ ShutdownRoom RespDeleteRoomResult `json:"shutdown_room,omitempty"`
+}
+
+// DeleteRoom deletes a room from the server, optionally blocking it and/or purging all data from the database.
+//
+// This calls the async version of the endpoint, which will return immediately and delete the room in the background.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#version-2-new-version
+func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (RespDeleteRoom, error) {
+ reqURL := cli.BuildAdminURL("v2", "rooms", roomID)
+ var resp RespDeleteRoom
+ _, err := cli.Client.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp)
+ return resp, err
+}
+
+func (cli *Client) DeleteRoomStatus(ctx context.Context, deleteID string) (resp RespDeleteRoomStatus, err error) {
+ reqURL := cli.BuildAdminURL("v2", "rooms", "delete_status", deleteID)
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ return
+}
+
+// DeleteRoomSync deletes a room from the server, optionally blocking it and/or purging all data from the database.
+//
+// This calls the synchronous version of the endpoint, which will block until the room is deleted.
+//
+// https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-1-old-version
+func (cli *Client) DeleteRoomSync(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (resp RespDeleteRoomResult, err error) {
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomID)
+ httpClient := &http.Client{}
+ _, err = cli.Client.MakeFullRequest(ctx, mautrix.FullRequest{
+ Method: http.MethodDelete,
+ URL: reqURL,
+ RequestJSON: &req,
+ ResponseJSON: &resp,
+ MaxAttempts: 1,
+ // Use a fresh HTTP client without timeouts
+ Client: httpClient,
+ })
+ httpClient.CloseIdleConnections()
+ return
+}
+
+type RespRoomsMembers struct {
+ Members []id.UserID `json:"members"`
+ Total int `json:"total"`
+}
+
+// RoomMembers gets the full list of members in a room.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#room-members-api
+func (cli *Client) RoomMembers(ctx context.Context, roomID id.RoomID) (RespRoomsMembers, error) {
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "members")
+ var resp RespRoomsMembers
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ return resp, err
+}
+
+type ReqMakeRoomAdmin struct {
+ UserID id.UserID `json:"user_id"`
+}
+
+// MakeRoomAdmin promotes a user to admin in a room. This requires that a local user has permission to promote users in the room.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api
+func (cli *Client) MakeRoomAdmin(ctx context.Context, roomIDOrAlias string, req ReqMakeRoomAdmin) error {
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomIDOrAlias, "make_room_admin")
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ return err
+}
+
+type ReqJoinUserToRoom struct {
+ UserID id.UserID `json:"user_id"`
+}
+
+// JoinUserToRoom makes a local user join the given room.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/room_membership.html
+func (cli *Client) JoinUserToRoom(ctx context.Context, roomID id.RoomID, req ReqJoinUserToRoom) error {
+ reqURL := cli.BuildAdminURL("v1", "join", roomID)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ return err
+}
+
+type ReqBlockRoom struct {
+ Block bool `json:"block"`
+}
+
+// BlockRoom blocks or unblocks a room.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#block-room-api
+func (cli *Client) BlockRoom(ctx context.Context, roomID id.RoomID, req ReqBlockRoom) error {
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block")
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
+ return err
+}
+
+// RoomsBlockResponse represents the response containing wether a room is blocked or not
+type RoomsBlockResponse struct {
+ Block bool `json:"block"`
+ UserID id.UserID `json:"user_id"`
+}
+
+// GetRoomBlockStatus gets whether a room is currently blocked.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#get-block-status
+func (cli *Client) GetRoomBlockStatus(ctx context.Context, roomID id.RoomID) (RoomsBlockResponse, error) {
+ var resp RoomsBlockResponse
+ reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block")
+ _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
+ return resp, err
+}
diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go
index aa1ce2a7..b1de55b6 100644
--- a/synapseadmin/userapi.go
+++ b/synapseadmin/userapi.go
@@ -21,7 +21,6 @@ import (
type ReqResetPassword struct {
// The user whose password to reset.
UserID id.UserID `json:"-"`
-
// The new password for the user. Required.
NewPassword string `json:"new_password"`
// Whether all the user's existing devices should be logged out after the password change.
@@ -33,11 +32,7 @@ type ReqResetPassword struct {
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password
func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error {
reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID)
- _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{
- Method: http.MethodPost,
- URL: reqURL,
- RequestJSON: &req,
- })
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
return err
}
@@ -48,12 +43,8 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability
func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) {
- u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username})
- _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{
- Method: http.MethodGet,
- URL: u,
- ResponseJSON: &resp,
- })
+ u := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username})
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
if err == nil && !resp.Available {
err = fmt.Errorf(`request returned OK status without "available": true`)
}
@@ -74,11 +65,7 @@ type RespListDevices struct {
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices
func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) {
- _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{
- Method: http.MethodGet,
- URL: cli.BuildAdminURL("v2", "users", userID, "devices"),
- ResponseJSON: &resp,
- })
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp)
return
}
@@ -86,7 +73,7 @@ type RespUserInfo struct {
UserID id.UserID `json:"name"`
DisplayName string `json:"displayname"`
AvatarURL id.ContentURIString `json:"avatar_url"`
- Guest int `json:"is_guest"`
+ Guest bool `json:"is_guest"`
Admin bool `json:"admin"`
Deactivated bool `json:"deactivated"`
Erased bool `json:"erased"`
@@ -102,10 +89,88 @@ type RespUserInfo struct {
//
// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account
func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) {
- _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{
- Method: http.MethodGet,
- URL: cli.BuildAdminURL("v2", "users", userID),
- ResponseJSON: &resp,
- })
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp)
+ return
+}
+
+type ReqDeleteUser struct {
+ Erase bool `json:"erase"`
+}
+
+// DeactivateAccount deactivates a specific local user account.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account
+func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req ReqDeleteUser) error {
+ reqURL := cli.BuildAdminURL("v1", "deactivate", userID)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ return err
+}
+
+type ReqSuspendUser struct {
+ Suspend bool `json:"suspend"`
+}
+
+// SuspendAccount suspends or unsuspends a specific local user account.
+//
+// https://element-hq.github.io/synapse/latest/admin_api/user_admin_api.html#suspendunsuspend-account
+func (cli *Client) SuspendAccount(ctx context.Context, userID id.UserID, req ReqSuspendUser) error {
+ reqURL := cli.BuildAdminURL("v1", "suspend", userID)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
+ return err
+}
+
+type ReqCreateOrModifyAccount struct {
+ Password string `json:"password,omitempty"`
+ LogoutDevices *bool `json:"logout_devices,omitempty"`
+
+ Deactivated *bool `json:"deactivated,omitempty"`
+ Admin *bool `json:"admin,omitempty"`
+ Locked *bool `json:"locked,omitempty"`
+
+ Displayname string `json:"displayname,omitempty"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ UserType string `json:"user_type,omitempty"`
+}
+
+// CreateOrModifyAccount creates or modifies an account on the server.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#create-or-modify-account
+func (cli *Client) CreateOrModifyAccount(ctx context.Context, userID id.UserID, req ReqCreateOrModifyAccount) error {
+ reqURL := cli.BuildAdminURL("v2", "users", userID)
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil)
+ return err
+}
+
+type RatelimitOverride struct {
+ MessagesPerSecond int `json:"messages_per_second"`
+ BurstCount int `json:"burst_count"`
+}
+
+type ReqSetRatelimit = RatelimitOverride
+
+// SetUserRatelimit overrides the message sending ratelimit for a specific user.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#set-ratelimit
+func (cli *Client) SetUserRatelimit(ctx context.Context, userID id.UserID, req ReqSetRatelimit) error {
+ reqURL := cli.BuildAdminURL("v1", "users", userID, "override_ratelimit")
+ _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil)
+ return err
+}
+
+type RespUserRatelimit = RatelimitOverride
+
+// GetUserRatelimit gets the ratelimit override for the given user.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#get-status-of-ratelimit
+func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp RespUserRatelimit, err error) {
+ _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp)
+ return
+}
+
+// DeleteUserRatelimit deletes the ratelimit override for the given user, returning them to the default ratelimits.
+//
+// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#delete-ratelimit
+func (cli *Client) DeleteUserRatelimit(ctx context.Context, userID id.UserID) (err error) {
+ _, err = cli.Client.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil)
return
}
diff --git a/sync.go b/sync.go
index d4208404..598df8e0 100644
--- a/sync.go
+++ b/sync.go
@@ -90,6 +90,7 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc
err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack())
}
}()
+ ctx = context.WithValue(ctx, SyncTokenContextKey, since)
for _, listener := range s.syncListeners {
if !listener(ctx, res, since) {
@@ -97,33 +98,38 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc
}
}
- s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice)
- s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence)
- s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData)
+ s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice, false)
+ s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence, false)
+ s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData, false)
for roomID, roomData := range res.Rooms.Join {
- s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState)
- s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline)
- s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral)
- s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData)
+ if roomData.StateAfter == nil {
+ s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState, false)
+ s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, false)
+ } else {
+ s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, true)
+ s.processSyncEvents(ctx, roomID, roomData.StateAfter.Events, event.SourceJoin|event.SourceState, false)
+ }
+ s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral, false)
+ s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData, false)
}
for roomID, roomData := range res.Rooms.Invite {
- s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState)
+ s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState, false)
}
for roomID, roomData := range res.Rooms.Leave {
- s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState)
- s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline)
+ s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState, false)
+ s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline, false)
}
return
}
-func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) {
+func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source, ignoreState bool) {
for _, evt := range events {
- s.processSyncEvent(ctx, roomID, evt, source)
+ s.processSyncEvent(ctx, roomID, evt, source, ignoreState)
}
}
-func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) {
+func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source, ignoreState bool) {
evt.RoomID = roomID
// Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer.
@@ -149,6 +155,7 @@ func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID,
}
evt.Mautrix.EventSource = source
+ evt.Mautrix.IgnoreState = ignoreState
s.Dispatch(ctx, evt)
}
@@ -191,8 +198,8 @@ func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, e
}
var defaultFilter = Filter{
- Room: RoomFilter{
- Timeline: FilterPart{
+ Room: &RoomFilter{
+ Timeline: &FilterPart{
Limit: 50,
},
},
@@ -257,7 +264,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState)
func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool {
for _, meta := range resp.Rooms.Invite {
- var inviteState []event.StrippedState
+ var inviteState []*event.Event
var inviteEvt *event.Event
for _, evt := range meta.State.Events {
if evt.Type == event.StateMember && evt.GetStateKey() == cli.UserID.String() {
@@ -265,12 +272,7 @@ func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string
} else {
evt.Type.Class = event.StateEventType
_ = evt.Content.ParseRaw(evt.Type)
- inviteState = append(inviteState, event.StrippedState{
- Content: evt.Content,
- Type: evt.Type,
- StateKey: evt.GetStateKey(),
- Sender: evt.Sender,
- })
+ inviteState = append(inviteState, evt)
}
}
if inviteEvt != nil {
diff --git a/url.go b/url.go
index 4646b442..91b3d49d 100644
--- a/url.go
+++ b/url.go
@@ -57,13 +57,13 @@ func BuildURL(baseURL *url.URL, path ...any) *url.URL {
// BuildURL builds a URL with the Client's homeserver and appservice user ID set already.
func (cli *Client) BuildURL(urlPath PrefixableURLPath) string {
- return cli.BuildURLWithQuery(urlPath, nil)
+ return cli.BuildURLWithFullQuery(urlPath, nil)
}
// BuildClientURL builds a URL with the Client's homeserver and appservice user ID set already.
// This method also automatically prepends the client API prefix (/_matrix/client).
func (cli *Client) BuildClientURL(urlPath ...any) string {
- return cli.BuildURLWithQuery(ClientURLPath(urlPath), nil)
+ return cli.BuildURLWithFullQuery(ClientURLPath(urlPath), nil)
}
type PrefixableURLPath interface {
@@ -97,15 +97,30 @@ func (saup SynapseAdminURLPath) FullPath() []any {
// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver
// and appservice user ID set already.
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
+ return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) {
+ for k, v := range urlQuery {
+ q.Set(k, v)
+ }
+ })
+}
+
+// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver
+// and appservice user ID set already.
+func (cli *Client) BuildURLWithFullQuery(urlPath PrefixableURLPath, fn func(q url.Values)) string {
+ if cli == nil {
+ return "client is nil"
+ }
hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...)
query := hsURL.Query()
if cli.SetAppServiceUserID {
query.Set("user_id", string(cli.UserID))
}
- if urlQuery != nil {
- for k, v := range urlQuery {
- query.Set(k, v)
- }
+ if cli.SetAppServiceDeviceID && cli.DeviceID != "" {
+ query.Set("device_id", string(cli.DeviceID))
+ query.Set("org.matrix.msc3202.device_id", string(cli.DeviceID))
+ }
+ if fn != nil {
+ fn(query)
}
hsURL.RawQuery = query.Encode()
return hsURL.String()
diff --git a/version.go b/version.go
index d92a7977..f00bbf39 100644
--- a/version.go
+++ b/version.go
@@ -4,10 +4,11 @@ import (
"fmt"
"regexp"
"runtime"
+ "runtime/debug"
"strings"
)
-const Version = "v0.17.0"
+const Version = "v0.26.3"
var GoModVersion = ""
var Commit = ""
@@ -15,11 +16,20 @@ var VersionWithCommit = Version
var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go")
-var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`)
-
func init() {
+ if GoModVersion == "" {
+ info, _ := debug.ReadBuildInfo()
+ if info != nil {
+ for _, mod := range info.Deps {
+ if mod.Path == "maunium.net/go/mautrix" {
+ GoModVersion = mod.Version
+ break
+ }
+ }
+ }
+ }
if GoModVersion != "" {
- match := goModVersionRegex.FindStringSubmatch(GoModVersion)
+ match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion)
if match != nil {
Commit = match[1]
}
diff --git a/versions.go b/versions.go
index d3dd3c67..61b2e4ea 100644
--- a/versions.go
+++ b/versions.go
@@ -19,6 +19,9 @@ type RespVersions struct {
}
func (versions *RespVersions) ContainsFunc(match func(found SpecVersion) bool) bool {
+ if versions == nil {
+ return false
+ }
for _, found := range versions.Versions {
if match(found) {
return true
@@ -40,6 +43,9 @@ func (versions *RespVersions) ContainsGreaterOrEqual(version SpecVersion) bool {
}
func (versions *RespVersions) GetLatest() (latest SpecVersion) {
+ if versions == nil {
+ return
+ }
for _, ver := range versions.Versions {
if ver.GreaterThan(latest) {
latest = ver
@@ -54,16 +60,34 @@ type UnstableFeature struct {
}
var (
- FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
+ FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
+ FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
+ FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
+ FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"}
+ FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"}
+ FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"}
+ FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"}
+ FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"}
+ FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116}
+ FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"}
- BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
- BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
- BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
- BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
- BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
+ BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
+ BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
+ BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
+ BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
+ BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
+ BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
+ BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
+ BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"}
+ BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"}
)
func (versions *RespVersions) Supports(feature UnstableFeature) bool {
+ if versions == nil {
+ return false
+ }
return versions.UnstableFeatures[feature.UnstableFlag] ||
(!feature.SpecVersion.IsEmpty() && versions.ContainsGreaterOrEqual(feature.SpecVersion))
}
@@ -95,6 +119,14 @@ var (
SpecV17 = MustParseSpecVersion("v1.7")
SpecV18 = MustParseSpecVersion("v1.8")
SpecV19 = MustParseSpecVersion("v1.9")
+ SpecV110 = MustParseSpecVersion("v1.10")
+ SpecV111 = MustParseSpecVersion("v1.11")
+ SpecV112 = MustParseSpecVersion("v1.12")
+ SpecV113 = MustParseSpecVersion("v1.13")
+ SpecV114 = MustParseSpecVersion("v1.14")
+ SpecV115 = MustParseSpecVersion("v1.15")
+ SpecV116 = MustParseSpecVersion("v1.16")
+ SpecV117 = MustParseSpecVersion("v1.17")
)
func (svf SpecVersionFormat) String() string {