mediaproxy: pass through query parameters

This commit is contained in:
Tulir Asokan 2024-11-06 13:10:29 +01:00
commit f588c35d8b
6 changed files with 21 additions and 10 deletions

View file

@ -16,6 +16,8 @@
* *(mediaproxy)* Added `GetMediaResponseCallback` and `GetMediaResponseFile`
to write proxied data directly to http response or temp file instead of
having to use an `io.Reader`.
* *(mediaproxy)* Dropped support for legacy media download endpoints.
* *(mediaproxy,bridgev2)* Made interface pass through query parameters.
[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781

View file

@ -71,7 +71,7 @@ func (br *Connector) GenerateContentURI(ctx context.Context, mediaID networkid.M
return mxc, nil
}
func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string) (response mediaproxy.GetMediaResponse, err error) {
func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string, params map[string]string) (response mediaproxy.GetMediaResponse, err error) {
mediaID, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(mediaIDStr, br.Config.DirectMedia.MediaIDPrefix))
if err != nil || !bytes.HasPrefix(mediaID, []byte(MediaIDPrefix)) || len(mediaID) < len(MediaIDPrefix)+MediaIDTruncatedHashLength+1 {
return nil, mediaproxy.ErrInvalidMediaIDSyntax
@ -82,5 +82,5 @@ func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string) (res
return nil, mautrix.MNotFound.WithMessage("Invalid checksum in media ID part")
}
remoteMediaID := networkid.MediaID(mediaID[len(MediaIDPrefix) : len(mediaID)-MediaIDTruncatedHashLength])
return br.Bridge.Network.(bridgev2.DirectMediableNetwork).Download(ctx, remoteMediaID)
return br.Bridge.Network.(bridgev2.DirectMediableNetwork).Download(ctx, remoteMediaID, params)
}

View file

@ -243,7 +243,7 @@ type StoppableNetwork interface {
type DirectMediableNetwork interface {
NetworkConnector
SetUseDirectMedia()
Download(ctx context.Context, mediaID networkid.MediaID) (mediaproxy.GetMediaResponse, error)
Download(ctx context.Context, mediaID networkid.MediaID, params map[string]string) (mediaproxy.GetMediaResponse, error)
}
type IdentifierValidatingNetwork interface {

2
go.mod
View file

@ -18,7 +18,7 @@ require (
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/yuin/goldmark v1.7.8
go.mau.fi/util v0.8.1
go.mau.fi/util v0.8.2-0.20241106111346-576742786fe9
go.mau.fi/zeroconfig v0.1.3
golang.org/x/crypto v0.28.0
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c

4
go.sum
View file

@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic=
github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
go.mau.fi/util v0.8.1 h1:Ga43cz6esQBYqcjZ/onRoVnYWoUwjWbsxVeJg2jOTSo=
go.mau.fi/util v0.8.1/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc=
go.mau.fi/util v0.8.2-0.20241106111346-576742786fe9 h1:zYcb/lTZudowXAjKi6Yc2/2y5xxglPFfy9ZT2pNGsuM=
go.mau.fi/util v0.8.2-0.20241106111346-576742786fe9/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc=
go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=

View file

@ -16,6 +16,7 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
"net/url"
"os"
"strconv"
"strings"
@ -95,7 +96,7 @@ type GetMediaResponseFile struct {
ContentType string
}
type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error)
type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error)
type MediaProxy struct {
KeyServer *federation.KeyServer
@ -218,9 +219,17 @@ func (err *ResponseError) Error() string {
var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
func queryToMap(vals url.Values) map[string]string {
m := make(map[string]string, len(vals))
for k, v := range vals {
m[k] = v[0]
}
return m
}
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
mediaID := mux.Vars(r)["mediaID"]
resp, err := mp.GetMedia(r.Context(), mediaID)
resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query()))
if err != nil {
//lint:ignore SA1019 deprecated types need to be supported until they're removed
var respError *ResponseError
@ -384,7 +393,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTemporaryRedirect)
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
mp.addHeaders(w, mimeType, r.PathValue("fileName"))
mp.addHeaders(w, mimeType, vars["fileName"])
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.WriteHeader(http.StatusOK)
_, err := wt.WriteTo(w)
@ -402,7 +411,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
}
}
} else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
mp.addHeaders(w, dataResp.GetContentType(), r.PathValue("fileName"))
mp.addHeaders(w, dataResp.GetContentType(), vars["fileName"])
if dataResp.GetContentLength() != 0 {
w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10))
}