From fca30af201507e65aa212d8997275d3f970ca5f7 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 10 Feb 2022 13:58:39 +0100 Subject: [PATCH] Add API for transient room data. --- api_signaling.go | 46 +++++++ hub.go | 55 ++++++++ room.go | 18 +++ session.go | 1 + testclient_test.go | 77 +++++++++++ transient_data.go | 189 ++++++++++++++++++++++++++ transient_data_test.go | 294 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 680 insertions(+) create mode 100644 transient_data.go create mode 100644 transient_data_test.go diff --git a/api_signaling.go b/api_signaling.go index 340fbdb..caae520 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -53,6 +53,8 @@ type ClientMessage struct { Control *ControlClientMessage `json:"control,omitempty"` Internal *InternalClientMessage `json:"internal,omitempty"` + + TransientData *TransientDataClientMessage `json:"transient,omitempty"` } func (m *ClientMessage) CheckValid() error { @@ -91,6 +93,12 @@ func (m *ClientMessage) CheckValid() error { } else if err := m.Internal.CheckValid(); err != nil { return err } + case "transient": + if m.TransientData == nil { + return fmt.Errorf("transient missing") + } else if err := m.TransientData.CheckValid(); err != nil { + return err + } } return nil } @@ -138,6 +146,8 @@ type ServerMessage struct { Control *ControlServerMessage `json:"control,omitempty"` Event *EventServerMessage `json:"event,omitempty"` + + TransientData *TransientDataServerMessage `json:"transient,omitempty"` } func (r *ServerMessage) CloseAfterSend(session Session) bool { @@ -326,6 +336,7 @@ const ( ServerFeatureMcu = "mcu" ServerFeatureSimulcast = "simulcast" ServerFeatureAudioVideoPermissions = "audio-video-permissions" + ServerFeatureTransientData = "transient-data" // Features for internal clients only. ServerFeatureInternalVirtualSessions = "virtual-sessions" @@ -334,9 +345,11 @@ const ( var ( DefaultFeatures = []string{ ServerFeatureAudioVideoPermissions, + ServerFeatureTransientData, } DefaultFeaturesInternal = []string{ ServerFeatureInternalVirtualSessions, + ServerFeatureTransientData, } ) @@ -636,3 +649,36 @@ type AnswerOfferMessage struct { RoomType string `json:"roomType"` Payload map[string]interface{} `json:"payload"` } + +// Type "transient" + +type TransientDataClientMessage struct { + Type string `json:"type"` + + Key string `json:"key,omitempty"` + Value *json.RawMessage `json:"value,omitempty"` +} + +func (m *TransientDataClientMessage) CheckValid() error { + switch m.Type { + case "set": + if m.Key == "" { + return fmt.Errorf("key missing") + } + // A "nil" value is allowed and will remove the key. + case "remove": + if m.Key == "" { + return fmt.Errorf("key missing") + } + } + return nil +} + +type TransientDataServerMessage struct { + Type string `json:"type"` + + Key string `json:"key,omitempty"` + OldValue interface{} `json:"oldvalue,omitempty"` + Value interface{} `json:"value,omitempty"` + Data map[string]interface{} `json:"data,omitempty"` +} diff --git a/hub.go b/hub.go index 39585ab..80e4b95 100644 --- a/hub.go +++ b/hub.go @@ -832,6 +832,8 @@ func (h *Hub) processMessage(client *Client, data []byte) { h.processControlMsg(client, &message) case "internal": h.processInternalMsg(client, &message) + case "transient": + h.processTransientMsg(client, &message) case "bye": h.processByeMsg(client, &message) case "hello": @@ -1656,6 +1658,59 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { } } +func isAllowedToUpdateTransientData(session Session) bool { + if session.ClientType() == HelloClientTypeInternal { + // Internal clients are always allowed. + return true + } + + if session.HasPermission(PERMISSION_TRANSIENT_DATA) { + return true + } + + return false +} + +func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) { + msg := message.TransientData + session := client.GetSession() + if session == nil { + // Client is not connected yet. + return + } + + room := session.GetRoom() + if room == nil { + response := message.NewErrorServerMessage(NewError("not_in_room", "No room joined yet.")) + session.SendMessage(response) + return + } + + switch msg.Type { + case "set": + if !isAllowedToUpdateTransientData(session) { + sendNotAllowed(session, message, "Not allowed to update transient data.") + return + } + + if msg.Value == nil { + room.SetTransientData(msg.Key, nil) + } else { + room.SetTransientData(msg.Key, *msg.Value) + } + case "remove": + if !isAllowedToUpdateTransientData(session) { + sendNotAllowed(session, message, "Not allowed to update transient data.") + return + } + + room.RemoveTransientData(msg.Key) + default: + response := message.NewErrorServerMessage(NewError("ignored", "Unsupported message type.")) + session.SendMessage(response) + } +} + func sendNotAllowed(session *ClientSession, message *ClientMessage, reason string) { response := message.NewErrorServerMessage(NewError("not_allowed", reason)) session.SendMessage(response) diff --git a/room.go b/room.go index ea05207..57ff6e6 100644 --- a/room.go +++ b/room.go @@ -79,6 +79,8 @@ type Room struct { // Timestamps of last NATS backend requests for the different types. lastNatsRoomRequests map[string]int64 + + transientData *TransientData } func GetSubjectForRoomId(roomId string, backend *Backend) string { @@ -139,6 +141,8 @@ func NewRoom(roomId string, properties *json.RawMessage, hub *Hub, n NatsClient, backendSubscription: backendSubscription, lastNatsRoomRequests: make(map[string]int64), + + transientData: NewTransientData(), } go room.run() @@ -331,6 +335,9 @@ func (r *Room) AddSession(session Session, sessionData *json.RawMessage) []Sessi r.publishSessionFlagsChanged(session) } } + if clientSession, ok := session.(*ClientSession); ok { + r.transientData.AddListener(clientSession) + } } return result } @@ -364,6 +371,9 @@ func (r *Room) RemoveSession(session Session) bool { if virtualSession, ok := session.(*VirtualSession); ok { delete(r.virtualSessions, virtualSession) } + if clientSession, ok := session.(*ClientSession); ok { + r.transientData.RemoveListener(clientSession) + } delete(r.inCallSessions, session) delete(r.roomSessionData, sid) if len(r.sessions) > 0 { @@ -792,3 +802,11 @@ func (r *Room) notifyInternalRoomDeleted() { s.(*ClientSession).SendMessage(msg) } } + +func (r *Room) SetTransientData(key string, value interface{}) { + r.transientData.Set(key, value) +} + +func (r *Room) RemoveTransientData(key string) { + r.transientData.Remove(key) +} diff --git a/session.go b/session.go index 41434ac..9a8da88 100644 --- a/session.go +++ b/session.go @@ -35,6 +35,7 @@ var ( PERMISSION_MAY_PUBLISH_VIDEO Permission = "publish-video" PERMISSION_MAY_PUBLISH_SCREEN Permission = "publish-screen" PERMISSION_MAY_CONTROL Permission = "control" + PERMISSION_TRANSIENT_DATA Permission = "transient-data" ) type SessionIdData struct { diff --git a/testclient_test.go b/testclient_test.go index dbb6b1c..ff19fd6 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -30,6 +30,7 @@ import ( "fmt" "net" "net/http/httptest" + "reflect" "strings" "testing" "time" @@ -112,6 +113,10 @@ func checkMessageType(message *ServerMessage, expectedType string) error { if message.Event == nil { return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message)) } + case "transient": + if message.TransientData == nil { + return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message)) + } } return nil @@ -407,6 +412,36 @@ func (c *TestClient) SendMessage(recipient MessageClientMessageRecipient, data i return c.WriteJSON(message) } +func (c *TestClient) SetTransientData(key string, value interface{}) error { + payload, err := json.Marshal(value) + if err != nil { + c.t.Fatal(err) + } + + message := &ClientMessage{ + Id: "efgh", + Type: "transient", + TransientData: &TransientDataClientMessage{ + Type: "set", + Key: key, + Value: (*json.RawMessage)(&payload), + }, + } + return c.WriteJSON(message) +} + +func (c *TestClient) RemoveTransientData(key string) error { + message := &ClientMessage{ + Id: "ijkl", + Type: "transient", + TransientData: &TransientDataClientMessage{ + Type: "remove", + Key: key, + }, + } + return c.WriteJSON(message) +} + func (c *TestClient) DrainMessages(ctx context.Context) error { select { case err := <-c.readErrorChan: @@ -764,3 +799,45 @@ func (c *TestClient) RunUntilAnswer(ctx context.Context, answer string) error { return nil } + +func checkMessageTransientSet(message *ServerMessage, key string, value interface{}, oldValue interface{}) error { + if err := checkMessageType(message, "transient"); err != nil { + return err + } else if message.TransientData.Type != "set" { + return fmt.Errorf("Expected transient set, got %+v", message.TransientData) + } else if message.TransientData.Key != key { + return fmt.Errorf("Expected transient set key %s, got %+v", key, message.TransientData) + } else if !reflect.DeepEqual(message.TransientData.Value, value) { + return fmt.Errorf("Expected transient set value %+v, got %+v", value, message.TransientData.Value) + } else if !reflect.DeepEqual(message.TransientData.OldValue, oldValue) { + return fmt.Errorf("Expected transient set old value %+v, got %+v", oldValue, message.TransientData.OldValue) + } + + return nil +} + +func checkMessageTransientRemove(message *ServerMessage, key string, oldValue interface{}) error { + if err := checkMessageType(message, "transient"); err != nil { + return err + } else if message.TransientData.Type != "remove" { + return fmt.Errorf("Expected transient remove, got %+v", message.TransientData) + } else if message.TransientData.Key != key { + return fmt.Errorf("Expected transient remove key %s, got %+v", key, message.TransientData) + } else if !reflect.DeepEqual(message.TransientData.OldValue, oldValue) { + return fmt.Errorf("Expected transient remove old value %+v, got %+v", oldValue, message.TransientData.OldValue) + } + + return nil +} + +func checkMessageTransientInitial(message *ServerMessage, data map[string]interface{}) error { + if err := checkMessageType(message, "transient"); err != nil { + return err + } else if message.TransientData.Type != "initial" { + return fmt.Errorf("Expected transient initial, got %+v", message.TransientData) + } else if !reflect.DeepEqual(message.TransientData.Data, data) { + return fmt.Errorf("Expected transient initial data %+v, got %+v", data, message.TransientData.Data) + } + + return nil +} diff --git a/transient_data.go b/transient_data.go new file mode 100644 index 0000000..392af02 --- /dev/null +++ b/transient_data.go @@ -0,0 +1,189 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2021 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "reflect" + "sync" +) + +type TransientListener interface { + SendMessage(message *ServerMessage) bool +} + +type TransientData struct { + mu sync.Mutex + data map[string]interface{} + listeners map[TransientListener]bool +} + +// NewTransientData creates a new transient data container. +func NewTransientData() *TransientData { + return &TransientData{} +} + +func (t *TransientData) notifySet(key string, prev, value interface{}) { + msg := &ServerMessage{ + Type: "transient", + TransientData: &TransientDataServerMessage{ + Type: "set", + Key: key, + OldValue: prev, + Value: value, + }, + } + for listener := range t.listeners { + listener.SendMessage(msg) + } +} + +func (t *TransientData) notifyDeleted(key string, prev interface{}) { + msg := &ServerMessage{ + Type: "transient", + TransientData: &TransientDataServerMessage{ + Type: "remove", + Key: key, + OldValue: prev, + }, + } + for listener := range t.listeners { + listener.SendMessage(msg) + } +} + +// AddListener adds a new listener to be notified about changes. +func (t *TransientData) AddListener(listener TransientListener) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.listeners == nil { + t.listeners = make(map[TransientListener]bool) + } + t.listeners[listener] = true + if len(t.data) > 0 { + msg := &ServerMessage{ + Type: "transient", + TransientData: &TransientDataServerMessage{ + Type: "initial", + Data: t.data, + }, + } + listener.SendMessage(msg) + } +} + +// RemoveListener removes a previously registered listener. +func (t *TransientData) RemoveListener(listener TransientListener) { + t.mu.Lock() + defer t.mu.Unlock() + + delete(t.listeners, listener) +} + +// Set sets a new value for the given key and notifies listeners +// if the value has been changed. +func (t *TransientData) Set(key string, value interface{}) bool { + if value == nil { + return t.Remove(key) + } + + t.mu.Lock() + defer t.mu.Unlock() + + prev, found := t.data[key] + if found && reflect.DeepEqual(prev, value) { + return false + } + + if t.data == nil { + t.data = make(map[string]interface{}) + } + t.data[key] = value + t.notifySet(key, prev, value) + return true +} + +// CompareAndSet sets a new value for the given key only for a given old value +// and notifies listeners if the value has been changed. +func (t *TransientData) CompareAndSet(key string, old, value interface{}) bool { + if value == nil { + return t.CompareAndRemove(key, old) + } + + t.mu.Lock() + defer t.mu.Unlock() + + prev, found := t.data[key] + if old != nil && (!found || !reflect.DeepEqual(prev, old)) { + return false + } else if old == nil && found { + return false + } + + t.data[key] = value + t.notifySet(key, prev, value) + return true +} + +// Remove deletes the value with the given key and notifies listeners +// if the key was removed. +func (t *TransientData) Remove(key string) bool { + t.mu.Lock() + defer t.mu.Unlock() + + prev, found := t.data[key] + if !found { + return false + } + + delete(t.data, key) + t.notifyDeleted(key, prev) + return true +} + +// CompareAndRemove deletes the value with the given key if it has a given value +// and notifies listeners if the key was removed. +func (t *TransientData) CompareAndRemove(key string, old interface{}) bool { + t.mu.Lock() + defer t.mu.Unlock() + + prev, found := t.data[key] + if !found || !reflect.DeepEqual(prev, old) { + return false + } + + delete(t.data, key) + t.notifyDeleted(key, prev) + return true +} + +// GetData returns a copy of the internal data. +func (t *TransientData) GetData() map[string]interface{} { + t.mu.Lock() + defer t.mu.Unlock() + + result := make(map[string]interface{}) + for k, v := range t.data { + result[k] = v + } + return result +} diff --git a/transient_data_test.go b/transient_data_test.go new file mode 100644 index 0000000..43c99c2 --- /dev/null +++ b/transient_data_test.go @@ -0,0 +1,294 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2021 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "testing" + "time" +) + +func Test_TransientData(t *testing.T) { + data := NewTransientData() + if data.Set("foo", nil) { + t.Errorf("should not have set value") + } + if !data.Set("foo", "bar") { + t.Errorf("should have set value") + } + if data.Set("foo", "bar") { + t.Errorf("should not have set value") + } + if !data.Set("foo", "baz") { + t.Errorf("should have set value") + } + if data.CompareAndSet("foo", "bar", "lala") { + t.Errorf("should not have set value") + } + if !data.CompareAndSet("foo", "baz", "lala") { + t.Errorf("should have set value") + } + if data.CompareAndSet("test", nil, nil) { + t.Errorf("should not have set value") + } + if !data.CompareAndSet("test", nil, "123") { + t.Errorf("should have set value") + } + if data.CompareAndRemove("test", "1234") { + t.Errorf("should not have removed value") + } + if !data.CompareAndRemove("test", "123") { + t.Errorf("should have removed value") + } + if data.Remove("lala") { + t.Errorf("should not have removed value") + } + if !data.Remove("foo") { + t.Errorf("should have removed value") + } +} + +func Test_TransientMessages(t *testing.T) { + hub, _, _, server, shutdown := CreateHubForTest(t) + defer shutdown() + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewTestClient(t, server, hub) + defer client1.CloseWithBye() + if err := client1.SendHello(testDefaultUserId + "1"); err != nil { + t.Fatal(err) + } + hello1, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + if err := client1.SetTransientData("foo", "bar"); err != nil { + t.Fatal(err) + } + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else { + if err := checkMessageError(msg, "not_in_room"); err != nil { + t.Fatal(err) + } + } + + client2 := NewTestClient(t, server, hub) + defer client2.CloseWithBye() + if err := client2.SendHello(testDefaultUserId + "2"); err != nil { + t.Fatal(err) + } + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + // Join room by id. + roomId := "test-room" + if room, err := client1.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + + // Give message processing some time. + time.Sleep(10 * time.Millisecond) + + if room, err := client2.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + + WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) + + session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) + if session1 == nil { + t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) + } + session2 := hub.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession) + if session2 == nil { + t.Fatalf("Session %s does not exist", hello2.Hello.SessionId) + } + + // Client 1 may modify transient data. + session1.SetPermissions([]Permission{PERMISSION_TRANSIENT_DATA}) + // Client 2 may not modify transient data. + session2.SetPermissions([]Permission{}) + + if err := client2.SetTransientData("foo", "bar"); err != nil { + t.Fatal(err) + } + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else { + if err := checkMessageError(msg, "not_allowed"); err != nil { + t.Fatal(err) + } + } + + if err := client1.SetTransientData("foo", "bar"); err != nil { + t.Fatal(err) + } + + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else { + if err := checkMessageTransientSet(msg, "foo", "bar", nil); err != nil { + t.Fatal(err) + } + } + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else { + if err := checkMessageTransientSet(msg, "foo", "bar", nil); err != nil { + t.Fatal(err) + } + } + + if err := client2.RemoveTransientData("foo"); err != nil { + t.Fatal(err) + } + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else { + if err := checkMessageError(msg, "not_allowed"); err != nil { + t.Fatal(err) + } + } + + // Setting the same value is ignored by the server. + if err := client1.SetTransientData("foo", "bar"); err != nil { + t.Fatal(err) + } + ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel2() + + if msg, err := client1.RunUntilMessage(ctx2); err != nil { + if err != context.DeadlineExceeded { + t.Fatal(err) + } + } else { + t.Errorf("Expected no payload, got %+v", msg) + } + + data := map[string]interface{}{ + "hello": "world", + } + if err := client1.SetTransientData("foo", data); err != nil { + t.Fatal(err) + } + + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else { + if err := checkMessageTransientSet(msg, "foo", data, "bar"); err != nil { + t.Fatal(err) + } + } + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else { + if err := checkMessageTransientSet(msg, "foo", data, "bar"); err != nil { + t.Fatal(err) + } + } + + if err := client1.RemoveTransientData("foo"); err != nil { + t.Fatal(err) + } + + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else { + if err := checkMessageTransientRemove(msg, "foo", data); err != nil { + t.Fatal(err) + } + } + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else { + if err := checkMessageTransientRemove(msg, "foo", data); err != nil { + t.Fatal(err) + } + } + + // Removing a non-existing key is ignored by the server. + if err := client1.RemoveTransientData("foo"); err != nil { + t.Fatal(err) + } + ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel3() + + if msg, err := client1.RunUntilMessage(ctx3); err != nil { + if err != context.DeadlineExceeded { + t.Fatal(err) + } + } else { + t.Errorf("Expected no payload, got %+v", msg) + } + + if err := client1.SetTransientData("abc", data); err != nil { + t.Fatal(err) + } + + client3 := NewTestClient(t, server, hub) + defer client3.CloseWithBye() + if err := client3.SendHello(testDefaultUserId + "3"); err != nil { + t.Fatal(err) + } + hello3, err := client3.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + if room, err := client3.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + + ignored, err := client3.RunUntilJoinedAndReturnIgnored(ctx, hello1.Hello, hello2.Hello, hello3.Hello) + if err != nil { + t.Fatal(err) + } + + var msg *ServerMessage + if len(ignored) == 0 { + if msg, err = client3.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } + } else if len(ignored) == 1 { + msg = ignored[0] + } else { + t.Fatalf("Received too many messages: %+v", ignored) + } + + if err := checkMessageTransientInitial(msg, map[string]interface{}{ + "abc": data, + }); err != nil { + t.Fatal(err) + } +}