diff --git a/CHANGELOG.md b/CHANGELOG.md index 56a297a8..a971597c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ * *(pushrules)* Added support for `sender_notification_permission` condition kind (used for `@room` mentions). * *(crypto)* Added support for `json.RawMessage` in `EncryptMegolmEvent`. +* *(mediaproxy)* Added `GetMediaResponseCallback` to write proxied response + directly instead of having to use an `io.Reader`. [MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index f2591428..d1ab0815 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -32,20 +32,54 @@ type GetMediaResponse interface { isGetMediaResponse() } -func (*GetMediaResponseURL) isGetMediaResponse() {} -func (*GetMediaResponseData) isGetMediaResponse() {} +func (*GetMediaResponseURL) isGetMediaResponse() {} +func (*GetMediaResponseData) isGetMediaResponse() {} +func (*GetMediaResponseCallback) isGetMediaResponse() {} type GetMediaResponseURL struct { URL string ExpiresAt time.Time } +type GetMediaResponseWriter interface { + GetMediaResponse + io.WriterTo + GetContentType() string + GetContentLength() int64 +} + type GetMediaResponseData struct { Reader io.ReadCloser ContentType string ContentLength int64 } +func (d *GetMediaResponseData) WriteTo(w io.Writer) (int64, error) { + return io.Copy(w, d.Reader) +} + +func (d *GetMediaResponseData) GetContentType() string { + return d.ContentType +} + +func (d *GetMediaResponseData) GetContentLength() int64 { + return d.ContentLength +} + +type GetMediaResponseCallback struct { + Callback func(w io.Writer) (int64, error) + ContentType string + ContentLength int64 +} + +func (d *GetMediaResponseCallback) WriteTo(w io.Writer) (int64, error) { + return d.Callback(w) +} + +func (d *GetMediaResponseCallback) GetContentLength() int64 { + return d.ContentLength +} + type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error) type MediaProxy struct { @@ -316,21 +350,21 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req log.Err(err).Msg("Failed to create multipart redirect field") return } - } else if dataResp, ok := resp.(*GetMediaResponseData); ok { + } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ - "Content-Type": {dataResp.ContentType}, + "Content-Type": {dataResp.GetContentType()}, }) if err != nil { log.Err(err).Msg("Failed to create multipart data field") return } - _, err = io.Copy(dataPart, dataResp.Reader) + _, err = dataResp.WriteTo(dataPart) if err != nil { log.Err(err).Msg("Failed to write multipart data field") return } } else { - panic("unknown GetMediaResponse type") + panic(fmt.Errorf("unknown GetMediaResponse type %T", resp)) } err = mpw.Close() if err != nil { @@ -374,18 +408,18 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store") } w.WriteHeader(http.StatusTemporaryRedirect) - } else if dataResp, ok := resp.(*GetMediaResponseData); ok { - w.Header().Set("Content-Type", dataResp.ContentType) - if dataResp.ContentLength != 0 { - w.Header().Set("Content-Length", strconv.FormatInt(dataResp.ContentLength, 10)) + } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { + w.Header().Set("Content-Type", dataResp.GetContentType()) + if dataResp.GetContentLength() != 0 { + w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10)) } w.WriteHeader(http.StatusOK) - _, err := io.Copy(w, dataResp.Reader) + _, err := dataResp.WriteTo(w) if err != nil { log.Err(err).Msg("Failed to write media data") } } else { - panic("unknown GetMediaResponse type") + panic(fmt.Errorf("unknown GetMediaResponse type %T", resp)) } }