diff --git a/appservice/appservice.go b/appservice/appservice.go index 518e1073..5dd067c0 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -19,7 +19,6 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "github.com/gorilla/websocket" "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" @@ -43,7 +42,7 @@ func Create() *AppService { intents: make(map[id.UserID]*IntentAPI), HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar}, StateStore: mautrix.NewMemoryStateStore().(StateStore), - Router: mux.NewRouter(), + Router: http.NewServeMux(), UserAgent: mautrix.DefaultUserAgent, txnIDC: NewTransactionIDCache(128), Live: true, @@ -61,12 +60,12 @@ func Create() *AppService { DefaultHTTPRetries: 4, } - as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) - as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost) - as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet) + as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction) + as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom) + as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser) + as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing) + as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive) + as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady) return as } @@ -114,13 +113,13 @@ var _ StateStore = (*mautrix.MemoryStateStore)(nil) // QueryHandler handles room alias and user ID queries from the homeserver. type QueryHandler interface { - QueryAlias(alias string) bool + QueryAlias(alias id.RoomAlias) bool QueryUser(userID id.UserID) bool } type QueryHandlerStub struct{} -func (qh *QueryHandlerStub) QueryAlias(alias string) bool { +func (qh *QueryHandlerStub) QueryAlias(alias id.RoomAlias) bool { return false } @@ -128,7 +127,7 @@ func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool { return false } -type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{}) +type WebsocketHandler func(WebsocketCommand) (ok bool, data any) type StateStore interface { mautrix.StateStore @@ -160,7 +159,7 @@ type AppService struct { QueryHandler QueryHandler StateStore StateStore - Router *mux.Router + Router *http.ServeMux UserAgent string server *http.Server HTTPClient *http.Client diff --git a/appservice/http.go b/appservice/http.go index 1ebe6e56..862de7fd 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -17,7 +17,6 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog" "go.mau.fi/util/exhttp" "go.mau.fi/util/exstrings" @@ -95,8 +94,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - txnID := vars["txnID"] + txnID := r.PathValue("txnID") if len(txnID) == 0 { mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w) return @@ -240,8 +238,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - roomAlias := vars["roomAlias"] + roomAlias := id.RoomAlias(r.PathValue("roomAlias")) ok := as.QueryHandler.QueryAlias(roomAlias) if ok { exhttp.WriteEmptyJSONResponse(w, http.StatusOK) @@ -256,8 +253,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - userID := id.UserID(vars["userID"]) + userID := id.UserID(r.PathValue("userID")) ok := as.QueryHandler.QueryUser(userID) if ok { exhttp.WriteEmptyJSONResponse(w, http.StatusOK) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index c168ae3d..af9931b0 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -12,6 +12,7 @@ import ( "encoding/base64" "errors" "fmt" + "net/http" "net/url" "os" "regexp" @@ -20,7 +21,6 @@ import ( "time" "unsafe" - "github.com/gorilla/mux" _ "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" @@ -223,7 +223,7 @@ func (br *Connector) GetPublicAddress() string { return br.Config.AppService.PublicAddress } -func (br *Connector) GetRouter() *mux.Router { +func (br *Connector) GetRouter() *http.ServeMux { if br.GetPublicAddress() != "" { return br.AS.Router } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index f865a19e..7eec37c1 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -17,7 +17,6 @@ import ( "sync" "time" - "github.com/gorilla/mux" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" @@ -40,7 +39,7 @@ type matrixAuthCacheEntry struct { } type ProvisioningAPI struct { - Router *mux.Router + Router *http.ServeMux br *Connector log zerolog.Logger @@ -91,12 +90,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } -func (prov *ProvisioningAPI) GetRouter() *mux.Router { +func (prov *ProvisioningAPI) GetRouter() *http.ServeMux { return prov.Router } type IProvisioningAPI interface { - GetRouter() *mux.Router + GetRouter() *http.ServeMux GetUser(r *http.Request) *bridgev2.User } @@ -116,41 +115,49 @@ func (prov *ProvisioningAPI) Init() { tp.Dialer.Timeout = 10 * time.Second tp.Transport.ResponseHeaderTimeout = 10 * time.Second tp.Transport.TLSHandshakeTimeout = 10 * time.Second - prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() - prov.Router.Use(hlog.NewHandler(prov.log)) - prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id")) - prov.Router.Use(exhttp.CORSMiddleware) - prov.Router.Use(requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true})) - prov.Router.Use(prov.AuthMiddleware) - prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) - prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows) - prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait) - prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout) - prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins) - prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList) - prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers) - prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier) - prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM) - prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup) + prov.Router = http.NewServeMux() + prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami) + prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows) + prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart) + prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}", prov.PostLoginSubmitInput) + prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}", prov.PostLoginWait) + prov.Router.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout) + prov.Router.HandleFunc("GET /v3/logins", prov.GetLogins) + prov.Router.HandleFunc("GET /v3/contacts", prov.GetContactList) + prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers) + prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier) + prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM) + prov.Router.HandleFunc("POST /v3/create_group", prov.PostCreateGroup) if prov.br.Config.Provisioning.EnableSessionTransfers { prov.log.Debug().Msg("Enabling session transfer API") - prov.Router.Path("/v3/session_transfer/init").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostInitSessionTransfer) - prov.Router.Path("/v3/session_transfer/finish").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostFinishSessionTransfer) + prov.Router.HandleFunc("POST /v3/session_transfer/init", prov.PostInitSessionTransfer) + prov.Router.HandleFunc("POST /v3/session_transfer/finish", prov.PostFinishSessionTransfer) } if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") - r := prov.br.AS.Router.PathPrefix("/debug").Subrouter() - r.Use(prov.DebugAuthMiddleware) - r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet) - r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet) - r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet) - r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet) - r.PathPrefix("/pprof/").HandlerFunc(pprof.Index) + debugRouter := http.NewServeMux() + debugRouter.HandleFunc("GET /pprof/cmdline", pprof.Cmdline) + debugRouter.HandleFunc("GET /pprof/profile", pprof.Profile) + debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol) + debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace) + debugRouter.HandleFunc("/pprof/", pprof.Index) + prov.br.AS.Router.Handle("/debug", exhttp.ApplyMiddleware( + debugRouter, + hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()), + prov.DebugAuthMiddleware, + )) } + + prov.br.AS.Router.Handle("/_matrix/provision", exhttp.ApplyMiddleware( + prov.Router, + hlog.NewHandler(prov.log), + hlog.RequestIDHandler("request_id", "Request-Id"), + exhttp.CORSMiddleware, + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + prov.AuthMiddleware, + )) } func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error { @@ -250,7 +257,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r) ctx = context.WithValue(ctx, provisioningUserKey, user) - if loginID, ok := mux.Vars(r)["loginProcessID"]; ok { + if loginID := r.PathValue("loginProcessID"); loginID != "" { prov.loginsLock.RLock() login, ok := prov.logins[loginID] prov.loginsLock.RUnlock() @@ -262,7 +269,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { login.Lock.Lock() // This will only unlock after the handler runs defer login.Lock.Unlock() - stepID := mux.Vars(r)["stepID"] + stepID := r.PathValue("stepID") if login.NextStep.StepID != stepID { zerolog.Ctx(r.Context()).Warn(). Str("request_step_id", stepID). @@ -271,7 +278,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { mautrix.MBadState.WithMessage("Step ID does not match").Write(w) return } - stepType := mux.Vars(r)["stepType"] + stepType := r.PathValue("stepType") if login.NextStep.Type != bridgev2.LoginStepType(stepType) { zerolog.Ctx(r.Context()).Warn(). Str("request_step_type", stepType). @@ -374,7 +381,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque login, err := prov.net.CreateLogin( r.Context(), prov.GetUser(r), - mux.Vars(r)["flowID"], + r.PathValue("flowID"), ) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") @@ -475,7 +482,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) - userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"]) + userLoginID := networkid.UserLoginID(r.PathValue("loginID")) if userLoginID == "all" { for { login := user.GetDefaultLogin() @@ -571,7 +578,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w) return } - resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) + resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") RespondWithError(w, err, "Internal error resolving identifier") diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 9db5f442..95e37262 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -16,8 +16,6 @@ import ( "net/http" "time" - "github.com/gorilla/mux" - "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/id" ) @@ -35,7 +33,7 @@ func (br *Connector) initPublicMedia() error { return fmt.Errorf("public media hash length is negative") } br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) - br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) return nil } @@ -76,16 +74,15 @@ var proxyHeadersToCopy = []string{ } func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) contentURI := id.ContentURI{ - Homeserver: vars["server"], - FileID: vars["mediaID"], + Homeserver: r.PathValue("server"), + FileID: r.PathValue("mediaID"), } if !contentURI.IsValid() { http.Error(w, "invalid content URI", http.StatusBadRequest) return } - checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"]) + checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum")) if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) { http.Error(w, "invalid base64 in checksum", http.StatusBadRequest) return diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index b5a575ba..b30e274a 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -10,11 +10,10 @@ import ( "context" "fmt" "io" + "net/http" "os" "time" - "github.com/gorilla/mux" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -64,7 +63,7 @@ type MatrixConnectorWithArbitraryRoomState interface { type MatrixConnectorWithServer interface { GetPublicAddress() string - GetRouter() *mux.Router + GetRouter() *http.ServeMux } type MatrixConnectorWithPublicMedia interface { diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go index b6bf3d2c..45ca7781 100644 --- a/crypto/verificationhelper/mockserver_test.go +++ b/crypto/verificationhelper/mockserver_test.go @@ -12,11 +12,9 @@ import ( "io" "net/http" "net/http/httptest" - "net/url" "strings" "testing" - "github.com/gorilla/mux" "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/require" "go.mau.fi/util/random" @@ -42,20 +40,6 @@ type mockServer struct { UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys } -func DecodeVarsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - var err error - for k, v := range vars { - vars[k], err = url.PathUnescape(v) - if err != nil { - panic(err) - } - } - next.ServeHTTP(w, r) - }) -} - func createMockServer(t *testing.T) *mockServer { t.Helper() @@ -69,15 +53,14 @@ func createMockServer(t *testing.T) *mockServer { UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, } - router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath() - router.Use(DecodeVarsMiddleware) - router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut) - router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut) - router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost) + router := http.NewServeMux() + router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) + router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) + router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) + router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) + router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) + router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp) + router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload) server.Server = httptest.NewServer(router) return &server @@ -118,10 +101,9 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { } func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) var req mautrix.ReqSendToDevice json.NewDecoder(r.Body).Decode(&req) - evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType} + evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} for user, devices := range req.Messages { for device, content := range devices { @@ -140,9 +122,8 @@ func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { } func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - userID := id.UserID(vars["userID"]) - eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType} + userID := id.UserID(r.PathValue("userID")) + eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType} jsonData, _ := io.ReadAll(r.Body) if _, ok := s.AccountData[userID]; !ok { diff --git a/federation/keyserver.go b/federation/keyserver.go index b0faf8fb..37998786 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -12,7 +12,7 @@ import ( "strconv" "time" - "github.com/gorilla/mux" + "go.mau.fi/util/exerrors" "go.mau.fi/util/exhttp" "go.mau.fi/util/jsontime" @@ -51,19 +51,21 @@ type KeyServer struct { } // Register registers the key server endpoints to the given router. -func (ks *KeyServer) Register(r *mux.Router) { - r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet) - r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet) - keyRouter := r.PathPrefix("/_matrix/key").Subrouter() - keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost) - keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mautrix.MUnrecognized.WithStatus(http.StatusNotFound).WithMessage("Unrecognized endpoint").Write(w) - }) - keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mautrix.MUnrecognized.WithStatus(http.StatusMethodNotAllowed).WithMessage("Invalid method for endpoint").Write(w) - }) +func (ks *KeyServer) Register(r *http.ServeMux) { + r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown) + r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion) + keyRouter := http.NewServeMux() + keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey) + keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys) + keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys) + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))), + MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))), + } + r.Handle("/_matrix/key", exhttp.ApplyMiddleware( + keyRouter, + exhttp.HandleErrors(errorBodies), + )) } // RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint. @@ -157,7 +159,7 @@ type GetQueryKeysResponse struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { - serverName := mux.Vars(r)["serverName"] + serverName := r.PathValue("serverName") minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts") minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64) if err != nil && minimumValidUntilTSString != "" { diff --git a/go.mod b/go.mod index 59f29c0c..d71e86ab 100644 --- a/go.mod +++ b/go.mod @@ -2,12 +2,11 @@ module maunium.net/go/mautrix go 1.23.0 -toolchain go1.24.4 +toolchain go1.24.5 require ( filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 - github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.28 @@ -18,10 +17,10 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.12 - go.mau.fi/util v0.8.8 + go.mau.fi/util v0.8.9-0.20250723171559-474867266038 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.40.0 - golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc + golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 golang.org/x/net v0.42.0 golang.org/x/sync v0.16.0 gopkg.in/yaml.v3 v3.0.1 @@ -33,7 +32,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb // indirect + github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect diff --git a/go.sum b/go.sum index 9f48386e..eaa97cc8 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,6 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= @@ -28,8 +26,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb h1:3PrKuO92dUTMrQ9dx0YNejC6U/Si6jqKmyQ9vWjwqR4= -github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e h1:D0bJD+4O3G4izvrQUmzCL80zazlN7EwJ0PPDhpJWC/I= +github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -53,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY= github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.8 h1:OnuEEc/sIJFhnq4kFggiImUpcmnmL/xpvQMRu5Fiy5c= -go.mau.fi/util v0.8.8/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c= +go.mau.fi/util v0.8.9-0.20250723171559-474867266038 h1:RVL8TVaYc3LTBBopfjCNDtD+6eZks0O+qgXN/9hsz7k= +go.mau.fi/util v0.8.9-0.20250723171559-474867266038/go.mod h1:GZZp5f9r2MgEu4GDvtB0XxCF7i6Z7Z8fM0w9a5oZH3Y= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= -golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc= -golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 4be799d3..6fbcdbad 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,7 @@ package mediaproxy import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -21,8 +22,9 @@ import ( "strings" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" "maunium.net/go/mautrix" "maunium.net/go/mautrix/federation" @@ -108,8 +110,8 @@ type MediaProxy struct { serverName string serverKey *federation.SigningKey - FederationRouter *mux.Router - ClientMediaRouter *mux.Router + FederationRouter *http.ServeMux + ClientMediaRouter *http.ServeMux } func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) { @@ -117,7 +119,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx if err != nil { return nil, err } - return &MediaProxy{ + mp := &MediaProxy{ serverName: serverName, serverKey: parsed, GetMedia: getMedia, @@ -132,7 +134,20 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"), }, }, - }, nil + } + mp.FederationRouter = http.NewServeMux() + mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation) + mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion) + mp.ClientMediaRouter = http.NewServeMux() + mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported) + return mp, nil } type BasicConfig struct { @@ -162,7 +177,7 @@ type ServerConfig struct { } func (mp *MediaProxy) Listen(cfg ServerConfig) error { - router := mux.NewRouter() + router := http.NewServeMux() mp.RegisterRoutes(router) return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) } @@ -188,38 +203,20 @@ func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache feder }) } -func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { - if mp.FederationRouter == nil { - mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter() +func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux) { + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))), + MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))), } - if mp.ClientMediaRouter == nil { - mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter() - } - - mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet) - mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut) - mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost) - mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost) - mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet) - mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) - mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) - corsMiddleware := func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") - w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';") - next.ServeHTTP(w, r) - }) - } - mp.ClientMediaRouter.Use(corsMiddleware) + router.Handle("/_matrix/federation", exhttp.ApplyMiddleware( + mp.FederationRouter, + exhttp.HandleErrors(errorBodies), + )) + router.Handle("/_matrix/client/v1/media", exhttp.ApplyMiddleware( + mp.ClientMediaRouter, + exhttp.CORSMiddleware, + exhttp.HandleErrors(errorBodies), + )) mp.KeyServer.Register(router) } @@ -234,7 +231,7 @@ func queryToMap(vals url.Values) map[string]string { } func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { - mediaID := mux.Vars(r)["mediaID"] + mediaID := r.PathValue("mediaID") if !id.IsValidMediaID(mediaID) { mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w) return nil @@ -380,8 +377,7 @@ func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName strin func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := zerolog.Ctx(ctx) - vars := mux.Vars(r) - if vars["serverName"] != mp.serverName { + if r.PathValue("serverName") != mp.serverName { mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) return } @@ -404,7 +400,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTemporaryRedirect) } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { - mp.addHeaders(w, mimeType, vars["fileName"]) + mp.addHeaders(w, mimeType, r.PathValue("fileName")) w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) w.WriteHeader(http.StatusOK) _, err := wt.WriteTo(w) @@ -425,7 +421,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { if dataResp, ok := writerResp.(*GetMediaResponseData); ok { defer dataResp.Reader.Close() } - mp.addHeaders(w, writerResp.GetContentType(), vars["fileName"]) + mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName")) if writerResp.GetContentLength() != 0 { w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10)) } @@ -491,11 +487,6 @@ var ( ErrPreviewURLNotSupported = mautrix.MUnrecognized. WithMessage("This is a media proxy and does not support URL previews."). WithStatus(http.StatusNotImplemented) - ErrUnknownEndpoint = mautrix.MUnrecognized. - WithMessage("Unrecognized endpoint") - ErrUnsupportedMethod = mautrix.MUnrecognized. - WithMessage("Invalid method for endpoint"). - WithStatus(http.StatusMethodNotAllowed) ) func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) { @@ -505,11 +496,3 @@ func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) { ErrPreviewURLNotSupported.Write(w) } - -func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) { - ErrUnknownEndpoint.Write(w) -} - -func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) { - ErrUnsupportedMethod.Write(w) -}