Merge pull request #575 from strukturag/transient-ttl

Support TTL for transient data.
This commit is contained in:
Joachim Bauch 2023-10-12 11:31:38 +02:00 committed by GitHub
commit c17c5fd444
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 178 additions and 29 deletions

View file

@ -28,6 +28,7 @@ import (
"net/url"
"sort"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
)
@ -847,6 +848,7 @@ type TransientDataClientMessage struct {
Key string `json:"key,omitempty"`
Value *json.RawMessage `json:"value,omitempty"`
TTL time.Duration `json:"ttl,omitempty"`
}
func (m *TransientDataClientMessage) CheckValid() error {

View file

@ -626,7 +626,8 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if _, err := client2.RunUntilHello(ctx); err != nil {
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
@ -635,16 +636,14 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) {
if _, err := client1.JoinRoom(ctx, roomId1); err != nil {
t.Fatal(err)
}
if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil {
t.Error(err)
}
roomId2 := "test-room2"
if _, err := client2.JoinRoom(ctx, roomId2); err != nil {
t.Fatal(err)
}
// Ignore "join" events.
if err := client1.DrainMessages(ctx); err != nil {
t.Error(err)
}
if err := client2.DrainMessages(ctx); err != nil {
if err := client2.RunUntilJoined(ctx, hello2.Hello); err != nil {
t.Error(err)
}
@ -702,6 +701,7 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) {
UserIds: []string{
testDefaultUserId,
},
Properties: (*json.RawMessage)(&testRoomProperties),
},
}

View file

@ -817,14 +817,17 @@ Message format (Client -> Server):
"transient": {
"type": "set",
"key": "sample-key",
"value": "any-json-object"
"value": "any-json-object",
"ttl": "optional-ttl"
}
}
- The `key` must be a string.
- The `value` can be of any type (i.e. string, number, array, object, etc.).
- The `ttl` is the time to live in nanoseconds. The value will be removed after
that time (if it is still present).
- Requests to set a value that is already present for the key are silently
ignored.
ignored. Any TTL value will be updated / removed.
Message format (Server -> Client):

4
hub.go
View file

