diff --git a/internal/common/connection.go b/internal/common/connection.go index 9ca18ec4..53272405 100644 --- a/internal/common/connection.go +++ b/internal/common/connection.go @@ -780,9 +780,8 @@ func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error if err := c.CheckParentDirs(path.Dir(destPath)); err != nil { return err } - done := make(chan bool) - defer close(done) - go keepConnectionAlive(c, done, 2*time.Minute) + stopKeepAlive := keepConnectionAlive(c, 2*time.Minute) + defer stopKeepAlive() return c.doRecursiveCopy(virtualSourcePath, destPath, srcInfo, createTargetDir, 0) } @@ -848,9 +847,8 @@ func (c *BaseConnection) renameInternal(virtualSourcePath, virtualTargetPath str if checkParentDestination { c.CheckParentDirs(path.Dir(virtualTargetPath)) //nolint:errcheck } - done := make(chan bool) - defer close(done) - go keepConnectionAlive(c, done, 2*time.Minute) + stopKeepAlive := keepConnectionAlive(c, 2*time.Minute) + defer stopKeepAlive() files, size, err := fsDst.Rename(fsSourcePath, fsTargetPath, checks) if err != nil { @@ -1869,18 +1867,22 @@ func getPermissionDeniedError(protocol string) error { } } -func keepConnectionAlive(c *BaseConnection, done chan bool, interval time.Duration) { - ticker := time.NewTicker(interval) - defer func() { - ticker.Stop() - }() +func keepConnectionAlive(c *BaseConnection, interval time.Duration) func() { + var timer *time.Timer + var closed atomic.Bool - for { - select { - case <-done: - return - case <-ticker.C: - c.UpdateLastActivity() + task := func() { + c.UpdateLastActivity() + + if !closed.Load() { + timer.Reset(interval) } } + + timer = time.AfterFunc(interval, task) + + return func() { + closed.Store(true) + timer.Stop() + } } diff --git a/internal/common/connection_test.go b/internal/common/connection_test.go index b1f2a506..d2a13d2f 100644 --- a/internal/common/connection_test.go +++ b/internal/common/connection_test.go @@ -627,12 +627,11 @@ func TestErrorResolvePath(t *testing.T) { func TestConnectionKeepAlive(t *testing.T) { conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{}) lastActivity := conn.GetLastActivity() - done := make(chan bool) - go func() { - time.Sleep(200 * time.Millisecond) - close(done) - }() - keepConnectionAlive(conn, done, 50*time.Millisecond) + + stop := keepConnectionAlive(conn, 50*time.Millisecond) + defer stop() + + time.Sleep(200 * time.Millisecond) assert.Greater(t, conn.GetLastActivity(), lastActivity) }