From f7db8a38e1046047b2f48659e9f879ccb818f253 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 7 Jul 2022 09:57:10 +0200 Subject: [PATCH] Send initial "welcome" message when clients connect. This can be used to detect server features before performing the actual "hello" handshake. --- api_proxy.go | 4 +-- api_signaling.go | 77 +++++++++++++++++++++++++++++++++++++------ api_signaling_test.go | 41 +++++++++++++++++++++++ hub.go | 76 +++++++++++++++++------------------------- hub_test.go | 23 +++++++++++++ proxy/proxy_server.go | 2 +- testclient_test.go | 20 +++++++++-- 7 files changed, 183 insertions(+), 60 deletions(-) diff --git a/api_proxy.go b/api_proxy.go index 8f511bc..093f227 100644 --- a/api_proxy.go +++ b/api_proxy.go @@ -156,8 +156,8 @@ func (m *HelloProxyClientMessage) CheckValid() error { type HelloProxyServerMessage struct { Version string `json:"version"` - SessionId string `json:"sessionid"` - Server *HelloServerMessageServer `json:"server,omitempty"` + SessionId string `json:"sessionid"` + Server *WelcomeServerMessage `json:"server,omitempty"` } // Type "bye" diff --git a/api_signaling.go b/api_signaling.go index cef25c8..930663c 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -25,6 +25,7 @@ import ( "encoding/json" "fmt" "net/url" + "sort" "strings" ) @@ -141,6 +142,8 @@ type ServerMessage struct { Error *Error `json:"error,omitempty"` + Welcome *WelcomeServerMessage `json:"welcome,omitempty"` + Hello *HelloServerMessage `json:"hello,omitempty"` Bye *ByeServerMessage `json:"bye,omitempty"` @@ -233,6 +236,54 @@ func (e *Error) Error() string { return e.Message } +type WelcomeServerMessage struct { + Version string `json:"version"` + Features []string `json:"features,omitempty"` + Country string `json:"country,omitempty"` +} + +func NewWelcomeServerMessage(version string, feature ...string) *WelcomeServerMessage { + message := &WelcomeServerMessage{ + Version: version, + Features: feature, + } + if len(feature) > 0 { + sort.Strings(message.Features) + } + return message +} + +func (m *WelcomeServerMessage) AddFeature(feature ...string) { + newFeatures := make([]string, len(m.Features)) + copy(newFeatures, m.Features) + for _, feat := range feature { + found := false + for _, f := range newFeatures { + if f == feat { + found = true + break + } + } + if !found { + newFeatures = append(newFeatures, feat) + } + } + sort.Strings(newFeatures) + m.Features = newFeatures +} + +func (m *WelcomeServerMessage) RemoveFeature(feature ...string) { + newFeatures := make([]string, len(m.Features)) + copy(newFeatures, m.Features) + for _, feat := range feature { + idx := sort.SearchStrings(newFeatures, feat) + if idx < len(newFeatures) && newFeatures[idx] == feat { + newFeatures = append(newFeatures[:idx], newFeatures[idx+1:]...) + } + } + m.Features = newFeatures +} + const ( HelloClientTypeClient = "client" HelloClientTypeInternal = "internal" @@ -345,6 +396,7 @@ const ( ServerFeatureAudioVideoPermissions = "audio-video-permissions" ServerFeatureTransientData = "transient-data" ServerFeatureInCallAll = "incall-all" + ServerFeatureWelcome = "welcome" // Features for internal clients only. ServerFeatureInternalVirtualSessions = "virtual-sessions" @@ -355,27 +407,32 @@ var ( ServerFeatureAudioVideoPermissions, ServerFeatureTransientData, ServerFeatureInCallAll, + ServerFeatureWelcome, } DefaultFeaturesInternal = []string{ ServerFeatureInternalVirtualSessions, ServerFeatureTransientData, ServerFeatureInCallAll, + ServerFeatureWelcome, + } + DefaultWelcomeFeatures = []string{ + ServerFeatureAudioVideoPermissions, + ServerFeatureInternalVirtualSessions, + ServerFeatureTransientData, + ServerFeatureInCallAll, + ServerFeatureWelcome, } ) -type HelloServerMessageServer struct { - Version string `json:"version"` - Features []string `json:"features,omitempty"` - Country string `json:"country,omitempty"` -} - type HelloServerMessage struct { Version string `json:"version"` - SessionId string `json:"sessionid"` - ResumeId string `json:"resumeid"` - UserId string `json:"userid"` - Server *HelloServerMessageServer `json:"server,omitempty"` + SessionId string `json:"sessionid"` + ResumeId string `json:"resumeid"` + UserId string `json:"userid"` + + // TODO: Remove once all clients have switched to the "welcome" message. + Server *WelcomeServerMessage `json:"server,omitempty"` } // Type "bye" diff --git a/api_signaling_test.go b/api_signaling_test.go index 2ddbd18..6e9bc7a 100644 --- a/api_signaling_test.go +++ b/api_signaling_test.go @@ -24,6 +24,8 @@ package signaling import ( "encoding/json" "fmt" + "reflect" + "sort" "testing" ) @@ -346,3 +348,42 @@ func TestIsChatRefresh(t *testing.T) { t.Error("message should not be detected as chat refresh") } } + +func assertEqualStrings(t *testing.T, expected, result []string) { + t.Helper() + + if expected == nil { + expected = make([]string, 0) + } else { + sort.Strings(expected) + } + if result == nil { + result = make([]string, 0) + } else { + sort.Strings(result) + } + + if !reflect.DeepEqual(expected, result) { + t.Errorf("Expected %+v, got %+v", expected, result) + } +} + +func Test_Welcome_AddRemoveFeature(t *testing.T) { + var msg WelcomeServerMessage + assertEqualStrings(t, []string{}, msg.Features) + + msg.AddFeature("one", "two", "one") + assertEqualStrings(t, []string{"one", "two"}, msg.Features) + if !sort.StringsAreSorted(msg.Features) { + t.Errorf("features should be sorted, got %+v", msg.Features) + } + + msg.AddFeature("three") + assertEqualStrings(t, []string{"one", "two", "three"}, msg.Features) + if !sort.StringsAreSorted(msg.Features) { + t.Errorf("features should be sorted, got %+v", msg.Features) + } + + msg.RemoveFeature("three", "one") + assertEqualStrings(t, []string{"two"}, msg.Features) +} diff --git a/hub.go b/hub.go index c7e131c..81ca1b1 100644 --- a/hub.go +++ b/hub.go @@ -106,8 +106,9 @@ type Hub struct { nats NatsClient upgrader websocket.Upgrader cookie *securecookie.SecureCookie - info *HelloServerMessageServer - infoInternal *HelloServerMessageServer + info *WelcomeServerMessage + infoInternal *WelcomeServerMessage + welcome atomic.Value // *ServerMessage stopped int32 stopChan chan bool @@ -297,15 +298,9 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s ReadBufferSize: websocketReadBufferSize, WriteBufferSize: websocketWriteBufferSize, }, - cookie: securecookie.New([]byte(hashKey), blockBytes).MaxAge(0), - info: &HelloServerMessageServer{ - Version: version, - Features: DefaultFeatures, - }, - infoInternal: &HelloServerMessageServer{ - Version: version, - Features: DefaultFeaturesInternal, - }, + cookie: securecookie.New([]byte(hashKey), blockBytes).MaxAge(0), + info: NewWelcomeServerMessage(version, DefaultFeatures...), + infoInternal: NewWelcomeServerMessage(version, DefaultFeaturesInternal...), stopChan: make(chan bool), @@ -339,6 +334,10 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s geoip: geoip, geoipOverrides: geoipOverrides, } + hub.setWelcomeMessage(&ServerMessage{ + Type: "welcome", + Welcome: NewWelcomeServerMessage(version, DefaultWelcomeFeatures...), + }) backend.hub = hub hub.upgrader.CheckOrigin = hub.checkOrigin r.HandleFunc("/spreed", func(w http.ResponseWriter, r *http.Request) { @@ -348,49 +347,31 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s return hub, nil } -func addFeature(msg *HelloServerMessageServer, feature string) { - var newFeatures []string - added := false - for _, f := range msg.Features { - newFeatures = append(newFeatures, f) - if f == feature { - added = true - } - } - if !added { - newFeatures = append(newFeatures, feature) - } - msg.Features = newFeatures +func (h *Hub) setWelcomeMessage(msg *ServerMessage) { + h.welcome.Store(msg) } -func removeFeature(msg *HelloServerMessageServer, feature string) { - var newFeatures []string - for _, f := range msg.Features { - if f != feature { - newFeatures = append(newFeatures, f) - } - } - msg.Features = newFeatures +func (h *Hub) getWelcomeMessage() *ServerMessage { + return h.welcome.Load().(*ServerMessage) } func (h *Hub) SetMcu(mcu Mcu) { h.mcu = mcu + // Create copy of message so it can be updated concurrently. + welcome := *h.getWelcomeMessage() if mcu == nil { - removeFeature(h.info, ServerFeatureMcu) - removeFeature(h.info, ServerFeatureSimulcast) - removeFeature(h.info, ServerFeatureUpdateSdp) - removeFeature(h.infoInternal, ServerFeatureMcu) - removeFeature(h.infoInternal, ServerFeatureSimulcast) - removeFeature(h.infoInternal, ServerFeatureUpdateSdp) + h.info.RemoveFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp) + h.infoInternal.RemoveFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp) + + welcome.Welcome.RemoveFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp) } else { log.Printf("Using a timeout of %s for MCU requests", h.mcuTimeout) - addFeature(h.info, ServerFeatureMcu) - addFeature(h.info, ServerFeatureSimulcast) - addFeature(h.info, ServerFeatureUpdateSdp) - addFeature(h.infoInternal, ServerFeatureMcu) - addFeature(h.infoInternal, ServerFeatureSimulcast) - addFeature(h.infoInternal, ServerFeatureUpdateSdp) + h.info.AddFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp) + h.infoInternal.AddFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp) + + welcome.Welcome.AddFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp) } + h.setWelcomeMessage(&welcome) } func (h *Hub) checkOrigin(r *http.Request) bool { @@ -398,7 +379,7 @@ func (h *Hub) checkOrigin(r *http.Request) bool { return true } -func (h *Hub) GetServerInfo(session Session) *HelloServerMessageServer { +func (h *Hub) GetServerInfo(session Session) *WelcomeServerMessage { if session.ClientType() == HelloClientTypeInternal { return h.infoInternal } @@ -685,6 +666,11 @@ func (h *Hub) startExpectHello(client *Client) { func (h *Hub) processNewClient(client *Client) { h.startExpectHello(client) + h.sendWelcome(client) +} + +func (h *Hub) sendWelcome(client *Client) { + client.SendMessage(h.getWelcomeMessage()) } func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData { diff --git a/hub_test.go b/hub_test.go index c45ebe9..2583c71 100644 --- a/hub_test.go +++ b/hub_test.go @@ -473,6 +473,29 @@ func performHousekeeping(hub *Hub, now time.Time) *sync.WaitGroup { return &wg } +func TestInitialWelcome(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewTestClientContext(ctx, t, server, hub) + defer client.CloseWithBye() + + msg, err := client.RunUntilMessage(ctx) + if err != nil { + t.Fatal(err) + } + + if msg.Type != "welcome" { + t.Errorf("Expected \"welcome\" message, got %+v", msg) + } else if msg.Welcome.Version == "" { + t.Errorf("Expected welcome version, got %+v", msg) + } else if len(msg.Welcome.Features) == 0 { + t.Errorf("Expected welcome features, got %+v", msg) + } +} + func TestExpectClientHello(t *testing.T) { hub, _, _, server := CreateHubForTest(t) diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index fc919cc..72a2feb 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -591,7 +591,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { Hello: &signaling.HelloProxyServerMessage{ Version: signaling.HelloVersion, SessionId: session.PublicId(), - Server: &signaling.HelloServerMessageServer{ + Server: &signaling.WelcomeServerMessage{ Version: s.version, Country: s.country, }, diff --git a/testclient_test.go b/testclient_test.go index 264d818..4cf40f4 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -190,9 +190,9 @@ type TestClient struct { publicId string } -func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient { +func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Server, hub *Hub) *TestClient { // Reference "hub" to prevent compiler error. - conn, _, err := websocket.DefaultDialer.Dial(getWebsocketUrl(server.URL), nil) + conn, _, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil) if err != nil { t.Fatal(err) } @@ -228,6 +228,22 @@ func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient } } +func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + client := NewTestClientContext(ctx, t, server, hub) + msg, err := client.RunUntilMessage(ctx) + if err != nil { + t.Fatal(err) + } + + if msg.Type != "welcome" { + t.Errorf("Expected welcome message, got %+v", msg) + } + return client +} + func (c *TestClient) CloseWithBye() { c.SendBye() // nolint c.Close()