@ -1965,9 +1965,9 @@ func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) {
}
if msg.Value == nil {
room.SetTransientData(msg.Key, nil)
room.SetTransientDataTTL(msg.Key, nil, msg.TTL)
} else {
room.SetTransientData(msg.Key, *msg.Value)
room.SetTransientDataTTL(msg.Key, *msg.Value, msg.TTL)
}
case "remove":
if !isAllowedToUpdateTransientData(session) {

View file

@ -1059,6 +1059,10 @@ func (r *Room) SetTransientData(key string, value interface{}) {
r.transientData.Set(key, value)
}
func (r *Room) SetTransientDataTTL(key string, value interface{}, ttl time.Duration) {
r.transientData.SetTTL(key, value, ttl)
}
func (r *Room) RemoveTransientData(key string) {
r.transientData.Remove(key)
}

View file

@ -578,7 +578,7 @@ func (c *TestClient) SendInternalRemoveSession(msg *RemoveSessionInternalClientM
return c.WriteJSON(message)
}
func (c *TestClient) SetTransientData(key string, value interface{}) error {
func (c *TestClient) SetTransientData(key string, value interface{}, ttl time.Duration) error {
payload, err := json.Marshal(value)
if err != nil {
c.t.Fatal(err)
@ -591,6 +591,7 @@ func (c *TestClient) SetTransientData(key string, value interface{}) error {
Type: "set",
Key: key,
Value: (*json.RawMessage)(&payload),
TTL: ttl,
},
}
return c.WriteJSON(message)

View file

@ -24,6 +24,7 @@ package signaling
import (
"reflect"
"sync"
"time"
)
type TransientListener interface {
@ -34,6 +35,8 @@ type TransientData struct {
mu sync.Mutex
data map[string]interface{}
listeners map[TransientListener]bool
timers map[string]*time.Timer
ttlCh chan<- struct{}
}
// NewTransientData creates a new transient data container.
@ -99,9 +102,59 @@ func (t *TransientData) RemoveListener(listener TransientListener) {
delete(t.listeners, listener)
}
func (t *TransientData) updateTTL(key string, value interface{}, ttl time.Duration) {
if ttl <= 0 {
delete(t.timers, key)
} else {
t.removeAfterTTL(key, value, ttl)
}
}
func (t *TransientData) removeAfterTTL(key string, value interface{}, ttl time.Duration) {
if ttl <= 0 {
return
}
if old, found := t.timers[key]; found {
old.Stop()
}
timer := time.AfterFunc(ttl, func() {
t.mu.Lock()
defer t.mu.Unlock()
t.compareAndRemove(key, value)
if t.ttlCh != nil {
select {
case t.ttlCh <- struct{}{}:
default:
}
}
})
if t.timers == nil {
t.timers = make(map[string]*time.Timer)
}
t.timers[key] = timer
}
func (t *TransientData) doSet(key string, value interface{}, prev interface{}, ttl time.Duration) {
if t.data == nil {
t.data = make(map[string]interface{})
}
t.data[key] = value
t.notifySet(key, prev, value)
t.removeAfterTTL(key, value, ttl)
}
// Set sets a new value for the given key and notifies listeners
// if the value has been changed.
func (t *TransientData) Set(key string, value interface{}) bool {
return t.SetTTL(key, value, 0)
}
// SetTTL sets a new value for the given key with a time-to-live and notifies
// listeners if the value has been changed.
func (t *TransientData) SetTTL(key string, value interface{}, ttl time.Duration) bool {
if value == nil {
return t.Remove(key)
}
@ -111,20 +164,24 @@ func (t *TransientData) Set(key string, value interface{}) bool {
prev, found := t.data[key]
if found && reflect.DeepEqual(prev, value) {
t.updateTTL(key, value, ttl)
return false
}
if t.data == nil {
t.data = make(map[string]interface{})
}
t.data[key] = value
t.notifySet(key, prev, value)
t.doSet(key, value, prev, ttl)
return true
}
// CompareAndSet sets a new value for the given key only for a given old value
// and notifies listeners if the value has been changed.
func (t *TransientData) CompareAndSet(key string, old, value interface{}) bool {
return t.CompareAndSetTTL(key, old, value, 0)
}
// CompareAndSetTTL sets a new value for the given key with a time-to-live,
// only for a given old value and notifies listeners if the value has been
// changed.
func (t *TransientData) CompareAndSetTTL(key string, old, value interface{}, ttl time.Duration) bool {
if value == nil {
return t.CompareAndRemove(key, old)
}
@ -139,11 +196,19 @@ func (t *TransientData) CompareAndSet(key string, old, value interface{}) bool {
return false
}
t.data[key] = value
t.notifySet(key, prev, value)
t.doSet(key, value, prev, ttl)
return true
}
func (t *TransientData) doRemove(key string, prev interface{}) {
delete(t.data, key)
if old, found := t.timers[key]; found {
old.Stop()
delete(t.timers, key)
}
t.notifyDeleted(key, prev)
}
// Remove deletes the value with the given key and notifies listeners
// if the key was removed.
func (t *TransientData) Remove(key string) bool {
@ -155,8 +220,7 @@ func (t *TransientData) Remove(key string) bool {
return false
}
delete(t.data, key)
t.notifyDeleted(key, prev)
t.doRemove(key, prev)
return true
}
@ -166,13 +230,16 @@ func (t *TransientData) CompareAndRemove(key string, old interface{}) bool {
t.mu.Lock()
defer t.mu.Unlock()
return t.compareAndRemove(key, old)
}
func (t *TransientData) compareAndRemove(key string, old interface{}) bool {
prev, found := t.data[key]
if !found || !reflect.DeepEqual(prev, old) {
return false
}
delete(t.data, key)
t.notifyDeleted(key, prev)
t.doRemove(key, prev)
return true
}

View file

@ -27,6 +27,13 @@ import (
"time"
)
func (t *TransientData) SetTTLChannel(ch chan<- struct{}) {
t.mu.Lock()
defer t.mu.Unlock()
t.ttlCh = ch
}
func Test_TransientData(t *testing.T) {
data := NewTransientData()
if data.Set("foo", nil) {
@ -53,6 +60,9 @@ func Test_TransientData(t *testing.T) {
if !data.CompareAndSet("test", nil, "123") {
t.Errorf("should have set value")
}
if data.CompareAndSet("test", nil, "456") {
t.Errorf("should not have set value")
}
if data.CompareAndRemove("test", "1234") {
t.Errorf("should not have removed value")
}
@ -65,6 +75,61 @@ func Test_TransientData(t *testing.T) {
if !data.Remove("foo") {
t.Errorf("should have removed value")
}
ttlCh := make(chan struct{})
data.SetTTLChannel(ttlCh)
if !data.SetTTL("test", "1234", time.Millisecond) {
t.Errorf("should have set value")
}
if value := data.GetData()["test"]; value != "1234" {
t.Errorf("expected 1234, got %v", value)
}
// Data is removed after the TTL
<-ttlCh
if value := data.GetData()["test"]; value != nil {
t.Errorf("expected no value, got %v", value)
}
if !data.SetTTL("test", "1234", time.Millisecond) {
t.Errorf("should have set value")
}
if value := data.GetData()["test"]; value != "1234" {
t.Errorf("expected 1234, got %v", value)
}
if !data.SetTTL("test", "2345", 3*time.Millisecond) {
t.Errorf("should have set value")
}
if value := data.GetData()["test"]; value != "2345" {
t.Errorf("expected 2345, got %v", value)
}
// Data is removed after the TTL only if the value still matches
time.Sleep(2 * time.Millisecond)
if value := data.GetData()["test"]; value != "2345" {
t.Errorf("expected 2345, got %v", value)
}
// Data is removed after the (second) TTL
<-ttlCh
if value := data.GetData()["test"]; value != nil {
t.Errorf("expected no value, got %v", value)
}
// Setting existing key will update the TTL
if !data.SetTTL("test", "1234", time.Millisecond) {
t.Errorf("should have set value")
}
if data.SetTTL("test", "1234", 3*time.Millisecond) {
t.Errorf("should not have set value")
}
// Data still exists after the first TTL
time.Sleep(2 * time.Millisecond)
if value := data.GetData()["test"]; value != "1234" {
t.Errorf("expected 1234, got %v", value)
}
// Data is removed after the (updated) TTL
<-ttlCh
if value := data.GetData()["test"]; value != nil {
t.Errorf("expected no value, got %v", value)
}
}
func Test_TransientMessages(t *testing.T) {
@ -83,7 +148,7 @@ func Test_TransientMessages(t *testing.T) {
t.Fatal(err)
}
if err := client1.SetTransientData("foo", "bar"); err != nil {
if err := client1.SetTransientData("foo", "bar", 0); err != nil {
t.Fatal(err)
}
if msg, err := client1.RunUntilMessage(ctx); err != nil {
@ -137,7 +202,7 @@ func Test_TransientMessages(t *testing.T) {
// Client 2 may not modify transient data.
session2.SetPermissions([]Permission{})
if err := client2.SetTransientData("foo", "bar"); err != nil {
if err := client2.SetTransientData("foo", "bar", 0); err != nil {
t.Fatal(err)
}
if msg, err := client2.RunUntilMessage(ctx); err != nil {
@ -148,7 +213,7 @@ func Test_TransientMessages(t *testing.T) {
}
}
if err := client1.SetTransientData("foo", "bar"); err != nil {
if err := client1.SetTransientData("foo", "bar", 0); err != nil {
t.Fatal(err)
}
@ -179,7 +244,7 @@ func Test_TransientMessages(t *testing.T) {
}
// Setting the same value is ignored by the server.
if err := client1.SetTransientData("foo", "bar"); err != nil {
if err := client1.SetTransientData("foo", "bar", 0); err != nil {
t.Fatal(err)
}
ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond)
@ -196,7 +261,7 @@ func Test_TransientMessages(t *testing.T) {
data := map[string]interface{}{
"hello": "world",
}
if err := client1.SetTransientData("foo", data); err != nil {
if err := client1.SetTransientData("foo", data, 0); err != nil {
t.Fatal(err)
}
@ -249,7 +314,7 @@ func Test_TransientMessages(t *testing.T) {
t.Errorf("Expected no payload, got %+v", msg)
}
if err := client1.SetTransientData("abc", data); err != nil {
if err := client1.SetTransientData("abc", data, 10*time.Millisecond); err != nil {
t.Fatal(err)
}
@ -290,4 +355,11 @@ func Test_TransientMessages(t *testing.T) {
}); err != nil {
t.Fatal(err)
}
time.Sleep(10 * time.Millisecond)
if msg, err = client3.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else if err := checkMessageTransientRemove(msg, "abc", data); err != nil {
t.Fatal(err)
}
}