From 39bddeb7d3f410d41bab3b3fba45ec33bc85a5bb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 11:01:49 +0100 Subject: [PATCH] mediaproxy: add support for temp files --- CHANGELOG.md | 5 ++- mediaproxy/mediaproxy.go | 89 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a971597c..26df4cbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,9 @@ * *(pushrules)* Added support for `sender_notification_permission` condition kind (used for `@room` mentions). * *(crypto)* Added support for `json.RawMessage` in `EncryptMegolmEvent`. -* *(mediaproxy)* Added `GetMediaResponseCallback` to write proxied response - directly instead of having to use an `io.Reader`. +* *(mediaproxy)* Added `GetMediaResponseCallback` and `GetMediaResponseFile` + to write proxied data directly to http response or temp file instead of + having to use an `io.Reader`. [MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index d1ab0815..ce8dd99b 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -17,6 +17,7 @@ import ( "net" "net/http" "net/textproto" + "os" "strconv" "strings" "time" @@ -35,6 +36,7 @@ type GetMediaResponse interface { func (*GetMediaResponseURL) isGetMediaResponse() {} func (*GetMediaResponseData) isGetMediaResponse() {} func (*GetMediaResponseCallback) isGetMediaResponse() {} +func (*GetMediaResponseFile) isGetMediaResponse() {} type GetMediaResponseURL struct { URL string @@ -48,6 +50,11 @@ type GetMediaResponseWriter interface { GetContentLength() int64 } +var ( + _ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil) + _ GetMediaResponseWriter = (*GetMediaResponseData)(nil) +) + type GetMediaResponseData struct { Reader io.ReadCloser ContentType string @@ -80,6 +87,15 @@ func (d *GetMediaResponseCallback) GetContentLength() int64 { return d.ContentLength } +func (d *GetMediaResponseCallback) GetContentType() string { + return d.ContentType +} + +type GetMediaResponseFile struct { + Callback func(w *os.File) error + ContentType string +} + type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error) type MediaProxy struct { @@ -350,6 +366,20 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req log.Err(err).Msg("Failed to create multipart redirect field") return } + } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { + _, err = doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { + dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {mimeType}, + }) + if err != nil { + return fmt.Errorf("failed to create multipart data field: %w", err) + } + _, err = wt.WriteTo(dataPart) + return err + }) + if err != nil { + log.Err(err).Msg("Failed to do media proxy with temp file") + } } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ "Content-Type": {dataResp.GetContentType()}, @@ -408,6 +438,20 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store") } w.WriteHeader(http.StatusTemporaryRedirect) + } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { + responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { + w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + w.WriteHeader(http.StatusOK) + _, err := wt.WriteTo(w) + return err + }) + if err != nil { + log.Err(err).Msg("Failed to do media proxy with temp file") + if !responseStarted { + mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w) + } + } } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { w.Header().Set("Content-Type", dataResp.GetContentType()) if dataResp.GetContentLength() != 0 { @@ -423,6 +467,51 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { } } +func doTempFileDownload( + data *GetMediaResponseFile, + respond func(w io.WriterTo, size int64, mimeType string) error, +) (bool, error) { + tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*") + if err != nil { + return false, fmt.Errorf("failed to create temp file: %w", err) + } + defer func() { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + }() + err = data.Callback(tempFile) + if err != nil { + return false, err + } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + } + fileInfo, err := tempFile.Stat() + if err != nil { + return false, fmt.Errorf("failed to stat temp file: %w", err) + } + mimeType := data.ContentType + if mimeType == "" { + buf := make([]byte, 512) + n, err := tempFile.Read(buf) + if err != nil { + return false, fmt.Errorf("failed to read temp file to detect mime: %w", err) + } + buf = buf[:n] + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + } + mimeType = http.DetectContentType(buf) + } + err = respond(tempFile, fileInfo.Size(), mimeType) + if err != nil { + return true, err + } + return true, nil +} + func jsonResponse(w http.ResponseWriter, status int, response interface{}) { w.Header().Add("Content-Type", "application/json") w.WriteHeader(status)