diff --git a/api_signaling.go b/api_signaling.go index 1c48c78..db4036b 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -28,6 +28,7 @@ import ( "net/url" "sort" "strings" + "time" "github.com/golang-jwt/jwt/v4" ) @@ -847,6 +848,7 @@ type TransientDataClientMessage struct { Key string `json:"key,omitempty"` Value *json.RawMessage `json:"value,omitempty"` + TTL time.Duration `json:"ttl,omitempty"` } func (m *TransientDataClientMessage) CheckValid() error { diff --git a/backend_server_test.go b/backend_server_test.go index 82f3ee8..e9d25bd 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -626,7 +626,8 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := client2.RunUntilHello(ctx); err != nil { + hello2, err := client2.RunUntilHello(ctx) + if err != nil { t.Fatal(err) } @@ -635,16 +636,14 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) { if _, err := client1.JoinRoom(ctx, roomId1); err != nil { t.Fatal(err) } + if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { + t.Error(err) + } roomId2 := "test-room2" if _, err := client2.JoinRoom(ctx, roomId2); err != nil { t.Fatal(err) } - - // Ignore "join" events. - if err := client1.DrainMessages(ctx); err != nil { - t.Error(err) - } - if err := client2.DrainMessages(ctx); err != nil { + if err := client2.RunUntilJoined(ctx, hello2.Hello); err != nil { t.Error(err) } @@ -702,6 +701,7 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) { UserIds: []string{ testDefaultUserId, }, + Properties: (*json.RawMessage)(&testRoomProperties), }, } diff --git a/docs/standalone-signaling-api-v1.md b/docs/standalone-signaling-api-v1.md index a53dd26..914cf7d 100644 --- a/docs/standalone-signaling-api-v1.md +++ b/docs/standalone-signaling-api-v1.md @@ -817,14 +817,17 @@ Message format (Client -> Server): "transient": { "type": "set", "key": "sample-key", - "value": "any-json-object" + "value": "any-json-object", + "ttl": "optional-ttl" } } - The `key` must be a string. - The `value` can be of any type (i.e. string, number, array, object, etc.). +- The `ttl` is the time to live in nanoseconds. The value will be removed after + that time (if it is still present). - Requests to set a value that is already present for the key are silently - ignored. + ignored. Any TTL value will be updated / removed. Message format (Server -> Client): diff --git a/hub.go b/hub.go index ee6b1eb..f7ef8fb 100644 --- a/hub.go +++ b/hub.go @@ -1965,9 +1965,9 @@ func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) { } if msg.Value == nil { - room.SetTransientData(msg.Key, nil) + room.SetTransientDataTTL(msg.Key, nil, msg.TTL) } else { - room.SetTransientData(msg.Key, *msg.Value) + room.SetTransientDataTTL(msg.Key, *msg.Value, msg.TTL) } case "remove": if !isAllowedToUpdateTransientData(session) { diff --git a/room.go b/room.go index a4d0e09..e0a533a 100644 --- a/room.go +++ b/room.go @@ -1059,6 +1059,10 @@ func (r *Room) SetTransientData(key string, value interface{}) { r.transientData.Set(key, value) } +func (r *Room) SetTransientDataTTL(key string, value interface{}, ttl time.Duration) { + r.transientData.SetTTL(key, value, ttl) +} + func (r *Room) RemoveTransientData(key string) { r.transientData.Remove(key) } diff --git a/testclient_test.go b/testclient_test.go index 2f2fde7..acb64e6 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -578,7 +578,7 @@ func (c *TestClient) SendInternalRemoveSession(msg *RemoveSessionInternalClientM return c.WriteJSON(message) } -func (c *TestClient) SetTransientData(key string, value interface{}) error { +func (c *TestClient) SetTransientData(key string, value interface{}, ttl time.Duration) error { payload, err := json.Marshal(value) if err != nil { c.t.Fatal(err) @@ -591,6 +591,7 @@ func (c *TestClient) SetTransientData(key string, value interface{}) error { Type: "set", Key: key, Value: (*json.RawMessage)(&payload), + TTL: ttl, }, } return c.WriteJSON(message) diff --git a/transient_data.go b/transient_data.go index 392af02..120a454 100644 --- a/transient_data.go +++ b/transient_data.go @@ -24,6 +24,7 @@ package signaling import ( "reflect" "sync" + "time" ) type TransientListener interface { @@ -34,6 +35,8 @@ type TransientData struct { mu sync.Mutex data map[string]interface{} listeners map[TransientListener]bool + timers map[string]*time.Timer + ttlCh chan<- struct{} } // NewTransientData creates a new transient data container. @@ -99,9 +102,59 @@ func (t *TransientData) RemoveListener(listener TransientListener) { delete(t.listeners, listener) } +func (t *TransientData) updateTTL(key string, value interface{}, ttl time.Duration) { + if ttl <= 0 { + delete(t.timers, key) + } else { + t.removeAfterTTL(key, value, ttl) + } +} + +func (t *TransientData) removeAfterTTL(key string, value interface{}, ttl time.Duration) { + if ttl <= 0 { + return + } + + if old, found := t.timers[key]; found { + old.Stop() + } + + timer := time.AfterFunc(ttl, func() { + t.mu.Lock() + defer t.mu.Unlock() + + t.compareAndRemove(key, value) + if t.ttlCh != nil { + select { + case t.ttlCh <- struct{}{}: + default: + } + } + }) + if t.timers == nil { + t.timers = make(map[string]*time.Timer) + } + t.timers[key] = timer +} + +func (t *TransientData) doSet(key string, value interface{}, prev interface{}, ttl time.Duration) { + if t.data == nil { + t.data = make(map[string]interface{}) + } + t.data[key] = value + t.notifySet(key, prev, value) + t.removeAfterTTL(key, value, ttl) +} + // Set sets a new value for the given key and notifies listeners // if the value has been changed. func (t *TransientData) Set(key string, value interface{}) bool { + return t.SetTTL(key, value, 0) +} + +// SetTTL sets a new value for the given key with a time-to-live and notifies +// listeners if the value has been changed. +func (t *TransientData) SetTTL(key string, value interface{}, ttl time.Duration) bool { if value == nil { return t.Remove(key) } @@ -111,20 +164,24 @@ func (t *TransientData) Set(key string, value interface{}) bool { prev, found := t.data[key] if found && reflect.DeepEqual(prev, value) { + t.updateTTL(key, value, ttl) return false } - if t.data == nil { - t.data = make(map[string]interface{}) - } - t.data[key] = value - t.notifySet(key, prev, value) + t.doSet(key, value, prev, ttl) return true } // CompareAndSet sets a new value for the given key only for a given old value // and notifies listeners if the value has been changed. func (t *TransientData) CompareAndSet(key string, old, value interface{}) bool { + return t.CompareAndSetTTL(key, old, value, 0) +} + +// CompareAndSetTTL sets a new value for the given key with a time-to-live, +// only for a given old value and notifies listeners if the value has been +// changed. +func (t *TransientData) CompareAndSetTTL(key string, old, value interface{}, ttl time.Duration) bool { if value == nil { return t.CompareAndRemove(key, old) } @@ -139,11 +196,19 @@ func (t *TransientData) CompareAndSet(key string, old, value interface{}) bool { return false } - t.data[key] = value - t.notifySet(key, prev, value) + t.doSet(key, value, prev, ttl) return true } +func (t *TransientData) doRemove(key string, prev interface{}) { + delete(t.data, key) + if old, found := t.timers[key]; found { + old.Stop() + delete(t.timers, key) + } + t.notifyDeleted(key, prev) +} + // Remove deletes the value with the given key and notifies listeners // if the key was removed. func (t *TransientData) Remove(key string) bool { @@ -155,8 +220,7 @@ func (t *TransientData) Remove(key string) bool { return false } - delete(t.data, key) - t.notifyDeleted(key, prev) + t.doRemove(key, prev) return true } @@ -166,13 +230,16 @@ func (t *TransientData) CompareAndRemove(key string, old interface{}) bool { t.mu.Lock() defer t.mu.Unlock() + return t.compareAndRemove(key, old) +} + +func (t *TransientData) compareAndRemove(key string, old interface{}) bool { prev, found := t.data[key] if !found || !reflect.DeepEqual(prev, old) { return false } - delete(t.data, key) - t.notifyDeleted(key, prev) + t.doRemove(key, prev) return true } diff --git a/transient_data_test.go b/transient_data_test.go index 598fac2..41ca971 100644 --- a/transient_data_test.go +++ b/transient_data_test.go @@ -27,6 +27,13 @@ import ( "time" ) +func (t *TransientData) SetTTLChannel(ch chan<- struct{}) { + t.mu.Lock() + defer t.mu.Unlock() + + t.ttlCh = ch +} + func Test_TransientData(t *testing.T) { data := NewTransientData() if data.Set("foo", nil) { @@ -53,6 +60,9 @@ func Test_TransientData(t *testing.T) { if !data.CompareAndSet("test", nil, "123") { t.Errorf("should have set value") } + if data.CompareAndSet("test", nil, "456") { + t.Errorf("should not have set value") + } if data.CompareAndRemove("test", "1234") { t.Errorf("should not have removed value") } @@ -65,6 +75,61 @@ func Test_TransientData(t *testing.T) { if !data.Remove("foo") { t.Errorf("should have removed value") } + + ttlCh := make(chan struct{}) + data.SetTTLChannel(ttlCh) + if !data.SetTTL("test", "1234", time.Millisecond) { + t.Errorf("should have set value") + } + if value := data.GetData()["test"]; value != "1234" { + t.Errorf("expected 1234, got %v", value) + } + // Data is removed after the TTL + <-ttlCh + if value := data.GetData()["test"]; value != nil { + t.Errorf("expected no value, got %v", value) + } + + if !data.SetTTL("test", "1234", time.Millisecond) { + t.Errorf("should have set value") + } + if value := data.GetData()["test"]; value != "1234" { + t.Errorf("expected 1234, got %v", value) + } + if !data.SetTTL("test", "2345", 3*time.Millisecond) { + t.Errorf("should have set value") + } + if value := data.GetData()["test"]; value != "2345" { + t.Errorf("expected 2345, got %v", value) + } + // Data is removed after the TTL only if the value still matches + time.Sleep(2 * time.Millisecond) + if value := data.GetData()["test"]; value != "2345" { + t.Errorf("expected 2345, got %v", value) + } + // Data is removed after the (second) TTL + <-ttlCh + if value := data.GetData()["test"]; value != nil { + t.Errorf("expected no value, got %v", value) + } + + // Setting existing key will update the TTL + if !data.SetTTL("test", "1234", time.Millisecond) { + t.Errorf("should have set value") + } + if data.SetTTL("test", "1234", 3*time.Millisecond) { + t.Errorf("should not have set value") + } + // Data still exists after the first TTL + time.Sleep(2 * time.Millisecond) + if value := data.GetData()["test"]; value != "1234" { + t.Errorf("expected 1234, got %v", value) + } + // Data is removed after the (updated) TTL + <-ttlCh + if value := data.GetData()["test"]; value != nil { + t.Errorf("expected no value, got %v", value) + } } func Test_TransientMessages(t *testing.T) { @@ -83,7 +148,7 @@ func Test_TransientMessages(t *testing.T) { t.Fatal(err) } - if err := client1.SetTransientData("foo", "bar"); err != nil { + if err := client1.SetTransientData("foo", "bar", 0); err != nil { t.Fatal(err) } if msg, err := client1.RunUntilMessage(ctx); err != nil { @@ -137,7 +202,7 @@ func Test_TransientMessages(t *testing.T) { // Client 2 may not modify transient data. session2.SetPermissions([]Permission{}) - if err := client2.SetTransientData("foo", "bar"); err != nil { + if err := client2.SetTransientData("foo", "bar", 0); err != nil { t.Fatal(err) } if msg, err := client2.RunUntilMessage(ctx); err != nil { @@ -148,7 +213,7 @@ func Test_TransientMessages(t *testing.T) { } } - if err := client1.SetTransientData("foo", "bar"); err != nil { + if err := client1.SetTransientData("foo", "bar", 0); err != nil { t.Fatal(err) } @@ -179,7 +244,7 @@ func Test_TransientMessages(t *testing.T) { } // Setting the same value is ignored by the server. - if err := client1.SetTransientData("foo", "bar"); err != nil { + if err := client1.SetTransientData("foo", "bar", 0); err != nil { t.Fatal(err) } ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -196,7 +261,7 @@ func Test_TransientMessages(t *testing.T) { data := map[string]interface{}{ "hello": "world", } - if err := client1.SetTransientData("foo", data); err != nil { + if err := client1.SetTransientData("foo", data, 0); err != nil { t.Fatal(err) } @@ -249,7 +314,7 @@ func Test_TransientMessages(t *testing.T) { t.Errorf("Expected no payload, got %+v", msg) } - if err := client1.SetTransientData("abc", data); err != nil { + if err := client1.SetTransientData("abc", data, 10*time.Millisecond); err != nil { t.Fatal(err) } @@ -290,4 +355,11 @@ func Test_TransientMessages(t *testing.T) { }); err != nil { t.Fatal(err) } + + time.Sleep(10 * time.Millisecond) + if msg, err = client3.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else if err := checkMessageTransientRemove(msg, "abc", data); err != nil { + t.Fatal(err) + } }