mediaproxy: add support for temp files

This commit is contained in:
Tulir Asokan 2024-11-06 11:01:49 +01:00
commit 39bddeb7d3
2 changed files with 92 additions and 2 deletions

View file

@ -17,6 +17,7 @@ import (
"net"
"net/http"
"net/textproto"
"os"
"strconv"
"strings"
"time"
@ -35,6 +36,7 @@ type GetMediaResponse interface {
func (*GetMediaResponseURL) isGetMediaResponse() {}
func (*GetMediaResponseData) isGetMediaResponse() {}
func (*GetMediaResponseCallback) isGetMediaResponse() {}
func (*GetMediaResponseFile) isGetMediaResponse() {}
type GetMediaResponseURL struct {
URL string
@ -48,6 +50,11 @@ type GetMediaResponseWriter interface {
GetContentLength() int64
}
var (
_ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil)
_ GetMediaResponseWriter = (*GetMediaResponseData)(nil)
)
type GetMediaResponseData struct {
Reader io.ReadCloser
ContentType string
@ -80,6 +87,15 @@ func (d *GetMediaResponseCallback) GetContentLength() int64 {
return d.ContentLength
}
func (d *GetMediaResponseCallback) GetContentType() string {
return d.ContentType
}
type GetMediaResponseFile struct {
Callback func(w *os.File) error
ContentType string
}
type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error)
type MediaProxy struct {
@ -350,6 +366,20 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req
log.Err(err).Msg("Failed to create multipart redirect field")
return
}
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
_, err = doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
"Content-Type": {mimeType},
})
if err != nil {
return fmt.Errorf("failed to create multipart data field: %w", err)
}
_, err = wt.WriteTo(dataPart)
return err
})
if err != nil {
log.Err(err).Msg("Failed to do media proxy with temp file")
}
} else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
"Content-Type": {dataResp.GetContentType()},
@ -408,6 +438,20 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-store")
}
w.WriteHeader(http.StatusTemporaryRedirect)
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
w.Header().Set("Content-Type", mimeType)
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.WriteHeader(http.StatusOK)
_, err := wt.WriteTo(w)
return err
})
if err != nil {
log.Err(err).Msg("Failed to do media proxy with temp file")
if !responseStarted {
mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
}
}
} else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
w.Header().Set("Content-Type", dataResp.GetContentType())
if dataResp.GetContentLength() != 0 {
@ -423,6 +467,51 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
}
}
func doTempFileDownload(
data *GetMediaResponseFile,
respond func(w io.WriterTo, size int64, mimeType string) error,
) (bool, error) {
tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*")
if err != nil {
return false, fmt.Errorf("failed to create temp file: %w", err)
}
defer func() {
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
}()
err = data.Callback(tempFile)
if err != nil {
return false, err
}
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
}
fileInfo, err := tempFile.Stat()
if err != nil {
return false, fmt.Errorf("failed to stat temp file: %w", err)
}
mimeType := data.ContentType
if mimeType == "" {
buf := make([]byte, 512)
n, err := tempFile.Read(buf)
if err != nil {
return false, fmt.Errorf("failed to read temp file to detect mime: %w", err)
}
buf = buf[:n]
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
}
mimeType = http.DetectContentType(buf)
}
err = respond(tempFile, fileInfo.Size(), mimeType)
if err != nil {
return true, err
}
return true, nil
}
func jsonResponse(w http.ResponseWriter, status int, response interface{}) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)