mediaproxy: pass through query parameters

This commit is contained in:
Tulir Asokan 2024-11-06 13:10:29 +01:00
commit f588c35d8b
6 changed files with 21 additions and 10 deletions

View file

@ -16,6 +16,7 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
"net/url"
"os"
"strconv"
"strings"
@ -95,7 +96,7 @@ type GetMediaResponseFile struct {
ContentType string
}
type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error)
type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error)
type MediaProxy struct {
KeyServer *federation.KeyServer
@ -218,9 +219,17 @@ func (err *ResponseError) Error() string {
var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
func queryToMap(vals url.Values) map[string]string {
m := make(map[string]string, len(vals))
for k, v := range vals {
m[k] = v[0]
}
return m
}
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
mediaID := mux.Vars(r)["mediaID"]
resp, err := mp.GetMedia(r.Context(), mediaID)
resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query()))
if err != nil {
//lint:ignore SA1019 deprecated types need to be supported until they're removed
var respError *ResponseError
@ -384,7 +393,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTemporaryRedirect)
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
mp.addHeaders(w, mimeType, r.PathValue("fileName"))
mp.addHeaders(w, mimeType, vars["fileName"])
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.WriteHeader(http.StatusOK)
_, err := wt.WriteTo(w)
@ -402,7 +411,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
}
}
} else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
mp.addHeaders(w, dataResp.GetContentType(), r.PathValue("fileName"))
mp.addHeaders(w, dataResp.GetContentType(), vars["fileName"])
if dataResp.GetContentLength() != 0 {
w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10))
}