mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
all: replace gorilla/mux with standard library
This commit is contained in:
parent
5b55330b85
commit
3d85625644
11 changed files with 149 additions and 188 deletions
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -19,7 +19,6 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
|
|
@ -43,7 +42,7 @@ func Create() *AppService {
|
|||
intents: make(map[id.UserID]*IntentAPI),
|
||||
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
|
||||
StateStore: mautrix.NewMemoryStateStore().(StateStore),
|
||||
Router: mux.NewRouter(),
|
||||
Router: http.NewServeMux(),
|
||||
UserAgent: mautrix.DefaultUserAgent,
|
||||
txnIDC: NewTransactionIDCache(128),
|
||||
Live: true,
|
||||
|
|
@ -61,12 +60,12 @@ func Create() *AppService {
|
|||
DefaultHTTPRetries: 4,
|
||||
}
|
||||
|
||||
as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
|
||||
as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
|
||||
as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
|
||||
as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost)
|
||||
as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
|
||||
as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
|
||||
as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction)
|
||||
as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom)
|
||||
as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser)
|
||||
as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing)
|
||||
as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive)
|
||||
as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady)
|
||||
|
||||
return as
|
||||
}
|
||||
|
|
@ -114,13 +113,13 @@ var _ StateStore = (*mautrix.MemoryStateStore)(nil)
|
|||
|
||||
// QueryHandler handles room alias and user ID queries from the homeserver.
|
||||
type QueryHandler interface {
|
||||
QueryAlias(alias string) bool
|
||||
QueryAlias(alias id.RoomAlias) bool
|
||||
QueryUser(userID id.UserID) bool
|
||||
}
|
||||
|
||||
type QueryHandlerStub struct{}
|
||||
|
||||
func (qh *QueryHandlerStub) QueryAlias(alias string) bool {
|
||||
func (qh *QueryHandlerStub) QueryAlias(alias id.RoomAlias) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
@ -128,7 +127,7 @@ func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
|
||||
type WebsocketHandler func(WebsocketCommand) (ok bool, data any)
|
||||
|
||||
type StateStore interface {
|
||||
mautrix.StateStore
|
||||
|
|
@ -160,7 +159,7 @@ type AppService struct {
|
|||
QueryHandler QueryHandler
|
||||
StateStore StateStore
|
||||
|
||||
Router *mux.Router
|
||||
Router *http.ServeMux
|
||||
UserAgent string
|
||||
server *http.Server
|
||||
HTTPClient *http.Client
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/exhttp"
|
||||
"go.mau.fi/util/exstrings"
|
||||
|
|
@ -95,8 +94,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
txnID := vars["txnID"]
|
||||
txnID := r.PathValue("txnID")
|
||||
if len(txnID) == 0 {
|
||||
mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w)
|
||||
return
|
||||
|
|
@ -240,8 +238,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
roomAlias := vars["roomAlias"]
|
||||
roomAlias := id.RoomAlias(r.PathValue("roomAlias"))
|
||||
ok := as.QueryHandler.QueryAlias(roomAlias)
|
||||
if ok {
|
||||
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
|
||||
|
|
@ -256,8 +253,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := id.UserID(vars["userID"])
|
||||
userID := id.UserID(r.PathValue("userID"))
|
||||
ok := as.QueryHandler.QueryUser(userID)
|
||||
if ok {
|
||||
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
|
|
@ -20,7 +21,6 @@ import (
|
|||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
|
@ -223,7 +223,7 @@ func (br *Connector) GetPublicAddress() string {
|
|||
return br.Config.AppService.PublicAddress
|
||||
}
|
||||
|
||||
func (br *Connector) GetRouter() *mux.Router {
|
||||
func (br *Connector) GetRouter() *http.ServeMux {
|
||||
if br.GetPublicAddress() != "" {
|
||||
return br.AS.Router
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -17,7 +17,6 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/xid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/hlog"
|
||||
|
|
@ -40,7 +39,7 @@ type matrixAuthCacheEntry struct {
|
|||
}
|
||||
|
||||
type ProvisioningAPI struct {
|
||||
Router *mux.Router
|
||||
Router *http.ServeMux
|
||||
|
||||
br *Connector
|
||||
log zerolog.Logger
|
||||
|
|
@ -91,12 +90,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
|
|||
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) GetRouter() *mux.Router {
|
||||
func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
|
||||
return prov.Router
|
||||
}
|
||||
|
||||
type IProvisioningAPI interface {
|
||||
GetRouter() *mux.Router
|
||||
GetRouter() *http.ServeMux
|
||||
GetUser(r *http.Request) *bridgev2.User
|
||||
}
|
||||
|
||||
|
|
@ -116,41 +115,49 @@ func (prov *ProvisioningAPI) Init() {
|
|||
tp.Dialer.Timeout = 10 * time.Second
|
||||
tp.Transport.ResponseHeaderTimeout = 10 * time.Second
|
||||
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
|
||||
prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter()
|
||||
prov.Router.Use(hlog.NewHandler(prov.log))
|
||||
prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id"))
|
||||
prov.Router.Use(exhttp.CORSMiddleware)
|
||||
prov.Router.Use(requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}))
|
||||
prov.Router.Use(prov.AuthMiddleware)
|
||||
prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami)
|
||||
prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows)
|
||||
prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart)
|
||||
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput)
|
||||
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait)
|
||||
prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout)
|
||||
prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins)
|
||||
prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList)
|
||||
prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers)
|
||||
prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier)
|
||||
prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM)
|
||||
prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup)
|
||||
prov.Router = http.NewServeMux()
|
||||
prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami)
|
||||
prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows)
|
||||
prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
|
||||
prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}", prov.PostLoginSubmitInput)
|
||||
prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}", prov.PostLoginWait)
|
||||
prov.Router.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout)
|
||||
prov.Router.HandleFunc("GET /v3/logins", prov.GetLogins)
|
||||
prov.Router.HandleFunc("GET /v3/contacts", prov.GetContactList)
|
||||
prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers)
|
||||
prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
|
||||
prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
|
||||
prov.Router.HandleFunc("POST /v3/create_group", prov.PostCreateGroup)
|
||||
|
||||
if prov.br.Config.Provisioning.EnableSessionTransfers {
|
||||
prov.log.Debug().Msg("Enabling session transfer API")
|
||||
prov.Router.Path("/v3/session_transfer/init").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostInitSessionTransfer)
|
||||
prov.Router.Path("/v3/session_transfer/finish").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostFinishSessionTransfer)
|
||||
prov.Router.HandleFunc("POST /v3/session_transfer/init", prov.PostInitSessionTransfer)
|
||||
prov.Router.HandleFunc("POST /v3/session_transfer/finish", prov.PostFinishSessionTransfer)
|
||||
}
|
||||
|
||||
if prov.br.Config.Provisioning.DebugEndpoints {
|
||||
prov.log.Debug().Msg("Enabling debug API at /debug")
|
||||
r := prov.br.AS.Router.PathPrefix("/debug").Subrouter()
|
||||
r.Use(prov.DebugAuthMiddleware)
|
||||
r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet)
|
||||
r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet)
|
||||
r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet)
|
||||
r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet)
|
||||
r.PathPrefix("/pprof/").HandlerFunc(pprof.Index)
|
||||
debugRouter := http.NewServeMux()
|
||||
debugRouter.HandleFunc("GET /pprof/cmdline", pprof.Cmdline)
|
||||
debugRouter.HandleFunc("GET /pprof/profile", pprof.Profile)
|
||||
debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol)
|
||||
debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace)
|
||||
debugRouter.HandleFunc("/pprof/", pprof.Index)
|
||||
prov.br.AS.Router.Handle("/debug", exhttp.ApplyMiddleware(
|
||||
debugRouter,
|
||||
hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()),
|
||||
prov.DebugAuthMiddleware,
|
||||
))
|
||||
}
|
||||
|
||||
prov.br.AS.Router.Handle("/_matrix/provision", exhttp.ApplyMiddleware(
|
||||
prov.Router,
|
||||
hlog.NewHandler(prov.log),
|
||||
hlog.RequestIDHandler("request_id", "Request-Id"),
|
||||
exhttp.CORSMiddleware,
|
||||
requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
|
||||
prov.AuthMiddleware,
|
||||
))
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error {
|
||||
|
|
@ -250,7 +257,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
|
|||
|
||||
ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r)
|
||||
ctx = context.WithValue(ctx, provisioningUserKey, user)
|
||||
if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {
|
||||
if loginID := r.PathValue("loginProcessID"); loginID != "" {
|
||||
prov.loginsLock.RLock()
|
||||
login, ok := prov.logins[loginID]
|
||||
prov.loginsLock.RUnlock()
|
||||
|
|
@ -262,7 +269,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
|
|||
login.Lock.Lock()
|
||||
// This will only unlock after the handler runs
|
||||
defer login.Lock.Unlock()
|
||||
stepID := mux.Vars(r)["stepID"]
|
||||
stepID := r.PathValue("stepID")
|
||||
if login.NextStep.StepID != stepID {
|
||||
zerolog.Ctx(r.Context()).Warn().
|
||||
Str("request_step_id", stepID).
|
||||
|
|
@ -271,7 +278,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
|
|||
mautrix.MBadState.WithMessage("Step ID does not match").Write(w)
|
||||
return
|
||||
}
|
||||
stepType := mux.Vars(r)["stepType"]
|
||||
stepType := r.PathValue("stepType")
|
||||
if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
|
||||
zerolog.Ctx(r.Context()).Warn().
|
||||
Str("request_step_type", stepType).
|
||||
|
|
@ -374,7 +381,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
|
|||
login, err := prov.net.CreateLogin(
|
||||
r.Context(),
|
||||
prov.GetUser(r),
|
||||
mux.Vars(r)["flowID"],
|
||||
r.PathValue("flowID"),
|
||||
)
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
|
||||
|
|
@ -475,7 +482,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
|
|||
|
||||
func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) {
|
||||
user := prov.GetUser(r)
|
||||
userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"])
|
||||
userLoginID := networkid.UserLoginID(r.PathValue("loginID"))
|
||||
if userLoginID == "all" {
|
||||
for {
|
||||
login := user.GetDefaultLogin()
|
||||
|
|
@ -571,7 +578,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
|
|||
mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w)
|
||||
return
|
||||
}
|
||||
resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat)
|
||||
resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat)
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
|
||||
RespondWithError(w, err, "Internal error resolving identifier")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -16,8 +16,6 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
|
@ -35,7 +33,7 @@ func (br *Connector) initPublicMedia() error {
|
|||
return fmt.Errorf("public media hash length is negative")
|
||||
}
|
||||
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
|
||||
br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet)
|
||||
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -76,16 +74,15 @@ var proxyHeadersToCopy = []string{
|
|||
}
|
||||
|
||||
func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
contentURI := id.ContentURI{
|
||||
Homeserver: vars["server"],
|
||||
FileID: vars["mediaID"],
|
||||
Homeserver: r.PathValue("server"),
|
||||
FileID: r.PathValue("mediaID"),
|
||||
}
|
||||
if !contentURI.IsValid() {
|
||||
http.Error(w, "invalid content URI", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"])
|
||||
checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum"))
|
||||
if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) {
|
||||
http.Error(w, "invalid base64 in checksum", http.StatusBadRequest)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -10,11 +10,10 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridgev2/database"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
|
|
@ -64,7 +63,7 @@ type MatrixConnectorWithArbitraryRoomState interface {
|
|||
|
||||
type MatrixConnectorWithServer interface {
|
||||
GetPublicAddress() string
|
||||
GetRouter() *mux.Router
|
||||
GetRouter() *http.ServeMux
|
||||
}
|
||||
|
||||
type MatrixConnectorWithPublicMedia interface {
|
||||
|
|
|
|||
|
|
@ -12,11 +12,9 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/zerolog/log" // zerolog-allow-global-log
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mau.fi/util/random"
|
||||
|
|
@ -42,20 +40,6 @@ type mockServer struct {
|
|||
UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
|
||||
}
|
||||
|
||||
func DecodeVarsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
var err error
|
||||
for k, v := range vars {
|
||||
vars[k], err = url.PathUnescape(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func createMockServer(t *testing.T) *mockServer {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -69,15 +53,14 @@ func createMockServer(t *testing.T) *mockServer {
|
|||
UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
|
||||
}
|
||||
|
||||
router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath()
|
||||
router.Use(DecodeVarsMiddleware)
|
||||
router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost)
|
||||
router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost)
|
||||
router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut)
|
||||
router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut)
|
||||
router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost)
|
||||
router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost)
|
||||
router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost)
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
|
||||
router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
|
||||
router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
|
||||
router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
|
||||
router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
|
||||
router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp)
|
||||
router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload)
|
||||
|
||||
server.Server = httptest.NewServer(router)
|
||||
return &server
|
||||
|
|
@ -118,10 +101,9 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
var req mautrix.ReqSendToDevice
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType}
|
||||
evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
|
||||
|
||||
for user, devices := range req.Messages {
|
||||
for device, content := range devices {
|
||||
|
|
@ -140,9 +122,8 @@ func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
userID := id.UserID(vars["userID"])
|
||||
eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType}
|
||||
userID := id.UserID(r.PathValue("userID"))
|
||||
eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
|
||||
|
||||
jsonData, _ := io.ReadAll(r.Body)
|
||||
if _, ok := s.AccountData[userID]; !ok {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -12,7 +12,7 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"go.mau.fi/util/exhttp"
|
||||
"go.mau.fi/util/jsontime"
|
||||
|
||||
|
|
@ -51,19 +51,21 @@ type KeyServer struct {
|
|||
}
|
||||
|
||||
// Register registers the key server endpoints to the given router.
|
||||
func (ks *KeyServer) Register(r *mux.Router) {
|
||||
r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet)
|
||||
r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet)
|
||||
keyRouter := r.PathPrefix("/_matrix/key").Subrouter()
|
||||
keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet)
|
||||
keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet)
|
||||
keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost)
|
||||
keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mautrix.MUnrecognized.WithStatus(http.StatusNotFound).WithMessage("Unrecognized endpoint").Write(w)
|
||||
})
|
||||
keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mautrix.MUnrecognized.WithStatus(http.StatusMethodNotAllowed).WithMessage("Invalid method for endpoint").Write(w)
|
||||
})
|
||||
func (ks *KeyServer) Register(r *http.ServeMux) {
|
||||
r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown)
|
||||
r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion)
|
||||
keyRouter := http.NewServeMux()
|
||||
keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey)
|
||||
keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys)
|
||||
keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys)
|
||||
errorBodies := exhttp.ErrorBodies{
|
||||
NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))),
|
||||
MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))),
|
||||
}
|
||||
r.Handle("/_matrix/key", exhttp.ApplyMiddleware(
|
||||
keyRouter,
|
||||
exhttp.HandleErrors(errorBodies),
|
||||
))
|
||||
}
|
||||
|
||||
// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint.
|
||||
|
|
@ -157,7 +159,7 @@ type GetQueryKeysResponse struct {
|
|||
//
|
||||
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
|
||||
func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
|
||||
serverName := mux.Vars(r)["serverName"]
|
||||
serverName := r.PathValue("serverName")
|
||||
minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
|
||||
minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
|
||||
if err != nil && minimumValidUntilTSString != "" {
|
||||
|
|
|
|||
9
go.mod
9
go.mod
|
|
@ -2,12 +2,11 @@ module maunium.net/go/mautrix
|
|||
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.4
|
||||
toolchain go1.24.5
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0
|
||||
github.com/chzyer/readline v1.5.1
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-sqlite3 v1.14.28
|
||||
|
|
@ -18,10 +17,10 @@ require (
|
|||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/yuin/goldmark v1.7.12
|
||||
go.mau.fi/util v0.8.8
|
||||
go.mau.fi/util v0.8.9-0.20250723171559-474867266038
|
||||
go.mau.fi/zeroconfig v0.1.3
|
||||
golang.org/x/crypto v0.40.0
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/sync v0.16.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
|
|
@ -33,7 +32,7 @@ require (
|
|||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb // indirect
|
||||
github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
|
|
|
|||
14
go.sum
14
go.sum
|
|
@ -13,8 +13,6 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
|
|||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
|
|
@ -28,8 +26,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
|
|||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb h1:3PrKuO92dUTMrQ9dx0YNejC6U/Si6jqKmyQ9vWjwqR4=
|
||||
github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
|
||||
github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e h1:D0bJD+4O3G4izvrQUmzCL80zazlN7EwJ0PPDhpJWC/I=
|
||||
github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
|
|
@ -53,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
|||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY=
|
||||
github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
go.mau.fi/util v0.8.8 h1:OnuEEc/sIJFhnq4kFggiImUpcmnmL/xpvQMRu5Fiy5c=
|
||||
go.mau.fi/util v0.8.8/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c=
|
||||
go.mau.fi/util v0.8.9-0.20250723171559-474867266038 h1:RVL8TVaYc3LTBBopfjCNDtD+6eZks0O+qgXN/9hsz7k=
|
||||
go.mau.fi/util v0.8.9-0.20250723171559-474867266038/go.mod h1:GZZp5f9r2MgEu4GDvtB0XxCF7i6Z7Z8fM0w9a5oZH3Y=
|
||||
go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
|
||||
go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -8,6 +8,7 @@ package mediaproxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -21,8 +22,9 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"go.mau.fi/util/exhttp"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/federation"
|
||||
|
|
@ -108,8 +110,8 @@ type MediaProxy struct {
|
|||
serverName string
|
||||
serverKey *federation.SigningKey
|
||||
|
||||
FederationRouter *mux.Router
|
||||
ClientMediaRouter *mux.Router
|
||||
FederationRouter *http.ServeMux
|
||||
ClientMediaRouter *http.ServeMux
|
||||
}
|
||||
|
||||
func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
|
||||
|
|
@ -117,7 +119,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MediaProxy{
|
||||
mp := &MediaProxy{
|
||||
serverName: serverName,
|
||||
serverKey: parsed,
|
||||
GetMedia: getMedia,
|
||||
|
|
@ -132,7 +134,20 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
|
|||
Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
mp.FederationRouter = http.NewServeMux()
|
||||
mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation)
|
||||
mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion)
|
||||
mp.ClientMediaRouter = http.NewServeMux()
|
||||
mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia)
|
||||
mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia)
|
||||
mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia)
|
||||
mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported)
|
||||
mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported)
|
||||
mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported)
|
||||
mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported)
|
||||
mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported)
|
||||
return mp, nil
|
||||
}
|
||||
|
||||
type BasicConfig struct {
|
||||
|
|
@ -162,7 +177,7 @@ type ServerConfig struct {
|
|||
}
|
||||
|
||||
func (mp *MediaProxy) Listen(cfg ServerConfig) error {
|
||||
router := mux.NewRouter()
|
||||
router := http.NewServeMux()
|
||||
mp.RegisterRoutes(router)
|
||||
return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
|
||||
}
|
||||
|
|
@ -188,38 +203,20 @@ func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache feder
|
|||
})
|
||||
}
|
||||
|
||||
func (mp *MediaProxy) RegisterRoutes(router *mux.Router) {
|
||||
if mp.FederationRouter == nil {
|
||||
mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter()
|
||||
func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux) {
|
||||
errorBodies := exhttp.ErrorBodies{
|
||||
NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))),
|
||||
MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))),
|
||||
}
|
||||
if mp.ClientMediaRouter == nil {
|
||||
mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter()
|
||||
}
|
||||
|
||||
mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet)
|
||||
mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet)
|
||||
mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
|
||||
mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet)
|
||||
mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
|
||||
mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut)
|
||||
mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost)
|
||||
mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost)
|
||||
mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet)
|
||||
mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet)
|
||||
mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
|
||||
mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
|
||||
mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
|
||||
mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
|
||||
corsMiddleware := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
|
||||
w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
mp.ClientMediaRouter.Use(corsMiddleware)
|
||||
router.Handle("/_matrix/federation", exhttp.ApplyMiddleware(
|
||||
mp.FederationRouter,
|
||||
exhttp.HandleErrors(errorBodies),
|
||||
))
|
||||
router.Handle("/_matrix/client/v1/media", exhttp.ApplyMiddleware(
|
||||
mp.ClientMediaRouter,
|
||||
exhttp.CORSMiddleware,
|
||||
exhttp.HandleErrors(errorBodies),
|
||||
))
|
||||
mp.KeyServer.Register(router)
|
||||
}
|
||||
|
||||
|
|
@ -234,7 +231,7 @@ func queryToMap(vals url.Values) map[string]string {
|
|||
}
|
||||
|
||||
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
|
||||
mediaID := mux.Vars(r)["mediaID"]
|
||||
mediaID := r.PathValue("mediaID")
|
||||
if !id.IsValidMediaID(mediaID) {
|
||||
mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w)
|
||||
return nil
|
||||
|
|
@ -380,8 +377,7 @@ func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName strin
|
|||
func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
log := zerolog.Ctx(ctx)
|
||||
vars := mux.Vars(r)
|
||||
if vars["serverName"] != mp.serverName {
|
||||
if r.PathValue("serverName") != mp.serverName {
|
||||
mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
|
||||
return
|
||||
}
|
||||
|
|
@ -404,7 +400,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(http.StatusTemporaryRedirect)
|
||||
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
|
||||
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
|
||||
mp.addHeaders(w, mimeType, vars["fileName"])
|
||||
mp.addHeaders(w, mimeType, r.PathValue("fileName"))
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := wt.WriteTo(w)
|
||||
|
|
@ -425,7 +421,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
|
|||
if dataResp, ok := writerResp.(*GetMediaResponseData); ok {
|
||||
defer dataResp.Reader.Close()
|
||||
}
|
||||
mp.addHeaders(w, writerResp.GetContentType(), vars["fileName"])
|
||||
mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName"))
|
||||
if writerResp.GetContentLength() != 0 {
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10))
|
||||
}
|
||||
|
|
@ -491,11 +487,6 @@ var (
|
|||
ErrPreviewURLNotSupported = mautrix.MUnrecognized.
|
||||
WithMessage("This is a media proxy and does not support URL previews.").
|
||||
WithStatus(http.StatusNotImplemented)
|
||||
ErrUnknownEndpoint = mautrix.MUnrecognized.
|
||||
WithMessage("Unrecognized endpoint")
|
||||
ErrUnsupportedMethod = mautrix.MUnrecognized.
|
||||
WithMessage("Invalid method for endpoint").
|
||||
WithStatus(http.StatusMethodNotAllowed)
|
||||
)
|
||||
|
||||
func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -505,11 +496,3 @@ func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request)
|
|||
func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) {
|
||||
ErrPreviewURLNotSupported.Write(w)
|
||||
}
|
||||
|
||||
func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
ErrUnknownEndpoint.Write(w)
|
||||
}
|
||||
|
||||
func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) {
|
||||
ErrUnsupportedMethod.Write(w)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue