From ce2ffd8232a5e525fe62c67d1a1659fd65426158 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 00:52:54 +0300 Subject: [PATCH] bridgev2/matrix: add new stream upload that uses a writer instead of a reader (#269) --- bridgev2/bridgeconfig/config.go | 11 +- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/intent.go | 111 ++++++++++++++------- bridgev2/matrix/mxmain/example-config.yaml | 3 + bridgev2/matrixinterface.go | 10 ++ crypto/attachment/attachments.go | 37 +++++++ 6 files changed, 130 insertions(+), 43 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index ab97c891..40a17622 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -69,11 +69,12 @@ type BridgeConfig struct { } type MatrixConfig struct { - MessageStatusEvents bool `yaml:"message_status_events"` - DeliveryReceipts bool `yaml:"delivery_receipts"` - MessageErrorNotices bool `yaml:"message_error_notices"` - SyncDirectChatList bool `yaml:"sync_direct_chat_list"` - FederateRooms bool `yaml:"federate_rooms"` + MessageStatusEvents bool `yaml:"message_status_events"` + DeliveryReceipts bool `yaml:"delivery_receipts"` + MessageErrorNotices bool `yaml:"message_error_notices"` + SyncDirectChatList bool `yaml:"sync_direct_chat_list"` + FederateRooms bool `yaml:"federate_rooms"` + UploadFileThreshold int64 `yaml:"upload_file_threshold"` } type ProvisioningConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 4eff205d..9597fa4f 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -84,6 +84,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "matrix", "message_error_notices") helper.Copy(up.Bool, "matrix", "sync_direct_chat_list") helper.Copy(up.Bool, "matrix", "federate_rooms") + helper.Copy(up.Int, "matrix", "upload_file_threshold") helper.Copy(up.Str, "provisioning", "prefix") if secret, ok := helper.Get(up.Str, 
"provisioning", "shared_secret"); !ok || secret == "generate" { diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 8846a30b..115d7393 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -238,13 +238,11 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by return } -const inMemoryUploadThreshold = 5 * 1024 * 1024 - -type writeToCapturer struct { +type simpleBuffer struct { data []byte } -func (w *writeToCapturer) Write(p []byte) (n int, err error) { +func (w *simpleBuffer) Write(p []byte) (n int, err error) { if w.data == nil { - w.data = p + w.data = append([]byte(nil), p...) // copy: an io.Writer must not retain the caller's slice } else { @@ -253,36 +251,50 @@ func (w *writeToCapturer) Write(p []byte) (n int, err error) { return len(p), nil } -func (as *ASIntent) UploadMediaStream(ctx context.Context, roomID id.RoomID, data io.Reader, size int64, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { +func (w *simpleBuffer) Seek(offset int64, whence int) (int64, error) { + if whence == io.SeekStart { + if offset == 0 { + w.data = nil + } else { + w.data = w.data[:offset] + } + return offset, nil + } + return 0, fmt.Errorf("unsupported whence value %d", whence) +} + +func (as *ASIntent) UploadMediaStream( + ctx context.Context, + roomID id.RoomID, + size int64, + requireFile bool, + fileName, + mimeType string, + cb bridgev2.FileStreamCallback, +) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { if size > as.Connector.MediaConfig.UploadSize { return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) - } else if 0 < size && size < inMemoryUploadThreshold { - var dataBytes []byte - wt, ok := data.(io.WriterTo) - if ok { - capturer := &writeToCapturer{} - _, err = wt.WriteTo(capturer) - if err != nil { - return "", nil, err - } - dataBytes = capturer.data - } else { - dataBytes, err = io.ReadAll(data) - if err != nil { - return "", nil, err 
- } - } - return as.UploadMedia(ctx, roomID, dataBytes, fileName, mimeType) } - tempFile, err := os.CreateTemp("", "mautrix-upload-*") + if !requireFile && 0 < size && size < as.Connector.Config.Matrix.UploadFileThreshold { + var buf simpleBuffer + replPath, err := cb(&buf) + if err != nil { + return "", nil, err + } else if replPath != "" { + panic(fmt.Errorf("logic error: replacement path must only be returned if requireFile is true")) + } + return as.UploadMedia(ctx, roomID, buf.data, fileName, mimeType) + } + var tempFile *os.File + tempFile, err = os.CreateTemp("", "mautrix-upload-*") if err != nil { - return "", nil, fmt.Errorf("failed to create temp file: %w", err) + err = fmt.Errorf("failed to create temp file: %w", err) + return } defer func() { _ = tempFile.Close() _ = os.Remove(tempFile.Name()) }() - var realSize int64 if roomID != "" { var encrypted bool if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { @@ -292,27 +304,56 @@ func (as *ASIntent) UploadMediaStream(ctx context.Context, roomID id.RoomID, dat file = &event.EncryptedFileInfo{ EncryptedFile: *attachment.NewEncryptedFile(), } - encryptStream := file.EncryptStream(data) - realSize, err = io.Copy(tempFile, encryptStream) - if err != nil { - return "", nil, fmt.Errorf("failed to write to temp file: %w", err) - } - err = encryptStream.Close() - if err != nil { - return "", nil, fmt.Errorf("failed to finalize encryption: %w", err) - } mimeType = "application/octet-stream" fileName = "" } - } else { - realSize, err = io.Copy(tempFile, data) + } + var replPath string + replPath, err = cb(tempFile) + if err != nil { + err = fmt.Errorf("failed to write to temp file: %w", err) + return + } + var replFile *os.File + if replPath != "" { + replFile, err = os.OpenFile(replPath, os.O_RDWR, 0) if err != nil { - return "", nil, fmt.Errorf("failed to write to temp file: %w", err) + err = fmt.Errorf("failed to open replacement file: %w", err) + return } + // Close the replacement file on return; only the temp file has a deferred Close above. + defer func() { + _ = replFile.Close() + }() + } else { + replFile = tempFile + _, 
err = replFile.Seek(0, io.SeekStart) + if err != nil { + err = fmt.Errorf("failed to seek to start of temp file: %w", err) + return + } + } + if file != nil { + err = file.EncryptFile(replFile) + if err != nil { + err = fmt.Errorf("failed to encrypt file: %w", err) + return + } + _, err = replFile.Seek(0, io.SeekStart) + if err != nil { + err = fmt.Errorf("failed to seek to start of temp file after encrypting: %w", err) + return + } + } + info, err := replFile.Stat() + if err != nil { + err = fmt.Errorf("failed to get temp file info: %w", err) + return } url, err = as.doUploadReq(ctx, file, mautrix.ReqUploadMedia{ - Content: tempFile, - ContentLength: realSize, + // Upload the file that was encrypted and measured above; this is the replacement file when cb returned one. + Content: replFile, + ContentLength: info.Size(), ContentType: mimeType, FileName: fileName, }) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 06ed010f..e0a5ed87 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -198,6 +198,9 @@ matrix: # Whether created rooms should have federation enabled. If false, created portal rooms # will never be federated. Changing this option requires recreating rooms. federate_rooms: true + # The threshold as bytes after which the bridge should roundtrip uploads via the disk + # rather than keeping the whole file in memory. + upload_file_threshold: 5242880 # Settings for provisioning API provisioning: diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 0628f16d..02528fde 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "io" "time" "github.com/gorilla/mux" @@ -78,6 +79,14 @@ type MatrixSendExtra struct { PartIndex int } +// FileStreamCallback is a callback function for file uploads that roundtrip via disk. +// +// The parameter is either a file or an in-memory buffer depending on the size of the file and whether the requireFile flag was set. 
+// +// The first return value can specify a file path to use instead of the original temp file. +// Returning a replacement path is only valid if the parameter is a file. +type FileStreamCallback func(file io.WriteSeeker) (string, error) + type MatrixAPI interface { GetMXID() id.UserID @@ -88,6 +97,7 @@ type MatrixAPI interface { MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) + UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, fileName, mimeType string, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) SetDisplayName(ctx context.Context, name string) error SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 344db4f0..cfa1c3e5 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -127,6 +127,43 @@ func (ef *EncryptedFile) EncryptInPlace(data []byte) { ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(checksum[:]) } +type ReadWriterAt interface { + io.WriterAt + io.Reader +} + +// EncryptFile encrypts the given file in-place and updates the SHA256 hash in the EncryptedFile struct. 
+func (ef *EncryptedFile) EncryptFile(file ReadWriterAt) error { + err := ef.decodeKeys(false) + if err != nil { + return err + } + block, _ := aes.NewCipher(ef.decoded.key[:]) + stream := cipher.NewCTR(block, ef.decoded.iv[:]) + hasher := sha256.New() + buf := make([]byte, 32*1024) + var writePtr int64 + var n int + for { + n, err = file.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + return err + } + if n == 0 { + break + } + stream.XORKeyStream(buf[:n], buf[:n]) + _, err = file.WriteAt(buf[:n], writePtr) + if err != nil { + return err + } + writePtr += int64(n) + hasher.Write(buf[:n]) + } + ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(hasher.Sum(nil)) + return nil +} + type encryptingReader struct { stream cipher.Stream hash hash.Hash