Send initial "welcome" message when clients connect.

This can be used to detect server features before performing the
actual "hello" handshake.
This commit is contained in:
Joachim Bauch 2022-07-07 09:57:10 +02:00
parent 32a2f822e0
commit f7db8a38e1
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
7 changed files with 183 additions and 60 deletions

View file

@ -156,8 +156,8 @@ func (m *HelloProxyClientMessage) CheckValid() error {
type HelloProxyServerMessage struct {
Version string `json:"version"`
SessionId string `json:"sessionid"`
Server *HelloServerMessageServer `json:"server,omitempty"`
SessionId string `json:"sessionid"`
Server *WelcomeServerMessage `json:"server,omitempty"`
}
// Type "bye"

View file

@ -25,6 +25,7 @@ import (
"encoding/json"
"fmt"
"net/url"
"sort"
"strings"
)
@ -141,6 +142,8 @@ type ServerMessage struct {
Error *Error `json:"error,omitempty"`
Welcome *WelcomeServerMessage `json:"welcome,omitempty"`
Hello *HelloServerMessage `json:"hello,omitempty"`
Bye *ByeServerMessage `json:"bye,omitempty"`
@ -233,6 +236,54 @@ func (e *Error) Error() string {
return e.Message
}
type WelcomeServerMessage struct {
Version string `json:"version"`
Features []string `json:"features,omitempty"`
Country string `json:"country,omitempty"`
}
func NewWelcomeServerMessage(version string, feature ...string) *WelcomeServerMessage {
message := &WelcomeServerMessage{
Version: version,
Features: feature,
}
if len(feature) > 0 {
sort.Strings(message.Features)
}
return message
}
func (m *WelcomeServerMessage) AddFeature(feature ...string) {
newFeatures := make([]string, len(m.Features))
copy(newFeatures, m.Features)
for _, feat := range feature {
found := false
for _, f := range newFeatures {
if f == feat {
found = true
break
}
}
if !found {
newFeatures = append(newFeatures, feat)
}
}
sort.Strings(newFeatures)
m.Features = newFeatures
}
func (m *WelcomeServerMessage) RemoveFeature(feature ...string) {
newFeatures := make([]string, len(m.Features))
copy(newFeatures, m.Features)
for _, feat := range feature {
idx := sort.SearchStrings(newFeatures, feat)
if idx < len(newFeatures) && newFeatures[idx] == feat {
newFeatures = append(newFeatures[:idx], newFeatures[idx+1:]...)
}
}
m.Features = newFeatures
}
const (
HelloClientTypeClient = "client"
HelloClientTypeInternal = "internal"
@ -345,6 +396,7 @@ const (
ServerFeatureAudioVideoPermissions = "audio-video-permissions"
ServerFeatureTransientData = "transient-data"
ServerFeatureInCallAll = "incall-all"
ServerFeatureWelcome = "welcome"
// Features for internal clients only.
ServerFeatureInternalVirtualSessions = "virtual-sessions"
@ -355,27 +407,32 @@ var (
ServerFeatureAudioVideoPermissions,
ServerFeatureTransientData,
ServerFeatureInCallAll,
ServerFeatureWelcome,
}
DefaultFeaturesInternal = []string{
ServerFeatureInternalVirtualSessions,
ServerFeatureTransientData,
ServerFeatureInCallAll,
ServerFeatureWelcome,
}
DefaultWelcomeFeatures = []string{
ServerFeatureAudioVideoPermissions,
ServerFeatureInternalVirtualSessions,
ServerFeatureTransientData,
ServerFeatureInCallAll,
ServerFeatureWelcome,
}
)
type HelloServerMessageServer struct {
Version string `json:"version"`
Features []string `json:"features,omitempty"`
Country string `json:"country,omitempty"`
}
type HelloServerMessage struct {
Version string `json:"version"`
SessionId string `json:"sessionid"`
ResumeId string `json:"resumeid"`
UserId string `json:"userid"`
Server *HelloServerMessageServer `json:"server,omitempty"`
SessionId string `json:"sessionid"`
ResumeId string `json:"resumeid"`
UserId string `json:"userid"`
// TODO: Remove once all clients have switched to the "welcome" message.
Server *WelcomeServerMessage `json:"server,omitempty"`
}
// Type "bye"

View file

@ -24,6 +24,8 @@ package signaling
import (
"encoding/json"
"fmt"
"reflect"
"sort"
"testing"
)
@ -346,3 +348,42 @@ func TestIsChatRefresh(t *testing.T) {
t.Error("message should not be detected as chat refresh")
}
}
func assertEqualStrings(t *testing.T, expected, result []string) {
t.Helper()
if expected == nil {
expected = make([]string, 0)
} else {
sort.Strings(expected)
}
if result == nil {
result = make([]string, 0)
} else {
sort.Strings(result)
}
if !reflect.DeepEqual(expected, result) {
t.Errorf("Expected %+v, got %+v", expected, result)
}
}
func Test_Welcome_AddRemoveFeature(t *testing.T) {
var msg WelcomeServerMessage
assertEqualStrings(t, []string{}, msg.Features)
msg.AddFeature("one", "two", "one")
assertEqualStrings(t, []string{"one", "two"}, msg.Features)
if !sort.StringsAreSorted(msg.Features) {
t.Errorf("features should be sorted, got %+v", msg.Features)
}
msg.AddFeature("three")
assertEqualStrings(t, []string{"one", "two", "three"}, msg.Features)
if !sort.StringsAreSorted(msg.Features) {
t.Errorf("features should be sorted, got %+v", msg.Features)
}
msg.RemoveFeature("three", "one")
assertEqualStrings(t, []string{"two"}, msg.Features)
}

76
hub.go
View file

@ -106,8 +106,9 @@ type Hub struct {
nats NatsClient
upgrader websocket.Upgrader
cookie *securecookie.SecureCookie
info *HelloServerMessageServer
infoInternal *HelloServerMessageServer
info *WelcomeServerMessage
infoInternal *WelcomeServerMessage
welcome atomic.Value // *ServerMessage
stopped int32
stopChan chan bool
@ -297,15 +298,9 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s
ReadBufferSize: websocketReadBufferSize,
WriteBufferSize: websocketWriteBufferSize,
},
cookie: securecookie.New([]byte(hashKey), blockBytes).MaxAge(0),
info: &HelloServerMessageServer{
Version: version,
Features: DefaultFeatures,
},
infoInternal: &HelloServerMessageServer{
Version: version,
Features: DefaultFeaturesInternal,
},
cookie: securecookie.New([]byte(hashKey), blockBytes).MaxAge(0),
info: NewWelcomeServerMessage(version, DefaultFeatures...),
infoInternal: NewWelcomeServerMessage(version, DefaultFeaturesInternal...),
stopChan: make(chan bool),
@ -339,6 +334,10 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s
geoip: geoip,
geoipOverrides: geoipOverrides,
}
hub.setWelcomeMessage(&ServerMessage{
Type: "welcome",
Welcome: NewWelcomeServerMessage(version, DefaultWelcomeFeatures...),
})
backend.hub = hub
hub.upgrader.CheckOrigin = hub.checkOrigin
r.HandleFunc("/spreed", func(w http.ResponseWriter, r *http.Request) {
@ -348,49 +347,31 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s
return hub, nil
}
func addFeature(msg *HelloServerMessageServer, feature string) {
var newFeatures []string
added := false
for _, f := range msg.Features {
newFeatures = append(newFeatures, f)
if f == feature {
added = true
}
}
if !added {
newFeatures = append(newFeatures, feature)
}
msg.Features = newFeatures
func (h *Hub) setWelcomeMessage(msg *ServerMessage) {
h.welcome.Store(msg)
}
func removeFeature(msg *HelloServerMessageServer, feature string) {
var newFeatures []string
for _, f := range msg.Features {
if f != feature {
newFeatures = append(newFeatures, f)
}
}
msg.Features = newFeatures
func (h *Hub) getWelcomeMessage() *ServerMessage {
return h.welcome.Load().(*ServerMessage)
}
func (h *Hub) SetMcu(mcu Mcu) {
h.mcu = mcu
// Create copy of message so it can be updated concurrently.
welcome := *h.getWelcomeMessage()
if mcu == nil {
removeFeature(h.info, ServerFeatureMcu)
removeFeature(h.info, ServerFeatureSimulcast)
removeFeature(h.info, ServerFeatureUpdateSdp)
removeFeature(h.infoInternal, ServerFeatureMcu)
removeFeature(h.infoInternal, ServerFeatureSimulcast)
removeFeature(h.infoInternal, ServerFeatureUpdateSdp)
h.info.RemoveFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp)
h.infoInternal.RemoveFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp)
welcome.Welcome.RemoveFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp)
} else {
log.Printf("Using a timeout of %s for MCU requests", h.mcuTimeout)
addFeature(h.info, ServerFeatureMcu)
addFeature(h.info, ServerFeatureSimulcast)
addFeature(h.info, ServerFeatureUpdateSdp)
addFeature(h.infoInternal, ServerFeatureMcu)
addFeature(h.infoInternal, ServerFeatureSimulcast)
addFeature(h.infoInternal, ServerFeatureUpdateSdp)
h.info.AddFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp)
h.infoInternal.AddFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp)
welcome.Welcome.AddFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp)
}
h.setWelcomeMessage(&welcome)
}
func (h *Hub) checkOrigin(r *http.Request) bool {
@ -398,7 +379,7 @@ func (h *Hub) checkOrigin(r *http.Request) bool {
return true
}
func (h *Hub) GetServerInfo(session Session) *HelloServerMessageServer {
func (h *Hub) GetServerInfo(session Session) *WelcomeServerMessage {
if session.ClientType() == HelloClientTypeInternal {
return h.infoInternal
}
@ -685,6 +666,11 @@ func (h *Hub) startExpectHello(client *Client) {
func (h *Hub) processNewClient(client *Client) {
h.startExpectHello(client)
h.sendWelcome(client)
}
func (h *Hub) sendWelcome(client *Client) {
client.SendMessage(h.getWelcomeMessage())
}
func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData {

View file

@ -473,6 +473,29 @@ func performHousekeeping(hub *Hub, now time.Time) *sync.WaitGroup {
return &wg
}
func TestInitialWelcome(t *testing.T) {
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client := NewTestClientContext(ctx, t, server, hub)
defer client.CloseWithBye()
msg, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
if msg.Type != "welcome" {
t.Errorf("Expected \"welcome\" message, got %+v", msg)
} else if msg.Welcome.Version == "" {
t.Errorf("Expected welcome version, got %+v", msg)
} else if len(msg.Welcome.Features) == 0 {
t.Errorf("Expected welcome features, got %+v", msg)
}
}
func TestExpectClientHello(t *testing.T) {
hub, _, _, server := CreateHubForTest(t)

View file

@ -591,7 +591,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) {
Hello: &signaling.HelloProxyServerMessage{
Version: signaling.HelloVersion,
SessionId: session.PublicId(),
Server: &signaling.HelloServerMessageServer{
Server: &signaling.WelcomeServerMessage{
Version: s.version,
Country: s.country,
},

View file

@ -190,9 +190,9 @@ type TestClient struct {
publicId string
}
func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient {
func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Server, hub *Hub) *TestClient {
// Reference "hub" to prevent compiler error.
conn, _, err := websocket.DefaultDialer.Dial(getWebsocketUrl(server.URL), nil)
conn, _, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil)
if err != nil {
t.Fatal(err)
}
@ -228,6 +228,22 @@ func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient
}
}
func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
client := NewTestClientContext(ctx, t, server, hub)
msg, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
if msg.Type != "welcome" {
t.Errorf("Expected welcome message, got %+v", msg)
}
return client
}
func (c *TestClient) CloseWithBye() {
c.SendBye() // nolint
c.Close()