mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
bridgev2/matrix: add new stream upload that uses a writer instead of a reader (#269)
This commit is contained in:
parent
79527df26e
commit
ce2ffd8232
6 changed files with 128 additions and 41 deletions
|
|
@ -69,11 +69,12 @@ type BridgeConfig struct {
|
|||
}
|
||||
|
||||
type MatrixConfig struct {
|
||||
MessageStatusEvents bool `yaml:"message_status_events"`
|
||||
DeliveryReceipts bool `yaml:"delivery_receipts"`
|
||||
MessageErrorNotices bool `yaml:"message_error_notices"`
|
||||
SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
|
||||
FederateRooms bool `yaml:"federate_rooms"`
|
||||
MessageStatusEvents bool `yaml:"message_status_events"`
|
||||
DeliveryReceipts bool `yaml:"delivery_receipts"`
|
||||
MessageErrorNotices bool `yaml:"message_error_notices"`
|
||||
SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
|
||||
FederateRooms bool `yaml:"federate_rooms"`
|
||||
UploadFileThreshold int64 `yaml:"upload_file_threshold"`
|
||||
}
|
||||
|
||||
type ProvisioningConfig struct {
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ func doUpgrade(helper up.Helper) {
|
|||
helper.Copy(up.Bool, "matrix", "message_error_notices")
|
||||
helper.Copy(up.Bool, "matrix", "sync_direct_chat_list")
|
||||
helper.Copy(up.Bool, "matrix", "federate_rooms")
|
||||
helper.Copy(up.Int, "matrix", "upload_file_threshold")
|
||||
|
||||
helper.Copy(up.Str, "provisioning", "prefix")
|
||||
if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" {
|
||||
|
|
|
|||
|
|
@ -238,13 +238,11 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by
|
|||
return
|
||||
}
|
||||
|
||||
const inMemoryUploadThreshold = 5 * 1024 * 1024
|
||||
|
||||
type writeToCapturer struct {
|
||||
type simpleBuffer struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (w *writeToCapturer) Write(p []byte) (n int, err error) {
|
||||
func (w *simpleBuffer) Write(p []byte) (n int, err error) {
|
||||
if w.data == nil {
|
||||
w.data = p
|
||||
} else {
|
||||
|
|
@ -253,36 +251,50 @@ func (w *writeToCapturer) Write(p []byte) (n int, err error) {
|
|||
return len(p), nil
|
||||
}
|
||||
|
||||
func (as *ASIntent) UploadMediaStream(ctx context.Context, roomID id.RoomID, data io.Reader, size int64, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) {
|
||||
func (w *simpleBuffer) Seek(offset int64, whence int) (int64, error) {
|
||||
if whence == io.SeekStart {
|
||||
if offset == 0 {
|
||||
w.data = nil
|
||||
} else {
|
||||
w.data = w.data[:offset]
|
||||
}
|
||||
return offset, nil
|
||||
}
|
||||
return 0, fmt.Errorf("unsupported whence value %d", whence)
|
||||
}
|
||||
|
||||
func (as *ASIntent) UploadMediaStream(
|
||||
ctx context.Context,
|
||||
roomID id.RoomID,
|
||||
size int64,
|
||||
requireFile bool,
|
||||
fileName,
|
||||
mimeType string,
|
||||
cb bridgev2.FileStreamCallback,
|
||||
) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) {
|
||||
if size > as.Connector.MediaConfig.UploadSize {
|
||||
return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000)
|
||||
} else if 0 < size && size < inMemoryUploadThreshold {
|
||||
var dataBytes []byte
|
||||
wt, ok := data.(io.WriterTo)
|
||||
if ok {
|
||||
capturer := &writeToCapturer{}
|
||||
_, err = wt.WriteTo(capturer)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
dataBytes = capturer.data
|
||||
} else {
|
||||
dataBytes, err = io.ReadAll(data)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
return as.UploadMedia(ctx, roomID, dataBytes, fileName, mimeType)
|
||||
}
|
||||
tempFile, err := os.CreateTemp("", "mautrix-upload-*")
|
||||
if !requireFile && 0 < size && size < as.Connector.Config.Matrix.UploadFileThreshold {
|
||||
var buf simpleBuffer
|
||||
replPath, err := cb(&buf)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
} else if replPath != "" {
|
||||
panic(fmt.Errorf("logic error: replacement path must only be returned if requireFile is true"))
|
||||
}
|
||||
return as.UploadMedia(ctx, roomID, buf.data, fileName, mimeType)
|
||||
}
|
||||
var tempFile *os.File
|
||||
tempFile, err = os.CreateTemp("", "mautrix-upload-*")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to create temp file: %w", err)
|
||||
err = fmt.Errorf("failed to create temp file: %w", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = tempFile.Close()
|
||||
_ = os.Remove(tempFile.Name())
|
||||
}()
|
||||
var realSize int64
|
||||
if roomID != "" {
|
||||
var encrypted bool
|
||||
if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil {
|
||||
|
|
@ -292,27 +304,50 @@ func (as *ASIntent) UploadMediaStream(ctx context.Context, roomID id.RoomID, dat
|
|||
file = &event.EncryptedFileInfo{
|
||||
EncryptedFile: *attachment.NewEncryptedFile(),
|
||||
}
|
||||
encryptStream := file.EncryptStream(data)
|
||||
realSize, err = io.Copy(tempFile, encryptStream)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to write to temp file: %w", err)
|
||||
}
|
||||
err = encryptStream.Close()
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to finalize encryption: %w", err)
|
||||
}
|
||||
mimeType = "application/octet-stream"
|
||||
fileName = ""
|
||||
}
|
||||
} else {
|
||||
realSize, err = io.Copy(tempFile, data)
|
||||
}
|
||||
var replPath string
|
||||
replPath, err = cb(tempFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to write to temp file: %w", err)
|
||||
}
|
||||
var replFile *os.File
|
||||
if replPath != "" {
|
||||
replFile, err = os.OpenFile(replPath, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to write to temp file: %w", err)
|
||||
err = fmt.Errorf("failed to open replacement file: %w", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
replFile = tempFile
|
||||
_, err = replFile.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to seek to start of temp file: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if file != nil {
|
||||
err = file.EncryptFile(replFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encrypt file: %w", err)
|
||||
return
|
||||
}
|
||||
_, err = replFile.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to seek to start of temp file after encrypting: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
info, err := replFile.Stat()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get temp file info: %w", err)
|
||||
return
|
||||
}
|
||||
url, err = as.doUploadReq(ctx, file, mautrix.ReqUploadMedia{
|
||||
Content: tempFile,
|
||||
ContentLength: realSize,
|
||||
ContentLength: info.Size(),
|
||||
ContentType: mimeType,
|
||||
FileName: fileName,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -198,6 +198,9 @@ matrix:
|
|||
# Whether created rooms should have federation enabled. If false, created portal rooms
|
||||
# will never be federated. Changing this option requires recreating rooms.
|
||||
federate_rooms: true
|
||||
# The threshold as bytes after which the bridge should roundtrip uploads via the disk
|
||||
# rather than keeping the whole file in memory.
|
||||
upload_file_threshold: 5242880
|
||||
|
||||
# Settings for provisioning API
|
||||
provisioning:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ package bridgev2
|
|||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
|
@ -78,6 +79,14 @@ type MatrixSendExtra struct {
|
|||
PartIndex int
|
||||
}
|
||||
|
||||
// FileStreamCallback is a callback function for file uploads that roundtrip via disk.
|
||||
//
|
||||
// The parameter is either a file or an in-memory buffer depending on the size of the file and whether the requireFile flag was set.
|
||||
//
|
||||
// The first return value can specify a file path to use instead of the original temp file.
|
||||
// Returning a replacement path is only valid if the parameter is a file.
|
||||
type FileStreamCallback func(file io.WriteSeeker) (string, error)
|
||||
|
||||
type MatrixAPI interface {
|
||||
GetMXID() id.UserID
|
||||
|
||||
|
|
@ -88,6 +97,7 @@ type MatrixAPI interface {
|
|||
MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error
|
||||
DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error)
|
||||
UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error)
|
||||
UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, fileName, mimeType string, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error)
|
||||
|
||||
SetDisplayName(ctx context.Context, name string) error
|
||||
SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error
|
||||
|
|
|
|||
|
|
@ -127,6 +127,43 @@ func (ef *EncryptedFile) EncryptInPlace(data []byte) {
|
|||
ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(checksum[:])
|
||||
}
|
||||
|
||||
type ReadWriterAt interface {
|
||||
io.WriterAt
|
||||
io.Reader
|
||||
}
|
||||
|
||||
// EncryptFile encrypts the given file in-place and updates the SHA256 hash in the EncryptedFile struct.
|
||||
func (ef *EncryptedFile) EncryptFile(file ReadWriterAt) error {
|
||||
err := ef.decodeKeys(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
block, _ := aes.NewCipher(ef.decoded.key[:])
|
||||
stream := cipher.NewCTR(block, ef.decoded.iv[:])
|
||||
hasher := sha256.New()
|
||||
buf := make([]byte, 32*1024)
|
||||
var writePtr int64
|
||||
var n int
|
||||
for {
|
||||
n, err = file.Read(buf)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
stream.XORKeyStream(buf[:n], buf[:n])
|
||||
_, err = file.WriteAt(buf[:n], writePtr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
writePtr += int64(n)
|
||||
hasher.Write(buf[:n])
|
||||
}
|
||||
ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(hasher.Sum(nil))
|
||||
return nil
|
||||
}
|
||||
|
||||
type encryptingReader struct {
|
||||
stream cipher.Stream
|
||||
hash hash.Hash
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue