all: replace gorilla/mux with standard library

This commit is contained in:
Tulir Asokan 2025-07-23 20:30:43 +03:00
commit 3d85625644
11 changed files with 149 additions and 188 deletions

View file

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -19,7 +19,6 @@ import (
"syscall"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"golang.org/x/net/publicsuffix"
@ -43,7 +42,7 @@ func Create() *AppService {
intents: make(map[id.UserID]*IntentAPI),
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
StateStore: mautrix.NewMemoryStateStore().(StateStore),
Router: mux.NewRouter(),
Router: http.NewServeMux(),
UserAgent: mautrix.DefaultUserAgent,
txnIDC: NewTransactionIDCache(128),
Live: true,
@ -61,12 +60,12 @@ func Create() *AppService {
DefaultHTTPRetries: 4,
}
as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost)
as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction)
as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom)
as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser)
as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing)
as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive)
as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady)
return as
}
@ -114,13 +113,13 @@ var _ StateStore = (*mautrix.MemoryStateStore)(nil)
// QueryHandler handles room alias and user ID queries from the homeserver.
type QueryHandler interface {
QueryAlias(alias string) bool
QueryAlias(alias id.RoomAlias) bool
QueryUser(userID id.UserID) bool
}
type QueryHandlerStub struct{}
func (qh *QueryHandlerStub) QueryAlias(alias string) bool {
func (qh *QueryHandlerStub) QueryAlias(alias id.RoomAlias) bool {
return false
}
@ -128,7 +127,7 @@ func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool {
return false
}
type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
type WebsocketHandler func(WebsocketCommand) (ok bool, data any)
type StateStore interface {
mautrix.StateStore
@ -160,7 +159,7 @@ type AppService struct {
QueryHandler QueryHandler
StateStore StateStore
Router *mux.Router
Router *http.ServeMux
UserAgent string
server *http.Server
HTTPClient *http.Client

View file

@ -17,7 +17,6 @@ import (
"syscall"
"time"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/exstrings"
@ -95,8 +94,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
return
}
vars := mux.Vars(r)
txnID := vars["txnID"]
txnID := r.PathValue("txnID")
if len(txnID) == 0 {
mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w)
return
@ -240,8 +238,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
return
}
vars := mux.Vars(r)
roomAlias := vars["roomAlias"]
roomAlias := id.RoomAlias(r.PathValue("roomAlias"))
ok := as.QueryHandler.QueryAlias(roomAlias)
if ok {
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
@ -256,8 +253,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
return
}
vars := mux.Vars(r)
userID := id.UserID(vars["userID"])
userID := id.UserID(r.PathValue("userID"))
ok := as.QueryHandler.QueryUser(userID)
if ok {
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)

View file

@ -12,6 +12,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"regexp"
@ -20,7 +21,6 @@ import (
"time"
"unsafe"
"github.com/gorilla/mux"
_ "github.com/lib/pq"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
@ -223,7 +223,7 @@ func (br *Connector) GetPublicAddress() string {
return br.Config.AppService.PublicAddress
}
func (br *Connector) GetRouter() *mux.Router {
func (br *Connector) GetRouter() *http.ServeMux {
if br.GetPublicAddress() != "" {
return br.AS.Router
}

View file

@ -1,4 +1,4 @@
// Copyright (c) 2024 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -17,7 +17,6 @@ import (
"sync"
"time"
"github.com/gorilla/mux"
"github.com/rs/xid"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
@ -40,7 +39,7 @@ type matrixAuthCacheEntry struct {
}
type ProvisioningAPI struct {
Router *mux.Router
Router *http.ServeMux
br *Connector
log zerolog.Logger
@ -91,12 +90,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
}
func (prov *ProvisioningAPI) GetRouter() *mux.Router {
func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
return prov.Router
}
type IProvisioningAPI interface {
GetRouter() *mux.Router
GetRouter() *http.ServeMux
GetUser(r *http.Request) *bridgev2.User
}
@ -116,41 +115,49 @@ func (prov *ProvisioningAPI) Init() {
tp.Dialer.Timeout = 10 * time.Second
tp.Transport.ResponseHeaderTimeout = 10 * time.Second
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter()
prov.Router.Use(hlog.NewHandler(prov.log))
prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id"))
prov.Router.Use(exhttp.CORSMiddleware)
prov.Router.Use(requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}))
prov.Router.Use(prov.AuthMiddleware)
prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami)
prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows)
prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart)
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput)
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait)
prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout)
prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins)
prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList)
prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers)
prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier)
prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM)
prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup)
prov.Router = http.NewServeMux()
prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami)
prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows)
prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}", prov.PostLoginSubmitInput)
prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}", prov.PostLoginWait)
prov.Router.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout)
prov.Router.HandleFunc("GET /v3/logins", prov.GetLogins)
prov.Router.HandleFunc("GET /v3/contacts", prov.GetContactList)
prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers)
prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
prov.Router.HandleFunc("POST /v3/create_group", prov.PostCreateGroup)
if prov.br.Config.Provisioning.EnableSessionTransfers {
prov.log.Debug().Msg("Enabling session transfer API")
prov.Router.Path("/v3/session_transfer/init").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostInitSessionTransfer)
prov.Router.Path("/v3/session_transfer/finish").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostFinishSessionTransfer)
prov.Router.HandleFunc("POST /v3/session_transfer/init", prov.PostInitSessionTransfer)
prov.Router.HandleFunc("POST /v3/session_transfer/finish", prov.PostFinishSessionTransfer)
}
if prov.br.Config.Provisioning.DebugEndpoints {
prov.log.Debug().Msg("Enabling debug API at /debug")
r := prov.br.AS.Router.PathPrefix("/debug").Subrouter()
r.Use(prov.DebugAuthMiddleware)
r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet)
r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet)
r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet)
r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet)
r.PathPrefix("/pprof/").HandlerFunc(pprof.Index)
debugRouter := http.NewServeMux()
debugRouter.HandleFunc("GET /pprof/cmdline", pprof.Cmdline)
debugRouter.HandleFunc("GET /pprof/profile", pprof.Profile)
debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol)
debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace)
debugRouter.HandleFunc("/pprof/", pprof.Index)
prov.br.AS.Router.Handle("/debug", exhttp.ApplyMiddleware(
debugRouter,
hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()),
prov.DebugAuthMiddleware,
))
}
prov.br.AS.Router.Handle("/_matrix/provision", exhttp.ApplyMiddleware(
prov.Router,
hlog.NewHandler(prov.log),
hlog.RequestIDHandler("request_id", "Request-Id"),
exhttp.CORSMiddleware,
requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
prov.AuthMiddleware,
))
}
func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error {
@ -250,7 +257,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r)
ctx = context.WithValue(ctx, provisioningUserKey, user)
if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {
if loginID := r.PathValue("loginProcessID"); loginID != "" {
prov.loginsLock.RLock()
login, ok := prov.logins[loginID]
prov.loginsLock.RUnlock()
@ -262,7 +269,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
login.Lock.Lock()
// This will only unlock after the handler runs
defer login.Lock.Unlock()
stepID := mux.Vars(r)["stepID"]
stepID := r.PathValue("stepID")
if login.NextStep.StepID != stepID {
zerolog.Ctx(r.Context()).Warn().
Str("request_step_id", stepID).
@ -271,7 +278,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
mautrix.MBadState.WithMessage("Step ID does not match").Write(w)
return
}
stepType := mux.Vars(r)["stepType"]
stepType := r.PathValue("stepType")
if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
zerolog.Ctx(r.Context()).Warn().
Str("request_step_type", stepType).
@ -374,7 +381,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
login, err := prov.net.CreateLogin(
r.Context(),
prov.GetUser(r),
mux.Vars(r)["flowID"],
r.PathValue("flowID"),
)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
@ -475,7 +482,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) {
user := prov.GetUser(r)
userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"])
userLoginID := networkid.UserLoginID(r.PathValue("loginID"))
if userLoginID == "all" {
for {
login := user.GetDefaultLogin()
@ -571,7 +578,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w)
return
}
resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat)
resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
RespondWithError(w, err, "Internal error resolving identifier")

View file

@ -1,4 +1,4 @@
// Copyright (c) 2024 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -16,8 +16,6 @@ import (
"net/http"
"time"
"github.com/gorilla/mux"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/id"
)
@ -35,7 +33,7 @@ func (br *Connector) initPublicMedia() error {
return fmt.Errorf("public media hash length is negative")
}
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet)
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
return nil
}
@ -76,16 +74,15 @@ var proxyHeadersToCopy = []string{
}
func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
contentURI := id.ContentURI{
Homeserver: vars["server"],
FileID: vars["mediaID"],
Homeserver: r.PathValue("server"),
FileID: r.PathValue("mediaID"),
}
if !contentURI.IsValid() {
http.Error(w, "invalid content URI", http.StatusBadRequest)
return
}
checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"])
checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum"))
if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) {
http.Error(w, "invalid base64 in checksum", http.StatusBadRequest)
return

View file

@ -1,4 +1,4 @@
// Copyright (c) 2024 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -10,11 +10,10 @@ import (
"context"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/gorilla/mux"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
@ -64,7 +63,7 @@ type MatrixConnectorWithArbitraryRoomState interface {
type MatrixConnectorWithServer interface {
GetPublicAddress() string
GetRouter() *mux.Router
GetRouter() *http.ServeMux
}
type MatrixConnectorWithPublicMedia interface {

View file

@ -12,11 +12,9 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log" // zerolog-allow-global-log
"github.com/stretchr/testify/require"
"go.mau.fi/util/random"
@ -42,20 +40,6 @@ type mockServer struct {
UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
}
func DecodeVarsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
var err error
for k, v := range vars {
vars[k], err = url.PathUnescape(v)
if err != nil {
panic(err)
}
}
next.ServeHTTP(w, r)
})
}
func createMockServer(t *testing.T) *mockServer {
t.Helper()
@ -69,15 +53,14 @@ func createMockServer(t *testing.T) *mockServer {
UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
}
router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath()
router.Use(DecodeVarsMiddleware)
router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost)
router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost)
router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut)
router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut)
router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost)
router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost)
router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost)
router := http.NewServeMux()
router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp)
router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload)
server.Server = httptest.NewServer(router)
return &server
@ -118,10 +101,9 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) {
}
func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
var req mautrix.ReqSendToDevice
json.NewDecoder(r.Body).Decode(&req)
evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType}
evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
for user, devices := range req.Messages {
for device, content := range devices {
@ -140,9 +122,8 @@ func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
}
func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
userID := id.UserID(vars["userID"])
eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType}
userID := id.UserID(r.PathValue("userID"))
eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
jsonData, _ := io.ReadAll(r.Body)
if _, ok := s.AccountData[userID]; !ok {

View file

@ -1,4 +1,4 @@
// Copyright (c) 2024 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -12,7 +12,7 @@ import (
"strconv"
"time"
"github.com/gorilla/mux"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/jsontime"
@ -51,19 +51,21 @@ type KeyServer struct {
}
// Register registers the key server endpoints to the given router.
func (ks *KeyServer) Register(r *mux.Router) {
r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet)
r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet)
keyRouter := r.PathPrefix("/_matrix/key").Subrouter()
keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet)
keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet)
keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost)
keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mautrix.MUnrecognized.WithStatus(http.StatusNotFound).WithMessage("Unrecognized endpoint").Write(w)
})
keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mautrix.MUnrecognized.WithStatus(http.StatusMethodNotAllowed).WithMessage("Invalid method for endpoint").Write(w)
})
func (ks *KeyServer) Register(r *http.ServeMux) {
r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown)
r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion)
keyRouter := http.NewServeMux()
keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey)
keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys)
keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys)
errorBodies := exhttp.ErrorBodies{
NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))),
MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))),
}
r.Handle("/_matrix/key", exhttp.ApplyMiddleware(
keyRouter,
exhttp.HandleErrors(errorBodies),
))
}
// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint.
@ -157,7 +159,7 @@ type GetQueryKeysResponse struct {
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
serverName := mux.Vars(r)["serverName"]
serverName := r.PathValue("serverName")
minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
if err != nil && minimumValidUntilTSString != "" {

9
go.mod
View file

@ -2,12 +2,11 @@ module maunium.net/go/mautrix
go 1.23.0
toolchain go1.24.4
toolchain go1.24.5
require (
filippo.io/edwards25519 v1.1.0
github.com/chzyer/readline v1.5.1
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.28
@ -18,10 +17,10 @@ require (
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/yuin/goldmark v1.7.12
go.mau.fi/util v0.8.8
go.mau.fi/util v0.8.9-0.20250723171559-474867266038
go.mau.fi/zeroconfig v0.1.3
golang.org/x/crypto v0.40.0
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
golang.org/x/net v0.42.0
golang.org/x/sync v0.16.0
gopkg.in/yaml.v3 v3.0.1
@ -33,7 +32,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb // indirect
github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect

14
go.sum
View file

@ -13,8 +13,6 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@ -28,8 +26,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb h1:3PrKuO92dUTMrQ9dx0YNejC6U/Si6jqKmyQ9vWjwqR4=
github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e h1:D0bJD+4O3G4izvrQUmzCL80zazlN7EwJ0PPDhpJWC/I=
github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@ -53,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY=
github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
go.mau.fi/util v0.8.8 h1:OnuEEc/sIJFhnq4kFggiImUpcmnmL/xpvQMRu5Fiy5c=
go.mau.fi/util v0.8.8/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c=
go.mau.fi/util v0.8.9-0.20250723171559-474867266038 h1:RVL8TVaYc3LTBBopfjCNDtD+6eZks0O+qgXN/9hsz7k=
go.mau.fi/util v0.8.9-0.20250723171559-474867266038/go.mod h1:GZZp5f9r2MgEu4GDvtB0XxCF7i6Z7Z8fM0w9a5oZH3Y=
go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=

View file

@ -1,4 +1,4 @@
// Copyright (c) 2024 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -8,6 +8,7 @@ package mediaproxy
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@ -21,8 +22,9 @@ import (
"strings"
"time"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exhttp"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/federation"
@ -108,8 +110,8 @@ type MediaProxy struct {
serverName string
serverKey *federation.SigningKey
FederationRouter *mux.Router
ClientMediaRouter *mux.Router
FederationRouter *http.ServeMux
ClientMediaRouter *http.ServeMux
}
func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
@ -117,7 +119,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
if err != nil {
return nil, err
}
return &MediaProxy{
mp := &MediaProxy{
serverName: serverName,
serverKey: parsed,
GetMedia: getMedia,
@ -132,7 +134,20 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
},
},
}, nil
}
mp.FederationRouter = http.NewServeMux()
mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation)
mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion)
mp.ClientMediaRouter = http.NewServeMux()
mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia)
mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia)
mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia)
mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported)
mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported)
mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported)
mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported)
mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported)
return mp, nil
}
type BasicConfig struct {
@ -162,7 +177,7 @@ type ServerConfig struct {
}
func (mp *MediaProxy) Listen(cfg ServerConfig) error {
router := mux.NewRouter()
router := http.NewServeMux()
mp.RegisterRoutes(router)
return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
}
@ -188,38 +203,20 @@ func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache feder
})
}
func (mp *MediaProxy) RegisterRoutes(router *mux.Router) {
if mp.FederationRouter == nil {
mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter()
func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux) {
errorBodies := exhttp.ErrorBodies{
NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))),
MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))),
}
if mp.ClientMediaRouter == nil {
mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter()
}
mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet)
mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet)
mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet)
mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut)
mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost)
mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost)
mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet)
mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet)
mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
corsMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';")
next.ServeHTTP(w, r)
})
}
mp.ClientMediaRouter.Use(corsMiddleware)
router.Handle("/_matrix/federation", exhttp.ApplyMiddleware(
mp.FederationRouter,
exhttp.HandleErrors(errorBodies),
))
router.Handle("/_matrix/client/v1/media", exhttp.ApplyMiddleware(
mp.ClientMediaRouter,
exhttp.CORSMiddleware,
exhttp.HandleErrors(errorBodies),
))
mp.KeyServer.Register(router)
}
@ -234,7 +231,7 @@ func queryToMap(vals url.Values) map[string]string {
}
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
mediaID := mux.Vars(r)["mediaID"]
mediaID := r.PathValue("mediaID")
if !id.IsValidMediaID(mediaID) {
mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w)
return nil
@ -380,8 +377,7 @@ func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName strin
func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := zerolog.Ctx(ctx)
vars := mux.Vars(r)
if vars["serverName"] != mp.serverName {
if r.PathValue("serverName") != mp.serverName {
mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
return
}
@ -404,7 +400,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTemporaryRedirect)
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
mp.addHeaders(w, mimeType, vars["fileName"])
mp.addHeaders(w, mimeType, r.PathValue("fileName"))
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.WriteHeader(http.StatusOK)
_, err := wt.WriteTo(w)
@ -425,7 +421,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
if dataResp, ok := writerResp.(*GetMediaResponseData); ok {
defer dataResp.Reader.Close()
}
mp.addHeaders(w, writerResp.GetContentType(), vars["fileName"])
mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName"))
if writerResp.GetContentLength() != 0 {
w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10))
}
@ -491,11 +487,6 @@ var (
ErrPreviewURLNotSupported = mautrix.MUnrecognized.
WithMessage("This is a media proxy and does not support URL previews.").
WithStatus(http.StatusNotImplemented)
ErrUnknownEndpoint = mautrix.MUnrecognized.
WithMessage("Unrecognized endpoint")
ErrUnsupportedMethod = mautrix.MUnrecognized.
WithMessage("Invalid method for endpoint").
WithStatus(http.StatusMethodNotAllowed)
)
func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) {
@ -505,11 +496,3 @@ func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request)
func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) {
ErrPreviewURLNotSupported.Write(w)
}
func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) {
ErrUnknownEndpoint.Write(w)
}
func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) {
ErrUnsupportedMethod.Write(w)
}