treewide: replace gorilla/mux with http.ServeMux
Some checks failed
Go / Lint (latest) (push) Has been cancelled
Go / Build (old, libolm) (push) Has been cancelled
Go / Build (latest, libolm) (push) Has been cancelled
Go / Build (old, goolm) (push) Has been cancelled
Go / Build (latest, goolm) (push) Has been cancelled

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
Sumner Evans 2024-08-23 09:45:45 -06:00
commit 481f435dfe
No known key found for this signature in database
11 changed files with 175 additions and 188 deletions

View file

@ -19,7 +19,6 @@ import (
"syscall"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"golang.org/x/net/publicsuffix"
@ -43,7 +42,7 @@ func Create() *AppService {
intents: make(map[id.UserID]*IntentAPI),
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
StateStore: mautrix.NewMemoryStateStore().(StateStore),
Router: mux.NewRouter(),
Router: http.NewServeMux(),
UserAgent: mautrix.DefaultUserAgent,
txnIDC: NewTransactionIDCache(128),
Live: true,
@ -61,12 +60,12 @@ func Create() *AppService {
DefaultHTTPRetries: 4,
}
as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost)
as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction)
as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom)
as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser)
as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing)
as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive)
as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady)
return as
}
@ -160,7 +159,7 @@ type AppService struct {
QueryHandler QueryHandler
StateStore StateStore
Router *mux.Router
Router *http.ServeMux
UserAgent string
server *http.Server
HTTPClient *http.Client

View file

@ -17,7 +17,6 @@ import (
"syscall"
"time"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
@ -101,8 +100,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
return
}
vars := mux.Vars(r)
txnID := vars["txnID"]
txnID := r.PathValue("txnID")
if len(txnID) == 0 {
Error{
ErrorCode: ErrNoTransactionID,
@ -258,9 +256,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
return
}
vars := mux.Vars(r)
roomAlias := vars["roomAlias"]
ok := as.QueryHandler.QueryAlias(roomAlias)
ok := as.QueryHandler.QueryAlias(r.PathValue("roomAlias"))
if ok {
WriteBlankOK(w)
} else {
@ -277,9 +273,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
return
}
vars := mux.Vars(r)
userID := id.UserID(vars["userID"])
ok := as.QueryHandler.QueryUser(userID)
ok := as.QueryHandler.QueryUser(id.UserID(r.PathValue("userID")))
if ok {
WriteBlankOK(w)
} else {

View file

@ -13,6 +13,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"regexp"
@ -21,7 +22,6 @@ import (
"time"
"unsafe"
"github.com/gorilla/mux"
_ "github.com/lib/pq"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
@ -222,7 +222,8 @@ func (br *Connector) GetPublicAddress() string {
return br.Config.AppService.PublicAddress
}
func (br *Connector) GetRouter() *mux.Router {
// TODO switch to http.ServeMux
func (br *Connector) GetRouter() *http.ServeMux {
if br.GetPublicAddress() != "" {
return br.AS.Router
}

View file

@ -17,10 +17,10 @@ import (
"sync"
"time"
"github.com/gorilla/mux"
"github.com/rs/xid"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/jsontime"
"go.mau.fi/util/requestlog"
@ -38,7 +38,7 @@ type matrixAuthCacheEntry struct {
}
type ProvisioningAPI struct {
Router *mux.Router
Router *http.ServeMux
br *Connector
log zerolog.Logger
@ -82,12 +82,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
}
func (prov *ProvisioningAPI) GetRouter() *mux.Router {
func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
return prov.Router
}
type IProvisioningAPI interface {
GetRouter() *mux.Router
GetRouter() *http.ServeMux
GetUser(r *http.Request) *bridgev2.User
}
@ -106,50 +106,44 @@ func (prov *ProvisioningAPI) Init() {
tp.Dialer.Timeout = 10 * time.Second
tp.Transport.ResponseHeaderTimeout = 10 * time.Second
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter()
prov.Router.Use(hlog.NewHandler(prov.log))
prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id"))
prov.Router.Use(corsMiddleware)
prov.Router.Use(requestlog.AccessLogger(false))
prov.Router.Use(prov.AuthMiddleware)
prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami)
prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows)
prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart)
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput)
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait)
prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout)
prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins)
prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList)
prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers)
prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier)
prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM)
prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup)
provRouter := http.NewServeMux()
provRouter.HandleFunc("GET /v3/whoami", prov.GetWhoami)
provRouter.HandleFunc("GET /v3/whoami/flows", prov.GetLoginFlows)
provRouter.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
provRouter.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLogin)
provRouter.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout)
provRouter.HandleFunc("GET /v3/logins", prov.GetLogins)
provRouter.HandleFunc("GET /v3/contacts", prov.GetContactList)
provRouter.HandleFunc("POST /v3/search_users", prov.PostSearchUsers)
provRouter.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
provRouter.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
provRouter.HandleFunc("POST /v3/create_group", prov.PostCreateGroup)
var provHandler http.Handler = prov.Router
provHandler = prov.AuthMiddleware(provHandler)
provHandler = requestlog.AccessLogger(false)(provHandler)
provHandler = exhttp.CORSMiddleware(provHandler)
provHandler = hlog.RequestIDHandler("request_id", "Request-Id")(provHandler)
provHandler = hlog.NewHandler(prov.log)(provHandler)
provHandler = http.StripPrefix(prov.br.Config.Provisioning.Prefix, provHandler)
prov.br.AS.Router.Handle(prov.br.Config.Provisioning.Prefix, provHandler)
if prov.br.Config.Provisioning.DebugEndpoints {
prov.log.Debug().Msg("Enabling debug API at /debug")
r := prov.br.AS.Router.PathPrefix("/debug").Subrouter()
r.Use(prov.DebugAuthMiddleware)
r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet)
r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet)
r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet)
r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet)
r.PathPrefix("/pprof/").HandlerFunc(pprof.Index)
debugRouter := http.NewServeMux()
// TODO do we need to strip prefix here?
debugRouter.HandleFunc("/debug/pprof", pprof.Index)
debugRouter.HandleFunc("GET /debug/pprof/trace", pprof.Trace)
debugRouter.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol)
debugRouter.HandleFunc("GET /debug/pprof/profile", pprof.Profile)
debugRouter.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline)
prov.br.AS.Router.Handle("/debug", prov.AuthMiddleware(debugRouter))
}
}
func corsMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
handler.ServeHTTP(w, r)
})
}
func jsonResponse(w http.ResponseWriter, status int, response any) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)
@ -270,7 +264,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
}
ctx := context.WithValue(r.Context(), provisioningUserKey, user)
if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {
if loginID := r.PathValue("loginProcessID"); loginID != "" {
prov.loginsLock.RLock()
login, ok := prov.logins[loginID]
prov.loginsLock.RUnlock()
@ -285,7 +279,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
login.Lock.Lock()
// This will only unlock after the handler runs
defer login.Lock.Unlock()
stepID := mux.Vars(r)["stepID"]
stepID := r.PathValue("stepID")
if login.NextStep.StepID != stepID {
zerolog.Ctx(r.Context()).Warn().
Str("request_step_id", stepID).
@ -297,7 +291,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
})
return
}
stepType := mux.Vars(r)["stepType"]
stepType := r.PathValue("stepType")
if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
zerolog.Ctx(r.Context()).Warn().
Str("request_step_type", stepType).
@ -401,7 +395,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
login, err := prov.net.CreateLogin(
r.Context(),
prov.GetUser(r),
mux.Vars(r)["flowID"],
r.PathValue("flowID"),
)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
@ -440,6 +434,17 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
}, bridgev2.DeleteOpts{LogoutRemote: true})
}
func (prov *ProvisioningAPI) PostLogin(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("stepType") {
case "user_input", "cookies":
prov.PostLoginSubmitInput(w, r)
case "display_and_wait":
prov.PostLoginWait(w, r)
default:
panic("Impossible state") // checked by the AuthMiddleware
}
}
func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) {
var params map[string]string
err := json.NewDecoder(r.Body).Decode(&params)
@ -493,7 +498,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) {
user := prov.GetUser(r)
userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"])
userLoginID := networkid.UserLoginID(r.PathValue("loginID"))
if userLoginID == "all" {
for {
login := user.GetDefaultLogin()
@ -596,7 +601,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
})
return
}
resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat)
resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
RespondWithError(w, err, "Internal error resolving identifier")

