diff --git a/platform/wait/wait.go b/platform/wait/wait.go index 6ad817b26..a1fa80903 100644 --- a/platform/wait/wait.go +++ b/platform/wait/wait.go @@ -25,8 +25,9 @@ func For(msg string, timeout, interval time.Duration, f func() (bool, error)) er stop, err := f() if stop { - return nil + return err } + if err != nil { lastErr = err } diff --git a/platform/wait/wait_test.go b/platform/wait/wait_test.go index 9722e6f2e..9d1a4ac34 100644 --- a/platform/wait/wait_test.go +++ b/platform/wait/wait_test.go @@ -1,11 +1,12 @@ package wait import ( + "errors" "testing" "time" ) -func TestForTimeout(t *testing.T) { +func TestFor_timeout(t *testing.T) { c := make(chan error) go func() { c <- For("", 3*time.Second, 1*time.Second, func() (bool, error) { @@ -24,3 +25,42 @@ func TestForTimeout(t *testing.T) { t.Logf("%v", err) } } + +func TestFor_stop(t *testing.T) { + c := make(chan error) + go func() { + c <- For("", 3*time.Second, 1*time.Second, func() (bool, error) { + return true, nil + }) + }() + + timeout := time.After(6 * time.Second) + select { + case <-timeout: + t.Fatal("timeout exceeded") + case err := <-c: + if err != nil { + t.Errorf("expected no timeout error; got %v", err) + } + } +} + +func TestFor_stop_error(t *testing.T) { + c := make(chan error) + go func() { + c <- For("", 3*time.Second, 1*time.Second, func() (bool, error) { + return true, errors.New("oops") + }) + }() + + timeout := time.After(6 * time.Second) + select { + case <-timeout: + t.Fatal("timeout exceeded") + case err := <-c: + if err == nil { + t.Errorf("expected error; got %v", err) + } + t.Logf("%v", err) + } +}