bridgev2/matrix: add new stream upload that uses a writer instead of a reader (#269)

This commit is contained in:
Tulir Asokan 2024-08-20 00:52:54 +03:00 committed by GitHub
commit ce2ffd8232
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 128 additions and 41 deletions

View file

@ -69,11 +69,12 @@ type BridgeConfig struct {
}
type MatrixConfig struct {
MessageStatusEvents bool `yaml:"message_status_events"`
DeliveryReceipts bool `yaml:"delivery_receipts"`
MessageErrorNotices bool `yaml:"message_error_notices"`
SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
FederateRooms bool `yaml:"federate_rooms"`
MessageStatusEvents bool `yaml:"message_status_events"`
DeliveryReceipts bool `yaml:"delivery_receipts"`
MessageErrorNotices bool `yaml:"message_error_notices"`
SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
FederateRooms bool `yaml:"federate_rooms"`
UploadFileThreshold int64 `yaml:"upload_file_threshold"`
}
type ProvisioningConfig struct {

View file

@ -84,6 +84,7 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "matrix", "message_error_notices")
helper.Copy(up.Bool, "matrix", "sync_direct_chat_list")
helper.Copy(up.Bool, "matrix", "federate_rooms")
helper.Copy(up.Int, "matrix", "upload_file_threshold")
helper.Copy(up.Str, "provisioning", "prefix")
if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" {

View file

@ -238,13 +238,11 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by
return
}
const inMemoryUploadThreshold = 5 * 1024 * 1024
type writeToCapturer struct {
type simpleBuffer struct {
data []byte
}
func (w *writeToCapturer) Write(p []byte) (n int, err error) {
func (w *simpleBuffer) Write(p []byte) (n int, err error) {
if w.data == nil {
w.data = p
} else {
@ -253,36 +251,50 @@ func (w *writeToCapturer) Write(p []byte) (n int, err error) {
return len(p), nil
}
func (as *ASIntent) UploadMediaStream(ctx context.Context, roomID id.RoomID, data io.Reader, size int64, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) {
func (w *simpleBuffer) Seek(offset int64, whence int) (int64, error) {
if whence == io.SeekStart {
if offset == 0 {
w.data = nil
} else {
w.data = w.data[:offset]
}
return offset, nil
}
return 0, fmt.Errorf("unsupported whence value %d", whence)
}
func (as *ASIntent) UploadMediaStream(
ctx context.Context,
roomID id.RoomID,
size int64,
requireFile bool,
fileName,
mimeType string,
cb bridgev2.FileStreamCallback,
) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) {
if size > as.Connector.MediaConfig.UploadSize {
return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000)
} else if 0 < size && size < inMemoryUploadThreshold {
var dataBytes []byte
wt, ok := data.(io.WriterTo)
if ok {
capturer := &writeToCapturer{}
_, err = wt.WriteTo(capturer)
if err != nil {
return "", nil, err
}
dataBytes = capturer.data
} else {
dataBytes, err = io.ReadAll(data)
if err != nil {
return "", nil, err
}
}
return as.UploadMedia(ctx, roomID, dataBytes, fileName, mimeType)
}
tempFile, err := os.CreateTemp("", "mautrix-upload-*")
if !requireFile && 0 < size && size < as.Connector.Config.Matrix.UploadFileThreshold {
var buf simpleBuffer
replPath, err := cb(&buf)
if err != nil {
return "", nil, err
} else if replPath != "" {
panic(fmt.Errorf("logic error: replacement path must only be returned if requireFile is true"))
}
return as.UploadMedia(ctx, roomID, buf.data, fileName, mimeType)
}
var tempFile *os.File
tempFile, err = os.CreateTemp("", "mautrix-upload-*")
if err != nil {
return "", nil, fmt.Errorf("failed to create temp file: %w", err)
err = fmt.Errorf("failed to create temp file: %w", err)
return
}
defer func() {
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
}()
var realSize int64
if roomID != "" {
var encrypted bool
if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil {
@ -292,27 +304,50 @@ func (as *ASIntent) UploadMediaStream(ctx context.Context, roomID id.RoomID, dat
file = &event.EncryptedFileInfo{
EncryptedFile: *attachment.NewEncryptedFile(),
}
encryptStream := file.EncryptStream(data)
realSize, err = io.Copy(tempFile, encryptStream)
if err != nil {
return "", nil, fmt.Errorf("failed to write to temp file: %w", err)
}
err = encryptStream.Close()
if err != nil {
return "", nil, fmt.Errorf("failed to finalize encryption: %w", err)
}
mimeType = "application/octet-stream"
fileName = ""
}
} else {
realSize, err = io.Copy(tempFile, data)
}
var replPath string
replPath, err = cb(tempFile)
if err != nil {
err = fmt.Errorf("failed to write to temp file: %w", err)
}
var replFile *os.File
if replPath != "" {
replFile, err = os.OpenFile(replPath, os.O_RDWR, 0)
if err != nil {
return "", nil, fmt.Errorf("failed to write to temp file: %w", err)
err = fmt.Errorf("failed to open replacement file: %w", err)
return
}
} else {
replFile = tempFile
_, err = replFile.Seek(0, io.SeekStart)
if err != nil {
err = fmt.Errorf("failed to seek to start of temp file: %w", err)
return
}
}
if file != nil {
err = file.EncryptFile(replFile)
if err != nil {
err = fmt.Errorf("failed to encrypt file: %w", err)
return
}
_, err = replFile.Seek(0, io.SeekStart)
if err != nil {
err = fmt.Errorf("failed to seek to start of temp file after encrypting: %w", err)
return
}
}
info, err := replFile.Stat()
if err != nil {
err = fmt.Errorf("failed to get temp file info: %w", err)
return
}
url, err = as.doUploadReq(ctx, file, mautrix.ReqUploadMedia{
Content: tempFile,
ContentLength: realSize,
ContentLength: info.Size(),
ContentType: mimeType,
FileName: fileName,
})

View file

@ -198,6 +198,9 @@ matrix:
# Whether created rooms should have federation enabled. If false, created portal rooms
# will never be federated. Changing this option requires recreating rooms.
federate_rooms: true
# The threshold as bytes after which the bridge should roundtrip uploads via the disk
# rather than keeping the whole file in memory.
upload_file_threshold: 5242880
# Settings for provisioning API
provisioning:

View file

@ -8,6 +8,7 @@ package bridgev2
import (
"context"
"io"
"time"
"github.com/gorilla/mux"
@ -78,6 +79,14 @@ type MatrixSendExtra struct {
PartIndex int
}
// FileStreamCallback is a callback function for file uploads that roundtrip via disk.
//
// The parameter is either a file or an in-memory buffer depending on the size of the file and whether the requireFile flag was set.
//
// The first return value can specify a file path to use instead of the original temp file.
// Returning a replacement path is only valid if the parameter is a file.
type FileStreamCallback func(file io.WriteSeeker) (string, error)
type MatrixAPI interface {
GetMXID() id.UserID
@ -88,6 +97,7 @@ type MatrixAPI interface {
MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error
DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error)
UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error)
UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, fileName, mimeType string, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error)
SetDisplayName(ctx context.Context, name string) error
SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error

View file

@ -127,6 +127,43 @@ func (ef *EncryptedFile) EncryptInPlace(data []byte) {
ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(checksum[:])
}
type ReadWriterAt interface {
io.WriterAt
io.Reader
}
// EncryptFile encrypts the given file in-place and updates the SHA256 hash in the EncryptedFile struct.
func (ef *EncryptedFile) EncryptFile(file ReadWriterAt) error {
err := ef.decodeKeys(false)
if err != nil {
return err
}
block, _ := aes.NewCipher(ef.decoded.key[:])
stream := cipher.NewCTR(block, ef.decoded.iv[:])
hasher := sha256.New()
buf := make([]byte, 32*1024)
var writePtr int64
var n int
for {
n, err = file.Read(buf)
if err != nil && !errors.Is(err, io.EOF) {
return err
}
if n == 0 {
break
}
stream.XORKeyStream(buf[:n], buf[:n])
_, err = file.WriteAt(buf[:n], writePtr)
if err != nil {
return err
}
writePtr += int64(n)
hasher.Write(buf[:n])
}
ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(hasher.Sum(nil))
return nil
}
type encryptingReader struct {
stream cipher.Stream
hash hash.Hash