diff --git a/internal/common/common.go b/internal/common/common.go index 3f0901c1..99b5e8d1 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -463,6 +463,7 @@ type ActiveTransfer interface { GetDownloadedSize() int64 GetUploadedSize() int64 GetVirtualPath() string + GetFsPath() string GetStartTime() time.Time SignalClose(err error) Truncate(fsPath string, size int64) (int64, error) diff --git a/internal/common/connection.go b/internal/common/connection.go index 53272405..4bbc89f0 100644 --- a/internal/common/connection.go +++ b/internal/common/connection.go @@ -296,6 +296,20 @@ func (c *BaseConnection) setTimes(fsPath string, atime time.Time, mtime time.Tim return false } +// getInfoForOngoingUpload returns upload statistics for an upload currently in +// progress on this connection. +func (c *BaseConnection) getInfoForOngoingUpload(fsPath string) (os.FileInfo, error) { + c.RLock() + defer c.RUnlock() + + for _, t := range c.activeTransfers { + if t.GetType() == TransferUpload && t.GetFsPath() == fsPath { + return vfs.NewFileInfo(t.GetVirtualPath(), false, t.GetSize(), t.GetStartTime(), false), nil + } + } + return nil, os.ErrNotExist +} + func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) (int64, error) { c.RLock() defer c.RUnlock() @@ -952,7 +966,19 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP info, err = fs.Stat(c.getRealFsPath(fsPath)) } if err != nil { - if !fs.IsNotExist(err) { + isNotExist := fs.IsNotExist(err) + if isNotExist { + // This is primarily useful for atomic storage backends, where files + // become visible only after they are closed. However, since we may + // be proxying (for example) an SFTP server backed by atomic + // storage, and this search only inspects transfers active on the + // current connection (typically just one), the check is inexpensive + // and safe to perform unconditionally. + if info, err := c.getInfoForOngoingUpload(fsPath); err == nil { + return info, nil + } + } + if !isNotExist { c.Log(logger.LevelWarn, "stat error for path %q: %+v", virtualPath, err) } return nil, c.GetFsError(fs, err) diff --git a/internal/common/connection_test.go b/internal/common/connection_test.go index d2a13d2f..78622800 100644 --- a/internal/common/connection_test.go +++ b/internal/common/connection_test.go @@ -1046,6 +1046,37 @@ func TestFilePatterns(t *testing.T) { require.Len(t, filtered, 1) } +func TestStatForOngoingTransfers(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: xid.New().String(), + Password: xid.New().String(), + HomeDir: filepath.Clean(os.TempDir()), + Status: 1, + Permissions: map[string][]string{ + "/": {"*"}, + }, + }, + } + fileName := "file.txt" + conn := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + tr := NewBaseTransfer(nil, conn, nil, filepath.Join(os.TempDir(), fileName), filepath.Join(os.TempDir(), fileName), + fileName, TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) + _, err := conn.DoStat("/file.txt", 0, false) + assert.NoError(t, err) + err = tr.Close() + assert.NoError(t, err) + tr = NewBaseTransfer(nil, conn, nil, filepath.Join(os.TempDir(), fileName), filepath.Join(os.TempDir(), fileName), + fileName, TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) + _, err = conn.DoStat("/file.txt", 0, false) + assert.Error(t, err) + err = tr.Close() + assert.NoError(t, err) + err = conn.CloseFS() + assert.NoError(t, err) +} + func TestListerAt(t *testing.T) { dir := t.TempDir() user := dataprovider.User{ diff --git a/internal/httpd/handler.go b/internal/httpd/handler.go index 77364b64..15b085fe 100644 --- a/internal/httpd/handler.go +++ b/internal/httpd/handler.go @@ -343,6 +343,10 @@ func (t *throttledReader) GetRealFsPath(_ string) string { return "" } +func (t *throttledReader) GetFsPath() string { + return "" +} + func (t *throttledReader) SetTimes(_ string, _ time.Time, _ time.Time) bool { return false }