View file

@ -16,8 +16,6 @@ import (
"net/http"
"time"
"github.com/gorilla/mux"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/id"
)
@ -35,7 +33,7 @@ func (br *Connector) initPublicMedia() error {
return fmt.Errorf("public media hash length is negative")
}
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet)
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
return nil
}
@ -76,16 +74,15 @@ var proxyHeadersToCopy = []string{
}
func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
contentURI := id.ContentURI{
Homeserver: vars["server"],
FileID: vars["mediaID"],
Homeserver: r.PathValue("server"),
FileID: r.PathValue("mediaID"),
}
if !contentURI.IsValid() {
http.Error(w, "invalid content URI", http.StatusBadRequest)
return
}
checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"])
checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum"))
if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) {
http.Error(w, "invalid base64 in checksum", http.StatusBadRequest)
return

View file

@ -10,11 +10,10 @@ import (
"context"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/gorilla/mux"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/bridgev2/database"
@ -58,7 +57,7 @@ type MatrixConnector interface {
type MatrixConnectorWithServer interface {
GetPublicAddress() string
GetRouter() *mux.Router
GetRouter() *http.ServeMux
}
type MatrixConnectorWithPublicMedia interface {

View file

@ -12,11 +12,9 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"
"go.mau.fi/util/random"
@ -42,19 +40,19 @@ type mockServer struct {
UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
}
func DecodeVarsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
var err error
for k, v := range vars {
vars[k], err = url.PathUnescape(v)
if err != nil {
panic(err)
}
}
next.ServeHTTP(w, r)
})
}
// func DecodeVarsMiddleware(next http.Handler) http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// vars := mux.Vars(r)
// var err error
// for k, v := range vars {
// vars[k], err = url.PathUnescape(v)
// if err != nil {
// panic(err)
// }
// }
// next.ServeHTTP(w, r)
// })
// }
func createMockServer(t *testing.T) *mockServer {
t.Helper()
@ -69,15 +67,14 @@ func createMockServer(t *testing.T) *mockServer {
UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
}
router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath()
router.Use(DecodeVarsMiddleware)
router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost)
router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost)
router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut)
router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut)
router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost)
router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost)
router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost)
router := http.NewServeMux()
router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp)
router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload)
server.Server = httptest.NewServer(router)
return &server
@ -118,10 +115,9 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) {
}
func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
var req mautrix.ReqSendToDevice
json.NewDecoder(r.Body).Decode(&req)
evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType}
evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
for user, devices := range req.Messages {
for device, content := range devices {
@ -140,9 +136,8 @@ func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
}
func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
userID := id.UserID(vars["userID"])
eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType}
userID := id.UserID(r.PathValue("userID"))
eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
jsonData, _ := io.ReadAll(r.Body)
if _, ok := s.AccountData[userID]; !ok {

View file

@ -13,7 +13,7 @@ import (
"strconv"
"time"
"github.com/gorilla/mux"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
@ -50,25 +50,32 @@ type KeyServer struct {
}
// Register registers the key server endpoints to the given router.
func (ks *KeyServer) Register(r *mux.Router) {
r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet)
r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet)
keyRouter := r.PathPrefix("/_matrix/key").Subrouter()
keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet)
keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet)
keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost)
keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
ErrCode: mautrix.MUnrecognized.ErrCode,
Err: "Unrecognized endpoint",
})
})
keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{
ErrCode: mautrix.MUnrecognized.ErrCode,
Err: "Invalid method for endpoint",
})
func (ks *KeyServer) Register(r *http.ServeMux) {
r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown)
r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion)
keyRouter := http.NewServeMux()
keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey)
keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys)
keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys)
keyHandler := exhttp.HandleErrors(keyRouter, exhttp.ErrorBodyGenerators{
NotFound: func() (body []byte) {
body, _ = json.Marshal(&mautrix.RespError{
ErrCode: mautrix.MUnrecognized.ErrCode,
Err: "Unrecognized endpoint",
})
return
},
MethodNotAllowed: func() (body []byte) {
body, _ = json.Marshal(&mautrix.RespError{
ErrCode: mautrix.MUnrecognized.ErrCode,
Err: "Invalid method for endpoint",
})
return
},
})
r.Handle("/_matrix/key", http.StripPrefix("/_matrix/key", keyHandler))
}
func jsonResponse(w http.ResponseWriter, code int, data any) {
@ -177,7 +184,7 @@ type GetQueryKeysResponse struct {
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
serverName := mux.Vars(r)["serverName"]
serverName := r.PathValue("serverName")
minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
if err != nil && minimumValidUntilTSString != "" {

1
go.mod
View file

@ -7,7 +7,6 @@ toolchain go1.23.2
require (
filippo.io/edwards25519 v1.1.0
github.com/chzyer/readline v1.5.1
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.23

4
go.sum
View file

@ -13,8 +13,6 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@ -53,6 +51,8 @@ github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee h1:/BGpUK7fzVyFgy5KBiyP7ktEDn20vzz/5FTngrXtIEE=
go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee/go.mod h1:L9qnqEkhe4KpuYmILrdttKTXL79MwGLyJ4EOskWxO3I=
go.mau.fi/util v0.7.1-0.20240826142731-d642a8a8b6fb h1:5sx2bjPNqkKB/EJsIinnRhXXomMBP2+7nYRIptwDlp4=
go.mau.fi/util v0.7.1-0.20240826142731-d642a8a8b6fb/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4=
go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=

View file

@ -21,8 +21,8 @@ import (
"strings"
"time"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"go.mau.fi/util/exhttp"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/federation"
@ -60,9 +60,9 @@ type MediaProxy struct {
serverName string
serverKey *federation.SigningKey
FederationRouter *mux.Router
LegacyMediaRouter *mux.Router
ClientMediaRouter *mux.Router
FederationRouter *http.ServeMux
LegacyMediaRouter *http.ServeMux
ClientMediaRouter *http.ServeMux
}
func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
@ -70,7 +70,8 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
if err != nil {
return nil, err
}
return &MediaProxy{
mp := &MediaProxy{
serverName: serverName,
serverKey: parsed,
GetMedia: getMedia,
@ -93,7 +94,29 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
},
},
}, nil
FederationRouter: http.NewServeMux(),
LegacyMediaRouter: http.NewServeMux(),
ClientMediaRouter: http.NewServeMux(),
}
mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation)
mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion)
addClientRoutes := func(router *http.ServeMux, prefix string) {
router.HandleFunc("GET "+prefix+"/download/{serverName}/{mediaID}", mp.DownloadMedia)
router.HandleFunc("GET "+prefix+"/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia)
router.HandleFunc("GET "+prefix+"/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia)
router.HandleFunc("PUT "+prefix+"/upload/{serverName}/{mediaID}", mp.UploadNotSupported)
router.HandleFunc("POST "+prefix+"/upload", mp.UploadNotSupported)
router.HandleFunc("POST "+prefix+"/create", mp.UploadNotSupported)
router.HandleFunc("GET "+prefix+"/config", mp.UploadNotSupported)
router.HandleFunc("GET "+prefix+"/preview_url", mp.PreviewURLNotSupported)
}
addClientRoutes(mp.LegacyMediaRouter, "/v3")
addClientRoutes(mp.LegacyMediaRouter, "/r0")
addClientRoutes(mp.LegacyMediaRouter, "/v1")
addClientRoutes(mp.ClientMediaRouter, "")
return mp, nil
}
type BasicConfig struct {
@ -123,7 +146,7 @@ type ServerConfig struct {
}
func (mp *MediaProxy) Listen(cfg ServerConfig) error {
router := mux.NewRouter()
router := http.NewServeMux()
mp.RegisterRoutes(router)
return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
}
@ -140,50 +163,17 @@ func (mp *MediaProxy) DisallowProxying() {
mp.ProxyClient = nil
}
func (mp *MediaProxy) RegisterRoutes(router *mux.Router) {
if mp.FederationRouter == nil {
mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter()
}
if mp.LegacyMediaRouter == nil {
mp.LegacyMediaRouter = router.PathPrefix("/_matrix/media").Subrouter()
}
if mp.ClientMediaRouter == nil {
mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter()
}
func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux) {
legacyMediaHandler := exhttp.HandleErrors(mp.LegacyMediaRouter, exhttp.ErrorBodyGenerators{NotFound: mp.UnknownEndpoint, MethodNotAllowed: mp.UnsupportedMethod})
federationHandler := exhttp.HandleErrors(mp.FederationRouter, exhttp.ErrorBodyGenerators{NotFound: mp.UnknownEndpoint, MethodNotAllowed: mp.UnsupportedMethod})
clientMediaHandler := exhttp.HandleErrors(mp.ClientMediaRouter, exhttp.ErrorBodyGenerators{NotFound: mp.UnknownEndpoint, MethodNotAllowed: mp.UnsupportedMethod})
mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet)
mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet)
addClientRoutes := func(router *mux.Router, prefix string) {
router.HandleFunc(prefix+"/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
router.HandleFunc(prefix+"/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet)
router.HandleFunc(prefix+"/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
router.HandleFunc(prefix+"/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut)
router.HandleFunc(prefix+"/upload", mp.UploadNotSupported).Methods(http.MethodPost)
router.HandleFunc(prefix+"/create", mp.UploadNotSupported).Methods(http.MethodPost)
router.HandleFunc(prefix+"/config", mp.UploadNotSupported).Methods(http.MethodGet)
router.HandleFunc(prefix+"/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet)
}
addClientRoutes(mp.LegacyMediaRouter, "/v3")
addClientRoutes(mp.LegacyMediaRouter, "/r0")
addClientRoutes(mp.LegacyMediaRouter, "/v1")
addClientRoutes(mp.ClientMediaRouter, "")
mp.LegacyMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
mp.LegacyMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
corsMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';")
next.ServeHTTP(w, r)
})
}
mp.LegacyMediaRouter.Use(corsMiddleware)
mp.ClientMediaRouter.Use(corsMiddleware)
legacyMediaHandler = exhttp.CORSMiddleware(legacyMediaHandler)
clientMediaHandler = exhttp.CORSMiddleware(clientMediaHandler)
router.Handle("/_matrix/federation", http.StripPrefix("/_matrix/federation", federationHandler))
router.Handle("/_matrix/media", http.StripPrefix("/_matrix/media", legacyMediaHandler))
router.Handle("/_matrix/client/v1/media", http.StripPrefix("/_matrix/client/v1/media", clientMediaHandler))
mp.KeyServer.Register(router)
}
@ -260,7 +250,7 @@ func (err *ResponseError) Error() string {
var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
mediaID := mux.Vars(r)["mediaID"]
mediaID := r.PathValue("mediaID")
resp, err := mp.GetMedia(r.Context(), mediaID)
if err != nil {
var respError *ResponseError
@ -342,8 +332,7 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req
func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := zerolog.Ctx(ctx)
vars := mux.Vars(r)
if vars["serverName"] != mp.serverName {
if r.PathValue("serverName") != mp.serverName {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
ErrCode: mautrix.MNotFound.ErrCode,
Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName),
@ -360,7 +349,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
// In any other case, redirect to the URL.
isFederated := strings.HasPrefix(r.Header.Get("Authorization"), "X-Matrix")
if mp.ProxyClient != nil && (r.URL.Query().Get("allow_redirect") != "true" || (mp.ForceProxyLegacyFederation && isFederated)) {
mp.proxyDownload(ctx, w, urlResp.URL, vars["fileName"])
mp.proxyDownload(ctx, w, urlResp.URL, r.PathValue("fileName"))
return
}
w.Header().Set("Location", urlResp.URL)
@ -409,16 +398,18 @@ func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Requ
})
}
func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
func (mp *MediaProxy) UnknownEndpoint() (body []byte) {
body, _ = json.Marshal(&mautrix.RespError{
ErrCode: mautrix.MUnrecognized.ErrCode,
Err: "Unrecognized endpoint",
})
return
}
func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{
func (mp *MediaProxy) UnsupportedMethod() (body []byte) {
body, _ = json.Marshal(&mautrix.RespError{
ErrCode: mautrix.MUnrecognized.ErrCode,
Err: "Invalid method for endpoint",
})
return
}