mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
mediaproxy: add support for temp files
This commit is contained in:
parent
56aadb232f
commit
39bddeb7d3
2 changed files with 92 additions and 2 deletions
|
|
@ -17,6 +17,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -35,6 +36,7 @@ type GetMediaResponse interface {
|
|||
func (*GetMediaResponseURL) isGetMediaResponse() {}
|
||||
func (*GetMediaResponseData) isGetMediaResponse() {}
|
||||
func (*GetMediaResponseCallback) isGetMediaResponse() {}
|
||||
func (*GetMediaResponseFile) isGetMediaResponse() {}
|
||||
|
||||
type GetMediaResponseURL struct {
|
||||
URL string
|
||||
|
|
@ -48,6 +50,11 @@ type GetMediaResponseWriter interface {
|
|||
GetContentLength() int64
|
||||
}
|
||||
|
||||
var (
|
||||
_ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil)
|
||||
_ GetMediaResponseWriter = (*GetMediaResponseData)(nil)
|
||||
)
|
||||
|
||||
type GetMediaResponseData struct {
|
||||
Reader io.ReadCloser
|
||||
ContentType string
|
||||
|
|
@ -80,6 +87,15 @@ func (d *GetMediaResponseCallback) GetContentLength() int64 {
|
|||
return d.ContentLength
|
||||
}
|
||||
|
||||
func (d *GetMediaResponseCallback) GetContentType() string {
|
||||
return d.ContentType
|
||||
}
|
||||
|
||||
type GetMediaResponseFile struct {
|
||||
Callback func(w *os.File) error
|
||||
ContentType string
|
||||
}
|
||||
|
||||
type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error)
|
||||
|
||||
type MediaProxy struct {
|
||||
|
|
@ -350,6 +366,20 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req
|
|||
log.Err(err).Msg("Failed to create multipart redirect field")
|
||||
return
|
||||
}
|
||||
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
|
||||
_, err = doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
|
||||
dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {mimeType},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create multipart data field: %w", err)
|
||||
}
|
||||
_, err = wt.WriteTo(dataPart)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to do media proxy with temp file")
|
||||
}
|
||||
} else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
|
||||
dataPart, err := mpw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {dataResp.GetContentType()},
|
||||
|
|
@ -408,6 +438,20 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Set("Cache-Control", "no-store")
|
||||
}
|
||||
w.WriteHeader(http.StatusTemporaryRedirect)
|
||||
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
|
||||
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
|
||||
w.Header().Set("Content-Type", mimeType)
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := wt.WriteTo(w)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to do media proxy with temp file")
|
||||
if !responseStarted {
|
||||
mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w)
|
||||
}
|
||||
}
|
||||
} else if dataResp, ok := resp.(GetMediaResponseWriter); ok {
|
||||
w.Header().Set("Content-Type", dataResp.GetContentType())
|
||||
if dataResp.GetContentLength() != 0 {
|
||||
|
|
@ -423,6 +467,51 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func doTempFileDownload(
|
||||
data *GetMediaResponseFile,
|
||||
respond func(w io.WriterTo, size int64, mimeType string) error,
|
||||
) (bool, error) {
|
||||
tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = tempFile.Close()
|
||||
_ = os.Remove(tempFile.Name())
|
||||
}()
|
||||
err = data.Callback(tempFile)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, err = tempFile.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
|
||||
}
|
||||
fileInfo, err := tempFile.Stat()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to stat temp file: %w", err)
|
||||
}
|
||||
mimeType := data.ContentType
|
||||
if mimeType == "" {
|
||||
buf := make([]byte, 512)
|
||||
n, err := tempFile.Read(buf)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read temp file to detect mime: %w", err)
|
||||
}
|
||||
buf = buf[:n]
|
||||
_, err = tempFile.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
|
||||
}
|
||||
mimeType = http.DetectContentType(buf)
|
||||
}
|
||||
err = respond(tempFile, fileInfo.Size(), mimeType)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func jsonResponse(w http.ResponseWriter, status int, response interface{}) {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue