diff --git a/transient_data.go b/transient_data.go index a96bb3a..af3c640 100644 --- a/transient_data.go +++ b/transient_data.go @@ -95,8 +95,6 @@ type TransientData struct { listeners map[TransientListener]bool // +checklocks:mu timers map[string]*time.Timer - // +checklocks:mu - ttlCh chan<- struct{} } // NewTransientData creates a new transient data container. @@ -181,7 +179,10 @@ func (t *TransientData) RemoveListener(listener TransientListener) { // +checklocks:t.mu func (t *TransientData) updateTTL(key string, value any, ttl time.Duration) { if ttl <= 0 { - delete(t.timers, key) + if old, found := t.timers[key]; found { + old.Stop() + delete(t.timers, key) + } } else { t.removeAfterTTL(key, value, ttl) } @@ -189,25 +190,20 @@ func (t *TransientData) updateTTL(key string, value any, ttl time.Duration) { // +checklocks:t.mu func (t *TransientData) removeAfterTTL(key string, value any, ttl time.Duration) { - if ttl <= 0 { - return - } - if old, found := t.timers[key]; found { old.Stop() } + if ttl <= 0 { + delete(t.timers, key) + return + } + timer := time.AfterFunc(ttl, func() { t.mu.Lock() defer t.mu.Unlock() t.compareAndRemove(key, value) - if t.ttlCh != nil { - select { - case t.ttlCh <- struct{}{}: - default: - } - } }) if t.timers == nil { t.timers = make(map[string]*time.Timer) diff --git a/transient_data_test.go b/transient_data_test.go index e2bee37..e4aff9a 100644 --- a/transient_data_test.go +++ b/transient_data_test.go @@ -26,6 +26,7 @@ import ( "net/http/httptest" "sync" "testing" + "testing/synctest" "time" "github.com/stretchr/testify/assert" @@ -34,59 +35,63 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/api" ) -func (t *TransientData) SetTTLChannel(ch chan<- struct{}) { - t.mu.Lock() - defer t.mu.Unlock() - - t.ttlCh = ch -} - func Test_TransientData(t *testing.T) { t.Parallel() - assert := assert.New(t) - data := NewTransientData() - assert.False(data.Set("foo", nil)) - assert.True(data.Set("foo", "bar")) - assert.False(data.Set("foo", "bar")) - assert.True(data.Set("foo", "baz")) - assert.False(data.CompareAndSet("foo", "bar", "lala")) - assert.True(data.CompareAndSet("foo", "baz", "lala")) - assert.False(data.CompareAndSet("test", nil, nil)) - assert.True(data.CompareAndSet("test", nil, "123")) - assert.False(data.CompareAndSet("test", nil, "456")) - assert.False(data.CompareAndRemove("test", "1234")) - assert.True(data.CompareAndRemove("test", "123")) - assert.False(data.Remove("lala")) - assert.True(data.Remove("foo")) + SynctestTest(t, func(t *testing.T) { + assert := assert.New(t) + data := NewTransientData() + assert.False(data.Set("foo", nil)) + assert.True(data.Set("foo", "bar")) + assert.False(data.Set("foo", "bar")) + assert.True(data.Set("foo", "baz")) + assert.False(data.CompareAndSet("foo", "bar", "lala")) + assert.True(data.CompareAndSet("foo", "baz", "lala")) + assert.False(data.CompareAndSet("test", nil, nil)) + assert.True(data.CompareAndSet("test", nil, "123")) + assert.False(data.CompareAndSet("test", nil, "456")) + assert.False(data.CompareAndRemove("test", "1234")) + assert.True(data.CompareAndRemove("test", "123")) + assert.False(data.Remove("lala")) + assert.True(data.Remove("foo")) - ttlCh := make(chan struct{}) - data.SetTTLChannel(ttlCh) - assert.True(data.SetTTL("test", "1234", time.Millisecond)) - assert.Equal("1234", data.GetData()["test"]) - // Data is removed after the TTL - <-ttlCh - assert.Nil(data.GetData()["test"]) + assert.True(data.SetTTL("test", "1234", time.Millisecond)) + assert.Equal("1234", data.GetData()["test"]) + // Data is removed after the TTL + start := time.Now() + time.Sleep(time.Millisecond) + synctest.Wait() + assert.Equal(time.Millisecond, time.Since(start)) + assert.Nil(data.GetData()["test"]) - assert.True(data.SetTTL("test", "1234", time.Millisecond)) - assert.Equal("1234", data.GetData()["test"]) - assert.True(data.SetTTL("test", "2345", 3*time.Millisecond)) - assert.Equal("2345", data.GetData()["test"]) - // Data is removed after the TTL only if the value still matches - time.Sleep(2 * time.Millisecond) - assert.Equal("2345", data.GetData()["test"]) - // Data is removed after the (second) TTL - <-ttlCh - assert.Nil(data.GetData()["test"]) + assert.True(data.SetTTL("test", "1234", time.Millisecond)) + assert.Equal("1234", data.GetData()["test"]) + assert.True(data.SetTTL("test", "2345", 3*time.Millisecond)) + assert.Equal("2345", data.GetData()["test"]) + start = time.Now() + // Data is removed after the TTL only if the value still matches + time.Sleep(2 * time.Millisecond) + synctest.Wait() + assert.Equal("2345", data.GetData()["test"]) + // Data is removed after the (second) TTL + time.Sleep(time.Millisecond) + synctest.Wait() + assert.Equal(3*time.Millisecond, time.Since(start)) + assert.Nil(data.GetData()["test"]) - // Setting existing key will update the TTL - assert.True(data.SetTTL("test", "1234", time.Millisecond)) - assert.False(data.SetTTL("test", "1234", 3*time.Millisecond)) - // Data still exists after the first TTL - time.Sleep(2 * time.Millisecond) - assert.Equal("1234", data.GetData()["test"]) - // Data is removed after the (updated) TTL - <-ttlCh - assert.Nil(data.GetData()["test"]) + // Setting existing key will update the TTL + assert.True(data.SetTTL("test", "1234", time.Millisecond)) + assert.False(data.SetTTL("test", "1234", 3*time.Millisecond)) + start = time.Now() + // Data still exists after the first TTL + time.Sleep(2 * time.Millisecond) + synctest.Wait() + assert.Equal("1234", data.GetData()["test"]) + // Data is removed after the (updated) TTL + time.Sleep(time.Millisecond) + synctest.Wait() + assert.Equal(3*time.Millisecond, time.Since(start)) + assert.Nil(data.GetData()["test"]) + }) } type MockTransientListener struct {