diff --git a/api/signaling.go b/api/signaling.go index 27ee638..b9e488a 100644 --- a/api/signaling.go +++ b/api/signaling.go @@ -835,6 +835,8 @@ func (m *MessageClientMessageData) CheckValid() error { return fmt.Errorf("invalid room type: %s", m.RoomType) } switch m.Type { + case "": + return errors.New("type missing") case "offer", "answer": sdpText, ok := GetStringMapEntry[string](m.Payload, "sdp") if !ok { @@ -1333,6 +1335,8 @@ type TransientDataClientMessage struct { func (m *TransientDataClientMessage) CheckValid() error { switch m.Type { + case "": + return errors.New("type missing") case "set": if m.Key == "" { return errors.New("key missing") diff --git a/api/signaling_test.go b/api/signaling_test.go index e7103dc..ed8c54c 100644 --- a/api/signaling_test.go +++ b/api/signaling_test.go @@ -36,6 +36,20 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/mock" ) +func TestRoomSessionIds(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + + var s1 RoomSessionId = "foo" + assert.False(s1.IsFederated()) + assert.EqualValues("foo", s1.WithoutFederation()) + + var s2 RoomSessionId = "federated|bar" + assert.True(s2.IsFederated()) + assert.EqualValues("bar", s2.WithoutFederation()) +} + type testCheckValid interface { CheckValid() error } @@ -53,6 +67,12 @@ func wrapMessage(messageType string, msg testCheckValid) *ClientMessage { wrapped.Bye = msg.(*ByeClientMessage) case "room": wrapped.Room = msg.(*RoomClientMessage) + case "control": + wrapped.Control = msg.(*ControlClientMessage) + case "internal": + wrapped.Internal = msg.(*InternalClientMessage) + case "transient": + wrapped.TransientData = msg.(*TransientDataClientMessage) default: return nil } @@ -65,19 +85,23 @@ func testMessages(t *testing.T, messageType string, valid_messages []testCheckVa for _, msg := range valid_messages { assert.NoError(msg.CheckValid(), "Message %+v should be valid", msg) - // If the inner message is valid, it should also be valid in a wrapped - // ClientMessage. - if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) { - assert.NoError(wrapped.CheckValid(), "Message %+v should be valid", wrapped) + if messageType != "" { + // If the inner message is valid, it should also be valid in a wrapped + // ClientMessage. + if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) { + assert.NoError(wrapped.CheckValid(), "Message %+v should be valid", wrapped) + } } } for _, msg := range invalid_messages { assert.Error(msg.CheckValid(), "Message %+v should not be valid", msg) - // If the inner message is invalid, it should also be invalid in a - // wrapped ClientMessage. - if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) { - assert.Error(wrapped.CheckValid(), "Message %+v should not be valid", wrapped) + if messageType != "" { + // If the inner message is invalid, it should also be invalid in a + // wrapped ClientMessage. + if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) { + assert.Error(wrapped.CheckValid(), "Message %+v should not be valid", wrapped) + } } } } @@ -142,6 +166,14 @@ func TestHelloClientMessage(t *testing.T) { Version: HelloVersionV2, ResumeId: "the-resume-id", }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: &HelloClientMessageAuth{ + Type: "federation", + Params: tokenAuthParams, + Url: "https://domain.invalid", + }, + }, } invalid_messages := []testCheckValid{ // Hello version 1 @@ -222,6 +254,28 @@ func TestHelloClientMessage(t *testing.T) { Url: "https://domain.invalid", }, }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: &HelloClientMessageAuth{ + Type: HelloClientTypeFederation, + Params: json.RawMessage("xyz"), // Invalid JSON. + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: &HelloClientMessageAuth{ + Type: HelloClientTypeFederation, + Params: json.RawMessage("{}"), + Url: "https://domain.invalid", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: &HelloClientMessageAuth{ + Type: HelloClientTypeFederation, + Params: tokenAuthParams, + }, + }, } testMessages(t, "hello", valid_messages, invalid_messages) @@ -315,6 +369,224 @@ func TestMessageClientMessage(t *testing.T) { assert.Error(msg.CheckValid()) } +func TestControlClientMessage(t *testing.T) { + t.Parallel() + valid_messages := []testCheckValid{ + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "session", + SessionId: "the-session-id", + }, + Data: json.RawMessage("{}"), + }, + }, + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "user", + UserId: "the-user-id", + }, + Data: json.RawMessage("{}"), + }, + }, + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "room", + }, + Data: json.RawMessage("{}"), + }, + }, + } + invalid_messages := []testCheckValid{ + &ControlClientMessage{ + MessageClientMessage{}, + }, + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "session", + SessionId: "the-session-id", + }, + }, + }, + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "session", + }, + Data: json.RawMessage("{}"), + }, + }, + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "session", + UserId: "the-user-id", + }, + Data: json.RawMessage("{}"), + }, + }, + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "user", + }, + Data: json.RawMessage("{}"), + }, + }, + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "user", + UserId: "the-user-id", + }, + }, + }, + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "user", + SessionId: "the-user-id", + }, + Data: json.RawMessage("{}"), + }, + }, + &ControlClientMessage{ + MessageClientMessage{ + Recipient: MessageClientMessageRecipient{ + Type: "unknown-type", + }, + Data: json.RawMessage("{}"), + }, + }, + } + testMessages(t, "control", valid_messages, invalid_messages) + + // But a "control" message must be present + msg := ClientMessage{ + Type: "control", + } + assert := assert.New(t) + assert.Error(msg.CheckValid()) +} + +func TestMessageClientMessageData(t *testing.T) { + t.Parallel() + valid_messages := []testCheckValid{ + &MessageClientMessageData{ + Type: "invalid", + RoomType: "video", + }, + &MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + }, + &MessageClientMessageData{ + Type: "answer", + RoomType: "video", + Payload: StringMap{ + "sdp": mock.MockSdpAnswerAudioAndVideo, + }, + }, + &MessageClientMessageData{ + Type: "candidate", + RoomType: "video", + Payload: StringMap{ + "candidate": StringMap{ + "candidate": "", + }, + }, + }, + &MessageClientMessageData{ + Type: "candidate", + RoomType: "video", + Payload: StringMap{ + "candidate": StringMap{ + "candidate": "candidate:0 1 UDP 2122194687 192.0.2.4 61665 typ host", + }, + }, + }, + } + invalid_messages := []testCheckValid{ + &MessageClientMessageData{}, + &MessageClientMessageData{ + RoomType: "invalid", + }, + &MessageClientMessageData{ + Type: "offer", + RoomType: "video", + }, + &MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: StringMap{ + "sdp": 1234, + }, + }, + &MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: StringMap{ + "sdp": "invalid-sdp", + }, + }, + &MessageClientMessageData{ + Type: "answer", + RoomType: "video", + }, + &MessageClientMessageData{ + Type: "answer", + RoomType: "video", + Payload: StringMap{ + "sdp": 1234, + }, + }, + &MessageClientMessageData{ + Type: "answer", + RoomType: "video", + Payload: StringMap{ + "sdp": "invalid-sdp", + }, + }, + &MessageClientMessageData{ + Type: "candidate", + RoomType: "video", + }, + &MessageClientMessageData{ + Type: "candidate", + RoomType: "video", + Payload: StringMap{ + "candidate": "invalid-candidate", + }, + }, + &MessageClientMessageData{ + Type: "candidate", + RoomType: "video", + Payload: StringMap{ + "candidate": StringMap{ + "candidate": 12345, + }, + }, + }, + &MessageClientMessageData{ + Type: "candidate", + RoomType: "video", + Payload: StringMap{ + "candidate": StringMap{ + "candidate": ":", + }, + }, + }, + } + + testMessages(t, "", valid_messages, invalid_messages) +} + func TestByeClientMessage(t *testing.T) { t.Parallel() // Any "bye" message is valid. @@ -335,11 +607,44 @@ func TestByeClientMessage(t *testing.T) { func TestRoomClientMessage(t *testing.T) { t.Parallel() - // Any "room" message is valid. + // Any regular "room" message is valid. valid_messages := []testCheckValid{ &RoomClientMessage{}, + &RoomClientMessage{ + Federation: &RoomFederationMessage{ + SignalingUrl: "http://signaling.domain.invalid/", + NextcloudUrl: "http://nextcloud.domain.invalid", + Token: "the token", + }, + }, + } + invalid_messages := []testCheckValid{ + &RoomClientMessage{ + Federation: &RoomFederationMessage{}, + }, + &RoomClientMessage{ + Federation: &RoomFederationMessage{ + SignalingUrl: ":", + }, + }, + &RoomClientMessage{ + Federation: &RoomFederationMessage{ + SignalingUrl: "http://signaling.domain.invalid", + }, + }, + &RoomClientMessage{ + Federation: &RoomFederationMessage{ + SignalingUrl: "http://signaling.domain.invalid/", + NextcloudUrl: ":", + }, + }, + &RoomClientMessage{ + Federation: &RoomFederationMessage{ + SignalingUrl: "http://signaling.domain.invalid/", + NextcloudUrl: "http://nextcloud.domain.invalid", + }, + }, } - invalid_messages := []testCheckValid{} testMessages(t, "room", valid_messages, invalid_messages) @@ -351,6 +656,177 @@ func TestRoomClientMessage(t *testing.T) { assert.Error(msg.CheckValid()) } +func TestInternalClientMessage(t *testing.T) { + t.Parallel() + valid_messages := []testCheckValid{ + &InternalClientMessage{ + Type: "invalid", + }, + &InternalClientMessage{ + Type: "addsession", + AddSession: &AddSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: "session id", + RoomId: "room id", + }, + }, + }, + &InternalClientMessage{ + Type: "updatesession", + UpdateSession: &UpdateSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: "session id", + RoomId: "room id", + }, + }, + }, + &InternalClientMessage{ + Type: "removesession", + RemoveSession: &RemoveSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: "session id", + RoomId: "room id", + }, + }, + }, + &InternalClientMessage{ + Type: "incall", + InCall: &InCallInternalClientMessage{}, + }, + &InternalClientMessage{ + Type: "dialout", + Dialout: &DialoutInternalClientMessage{ + Type: "invalid", + }, + }, + &InternalClientMessage{ + Type: "dialout", + Dialout: &DialoutInternalClientMessage{ + Type: "error", + Error: &Error{}, + }, + }, + &InternalClientMessage{ + Type: "dialout", + Dialout: &DialoutInternalClientMessage{ + Type: "status", + Status: &DialoutStatusInternalClientMessage{}, + }, + }, + } + invalid_messages := []testCheckValid{ + &InternalClientMessage{}, + &InternalClientMessage{ + Type: "addsession", + }, + &InternalClientMessage{ + Type: "addsession", + AddSession: &AddSessionInternalClientMessage{}, + }, + &InternalClientMessage{ + Type: "addsession", + AddSession: &AddSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: "session id", + }, + }, + }, + &InternalClientMessage{ + Type: "updatesession", + }, + &InternalClientMessage{ + Type: "updatesession", + UpdateSession: &UpdateSessionInternalClientMessage{}, + }, + &InternalClientMessage{ + Type: "updatesession", + UpdateSession: &UpdateSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: "session id", + }, + }, + }, + &InternalClientMessage{ + Type: "removesession", + }, + &InternalClientMessage{ + Type: "removesession", + RemoveSession: &RemoveSessionInternalClientMessage{}, + }, + &InternalClientMessage{ + Type: "removesession", + RemoveSession: &RemoveSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: "session id", + }, + }, + }, + &InternalClientMessage{ + Type: "incall", + }, + &InternalClientMessage{ + Type: "dialout", + }, + &InternalClientMessage{ + Type: "dialout", + Dialout: &DialoutInternalClientMessage{}, + }, + &InternalClientMessage{ + Type: "dialout", + Dialout: &DialoutInternalClientMessage{ + Type: "error", + }, + }, + &InternalClientMessage{ + Type: "dialout", + Dialout: &DialoutInternalClientMessage{ + Type: "status", + }, + }, + } + + testMessages(t, "internal", valid_messages, invalid_messages) + + // But a "internal" message must be present + msg := ClientMessage{ + Type: "internal", + } + assert := assert.New(t) + assert.Error(msg.CheckValid()) +} + +func TestTransientDataClientMessage(t *testing.T) { + t.Parallel() + valid_messages := []testCheckValid{ + &TransientDataClientMessage{ + Type: "set", + Key: "foo", + }, + &TransientDataClientMessage{ + Type: "remove", + Key: "foo", + }, + } + invalid_messages := []testCheckValid{ + &TransientDataClientMessage{}, + &TransientDataClientMessage{ + Type: "set", + }, + &TransientDataClientMessage{ + Type: "remove", + }, + } + + testMessages(t, "transient", valid_messages, invalid_messages) + + // But a "transient" message must be present + msg := ClientMessage{ + Type: "transient", + } + assert := assert.New(t) + assert.Error(msg.CheckValid()) +} + func TestErrorMessages(t *testing.T) { t.Parallel() assert := assert.New(t) diff --git a/client/client.go b/client/client.go index 9679526..796c9db 100644 --- a/client/client.go +++ b/client/client.go @@ -462,33 +462,6 @@ close: return false } -func (c *Client) writeError(e error) bool { // nolint - message := &api.ServerMessage{ - Type: "error", - Error: api.NewError("internal_error", e.Error()), - } - c.mu.Lock() - defer c.mu.Unlock() - if c.conn == nil { - return false - } - - if !c.writeMessageLocked(message) { - return false - } - - closeData := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, e.Error()) - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint - if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil { - if sessionId := c.GetSessionId(); sessionId != "" { - c.logger.Printf("Could not send close message to client %s: %v", sessionId, err) - } else { - c.logger.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err) - } - } - return false -} - func (c *Client) writeMessage(message WritableClientMessage) bool { c.mu.Lock() defer c.mu.Unlock() diff --git a/client/client_test.go b/client/client_test.go index bac0f70..6ec13a2 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -111,6 +111,7 @@ func (c *serverClient) GetSessionId() api.PublicSessionId { } func (c *serverClient) OnClosed() { + c.Close() c.handler.removeClient(c) } @@ -120,6 +121,9 @@ func (c *serverClient) OnMessageReceived(message []byte) { var s string if err := json.Unmarshal(message, &s); assert.NoError(c.t, err) { assert.Equal(c.t, "Hello world!", s) + c.sendPing() + assert.EqualValues(c.t, "DE", c.Country()) + assert.False(c.t, c.Client.IsInRoom("room-id")) c.SendMessage(&api.ServerMessage{ Type: "welcome", Welcome: &api.WelcomeServerMessage{ @@ -212,8 +216,15 @@ func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { id := h.id.Add(1) client := newTestClient(h, r, conn, id) h.addClient(client) + + closed := make(chan struct{}) + context.AfterFunc(client.Context(), func() { + close(closed) + }) + go client.WritePump() client.ReadPump() + <-closed } type localClient struct { @@ -246,6 +257,10 @@ func (c *localClient) WriteJSON(v any) error { return c.conn.WriteJSON(v) } +func (c *localClient) Write(v []byte) error { + return c.conn.WriteMessage(websocket.BinaryMessage, v) +} + func (c *localClient) ReadJSON(v any) error { return c.conn.ReadJSON(v) } @@ -281,6 +296,14 @@ func TestClient(t *testing.T) { assert.EqualValues(1, clients[0].received.Load()) } + require.NoError(client.Write([]byte("Hello world!"))) + if assert.NoError(client.ReadJSON(&msg)) && + assert.Equal("error", msg.Type) && + assert.NotNil(msg.Error) { + assert.Equal("invalid_format", msg.Error.Code) + assert.Equal("Invalid data format.", msg.Error.Message) + } + require.NoError(client.WriteJSON("Send error")) if assert.NoError(client.ReadJSON(&msg)) && assert.Equal("error", msg.Type) && diff --git a/cmd/proxy/proxy_server.go b/cmd/proxy/proxy_server.go index ec11ba7..53e0051 100644 --- a/cmd/proxy/proxy_server.go +++ b/cmd/proxy/proxy_server.go @@ -265,32 +265,6 @@ func NewProxyServer(ctx context.Context, r *mux.Router, version string, config * return nil, err } - statsAllowed, _ := config.GetString("stats", "allowed_ips") - statsAllowedIps, err := container.ParseIPList(statsAllowed) - if err != nil { - return nil, err - } - - if !statsAllowedIps.Empty() { - logger.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) - } else { - statsAllowedIps = container.DefaultAllowedIPs() - logger.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) - } - - trustedProxies, _ := config.GetString("app", "trustedproxies") - trustedProxiesIps, err := container.ParseIPList(trustedProxies) - if err != nil { - return nil, err - } - - if !trustedProxiesIps.Empty() { - logger.Printf("Trusted proxies: %s", trustedProxiesIps) - } else { - trustedProxiesIps = client.DefaultTrustedProxies - logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) - } - countryString, _ := config.GetString("app", "country") country := geoip.Country(strings.ToUpper(countryString)) if geoip.IsValidCountry(country) { @@ -350,8 +324,6 @@ func NewProxyServer(ctx context.Context, r *mux.Router, version string, config * logger.Printf("No token id configured, remote streams will be disabled") } - maxIncoming, maxOutgoing := getTargetBandwidths(logger, config) - mcuTimeoutSeconds, _ := config.GetInt("mcu", "timeout") if mcuTimeoutSeconds <= 0 { mcuTimeoutSeconds = defaultMcuTimeoutSeconds @@ -397,10 +369,10 @@ func NewProxyServer(ctx context.Context, r *mux.Router, version string, config * remotePublishers: make(map[string]map[*proxyRemotePublisher]bool), } - result.maxIncoming.Store(maxIncoming) - result.maxOutgoing.Store(maxOutgoing) - result.statsAllowedIps.Store(statsAllowedIps) - result.trustedProxies.Store(trustedProxiesIps) + if err := result.loadConfig(config, false); err != nil { + return nil, err + } + result.upgrader.CheckOrigin = result.checkOrigin statsLoadCurrent.Set(0) @@ -629,7 +601,7 @@ func (s *ProxyServer) ScheduleShutdown() { } } -func (s *ProxyServer) Reload(config *goconf.ConfigFile) { +func (s *ProxyServer) loadConfig(config *goconf.ConfigFile, fromReload bool) error { statsAllowed, _ := config.GetString("stats", "allowed_ips") if statsAllowedIps, err := container.ParseIPList(statsAllowed); err == nil { if !statsAllowedIps.Empty() { @@ -639,8 +611,10 @@ func (s *ProxyServer) Reload(config *goconf.ConfigFile) { s.logger.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) } s.statsAllowedIps.Store(statsAllowedIps) - } else { + } else if fromReload { s.logger.Printf("Error parsing allowed stats ips from \"%s\": %s", statsAllowedIps, err) + } else { + return fmt.Errorf("error parsing allowed stats ips from \"%s\": %w", statsAllowedIps, err) } trustedProxies, _ := config.GetString("app", "trustedproxies") @@ -652,18 +626,28 @@ func (s *ProxyServer) Reload(config *goconf.ConfigFile) { s.logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } s.trustedProxies.Store(trustedProxiesIps) - } else { + } else if fromReload { s.logger.Printf("Error parsing trusted proxies from \"%s\": %s", trustedProxies, err) + } else { + return fmt.Errorf("error parsing trusted proxies ips from \"%s\": %w", trustedProxies, err) } maxIncoming, maxOutgoing := getTargetBandwidths(s.logger, config) oldIncoming := s.maxIncoming.Swap(maxIncoming) oldOutgoing := s.maxOutgoing.Swap(maxOutgoing) - if oldIncoming != maxIncoming || oldOutgoing != maxOutgoing { + if fromReload && (oldIncoming != maxIncoming || oldOutgoing != maxOutgoing) { // Notify sessions about updated load / bandwidth usage. go s.sendLoadToAll(s.load.Load(), s.currentIncoming.Load(), s.currentOutgoing.Load()) } + return nil +} + +func (s *ProxyServer) Reload(config *goconf.ConfigFile) { + if err := s.loadConfig(config, true); err != nil { + s.logger.Printf("Error reloading configuration: %s", err) + } + s.tokens.Reload(config) s.mcu.Reload(config) } diff --git a/cmd/proxy/proxy_server_test.go b/cmd/proxy/proxy_server_test.go index f04e55c..4aa212d 100644 --- a/cmd/proxy/proxy_server_test.go +++ b/cmd/proxy/proxy_server_test.go @@ -28,6 +28,7 @@ import ( "crypto/x509" "encoding/pem" "errors" + "fmt" "net" "net/http/httptest" "os" @@ -48,9 +49,11 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/internal" "github.com/strukturag/nextcloud-spreed-signaling/log" logtest "github.com/strukturag/nextcloud-spreed-signaling/log/test" + "github.com/strukturag/nextcloud-spreed-signaling/mock" "github.com/strukturag/nextcloud-spreed-signaling/proxy" "github.com/strukturag/nextcloud-spreed-signaling/sfu" "github.com/strukturag/nextcloud-spreed-signaling/talk" + "github.com/strukturag/nextcloud-spreed-signaling/test" ) const ( @@ -461,6 +464,7 @@ type PublisherTestMCU struct { type TestPublisherWithBandwidth struct { TestMCUPublisher + t *testing.T bandwidth *sfu.ClientBandwidthInfo } @@ -468,6 +472,25 @@ func (p *TestPublisherWithBandwidth) Bandwidth() *sfu.ClientBandwidthInfo { return p.bandwidth } +func (p *TestPublisherWithBandwidth) SendMessage(ctx context.Context, message *api.MessageClientMessage, data *api.MessageClientMessageData, callback func(error, api.StringMap)) { + switch data.Type { + case "offer": + assert.Equal(p.t, mock.MockSdpOfferAudioAndVideo, data.Payload["sdp"]) + assert.NotNil(p.t, data.OfferSdp) + callback(nil, api.StringMap{ + "type": "answer", + "sdp": mock.MockSdpAnswerAudioAndVideo, + }) + case "requestoffer": + callback(nil, api.StringMap{ + "type": "offer", + "sdp": mock.MockSdpOfferAudioOnly, + }) + default: + callback(fmt.Errorf("type %s not implemented", data.Type), nil) + } +} + func (m *PublisherTestMCU) NewPublisher(ctx context.Context, listener sfu.Listener, id api.PublicSessionId, sid string, streamType sfu.StreamType, settings sfu.NewPublisherSettings, initiator sfu.Initiator) (sfu.Publisher, error) { publisher := &TestPublisherWithBandwidth{ TestMCUPublisher: TestMCUPublisher{ @@ -476,6 +499,7 @@ func (m *PublisherTestMCU) NewPublisher(ctx context.Context, listener sfu.Listen streamType: streamType, }, + t: m.t, bandwidth: &sfu.ClientBandwidthInfo{ Sent: api.BandwidthFromBytes(1000), Received: api.BandwidthFromBytes(2000), @@ -497,6 +521,9 @@ func TestProxyPublisherBandwidth(t *testing.T) { assert := assert.New(t) require := require.New(t) proxyServer, key, server := newProxyServerForTest(t) + t.Cleanup(func() { + assert.EqualValues(0, proxyServer.GetSessionsCount()) + }) proxyServer.maxIncoming.Store(api.BandwidthFromBytes(10000)) proxyServer.maxOutgoing.Store(api.BandwidthFromBytes(10000)) @@ -507,6 +534,8 @@ func TestProxyPublisherBandwidth(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() + assert.EqualValues(0, proxyServer.GetSessionsCount()) + client := NewProxyTestClient(ctx, t, server.URL) defer client.CloseWithBye() @@ -528,12 +557,19 @@ func TestProxyPublisherBandwidth(t *testing.T) { }, })) + var clientId string if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { assert.Equal("2345", message.Id) if err := checkMessageType(message, "command"); assert.NoError(err) { assert.NotEmpty(message.Command.Id) + clientId = message.Command.Id } } + require.NotEmpty(clientId) + + if publisher := proxyServer.GetPublisher(clientId); assert.NotNil(publisher) { + assert.Equal(clientId, proxyServer.GetClientId(publisher)) + } proxyServer.updateLoad() @@ -550,6 +586,66 @@ func TestProxyPublisherBandwidth(t *testing.T) { } } } + + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "3456", + Type: "payload", + Payload: &proxy.PayloadClientMessage{ + Type: "offer", + ClientId: clientId, + Payload: api.StringMap{ + "type": "offer", + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + }, + })) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + assert.Equal("payload", message.Type) + if payload := message.Payload; assert.NotNil(payload) { + assert.Equal(clientId, payload.ClientId) + assert.Equal("offer", payload.Type) + assert.Equal("answer", payload.Payload["type"]) + assert.Equal(mock.MockSdpAnswerAudioAndVideo, payload.Payload["sdp"]) + } + } + + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "4567", + Type: "payload", + Payload: &proxy.PayloadClientMessage{ + Type: "endOfCandidates", + ClientId: clientId, + }, + })) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + assert.Equal("payload", message.Type) + if payload := message.Payload; assert.NotNil(payload) { + assert.Equal(clientId, payload.ClientId) + assert.Equal("endOfCandidates", payload.Type) + assert.Empty(payload.Payload) + } + } + + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "5678", + Type: "payload", + Payload: &proxy.PayloadClientMessage{ + Type: "requestoffer", + ClientId: clientId, + }, + })) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("5678", message.Id) + assert.Equal("payload", message.Type) + if payload := message.Payload; assert.NotNil(payload) { + assert.Equal(clientId, payload.ClientId) + assert.Equal("requestoffer", payload.Type) + assert.Equal("offer", payload.Payload["type"]) + assert.Equal(mock.MockSdpOfferAudioOnly, payload.Payload["sdp"]) + } + } } type HangingTestMCU struct { @@ -1605,3 +1701,332 @@ func TestProxyUnpublishRemoteOnSessionClose(t *testing.T) { assert.Nil(publisher.getRemoteData()) } } + +func TestExpireSessions(t *testing.T) { + t.Parallel() + + test.SynctestTest(t, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + proxyServer, key, server := newProxyServerForTest(t) + server.Close() + + // No-op + proxyServer.expireSessions() + + claims := &proxy.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + require.NoError(err) + + hello := &proxy.HelloClientMessage{ + Version: "1.0", + Token: tokenString, + } + session, err := proxyServer.NewSession(hello) + require.NoError(err) + t.Cleanup(func() { + proxyServer.DeleteSession(session.Sid()) + }) + assert.Same(session, proxyServer.GetSession(session.Sid())) + + proxyServer.expireSessions() + assert.Same(session, proxyServer.GetSession(session.Sid())) + + time.Sleep(sessionExpirationTime) + proxyServer.expireSessions() + + assert.Nil(proxyServer.GetSession(session.Sid())) + }) +} + +func TestScheduleShutdownEmpty(t *testing.T) { + t.Parallel() + + proxyServer, _, _ := newProxyServerForTest(t) + + proxyServer.ScheduleShutdown() + <-proxyServer.ShutdownChannel() +} + +func TestScheduleShutdownNoClients(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + proxyServer, key, server := newProxyServerForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + proxyServer.ScheduleShutdown() + + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("event", msg.Type) + if event := msg.Event; assert.NotNil(event) { + assert.Equal("shutdown-scheduled", event.Type) + } + } + + <-proxyServer.ShutdownChannel() +} + +func TestScheduleShutdown(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + proxyServer, key, server := newProxyServerForTest(t) + + mcu := NewPublisherTestMCU(t) + proxyServer.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := api.PublicSessionId("the-publisher-id") + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "2345", + Type: "command", + Command: &proxy.CommandClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: sfu.StreamTypeVideo, + PublisherSettings: &sfu.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: sfu.MediaTypeAudio | sfu.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + readyChan := make(chan struct{}) + var readyReceived atomic.Bool + go func() { + for { + select { + case <-proxyServer.ShutdownChannel(): + return + case <-readyChan: + readyReceived.Store(true) + case <-ctx.Done(): + assert.NoError(ctx.Err()) + return + } + } + }() + + proxyServer.ScheduleShutdown() + + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("event", msg.Type) + if event := msg.Event; assert.NotNil(event) { + assert.Equal("shutdown-scheduled", event.Type) + } + } + close(readyChan) + + proxyServer.ScheduleShutdown() + + select { + case <-proxyServer.ShutdownChannel(): + assert.Fail("should only shutdown after all clients closed") + default: + } + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "4567", + Type: "command", + Command: &proxy.CommandClientMessage{ + Type: "delete-publisher", + ClientId: clientId, + }, + })) + + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + <-proxyServer.ShutdownChannel() + assert.True(readyReceived.Load()) +} + +func TestScheduleShutdownOnResume(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + proxyServer, key, server := newProxyServerForTest(t) + + mcu := NewPublisherTestMCU(t) + proxyServer.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + hello, err := client.RunUntilHello(ctx) + if assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err = client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := api.PublicSessionId("the-publisher-id") + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "2345", + Type: "command", + Command: &proxy.CommandClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: sfu.StreamTypeVideo, + PublisherSettings: &sfu.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: sfu.MediaTypeAudio | sfu.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + readyChan := make(chan struct{}) + var readyReceived atomic.Bool + go func() { + for { + select { + case <-proxyServer.ShutdownChannel(): + return + case <-readyChan: + readyReceived.Store(true) + case <-ctx.Done(): + assert.NoError(ctx.Err()) + return + } + } + }() + + client.Close() + + proxyServer.ScheduleShutdown() + + client = NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + hello2 := &proxy.ClientMessage{ + Id: "1234", + Type: "hello", + Hello: &proxy.HelloClientMessage{ + Version: "1.0", + Features: []string{}, + ResumeId: hello.Hello.SessionId, + }, + } + require.NoError(client.WriteJSON(hello2)) + + if hello3, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(hello.Hello.SessionId, hello3.Hello.SessionId) + } + + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("event", msg.Type) + if event := msg.Event; assert.NotNil(event) { + assert.Equal("shutdown-scheduled", event.Type) + } + } + + client2 := NewProxyTestClient(ctx, t, server.URL) + defer client2.CloseWithBye() + + require.NoError(client2.SendHello(key)) + + if hello, err := client2.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("event", msg.Type) + if event := msg.Event; assert.NotNil(event) { + assert.Equal("shutdown-scheduled", event.Type) + } + } + close(readyChan) + + proxyServer.ScheduleShutdown() + + select { + case <-proxyServer.ShutdownChannel(): + assert.Fail("should only shutdown after all clients closed") + default: + } + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "4567", + Type: "command", + Command: &proxy.CommandClientMessage{ + Type: "delete-publisher", + ClientId: clientId, + }, + })) + + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + <-proxyServer.ShutdownChannel() + assert.True(readyReceived.Load()) +} diff --git a/cmd/proxy/proxy_session.go b/cmd/proxy/proxy_session.go index 0c12f23..7bc7133 100644 --- a/cmd/proxy/proxy_session.go +++ b/cmd/proxy/proxy_session.go @@ -117,7 +117,7 @@ func (s *ProxySession) LastUsed() time.Time { func (s *ProxySession) IsExpired() bool { expiresAt := s.LastUsed().Add(sessionExpirationTime) - return expiresAt.Before(time.Now()) + return !expiresAt.After(time.Now()) } func (s *ProxySession) MarkUsed() { diff --git a/cmd/proxy/proxy_tokens_etcd_test.go b/cmd/proxy/proxy_tokens_etcd_test.go index 32a44e8..aba5a5b 100644 --- a/cmd/proxy/proxy_tokens_etcd_test.go +++ b/cmd/proxy/proxy_tokens_etcd_test.go @@ -153,3 +153,34 @@ func TestProxyTokensEtcd(t *testing.T) { assert.True(key2.PublicKey.Equal(token.key)) } } + +func TestProxyTokensEtcdReload(t *testing.T) { + t.Parallel() + assert := assert.New(t) + tokens, etcd := newTokensEtcdForTesting(t) + + key1 := generateAndSaveKey(t, etcd, "/foo") + + if token, err := tokens.Get("foo"); assert.NoError(err) && assert.NotNil(token) { + assert.True(key1.PublicKey.Equal(token.key)) + } + + if token, err := tokens.Get("bar"); assert.NoError(err) { + assert.Nil(token) + } + + cfg := goconf.NewConfigFile() + cfg.AddOption("etcd", "endpoints", etcd.Config().ListenClientUrls[0].String()) + cfg.AddOption("tokens", "keyformat", "/reload/%s/key") + + tokens.Reload(cfg) + key2 := generateAndSaveKey(t, etcd, "/reload/bar/key") + + if token, err := tokens.Get("foo"); assert.NoError(err) { + assert.Nil(token) + } + + if token, err := tokens.Get("bar"); assert.NoError(err) && assert.NotNil(token) { + assert.True(key2.PublicKey.Equal(token.key)) + } +} diff --git a/cmd/proxy/proxy_tokens_static.go b/cmd/proxy/proxy_tokens_static.go index 2abb43c..baa9148 100644 --- a/cmd/proxy/proxy_tokens_static.go +++ b/cmd/proxy/proxy_tokens_static.go @@ -84,7 +84,7 @@ func (t *tokensStatic) load(cfg *goconf.ConfigFile, ignoreErrors bool) error { keyData, err := os.ReadFile(filename) if err != nil { if !ignoreErrors { - return fmt.Errorf("could not read public key from %s: %s", filename, err) + return fmt.Errorf("could not read public key from %s: %w", filename, err) } t.logger.Printf("Could not read public key from %s, ignoring: %s", filename, err) @@ -93,7 +93,7 @@ func (t *tokensStatic) load(cfg *goconf.ConfigFile, ignoreErrors bool) error { key, err := jwt.ParseRSAPublicKeyFromPEM(keyData) if err != nil { if !ignoreErrors { - return fmt.Errorf("could not parse public key from %s: %s", filename, err) + return fmt.Errorf("could not parse public key from %s: %w", filename, err) } t.logger.Printf("Could not parse public key from %s, ignoring: %s", filename, err) diff --git a/cmd/proxy/proxy_tokens_static_test.go b/cmd/proxy/proxy_tokens_static_test.go new file mode 100644 index 0000000..662174e --- /dev/null +++ b/cmd/proxy/proxy_tokens_static_test.go @@ -0,0 +1,185 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2026 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "os" + "path" + "testing" + + "github.com/dlintw/goconf" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/strukturag/nextcloud-spreed-signaling/internal" + logtest "github.com/strukturag/nextcloud-spreed-signaling/log/test" +) + +func TestStaticTokens(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + + filename := path.Join(t.TempDir(), "token.pub") + + key1, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(err) + require.NoError(internal.WritePublicKey(&key1.PublicKey, filename)) + + logger := logtest.NewLoggerForTest(t) + config := goconf.NewConfigFile() + config.AddOption("tokens", "foo", filename) + + tokens, err := NewProxyTokensStatic(logger, config) + require.NoError(err) + + defer tokens.Close() + + if token, err := tokens.Get("foo"); assert.NoError(err) { + assert.Equal("foo", token.id) + assert.True(key1.PublicKey.Equal(token.key)) + } + + key2, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(err) + require.NoError(internal.WritePublicKey(&key2.PublicKey, filename)) + + tokens.Reload(config) + + if token, err := tokens.Get("foo"); assert.NoError(err) { + assert.Equal("foo", token.id) + assert.True(key2.PublicKey.Equal(token.key)) + } +} + +func testStaticTokensMissing(t *testing.T, reload bool) { + require := require.New(t) + assert := assert.New(t) + + filename := path.Join(t.TempDir(), "token.pub") + + logger := logtest.NewLoggerForTest(t) + config := goconf.NewConfigFile() + if !reload { + config.AddOption("tokens", "foo", filename) + } + + tokens, err := NewProxyTokensStatic(logger, config) + if !reload { + assert.ErrorIs(err, os.ErrNotExist) + return + } + + require.NoError(err) + defer tokens.Close() + + config.AddOption("tokens", "foo", filename) + tokens.Reload(config) +} + +func TestStaticTokensMissing(t *testing.T) { + t.Parallel() + + testStaticTokensMissing(t, false) +} + +func TestStaticTokensMissingReload(t *testing.T) { + t.Parallel() + + testStaticTokensMissing(t, true) +} + +func testStaticTokensEmpty(t *testing.T, reload bool) { + require := require.New(t) + assert := assert.New(t) + + logger := logtest.NewLoggerForTest(t) + config := goconf.NewConfigFile() + if !reload { + config.AddOption("tokens", "foo", "") + } + + tokens, err := NewProxyTokensStatic(logger, config) + if !reload { + assert.ErrorContains(err, "no filename given") + return + } + + require.NoError(err) + defer tokens.Close() + + config.AddOption("tokens", "foo", "") + tokens.Reload(config) +} + +func TestStaticTokensEmpty(t *testing.T) { + t.Parallel() + + testStaticTokensEmpty(t, false) +} + +func TestStaticTokensEmptyReload(t *testing.T) { + t.Parallel() + + testStaticTokensEmpty(t, true) +} + +func testStaticTokensInvalidData(t *testing.T, reload bool) { + require := require.New(t) + assert := assert.New(t) + + filename := path.Join(t.TempDir(), "token.pub") + require.NoError(os.WriteFile(filename, []byte("invalid-key-data"), 0600)) + + logger := logtest.NewLoggerForTest(t) + config := goconf.NewConfigFile() + if !reload { + config.AddOption("tokens", "foo", filename) + } + + tokens, err := NewProxyTokensStatic(logger, config) + if !reload { + assert.ErrorContains(err, "could not parse public key") + return + } + + require.NoError(err) + defer tokens.Close() + + config.AddOption("tokens", "foo", filename) + tokens.Reload(config) +} + +func TestStaticTokensInvalidData(t *testing.T) { + t.Parallel() + + testStaticTokensInvalidData(t, false) +} + +func TestStaticTokensInvalidDataReload(t *testing.T) { + t.Parallel() + + testStaticTokensInvalidData(t, true) +} diff --git a/grpc/test/server.go b/grpc/test/server.go index fd1ba21..ef67170 100644 --- a/grpc/test/server.go +++ b/grpc/test/server.go @@ -22,7 +22,10 @@ package test import ( + "context" + "errors" "net" + "net/url" "strconv" "testing" @@ -30,9 +33,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/strukturag/nextcloud-spreed-signaling/api" "github.com/strukturag/nextcloud-spreed-signaling/grpc" "github.com/strukturag/nextcloud-spreed-signaling/log" logtest "github.com/strukturag/nextcloud-spreed-signaling/log/test" + "github.com/strukturag/nextcloud-spreed-signaling/sfu" + "github.com/strukturag/nextcloud-spreed-signaling/talk" "github.com/strukturag/nextcloud-spreed-signaling/test" ) @@ -71,3 +77,46 @@ func NewServerForTest(t *testing.T) (server *grpc.Server, addr string) { config := goconf.NewConfigFile() return NewServerForTestWithConfig(t, config) } + +type MockHub struct { +} + +func (h *MockHub) GetSessionIdByResumeId(resumeId api.PrivateSessionId) api.PublicSessionId { + return "" +} + +func (h *MockHub) GetSessionIdByRoomSessionId(roomSessionId api.RoomSessionId) (api.PublicSessionId, error) { + return "", errors.New("not implemented") +} + +func (h *MockHub) IsSessionIdInCall(sessionId api.PublicSessionId, roomId string, backendUrl string) (bool, bool) { + return false, false +} + +func (h *MockHub) DisconnectSessionByRoomSessionId(sessionId api.PublicSessionId, roomSessionId api.RoomSessionId, reason string) { +} + +func (h *MockHub) GetBackend(u *url.URL) *talk.Backend { + return nil +} + +func (h *MockHub) GetInternalSessions(roomId string, backend *talk.Backend) ([]*grpc.InternalSessionData, []*grpc.VirtualSessionData, bool) { + return nil, nil, false +} + +func (h *MockHub) GetTransientEntries(roomId string, backend *talk.Backend) (api.TransientDataEntries, bool) { + return nil, false +} + +func (h *MockHub) GetPublisherIdForSessionId(ctx context.Context, sessionId api.PublicSessionId, streamType sfu.StreamType) (*grpc.GetPublisherIdReply, error) { + return nil, errors.New("not implemented") +} + +func (h *MockHub) ProxySession(request grpc.RpcSessions_ProxySessionServer) error { + return errors.New("not implemented") +} + +var ( + // Compile-time check that MockHub implements the interface. + _ grpc.ServerHub = &MockHub{} +) diff --git a/grpc/test/server_test.go b/grpc/test/server_test.go index bd5e871..323334e 100644 --- a/grpc/test/server_test.go +++ b/grpc/test/server_test.go @@ -22,12 +22,43 @@ package test import ( + "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/strukturag/nextcloud-spreed-signaling/geoip" + "github.com/strukturag/nextcloud-spreed-signaling/grpc" + "github.com/strukturag/nextcloud-spreed-signaling/sfu" + "github.com/strukturag/nextcloud-spreed-signaling/talk" ) +type emptyReceiver struct { +} + +func (r *emptyReceiver) RemoteAddr() string { + return "127.0.0.1" +} + +func (r *emptyReceiver) Country() geoip.Country { + return "DE" +} + +func (r *emptyReceiver) UserAgent() string { + return "testing" +} + +func (r *emptyReceiver) OnProxyMessage(message *grpc.ServerSessionMessage) error { + return errors.New("not implemented") +} + +func (r *emptyReceiver) OnProxyClose(err error) { + // Ignore +} + func TestServer(t *testing.T) { t.Parallel() @@ -37,15 +68,63 @@ func TestServer(t *testing.T) { serverId := "the-test-server-id" server, addr := NewServerForTest(t) server.SetServerId(serverId) + hub := &MockHub{} + server.SetHub(hub) clients, _ := NewClientsForTest(t, addr, nil) require.NoError(clients.WaitForInitialized(t.Context())) + backend := talk.NewCompatBackend(nil) + for _, client := range clients.GetClients() { if id, version, err := client.GetServerId(t.Context()); assert.NoError(err) { assert.Equal(serverId, id) assert.NotEmpty(version) } + + reply, err := client.LookupResumeId(t.Context(), "resume-id") + assert.ErrorIs(err, grpc.ErrNoSuchResumeId) + assert.Nil(reply) + + id, err := client.LookupSessionId(t.Context(), "session-id", "") + if s, ok := status.FromError(err); assert.True(ok) { + assert.Equal(codes.Unknown, s.Code()) + assert.Equal("not implemented", s.Message()) + } + assert.Empty(id) + + if incall, err := client.IsSessionInCall(t.Context(), "session-id", "room-id", ""); assert.NoError(err) { + assert.False(incall) + } + + if internal, virtual, err := client.GetInternalSessions(t.Context(), "room-id", nil); assert.NoError(err) { + assert.Empty(internal) + assert.Empty(virtual) + } + + publisherId, proxyUrl, ip, connToken, publisherToken, err := client.GetPublisherId(t.Context(), "session-id", sfu.StreamTypeVideo) + if s, ok := status.FromError(err); assert.True(ok) { + assert.Equal(codes.Unknown, s.Code()) + assert.Equal("not implemented", s.Message()) + } + assert.Empty(publisherId) + assert.Empty(proxyUrl) + assert.Empty(ip) + assert.Empty(connToken) + assert.Empty(publisherToken) + + if count, err := client.GetSessionCount(t.Context(), ""); assert.NoError(err) { + assert.EqualValues(0, count) + } + + if data, err := client.GetTransientData(t.Context(), "room-id", backend); assert.NoError(err) { + assert.Empty(data) + } + + receiver := &emptyReceiver{} + proxy, err := client.ProxySession(t.Context(), "session-id", receiver) + assert.NoError(err) + assert.NotNil(proxy) } } diff --git a/metrics/prometheus_test.go b/metrics/prometheus_test.go new file mode 100644 index 0000000..801ae3f --- /dev/null +++ b/metrics/prometheus_test.go @@ -0,0 +1,67 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2026 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package metrics + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" +) + +func TestRegistration(t *testing.T) { + t.Parallel() + + collectors := []prometheus.Collector{ + prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "signaling", + Subsystem: "test", + Name: "value_total", + Help: "Total value.", + }), + prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "signaling", + Subsystem: "test", + Name: "value", + Help: "Current value.", + }, []string{"foo", "bar"}), + } + // Can unregister without previous registration + UnregisterAll(collectors...) + RegisterAll(collectors...) + // Can register multiple times + RegisterAll(collectors...) + UnregisterAll(collectors...) +} + +func TestRegistrationError(t *testing.T) { + t.Parallel() + + defer func() { + value := recover() + if err, ok := value.(error); assert.True(t, ok) { + assert.ErrorContains(t, err, "is not a valid metric name") + } + }() + + RegisterAll(prometheus.NewCounter(prometheus.CounterOpts{})) +} diff --git a/pool/buffer_pool_test.go b/pool/buffer_pool_test.go new file mode 100644 index 0000000..f06b969 --- /dev/null +++ b/pool/buffer_pool_test.go @@ -0,0 +1,141 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2026 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package pool + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBufferPool(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + + var pool BufferPool + + buf1 := pool.Get() + assert.NotNil(buf1) + buf2 := pool.Get() + assert.NotSame(buf1, buf2) + + buf1.WriteString("empty string") + pool.Put(buf1) + + buf3 := pool.Get() + assert.Equal(0, buf3.Len()) + + pool.Put(nil) +} + +func TestBufferPoolReadAll(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + + s := "Hello world!" + data := bytes.NewBufferString(s) + + var pool BufferPool + + buf1 := pool.Get() + assert.NotNil(buf1) + pool.Put(buf1) + + buf2, err := pool.ReadAll(data) + require.NoError(err) + assert.Equal(s, buf2.String()) +} + +var errTest = errors.New("test error") + +type errorReader struct{} + +func (e errorReader) Read(b []byte) (int, error) { + return 0, errTest +} + +func TestBufferPoolReadAllError(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + + var pool BufferPool + + buf1 := pool.Get() + assert.NotNil(buf1) + pool.Put(buf1) + + r := &errorReader{} + buf2, err := pool.ReadAll(r) + assert.ErrorIs(err, errTest) + assert.Nil(buf2) +} + +func TestBufferPoolMarshalAsJSON(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + + var pool BufferPool + buf1 := pool.Get() + assert.NotNil(buf1) + pool.Put(buf1) + + s := "Hello world!" + buf2, err := pool.MarshalAsJSON(s) + require.NoError(err) + + assert.Equal(fmt.Sprintf("\"%s\"\n", s), buf2.String()) +} + +type errorMarshaler struct { + json.Marshaler +} + +func (e errorMarshaler) MarshalJSON() ([]byte, error) { + return nil, errTest +} + +func TestBufferPoolMarshalAsJSONError(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + + var pool BufferPool + buf1 := pool.Get() + assert.NotNil(buf1) + pool.Put(buf1) + + var ob errorMarshaler + buf2, err := pool.MarshalAsJSON(ob) + assert.ErrorIs(err, errTest) + assert.Nil(buf2) +} diff --git a/proxy/api_test.go b/proxy/api_test.go new file mode 100644 index 0000000..7d5b1ca --- /dev/null +++ b/proxy/api_test.go @@ -0,0 +1,591 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2026 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package proxy + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/internal" + "github.com/strukturag/nextcloud-spreed-signaling/mock" + "github.com/strukturag/nextcloud-spreed-signaling/sfu" +) + +func TestValidate(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + testcases := []struct { + message *ClientMessage + reason string + }{ + { + &ClientMessage{}, + "type missing", + }, + { + // Unknown types are ignored. + &ClientMessage{ + Type: "invalid", + }, + "", + }, + { + &ClientMessage{ + Type: "hello", + }, + "hello missing", + }, + { + &ClientMessage{ + Type: "hello", + Hello: &HelloClientMessage{}, + }, + "unsupported hello version", + }, + { + &ClientMessage{ + Type: "hello", + Hello: &HelloClientMessage{ + Version: "abc", + }, + }, + "unsupported hello version", + }, + { + &ClientMessage{ + Type: "hello", + Hello: &HelloClientMessage{ + Version: "1.0", + }, + }, + "token missing", + }, + { + &ClientMessage{ + Type: "hello", + Hello: &HelloClientMessage{ + Version: "1.0", + Token: "token", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "hello", + Hello: &HelloClientMessage{ + Version: "1.0", + ResumeId: "resume-id", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "bye", + }, + "", + }, + { + &ClientMessage{ + Type: "bye", + Bye: &ByeClientMessage{}, + }, + "", + }, + { + &ClientMessage{ + Type: "command", + }, + "command missing", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{}, + }, + "type missing", + }, + { + // Unknown types are ignored. + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "invalid", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "create-publisher", + }, + }, + "stream type missing", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "create-publisher", + StreamType: sfu.StreamTypeVideo, + }, + }, + "", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "create-subscriber", + }, + }, + "publisher id missing", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "create-subscriber", + PublisherId: "foo", + }, + }, + "stream type missing", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "create-subscriber", + PublisherId: "foo", + StreamType: sfu.StreamTypeVideo, + }, + }, + "", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "create-subscriber", + PublisherId: "foo", + StreamType: sfu.StreamTypeVideo, + RemoteUrl: "http://domain.invalid", + }, + }, + "remote token missing", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "create-subscriber", + PublisherId: "foo", + StreamType: sfu.StreamTypeVideo, + RemoteUrl: ":", + RemoteToken: "remote-token", + }, + }, + "invalid remote url", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "create-subscriber", + PublisherId: "foo", + StreamType: sfu.StreamTypeVideo, + RemoteUrl: "http://domain.invalid", + RemoteToken: "remote-token", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "delete-publisher", + }, + }, + "client id missing", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "delete-publisher", + ClientId: "foo", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "delete-subscriber", + }, + }, + "client id missing", + }, + { + &ClientMessage{ + Type: "command", + Command: &CommandClientMessage{ + Type: "delete-subscriber", + ClientId: "foo", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "payload", + }, + "payload missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{}, + }, + "type missing", + }, + { + // Unknown types are ignored. + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "invalid", + }, + }, + "client id missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "offer", + }, + }, + "payload missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "offer", + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + }, + }, + "client id missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "offer", + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + ClientId: "foo", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "answer", + }, + }, + "payload missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "answer", + Payload: api.StringMap{ + "sdp": mock.MockSdpAnswerAudioAndVideo, + }, + }, + }, + "client id missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "answer", + Payload: api.StringMap{ + "sdp": mock.MockSdpAnswerAudioAndVideo, + }, + ClientId: "foo", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "candidate", + }, + }, + "payload missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "candidate", + Payload: api.StringMap{ + "candidate": "invalid-candidate", + }, + }, + }, + "client id missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "candidate", + Payload: api.StringMap{ + "candidate": "invalid-candidate", + }, + ClientId: "foo", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "endOfCandidates", + }, + }, + "client id missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "endOfCandidates", + ClientId: "foo", + }, + }, + "", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "requestoffer", + }, + }, + "client id missing", + }, + { + &ClientMessage{ + Type: "payload", + Payload: &PayloadClientMessage{ + Type: "requestoffer", + ClientId: "foo", + }, + }, + "", + }, + } + + for idx, tc := range testcases { + err := tc.message.CheckValid() + if tc.reason == "" { + assert.NoError(err, "failed for testcase %d: %+v", idx, tc.message) + } else { + assert.ErrorContains(err, tc.reason, "failed for testcase %d: %+v", idx, tc.message) + } + } +} + +func TestServerErrorMessage(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + message := &ClientMessage{ + Id: "12346", + } + err := message.NewErrorServerMessage(api.NewError("error_code", "Test error")) + assert.Equal(message.Id, err.Id) + if e := err.Error; assert.NotNil(e) { + assert.Equal("error_code", e.Code) + assert.Equal("Test error", e.Message) + assert.Empty(e.Details) + } +} + +func TestWrapperServerErrorMessage(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + message := &ClientMessage{ + Id: "12346", + } + err := message.NewWrappedErrorServerMessage(errors.New("an internal server error")) + assert.Equal(message.Id, err.Id) + if e := err.Error; assert.NotNil(e) { + assert.Equal("internal_error", e.Code) + assert.Equal("an internal server error", e.Message) + assert.Empty(e.Details) + } +} + +func TestCloseAfterSend(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + message := &ServerMessage{ + Type: "bye", + } + assert.True(message.CloseAfterSend(nil)) + + for _, msgType := range []string{ + "error", + "hello", + "command", + "payload", + "event", + } { + message = &ServerMessage{ + Type: msgType, + } + assert.False(message.CloseAfterSend(nil), "failed for %s", msgType) + } +} + +func TestAllowIncoming(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + testcases := []struct { + bw float64 + allow bool + }{ + { + 0, true, + }, + { + 99, true, + }, + { + 99.9, true, + }, + { + 100, false, + }, + { + 200, false, + }, + } + + bw := EventServerBandwidth{ + Incoming: nil, + } + assert.True(bw.AllowIncoming()) + for idx, tc := range testcases { + bw := EventServerBandwidth{ + Incoming: internal.MakePtr(tc.bw), + } + assert.Equal(tc.allow, bw.AllowIncoming(), "failed for testcase %d: %+v", idx, tc) + } +} + +func TestAllowOutgoing(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + testcases := []struct { + bw float64 + allow bool + }{ + { + 0, true, + }, + { + 99, true, + }, + { + 99.9, true, + }, + { + 100, false, + }, + { + 200, false, + }, + } + + bw := EventServerBandwidth{ + Outgoing: nil, + } + assert.True(bw.AllowOutgoing()) + for idx, tc := range testcases { + bw := EventServerBandwidth{ + Outgoing: internal.MakePtr(tc.bw), + } + assert.Equal(tc.allow, bw.AllowOutgoing(), "failed for testcase %d: %+v", idx, tc) + } +} + +func TestInformationEtcd(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + + info1 := &InformationEtcd{} + assert.ErrorContains(info1.CheckValid(), "address missing") + + info2 := &InformationEtcd{ + Address: "http://domain.invalid", + } + if assert.NoError(info2.CheckValid()) { + assert.Equal("http://domain.invalid/", info2.Address) + } + + info3 := &InformationEtcd{ + Address: "http://domain.invalid/", + } + if assert.NoError(info3.CheckValid()) { + assert.Equal("http://domain.invalid/", info3.Address) + } +} diff --git a/security/certificate_reloader_test.go b/security/certificate_reloader_test.go new file mode 100644 index 0000000..f8c61b0 --- /dev/null +++ b/security/certificate_reloader_test.go @@ -0,0 +1,154 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2026 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package security + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "path" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/strukturag/nextcloud-spreed-signaling/internal" + logtest "github.com/strukturag/nextcloud-spreed-signaling/log/test" +) + +type withReloadCounter interface { + GetReloadCounter() uint64 +} + +func waitForReload(ctx context.Context, t *testing.T, r withReloadCounter, expected uint64) bool { + t.Helper() + + for r.GetReloadCounter() < expected { + if !assert.NoError(t, ctx.Err()) { + return false + } + + time.Sleep(time.Millisecond) + } + return true +} + +func TestCertificateReloader(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + + key, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(err) + + org1 := "Testing certificate" + cert1 := internal.GenerateSelfSignedCertificateForTesting(t, org1, key) + + tmpdir := t.TempDir() + certFile := path.Join(tmpdir, "cert.pem") + privkeyFile := path.Join(tmpdir, "privkey.pem") + pubkeyFile := path.Join(tmpdir, "pubkey.pem") + + require.NoError(internal.WritePrivateKey(key, privkeyFile)) + require.NoError(internal.WritePublicKey(&key.PublicKey, pubkeyFile)) + require.NoError(internal.WriteCertificate(cert1, certFile)) + + logger := logtest.NewLoggerForTest(t) + reloader, err := NewCertificateReloader(logger, certFile, privkeyFile) + require.NoError(err) + + defer reloader.Close() + + if cert, err := reloader.GetCertificate(nil); assert.NoError(err) { + assert.True(cert1.Equal(cert.Leaf)) + assert.True(key.Equal(cert.PrivateKey)) + } + if cert, err := reloader.GetClientCertificate(nil); assert.NoError(err) { + assert.True(cert1.Equal(cert.Leaf)) + assert.True(key.Equal(cert.PrivateKey)) + } + + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + defer cancel() + + org2 := "Updated certificate" + cert2 := internal.GenerateSelfSignedCertificateForTesting(t, org2, key) + internal.ReplaceCertificate(t, certFile, cert2) + + waitForReload(ctx, t, reloader, 1) + + if cert, err := reloader.GetCertificate(nil); assert.NoError(err) { + assert.True(cert2.Equal(cert.Leaf)) + assert.True(key.Equal(cert.PrivateKey)) + } + if cert, err := reloader.GetClientCertificate(nil); assert.NoError(err) { + assert.True(cert2.Equal(cert.Leaf)) + assert.True(key.Equal(cert.PrivateKey)) + } +} + +func TestCertPoolReloader(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + + key, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(err) + + org1 := "Testing certificate" + cert1 := internal.GenerateSelfSignedCertificateForTesting(t, org1, key) + + tmpdir := t.TempDir() + certFile := path.Join(tmpdir, "cert.pem") + privkeyFile := path.Join(tmpdir, "privkey.pem") + pubkeyFile := path.Join(tmpdir, "pubkey.pem") + + require.NoError(internal.WritePrivateKey(key, privkeyFile)) + require.NoError(internal.WritePublicKey(&key.PublicKey, pubkeyFile)) + require.NoError(internal.WriteCertificate(cert1, certFile)) + + logger := logtest.NewLoggerForTest(t) + reloader, err := NewCertPoolReloader(logger, certFile) + require.NoError(err) + + defer reloader.Close() + + pool1 := reloader.GetCertPool() + assert.NotNil(pool1) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + defer cancel() + + org2 := "Updated certificate" + cert2 := internal.GenerateSelfSignedCertificateForTesting(t, org2, key) + internal.ReplaceCertificate(t, certFile, cert2) + + waitForReload(ctx, t, reloader, 1) + + pool2 := reloader.GetCertPool() + assert.NotNil(pool2) + + assert.False(pool1.Equal(pool2)) +} diff --git a/server/clientsession_test.go b/server/clientsession_test.go index 5466618..5c2def2 100644 --- a/server/clientsession_test.go +++ b/server/clientsession_test.go @@ -742,3 +742,178 @@ func TestPermissionHideDisplayNames(t *testing.T) { t.Run("without-hide-displaynames", testFunc(false)) // nolint:paralleltest t.Run("with-hide-displaynames", testFunc(true)) // nolint:paralleltest } + +func Test_ClientSessionPublisherEvents(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + hub, _, _, server := CreateHubForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + mcu := test.NewSFU(t) + require.NoError(mcu.Start(ctx)) + defer mcu.Stop() + + hub.SetMcu(mcu) + + client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId) + defer client.CloseWithBye() + + roomId := "test-room" + roomMsg := MustSucceed2(t, client.JoinRoom, ctx, roomId) + require.Equal(roomId, roomMsg.Room.RoomId) + + client.RunUntilJoined(ctx, hello.Hello) + + room := hub.getRoom(roomId) + require.NotNil(room) + + session := hub.GetSessionByPublicId(hello.Hello.SessionId).(*ClientSession) + require.NotNil(session, "Session %s does not exist", hello.Hello.SessionId) + + require.NoError(client.SendMessage(api.MessageClientMessageRecipient{ + Type: "session", + SessionId: hello.Hello.SessionId, + }, api.MessageClientMessageData{ + Type: "offer", + Sid: "54321", + RoomType: string(sfu.StreamTypeVideo), + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + })) + + require.True(client.RunUntilAnswer(ctx, mock.MockSdpAnswerAudioAndVideo)) + + pub := mcu.GetPublisher(hello.Hello.SessionId) + require.NotNil(pub) + + assert.Equal(pub, session.GetPublisher(sfu.StreamTypeVideo)) + session.OnIceCandidate(pub, "test-candidate") + + if message, ok := client.RunUntilMessage(ctx); ok { + assert.Equal("message", message.Type) + if msg := message.Message; assert.NotNil(msg) { + if sender := msg.Sender; assert.NotNil(sender) { + assert.Equal("session", sender.Type) + assert.Equal(hello.Hello.SessionId, sender.SessionId) + } + var ao api.AnswerOfferMessage + if assert.NoError(json.Unmarshal(msg.Data, &ao)) { + assert.Equal(hello.Hello.SessionId, ao.From) + assert.Equal(hello.Hello.SessionId, ao.To) + assert.Equal("candidate", ao.Type) + assert.EqualValues(sfu.StreamTypeVideo, ao.RoomType) + assert.Equal("test-candidate", ao.Payload["candidate"]) + } + } + } + + // No-op + session.OnIceCompleted(pub) + + session.PublisherClosed(pub) + assert.Nil(session.GetPublisher(sfu.StreamTypeVideo)) +} + +func Test_ClientSessionSubscriberEvents(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + hub, _, _, server := CreateHubForTest(t) + hub.allowSubscribeAnyStream = true + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + mcu := test.NewSFU(t) + require.NoError(mcu.Start(ctx)) + defer mcu.Stop() + + hub.SetMcu(mcu) + + client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") + defer client1.CloseWithBye() + client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") + defer client2.CloseWithBye() + + roomId := "test-room" + roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) + require.Equal(roomId, roomMsg.Room.RoomId) + roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) + require.Equal(roomId, roomMsg.Room.RoomId) + + WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) + + room := hub.getRoom(roomId) + require.NotNil(room) + + session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) + require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId) + session2 := hub.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession) + require.NotNil(session2, "Session %s does not exist", hello2.Hello.SessionId) + + require.NoError(client1.SendMessage(api.MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, api.MessageClientMessageData{ + Type: "offer", + Sid: "54321", + RoomType: string(sfu.StreamTypeVideo), + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + })) + + require.True(client1.RunUntilAnswer(ctx, mock.MockSdpAnswerAudioAndVideo)) + + require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, api.MessageClientMessageData{ + Type: "requestoffer", + Sid: "54321", + RoomType: string(sfu.StreamTypeVideo), + })) + + require.True(client2.RunUntilOffer(ctx, mock.MockSdpOfferAudioAndVideo)) + + sub := mcu.GetSubscriber(hello1.Hello.SessionId, sfu.StreamTypeVideo) + require.NotNil(sub) + + assert.Equal(sub, session2.GetSubscriber(hello1.Hello.SessionId, sfu.StreamTypeVideo)) + session2.OnIceCandidate(sub, "test-candidate") + + if message, ok := client2.RunUntilMessage(ctx); ok { + assert.Equal("message", message.Type) + if msg := message.Message; assert.NotNil(msg) { + if sender := msg.Sender; assert.NotNil(sender) { + assert.Equal("session", sender.Type) + assert.Equal(hello1.Hello.SessionId, sender.SessionId) + } + var ao api.AnswerOfferMessage + if assert.NoError(json.Unmarshal(msg.Data, &ao)) { + assert.Equal(hello1.Hello.SessionId, ao.From) + assert.Equal(hello2.Hello.SessionId, ao.To) + assert.Equal("candidate", ao.Type) + assert.EqualValues(sfu.StreamTypeVideo, ao.RoomType) + assert.Equal("test-candidate", ao.Payload["candidate"]) + } + } + } + + // No-op + session2.OnIceCompleted(sub) + + session2.OnUpdateOffer(sub, api.StringMap{ + "type": "offer", + "sdp": mock.MockSdpOfferAudioOnly, + }) + + require.True(client2.RunUntilOffer(ctx, mock.MockSdpOfferAudioOnly)) + + session2.SubscriberClosed(sub) + assert.Nil(session2.GetSubscriber(hello1.Hello.SessionId, sfu.StreamTypeVideo)) +} diff --git a/server/hub_sfu_janus_test.go b/server/hub_sfu_janus_test.go deleted file mode 100644 index 5d30758..0000000 --- a/server/hub_sfu_janus_test.go +++ /dev/null @@ -1,758 +0,0 @@ -/** - * Standalone signaling server for the Nextcloud Spreed app. - * Copyright (C) 2026 struktur AG - * - * @author Joachim Bauch - * - * @license GNU AGPL version 3 or any later version - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package server - -import ( - "context" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/dlintw/goconf" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/strukturag/nextcloud-spreed-signaling/api" - "github.com/strukturag/nextcloud-spreed-signaling/log" - logtest "github.com/strukturag/nextcloud-spreed-signaling/log/test" - "github.com/strukturag/nextcloud-spreed-signaling/mock" - "github.com/strukturag/nextcloud-spreed-signaling/sfu" - sfujanus "github.com/strukturag/nextcloud-spreed-signaling/sfu/janus" - "github.com/strukturag/nextcloud-spreed-signaling/sfu/janus/janus" - janustest "github.com/strukturag/nextcloud-spreed-signaling/sfu/janus/test" -) - -type JanusSFU interface { - sfu.SFU - - SetStats(stats sfujanus.Stats) - Settings() *sfujanus.Settings -} - -func newMcuJanusForTesting(t *testing.T) (JanusSFU, *janustest.JanusGateway) { - gateway := janustest.NewJanusGateway(t) - - config := goconf.NewConfigFile() - if strings.Contains(t.Name(), "Filter") { - config.AddOption("mcu", "blockedcandidates", "192.0.0.0/24, 192.168.0.0/16") - } - logger := logtest.NewLoggerForTest(t) - ctx := log.NewLoggerContext(t.Context(), logger) - mcu, err := sfujanus.NewJanusSFUWithGateway(ctx, gateway, config) - require.NoError(t, err) - t.Cleanup(func() { - mcu.Stop() - }) - - require.NoError(t, mcu.Start(ctx)) - return mcu.(JanusSFU), gateway -} - -type mockJanusStats struct { - called atomic.Bool - - mu sync.Mutex - // +checklocks:mu - value map[sfu.StreamType]int -} - -func (s *mockJanusStats) Value(streamType sfu.StreamType) int { - s.mu.Lock() - defer s.mu.Unlock() - - return s.value[streamType] -} - -func (s *mockJanusStats) IncSubscriber(streamType sfu.StreamType) { - s.called.Store(true) - - s.mu.Lock() - defer s.mu.Unlock() - - if s.value == nil { - s.value = make(map[sfu.StreamType]int) - } - s.value[streamType]++ -} - -func (s *mockJanusStats) DecSubscriber(streamType sfu.StreamType) { - s.called.Store(true) - - s.mu.Lock() - defer s.mu.Unlock() - - if s.value == nil { - s.value = make(map[sfu.StreamType]int) - } - s.value[streamType]-- -} - -func Test_JanusSubscriberNoSuchRoom(t *testing.T) { - t.Parallel() - require := require.New(t) - assert := assert.New(t) - - stats := &mockJanusStats{} - - t.Cleanup(func() { - if !t.Failed() { - assert.True(stats.called.Load(), "stats were not called") - assert.Equal(0, stats.Value("video")) - } - }) - - mcu, gateway := newMcuJanusForTesting(t) - mcu.SetStats(stats) - gateway.RegisterHandlers(map[string]janustest.JanusHandler{ - "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { - assert.EqualValues(1, room.Id()) - return &janus.EventMsg{ - Jsep: api.StringMap{ - "type": "answer", - "sdp": mock.MockSdpAnswerAudioAndVideo, - }, - }, nil - }, - }) - - hub, _, _, server := CreateHubForTest(t) - hub.SetMcu(mcu) - - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") - client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") - require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) - require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) - - // Join room by id. - roomId := "test-room" - roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - // Give message processing some time. - time.Sleep(10 * time.Millisecond) - - roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) - - // Simulate request from the backend that sessions joined the call. - users1 := []api.StringMap{ - { - "sessionId": hello1.Hello.SessionId, - "inCall": 1, - }, - { - "sessionId": hello2.Hello.SessionId, - "inCall": 1, - }, - } - room := hub.getRoom(roomId) - require.NotNil(room, "Could not find room %s", roomId) - room.PublishUsersInCallChanged(users1, users1) - checkReceiveClientEvent(ctx, t, client1, "update", nil) - checkReceiveClientEvent(ctx, t, client2, "update", nil) - - require.NoError(client1.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "offer", - RoomType: "video", - Payload: api.StringMap{ - "sdp": mock.MockSdpOfferAudioAndVideo, - }, - })) - - client1.RunUntilAnswer(ctx, mock.MockSdpAnswerAudioAndVideo) - - require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "requestoffer", - RoomType: "video", - })) - - MustSucceed2(t, client2.RunUntilError, ctx, "processing_failed") // nolint - - require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "requestoffer", - RoomType: "video", - })) - - client2.RunUntilOffer(ctx, mock.MockSdpOfferAudioAndVideo) -} - -func test_JanusSubscriberAlreadyJoined(t *testing.T) { - require := require.New(t) - assert := assert.New(t) - - stats := &mockJanusStats{} - - t.Cleanup(func() { - if !t.Failed() { - assert.True(stats.called.Load(), "stats were not called") - assert.Equal(0, stats.Value("video")) - } - }) - - mcu, gateway := newMcuJanusForTesting(t) - mcu.SetStats(stats) - gateway.RegisterHandlers(map[string]janustest.JanusHandler{ - "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { - assert.EqualValues(1, room.Id()) - return &janus.EventMsg{ - Jsep: api.StringMap{ - "type": "answer", - "sdp": mock.MockSdpAnswerAudioAndVideo, - }, - }, nil - }, - }) - - hub, _, _, server := CreateHubForTest(t) - hub.SetMcu(mcu) - - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") - client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") - require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) - require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) - - // Join room by id. - roomId := "test-room" - roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - // Give message processing some time. - time.Sleep(10 * time.Millisecond) - - roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) - - // Simulate request from the backend that sessions joined the call. - users1 := []api.StringMap{ - { - "sessionId": hello1.Hello.SessionId, - "inCall": 1, - }, - { - "sessionId": hello2.Hello.SessionId, - "inCall": 1, - }, - } - room := hub.getRoom(roomId) - require.NotNil(room, "Could not find room %s", roomId) - room.PublishUsersInCallChanged(users1, users1) - checkReceiveClientEvent(ctx, t, client1, "update", nil) - checkReceiveClientEvent(ctx, t, client2, "update", nil) - - require.NoError(client1.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "offer", - RoomType: "video", - Payload: api.StringMap{ - "sdp": mock.MockSdpOfferAudioAndVideo, - }, - })) - - client1.RunUntilAnswer(ctx, mock.MockSdpAnswerAudioAndVideo) - - require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "requestoffer", - RoomType: "video", - })) - - if strings.Contains(t.Name(), "AttachError") { - MustSucceed2(t, client2.RunUntilError, ctx, "processing_failed") // nolint - - require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "requestoffer", - RoomType: "video", - })) - } - - client2.RunUntilOffer(ctx, mock.MockSdpOfferAudioAndVideo) -} - -func Test_JanusSubscriberAlreadyJoined(t *testing.T) { - t.Parallel() - test_JanusSubscriberAlreadyJoined(t) -} - -func Test_JanusSubscriberAlreadyJoinedAttachError(t *testing.T) { - t.Parallel() - test_JanusSubscriberAlreadyJoined(t) -} - -func Test_JanusSubscriberTimeout(t *testing.T) { - t.Parallel() - require := require.New(t) - assert := assert.New(t) - - stats := &mockJanusStats{} - - t.Cleanup(func() { - if !t.Failed() { - assert.True(stats.called.Load(), "stats were not called") - assert.Equal(0, stats.Value("video")) - } - }) - - mcu, gateway := newMcuJanusForTesting(t) - mcu.SetStats(stats) - gateway.RegisterHandlers(map[string]janustest.JanusHandler{ - "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { - assert.EqualValues(1, room.Id()) - return &janus.EventMsg{ - Jsep: api.StringMap{ - "type": "answer", - "sdp": mock.MockSdpAnswerAudioAndVideo, - }, - }, nil - }, - }) - - hub, _, _, server := CreateHubForTest(t) - hub.SetMcu(mcu) - - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") - client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") - require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) - require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) - - // Join room by id. - roomId := "test-room" - roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - // Give message processing some time. - time.Sleep(10 * time.Millisecond) - - roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) - - // Simulate request from the backend that sessions joined the call. - users1 := []api.StringMap{ - { - "sessionId": hello1.Hello.SessionId, - "inCall": 1, - }, - { - "sessionId": hello2.Hello.SessionId, - "inCall": 1, - }, - } - room := hub.getRoom(roomId) - require.NotNil(room, "Could not find room %s", roomId) - room.PublishUsersInCallChanged(users1, users1) - checkReceiveClientEvent(ctx, t, client1, "update", nil) - checkReceiveClientEvent(ctx, t, client2, "update", nil) - - require.NoError(client1.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "offer", - RoomType: "video", - Payload: api.StringMap{ - "sdp": mock.MockSdpOfferAudioAndVideo, - }, - })) - - client1.RunUntilAnswer(ctx, mock.MockSdpAnswerAudioAndVideo) - - oldTimeout := mcu.Settings().Timeout() - mcu.Settings().SetTimeout(100 * time.Millisecond) - - require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "requestoffer", - RoomType: "video", - })) - - MustSucceed2(t, client2.RunUntilError, ctx, "processing_failed") // nolint - - mcu.Settings().SetTimeout(oldTimeout) - - require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "requestoffer", - RoomType: "video", - })) - - client2.RunUntilOffer(ctx, mock.MockSdpOfferAudioAndVideo) -} - -func Test_JanusSubscriberCloseEmptyStreams(t *testing.T) { - t.Parallel() - require := require.New(t) - assert := assert.New(t) - - stats := &mockJanusStats{} - - t.Cleanup(func() { - if !t.Failed() { - assert.True(stats.called.Load(), "stats were not called") - assert.Equal(0, stats.Value("video")) - } - }) - - mcu, gateway := newMcuJanusForTesting(t) - mcu.SetStats(stats) - gateway.RegisterHandlers(map[string]janustest.JanusHandler{ - "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { - assert.EqualValues(1, room.Id()) - return &janus.EventMsg{ - Jsep: api.StringMap{ - "type": "answer", - "sdp": mock.MockSdpAnswerAudioAndVideo, - }, - }, nil - }, - }) - - hub, _, _, server := CreateHubForTest(t) - hub.SetMcu(mcu) - - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") - client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") - require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) - require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) - - // Join room by id. - roomId := "test-room" - roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - // Give message processing some time. - time.Sleep(10 * time.Millisecond) - - roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) - - // Simulate request from the backend that sessions joined the call. - users1 := []api.StringMap{ - { - "sessionId": hello1.Hello.SessionId, - "inCall": 1, - }, - { - "sessionId": hello2.Hello.SessionId, - "inCall": 1, - }, - } - room := hub.getRoom(roomId) - require.NotNil(room, "Could not find room %s", roomId) - room.PublishUsersInCallChanged(users1, users1) - checkReceiveClientEvent(ctx, t, client1, "update", nil) - checkReceiveClientEvent(ctx, t, client2, "update", nil) - - require.NoError(client1.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "offer", - RoomType: "video", - Payload: api.StringMap{ - "sdp": mock.MockSdpOfferAudioAndVideo, - }, - })) - - client1.RunUntilAnswer(ctx, mock.MockSdpAnswerAudioAndVideo) - - require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "requestoffer", - RoomType: "video", - })) - - client2.RunUntilOffer(ctx, mock.MockSdpOfferAudioAndVideo) - - sess2 := hub.GetSessionByPublicId(hello2.Hello.SessionId) - require.NotNil(sess2) - session2 := sess2.(*ClientSession) - - sub := session2.GetSubscriber(hello1.Hello.SessionId, sfu.StreamTypeVideo) - require.NotNil(sub) - - subscriber := sub.(sfujanus.Subscriber) - handle := subscriber.JanusHandle() - require.NotNil(handle) - - for ctx.Err() == nil { - if handle = subscriber.JanusHandle(); handle == nil { - break - } - - time.Sleep(time.Millisecond) - } - - assert.Nil(handle, "subscriber should have been closed") -} - -func Test_JanusSubscriberRoomDestroyed(t *testing.T) { - t.Parallel() - require := require.New(t) - assert := assert.New(t) - - stats := &mockJanusStats{} - - t.Cleanup(func() { - if !t.Failed() { - assert.True(stats.called.Load(), "stats were not called") - assert.Equal(0, stats.Value("video")) - } - }) - - mcu, gateway := newMcuJanusForTesting(t) - mcu.SetStats(stats) - gateway.RegisterHandlers(map[string]janustest.JanusHandler{ - "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { - assert.EqualValues(1, room.Id()) - return &janus.EventMsg{ - Jsep: api.StringMap{ - "type": "answer", - "sdp": mock.MockSdpAnswerAudioAndVideo, - }, - }, nil - }, - }) - - hub, _, _, server := CreateHubForTest(t) - hub.SetMcu(mcu) - - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") - client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") - require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) - require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) - - // Join room by id. - roomId := "test-room" - roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - // Give message processing some time. - time.Sleep(10 * time.Millisecond) - - roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) - - // Simulate request from the backend that sessions joined the call. - users1 := []api.StringMap{ - { - "sessionId": hello1.Hello.SessionId, - "inCall": 1, - }, - { - "sessionId": hello2.Hello.SessionId, - "inCall": 1, - }, - } - room := hub.getRoom(roomId) - require.NotNil(room, "Could not find room %s", roomId) - room.PublishUsersInCallChanged(users1, users1) - checkReceiveClientEvent(ctx, t, client1, "update", nil) - checkReceiveClientEvent(ctx, t, client2, "update", nil) - - require.NoError(client1.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "offer", - RoomType: "video", - Payload: api.StringMap{ - "sdp": mock.MockSdpOfferAudioAndVideo, - }, - })) - - client1.RunUntilAnswer(ctx, mock.MockSdpAnswerAudioAndVideo) - - require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "requestoffer", - RoomType: "video", - })) - - client2.RunUntilOffer(ctx, mock.MockSdpOfferAudioAndVideo) - - sess2 := hub.GetSessionByPublicId(hello2.Hello.SessionId) - require.NotNil(sess2) - session2 := sess2.(*ClientSession) - - sub := session2.GetSubscriber(hello1.Hello.SessionId, sfu.StreamTypeVideo) - require.NotNil(sub) - - subscriber := sub.(sfujanus.Subscriber) - handle := subscriber.JanusHandle() - require.NotNil(handle) - - for ctx.Err() == nil { - if handle = subscriber.JanusHandle(); handle == nil { - break - } - - time.Sleep(time.Millisecond) - } - - assert.Nil(handle, "subscriber should have been closed") -} - -func Test_JanusSubscriberUpdateOffer(t *testing.T) { - t.Parallel() - require := require.New(t) - assert := assert.New(t) - - stats := &mockJanusStats{} - - t.Cleanup(func() { - if !t.Failed() { - assert.True(stats.called.Load(), "stats were not called") - assert.Equal(0, stats.Value("video")) - } - }) - - mcu, gateway := newMcuJanusForTesting(t) - mcu.SetStats(stats) - gateway.RegisterHandlers(map[string]janustest.JanusHandler{ - "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { - assert.EqualValues(1, room.Id()) - return &janus.EventMsg{ - Jsep: api.StringMap{ - "type": "answer", - "sdp": mock.MockSdpAnswerAudioAndVideo, - }, - }, nil - }, - }) - - hub, _, _, server := CreateHubForTest(t) - hub.SetMcu(mcu) - - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") - client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") - require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) - require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) - - // Join room by id. - roomId := "test-room" - roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - // Give message processing some time. - time.Sleep(10 * time.Millisecond) - - roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) - require.Equal(roomId, roomMsg.Room.RoomId) - - WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) - - // Simulate request from the backend that sessions joined the call. - users1 := []api.StringMap{ - { - "sessionId": hello1.Hello.SessionId, - "inCall": 1, - }, - { - "sessionId": hello2.Hello.SessionId, - "inCall": 1, - }, - } - room := hub.getRoom(roomId) - require.NotNil(room, "Could not find room %s", roomId) - room.PublishUsersInCallChanged(users1, users1) - checkReceiveClientEvent(ctx, t, client1, "update", nil) - checkReceiveClientEvent(ctx, t, client2, "update", nil) - - require.NoError(client1.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "offer", - RoomType: "video", - Payload: api.StringMap{ - "sdp": mock.MockSdpOfferAudioAndVideo, - }, - })) - - client1.RunUntilAnswer(ctx, mock.MockSdpAnswerAudioAndVideo) - - require.NoError(client2.SendMessage(api.MessageClientMessageRecipient{ - Type: "session", - SessionId: hello1.Hello.SessionId, - }, api.MessageClientMessageData{ - Type: "requestoffer", - RoomType: "video", - })) - - client2.RunUntilOffer(ctx, mock.MockSdpOfferAudioAndVideo) - - // Test MCU will trigger an updated offer. - client2.RunUntilOffer(ctx, mock.MockSdpOfferAudioOnly) -} diff --git a/server/hub_sfu_proxy_test.go b/server/hub_sfu_proxy_test.go deleted file mode 100644 index 278f3be..0000000 --- a/server/hub_sfu_proxy_test.go +++ /dev/null @@ -1,653 +0,0 @@ -/** - * Standalone signaling server for the Nextcloud Spreed app. - * Copyright (C) 2026 struktur AG - * - * @author Joachim Bauch - * - * @license GNU AGPL version 3 or any later version - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package server - -import ( - "context" - "errors" - "net/url" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/strukturag/nextcloud-spreed-signaling/api" - etcdtest "github.com/strukturag/nextcloud-spreed-signaling/etcd/test" - "github.com/strukturag/nextcloud-spreed-signaling/grpc" - grpctest "github.com/strukturag/nextcloud-spreed-signaling/grpc/test" - "github.com/strukturag/nextcloud-spreed-signaling/sfu" - "github.com/strukturag/nextcloud-spreed-signaling/sfu/mock" - proxytest "github.com/strukturag/nextcloud-spreed-signaling/sfu/proxy/test" - "github.com/strukturag/nextcloud-spreed-signaling/sfu/proxy/testserver" - "github.com/strukturag/nextcloud-spreed-signaling/talk" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -type mockGrpcServerHub struct { - proxy atomic.Pointer[sfu.WithToken] - sessionsLock sync.Mutex - // +checklocks:sessionsLock - sessionByPublicId map[api.PublicSessionId]Session -} - -func (h *mockGrpcServerHub) setProxy(t *testing.T, proxy sfu.SFU) { - t.Helper() - - wt, ok := proxy.(sfu.WithToken) - require.True(t, ok, "need a sfu with token support") - h.proxy.Store(&wt) -} - -func (h *mockGrpcServerHub) getSession(sessionId api.PublicSessionId) Session { - h.sessionsLock.Lock() - defer h.sessionsLock.Unlock() - - return h.sessionByPublicId[sessionId] -} - -func (h *mockGrpcServerHub) addSession(session *ClientSession) { - h.sessionsLock.Lock() - defer h.sessionsLock.Unlock() - if h.sessionByPublicId == nil { - h.sessionByPublicId = make(map[api.PublicSessionId]Session) - } - h.sessionByPublicId[session.PublicId()] = session -} - -func (h *mockGrpcServerHub) removeSession(session *ClientSession) { - h.sessionsLock.Lock() - defer h.sessionsLock.Unlock() - delete(h.sessionByPublicId, session.PublicId()) -} - -func (h *mockGrpcServerHub) GetSessionIdByResumeId(resumeId api.PrivateSessionId) api.PublicSessionId { - return "" -} - -func (h *mockGrpcServerHub) GetSessionIdByRoomSessionId(roomSessionId api.RoomSessionId) (api.PublicSessionId, error) { - return "", nil -} - -func (h *mockGrpcServerHub) IsSessionIdInCall(sessionId api.PublicSessionId, roomId string, backendUrl string) (bool, bool) { - return false, false -} - -func (h *mockGrpcServerHub) DisconnectSessionByRoomSessionId(sessionId api.PublicSessionId, roomSessionId api.RoomSessionId, reason string) { -} - -func (h *mockGrpcServerHub) GetBackend(u *url.URL) *talk.Backend { - return nil -} - -func (h *mockGrpcServerHub) GetInternalSessions(roomId string, backend *talk.Backend) ([]*grpc.InternalSessionData, []*grpc.VirtualSessionData, bool) { - return nil, nil, false -} - -func (h *mockGrpcServerHub) GetTransientEntries(roomId string, backend *talk.Backend) (api.TransientDataEntries, bool) { - return nil, false -} - -func (h *mockGrpcServerHub) GetPublisherIdForSessionId(ctx context.Context, sessionId api.PublicSessionId, streamType sfu.StreamType) (*grpc.GetPublisherIdReply, error) { - session := h.getSession(sessionId) - if session == nil { - return nil, status.Error(codes.NotFound, "no such session") - } - - clientSession, ok := session.(*ClientSession) - if !ok { - return nil, status.Error(codes.NotFound, "no such session") - } - - publisher := clientSession.GetOrWaitForPublisher(ctx, streamType) - if publisher, ok := publisher.(sfu.PublisherWithConnectionUrlAndIP); ok { - connUrl, ip := publisher.GetConnectionURL() - reply := &grpc.GetPublisherIdReply{ - PublisherId: publisher.Id(), - ProxyUrl: connUrl, - } - if len(ip) > 0 { - reply.Ip = ip.String() - } - - if proxy := h.proxy.Load(); proxy != nil { - reply.ConnectToken, _ = (*proxy).CreateToken("") - reply.PublisherToken, _ = (*proxy).CreateToken(publisher.Id()) - } - return reply, nil - } - - return nil, status.Error(codes.NotFound, "no such publisher") -} - -func (h *mockGrpcServerHub) ProxySession(request grpc.RpcSessions_ProxySessionServer) error { - return errors.New("not implemented") -} - -func Test_ProxyRemotePublisher(t *testing.T) { - t.Parallel() - - embedEtcd := etcdtest.NewServerForTest(t) - - grpcServer1, addr1 := grpctest.NewServerForTest(t) - grpcServer2, addr2 := grpctest.NewServerForTest(t) - - hub1 := &mockGrpcServerHub{} - hub2 := &mockGrpcServerHub{} - grpcServer1.SetHub(hub1) - grpcServer2.SetHub(hub2) - - embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) - embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) - - server1 := testserver.NewProxyServerForTest(t, "DE") - server2 := testserver.NewProxyServerForTest(t, "DE") - - mcu1, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - server2, - }, - }, 1, nil) - hub1.setProxy(t, mcu1) - mcu2, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - server2, - }, - }, 2, nil) - hub2.setProxy(t, mcu2) - - ctx, cancel := context.WithTimeout(t.Context(), testTimeout) - defer cancel() - - pubId := api.PublicSessionId("the-publisher") - pubSid := "1234567890" - pubListener := mock.NewListener(pubId + "-public") - pubInitiator := mock.NewInitiator("DE") - - session1 := &ClientSession{ - publicId: pubId, - publishers: make(map[sfu.StreamType]sfu.Publisher), - } - hub1.addSession(session1) - defer hub1.removeSession(session1) - - pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ - MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, - }, pubInitiator) - require.NoError(t, err) - - defer pub.Close(context.Background()) - - session1.mu.Lock() - session1.publishers[sfu.StreamTypeVideo] = pub - session1.publisherWaiters.Wakeup() - session1.mu.Unlock() - - subListener := mock.NewListener("subscriber-public") - subInitiator := mock.NewInitiator("DE") - sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) - require.NoError(t, err) - - defer sub.Close(context.Background()) -} - -func Test_ProxyMultipleRemotePublisher(t *testing.T) { - t.Parallel() - - embedEtcd := etcdtest.NewServerForTest(t) - - grpcServer1, addr1 := grpctest.NewServerForTest(t) - grpcServer2, addr2 := grpctest.NewServerForTest(t) - grpcServer3, addr3 := grpctest.NewServerForTest(t) - - hub1 := &mockGrpcServerHub{} - hub2 := &mockGrpcServerHub{} - hub3 := &mockGrpcServerHub{} - grpcServer1.SetHub(hub1) - grpcServer2.SetHub(hub2) - grpcServer3.SetHub(hub3) - - embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) - embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) - embedEtcd.SetValue("/grpctargets/three", []byte("{\"address\":\""+addr3+"\"}")) - - server1 := testserver.NewProxyServerForTest(t, "DE") - server2 := testserver.NewProxyServerForTest(t, "US") - server3 := testserver.NewProxyServerForTest(t, "US") - - mcu1, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - server2, - server3, - }, - }, 1, nil) - hub1.setProxy(t, mcu1) - mcu2, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - server2, - server3, - }, - }, 2, nil) - hub2.setProxy(t, mcu2) - mcu3, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - server2, - server3, - }, - }, 3, nil) - hub3.setProxy(t, mcu3) - - ctx, cancel := context.WithTimeout(t.Context(), testTimeout) - defer cancel() - - pubId := api.PublicSessionId("the-publisher") - pubSid := "1234567890" - pubListener := mock.NewListener(pubId + "-public") - pubInitiator := mock.NewInitiator("DE") - - session1 := &ClientSession{ - publicId: pubId, - publishers: make(map[sfu.StreamType]sfu.Publisher), - } - hub1.addSession(session1) - defer hub1.removeSession(session1) - - pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ - MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, - }, pubInitiator) - require.NoError(t, err) - - defer pub.Close(context.Background()) - - session1.mu.Lock() - session1.publishers[sfu.StreamTypeVideo] = pub - session1.publisherWaiters.Wakeup() - session1.mu.Unlock() - - sub1Listener := mock.NewListener("subscriber-public-1") - sub1Initiator := mock.NewInitiator("US") - sub1, err := mcu2.NewSubscriber(ctx, sub1Listener, pubId, sfu.StreamTypeVideo, sub1Initiator) - require.NoError(t, err) - - defer sub1.Close(context.Background()) - - sub2Listener := mock.NewListener("subscriber-public-2") - sub2Initiator := mock.NewInitiator("US") - sub2, err := mcu3.NewSubscriber(ctx, sub2Listener, pubId, sfu.StreamTypeVideo, sub2Initiator) - require.NoError(t, err) - - defer sub2.Close(context.Background()) -} - -func Test_ProxyRemotePublisherWait(t *testing.T) { - t.Parallel() - - embedEtcd := etcdtest.NewServerForTest(t) - - grpcServer1, addr1 := grpctest.NewServerForTest(t) - grpcServer2, addr2 := grpctest.NewServerForTest(t) - - hub1 := &mockGrpcServerHub{} - hub2 := &mockGrpcServerHub{} - grpcServer1.SetHub(hub1) - grpcServer2.SetHub(hub2) - - embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) - embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) - - server1 := testserver.NewProxyServerForTest(t, "DE") - server2 := testserver.NewProxyServerForTest(t, "DE") - - mcu1, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - server2, - }, - }, 1, nil) - hub1.setProxy(t, mcu1) - mcu2, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - server2, - }, - }, 2, nil) - hub2.setProxy(t, mcu2) - - ctx, cancel := context.WithTimeout(t.Context(), testTimeout) - defer cancel() - - pubId := api.PublicSessionId("the-publisher") - pubSid := "1234567890" - pubListener := mock.NewListener(pubId + "-public") - pubInitiator := mock.NewInitiator("DE") - - session1 := &ClientSession{ - publicId: pubId, - publishers: make(map[sfu.StreamType]sfu.Publisher), - } - hub1.addSession(session1) - defer hub1.removeSession(session1) - - subListener := mock.NewListener("subscriber-public") - subInitiator := mock.NewInitiator("DE") - - done := make(chan struct{}) - go func() { - defer close(done) - sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) - if !assert.NoError(t, err) { - return - } - - defer sub.Close(context.Background()) - }() - - // Give subscriber goroutine some time to start - time.Sleep(100 * time.Millisecond) - - pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ - MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, - }, pubInitiator) - require.NoError(t, err) - - defer pub.Close(context.Background()) - - session1.mu.Lock() - session1.publishers[sfu.StreamTypeVideo] = pub - session1.publisherWaiters.Wakeup() - session1.mu.Unlock() - - select { - case <-done: - case <-ctx.Done(): - assert.NoError(t, ctx.Err()) - } -} - -func Test_ProxyRemotePublisherTemporary(t *testing.T) { - t.Parallel() - - assert := assert.New(t) - embedEtcd := etcdtest.NewServerForTest(t) - - grpcServer1, addr1 := grpctest.NewServerForTest(t) - grpcServer2, addr2 := grpctest.NewServerForTest(t) - - hub1 := &mockGrpcServerHub{} - hub2 := &mockGrpcServerHub{} - grpcServer1.SetHub(hub1) - grpcServer2.SetHub(hub2) - - embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) - embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) - - server1 := testserver.NewProxyServerForTest(t, "DE") - server2 := testserver.NewProxyServerForTest(t, "DE") - - mcu1, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - }, - }, 1, nil) - hub1.setProxy(t, mcu1) - mcu2, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server2, - }, - }, 2, nil) - hub2.setProxy(t, mcu2) - - ctx, cancel := context.WithTimeout(t.Context(), testTimeout) - defer cancel() - - pubId := api.PublicSessionId("the-publisher") - pubSid := "1234567890" - pubListener := mock.NewListener(pubId + "-public") - pubInitiator := mock.NewInitiator("DE") - - session1 := &ClientSession{ - publicId: pubId, - publishers: make(map[sfu.StreamType]sfu.Publisher), - } - hub1.addSession(session1) - defer hub1.removeSession(session1) - - pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ - MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, - }, pubInitiator) - require.NoError(t, err) - - defer pub.Close(context.Background()) - - session1.mu.Lock() - session1.publishers[sfu.StreamTypeVideo] = pub - session1.publisherWaiters.Wakeup() - session1.mu.Unlock() - - type connectionCounter interface { - ConnectionsCount() int - } - - if counter2, ok := mcu2.(connectionCounter); assert.True(ok) { - assert.Equal(1, counter2.ConnectionsCount()) - } - - subListener := mock.NewListener("subscriber-public") - subInitiator := mock.NewInitiator("DE") - sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) - require.NoError(t, err) - - defer sub.Close(context.Background()) - - if connSub, ok := sub.(sfu.SubscriberWithConnectionUrlAndIP); assert.True(ok) { - url, ip := connSub.GetConnectionURL() - assert.Equal(server1.URL(), url) - assert.Empty(ip) - } - - // The temporary connection has been added - if counter2, ok := mcu2.(connectionCounter); assert.True(ok) { - assert.Equal(2, counter2.ConnectionsCount()) - } - - sub.Close(context.Background()) - - // Wait for temporary connection to be removed. -loop: - for { - select { - case <-ctx.Done(): - assert.NoError(ctx.Err()) - default: - if counter2, ok := mcu2.(connectionCounter); assert.True(ok) { - if counter2.ConnectionsCount() == 1 { - break loop - } - } - } - } -} - -func Test_ProxyConnectToken(t *testing.T) { - t.Parallel() - - embedEtcd := etcdtest.NewServerForTest(t) - - grpcServer1, addr1 := grpctest.NewServerForTest(t) - grpcServer2, addr2 := grpctest.NewServerForTest(t) - - hub1 := &mockGrpcServerHub{} - hub2 := &mockGrpcServerHub{} - grpcServer1.SetHub(hub1) - grpcServer2.SetHub(hub2) - - embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) - embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) - - server1 := testserver.NewProxyServerForTest(t, "DE") - server2 := testserver.NewProxyServerForTest(t, "DE") - - // Signaling server instances are in a cluster but don't share their proxies, - // i.e. they are only known to their local proxy, not the one of the other - // signaling server - so the connection token must be passed between them. - mcu1, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - }, - }, 1, nil) - hub1.setProxy(t, mcu1) - mcu2, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server2, - }, - }, 2, nil) - hub2.setProxy(t, mcu2) - - ctx, cancel := context.WithTimeout(t.Context(), testTimeout) - defer cancel() - - pubId := api.PublicSessionId("the-publisher") - pubSid := "1234567890" - pubListener := mock.NewListener(pubId + "-public") - pubInitiator := mock.NewInitiator("DE") - - session1 := &ClientSession{ - publicId: pubId, - publishers: make(map[sfu.StreamType]sfu.Publisher), - } - hub1.addSession(session1) - defer hub1.removeSession(session1) - - pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ - MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, - }, pubInitiator) - require.NoError(t, err) - - defer pub.Close(context.Background()) - - session1.mu.Lock() - session1.publishers[sfu.StreamTypeVideo] = pub - session1.publisherWaiters.Wakeup() - session1.mu.Unlock() - - subListener := mock.NewListener("subscriber-public") - subInitiator := mock.NewInitiator("DE") - sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) - require.NoError(t, err) - - defer sub.Close(context.Background()) -} - -func Test_ProxyPublisherToken(t *testing.T) { - t.Parallel() - - embedEtcd := etcdtest.NewServerForTest(t) - - grpcServer1, addr1 := grpctest.NewServerForTest(t) - grpcServer2, addr2 := grpctest.NewServerForTest(t) - - hub1 := &mockGrpcServerHub{} - hub2 := &mockGrpcServerHub{} - grpcServer1.SetHub(hub1) - grpcServer2.SetHub(hub2) - - embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) - embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) - - server1 := testserver.NewProxyServerForTest(t, "DE") - server2 := testserver.NewProxyServerForTest(t, "US") - - // Signaling server instances are in a cluster but don't share their proxies, - // i.e. they are only known to their local proxy, not the one of the other - // signaling server - so the connection token must be passed between them. - // Also the subscriber is connecting from a different country, so a remote - // stream will be created that needs a valid token from the remote proxy. - mcu1, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server1, - }, - }, 1, nil) - hub1.setProxy(t, mcu1) - mcu2, _ := proxytest.NewMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ - Etcd: embedEtcd, - Servers: []testserver.ProxyTestServer{ - server2, - }, - }, 2, nil) - hub2.setProxy(t, mcu2) - // Support remote subscribers for the tests. - server1.Servers = append(server1.Servers, server2) - server2.Servers = append(server2.Servers, server1) - - ctx, cancel := context.WithTimeout(t.Context(), testTimeout) - defer cancel() - - pubId := api.PublicSessionId("the-publisher") - pubSid := "1234567890" - pubListener := mock.NewListener(pubId + "-public") - pubInitiator := mock.NewInitiator("DE") - - session1 := &ClientSession{ - publicId: pubId, - publishers: make(map[sfu.StreamType]sfu.Publisher), - } - hub1.addSession(session1) - defer hub1.removeSession(session1) - - pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ - MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, - }, pubInitiator) - require.NoError(t, err) - - defer pub.Close(context.Background()) - - session1.mu.Lock() - session1.publishers[sfu.StreamTypeVideo] = pub - session1.publisherWaiters.Wakeup() - session1.mu.Unlock() - - subListener := mock.NewListener("subscriber-public") - subInitiator := mock.NewInitiator("US") - sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) - require.NoError(t, err) - - defer sub.Close(context.Background()) -} diff --git a/server/hub_test.go b/server/hub_test.go index 2bd62a1..13f0af0 100644 --- a/server/hub_test.go +++ b/server/hub_test.go @@ -65,6 +65,7 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/mock" natstest "github.com/strukturag/nextcloud-spreed-signaling/nats/test" "github.com/strukturag/nextcloud-spreed-signaling/session" + "github.com/strukturag/nextcloud-spreed-signaling/sfu" sfutest "github.com/strukturag/nextcloud-spreed-signaling/sfu/test" "github.com/strukturag/nextcloud-spreed-signaling/talk" "github.com/strukturag/nextcloud-spreed-signaling/test" @@ -5259,3 +5260,53 @@ func TestGracefulShutdownOnExpiration(t *testing.T) { assert.Fail("should have shutdown") } } + +func TestHubGetPublisherIdForSessionId(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + hub, _, _, server := CreateHubForTest(t) + + mcu := sfutest.NewSFU(t) + hub.SetMcu(mcu) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId) + + done := make(chan struct{}) + go func() { + defer close(done) + + if reply, err := hub.GetPublisherIdForSessionId(ctx, hello.Hello.SessionId, sfu.StreamTypeVideo); assert.NoError(err) && assert.NotNil(reply) { + assert.Equal("https://proxy.domain.invalid", reply.ProxyUrl) + assert.Equal("10.20.30.40", reply.Ip) + // The test-SFU doesn't support token creation. + assert.Empty(reply.ConnectToken) + assert.Empty(reply.PublisherToken) + + if session := hub.GetSessionByPublicId(hello.Hello.SessionId); assert.NotNil(session) { + if cs, ok := session.(*ClientSession); assert.True(ok) { + if pub := cs.GetPublisher(sfu.StreamTypeVideo); assert.NotNil(pub) { + assert.EqualValues(pub.PublisherId(), reply.PublisherId) + } + } + } + } + }() + + require.NoError(client.SendMessage(api.MessageClientMessageRecipient{ + Type: "session", + SessionId: hello.Hello.SessionId, + }, api.MessageClientMessageData{ + Type: "offer", + Sid: "54321", + RoomType: "video", + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + })) + + <-done +} diff --git a/sfu/janus/janus_test.go b/sfu/janus/janus_test.go index 5906b50..54c7c9c 100644 --- a/sfu/janus/janus_test.go +++ b/sfu/janus/janus_test.go @@ -23,6 +23,7 @@ package janus import ( "context" + "encoding/json" "strings" "sync" "sync/atomic" @@ -77,7 +78,16 @@ func newMcuJanusForTesting(t *testing.T) (*janusSFU, *janustest.JanusGateway) { } type TestMcuListener struct { - id api.PublicSessionId + id api.PublicSessionId + closed atomic.Bool + updatedOffer chan api.StringMap +} + +func NewTestMcuListener(id api.PublicSessionId) *TestMcuListener { + return &TestMcuListener{ + id: id, + updatedOffer: make(chan api.StringMap), + } } func (t *TestMcuListener) PublicId() api.PublicSessionId { @@ -85,7 +95,7 @@ func (t *TestMcuListener) PublicId() api.PublicSessionId { } func (t *TestMcuListener) OnUpdateOffer(client sfu.Client, offer api.StringMap) { - + t.updatedOffer <- offer } func (t *TestMcuListener) OnIceCandidate(client sfu.Client, candidate any) { @@ -105,7 +115,7 @@ func (t *TestMcuListener) PublisherClosed(publisher sfu.Publisher) { } func (t *TestMcuListener) SubscriberClosed(subscriber sfu.Subscriber) { - + t.closed.Store(true) } type TestMcuController struct { @@ -880,3 +890,844 @@ func Test_JanusRemotePublisher(t *testing.T) { assert.EqualValues(1, added.Load()) assert.EqualValues(1, removed.Load()) } + +type mockJanusStats struct { + called atomic.Bool + + mu sync.Mutex + // +checklocks:mu + value map[sfu.StreamType]int +} + +func (s *mockJanusStats) Value(streamType sfu.StreamType) int { + s.mu.Lock() + defer s.mu.Unlock() + + return s.value[streamType] +} + +func (s *mockJanusStats) IncSubscriber(streamType sfu.StreamType) { + s.called.Store(true) + + s.mu.Lock() + defer s.mu.Unlock() + + if s.value == nil { + s.value = make(map[sfu.StreamType]int) + } + s.value[streamType]++ +} + +func (s *mockJanusStats) DecSubscriber(streamType sfu.StreamType) { + s.called.Store(true) + + s.mu.Lock() + defer s.mu.Unlock() + + if s.value == nil { + s.value = make(map[sfu.StreamType]int) + } + s.value[streamType]-- +} + +func Test_SubscriberNoSuchRoom(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + + stats := &mockJanusStats{} + + t.Cleanup(func() { + if !t.Failed() { + assert.True(stats.called.Load(), "stats were not called") + assert.Equal(0, stats.Value("video")) + } + }) + + mcu, gateway := newMcuJanusForTesting(t) + mcu.SetStats(stats) + gateway.RegisterHandlers(map[string]janustest.JanusHandler{ + "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + assert.EqualValues(1, room.Id()) + return &janus.EventMsg{ + Jsep: api.StringMap{ + "type": "answer", + "sdp": mock.MockSdpAnswerAudioAndVideo, + }, + }, nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("publisher-id") + listener1 := &TestMcuListener{ + id: pubId, + } + + settings1 := sfu.NewPublisherSettings{} + initiator1 := &TestMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", sfu.StreamTypeVideo, settings1, initiator1) + require.NoError(err) + + defer pub.Close(context.Background()) + + msgData := &api.MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + } + require.NoError(msgData.CheckValid()) + payload, err := json.Marshal(msgData) + require.NoError(err) + msg := &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + ch := make(chan struct{}) + pub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpAnswerAudioAndVideo, sm["sdp"]) + }) + <-ch + + listener2 := &TestMcuListener{ + id: pubId, + } + initiator2 := &TestMcuInitiator{ + country: "DE", + } + + sub, err := mcu.NewSubscriber(ctx, listener2, pubId, sfu.StreamTypeVideo, initiator2) + require.NoError(err) + + defer sub.Close(context.Background()) + + msgData = &api.MessageClientMessageData{ + Type: "requestoffer", + RoomType: "video", + } + require.NoError(msgData.CheckValid()) + payload, err = json.Marshal(msgData) + require.NoError(err) + msg = &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + sub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.ErrorContains(err, "not created yet for") + assert.Empty(sm) + }) + <-ch + + assert.True(listener2.closed.Load()) + + listener3 := &TestMcuListener{ + id: pubId, + } + + sub, err = mcu.NewSubscriber(ctx, listener3, pubId, sfu.StreamTypeVideo, initiator2) + require.NoError(err) + + defer sub.Close(context.Background()) + + sub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpOfferAudioAndVideo, sm["sdp"]) + }) + <-ch + + assert.False(listener3.closed.Load()) +} + +func test_JanusSubscriberAlreadyJoined(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + stats := &mockJanusStats{} + + t.Cleanup(func() { + if !t.Failed() { + assert.True(stats.called.Load(), "stats were not called") + assert.Equal(0, stats.Value("video")) + } + }) + + mcu, gateway := newMcuJanusForTesting(t) + mcu.SetStats(stats) + gateway.RegisterHandlers(map[string]janustest.JanusHandler{ + "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + assert.EqualValues(1, room.Id()) + return &janus.EventMsg{ + Jsep: api.StringMap{ + "type": "answer", + "sdp": mock.MockSdpAnswerAudioAndVideo, + }, + }, nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("publisher-id") + listener1 := &TestMcuListener{ + id: pubId, + } + + settings1 := sfu.NewPublisherSettings{} + initiator1 := &TestMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", sfu.StreamTypeVideo, settings1, initiator1) + require.NoError(err) + + defer pub.Close(context.Background()) + + msgData := &api.MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + } + require.NoError(msgData.CheckValid()) + payload, err := json.Marshal(msgData) + require.NoError(err) + msg := &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + ch := make(chan struct{}) + pub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpAnswerAudioAndVideo, sm["sdp"]) + }) + <-ch + + listener2 := &TestMcuListener{ + id: pubId, + } + initiator2 := &TestMcuInitiator{ + country: "DE", + } + + sub, err := mcu.NewSubscriber(ctx, listener2, pubId, sfu.StreamTypeVideo, initiator2) + require.NoError(err) + + defer sub.Close(context.Background()) + + msgData = &api.MessageClientMessageData{ + Type: "requestoffer", + RoomType: "video", + } + require.NoError(msgData.CheckValid()) + payload, err = json.Marshal(msgData) + require.NoError(err) + msg = &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + sub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + if strings.Contains(t.Name(), "AttachError") { + assert.ErrorContains(err, "already connected as subscriber for") + assert.Empty(sm) + } else { + assert.NoError(err) + assert.Equal(mock.MockSdpOfferAudioAndVideo, sm["sdp"]) + } + }) + <-ch + + if strings.Contains(t.Name(), "AttachError") { + assert.True(listener2.closed.Load()) + + listener3 := &TestMcuListener{ + id: pubId, + } + + sub, err := mcu.NewSubscriber(ctx, listener3, pubId, sfu.StreamTypeVideo, initiator2) + require.NoError(err) + + defer sub.Close(context.Background()) + + sub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpOfferAudioAndVideo, sm["sdp"]) + }) + <-ch + } +} + +func Test_SubscriberAlreadyJoined(t *testing.T) { + t.Parallel() + test_JanusSubscriberAlreadyJoined(t) +} + +func Test_SubscriberAlreadyJoinedAttachError(t *testing.T) { + t.Parallel() + test_JanusSubscriberAlreadyJoined(t) +} + +func Test_SubscriberTimeout(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + + stats := &mockJanusStats{} + + t.Cleanup(func() { + if !t.Failed() { + assert.True(stats.called.Load(), "stats were not called") + assert.Equal(0, stats.Value("video")) + } + }) + + mcu, gateway := newMcuJanusForTesting(t) + mcu.SetStats(stats) + gateway.RegisterHandlers(map[string]janustest.JanusHandler{ + "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + assert.EqualValues(1, room.Id()) + return &janus.EventMsg{ + Jsep: api.StringMap{ + "type": "answer", + "sdp": mock.MockSdpAnswerAudioAndVideo, + }, + }, nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("publisher-id") + listener1 := &TestMcuListener{ + id: pubId, + } + + settings1 := sfu.NewPublisherSettings{} + initiator1 := &TestMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", sfu.StreamTypeVideo, settings1, initiator1) + require.NoError(err) + + defer pub.Close(context.Background()) + + msgData := &api.MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + } + require.NoError(msgData.CheckValid()) + payload, err := json.Marshal(msgData) + require.NoError(err) + msg := &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + ch := make(chan struct{}) + pub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpAnswerAudioAndVideo, sm["sdp"]) + }) + <-ch + + oldTimeout := mcu.Settings().Timeout() + mcu.Settings().SetTimeout(100 * time.Millisecond) + + listener2 := &TestMcuListener{ + id: pubId, + } + initiator2 := &TestMcuInitiator{ + country: "DE", + } + + sub, err := mcu.NewSubscriber(ctx, listener2, pubId, sfu.StreamTypeVideo, initiator2) + require.NoError(err) + + defer sub.Close(context.Background()) + + msgData = &api.MessageClientMessageData{ + Type: "requestoffer", + RoomType: "video", + } + require.NoError(msgData.CheckValid()) + payload, err = json.Marshal(msgData) + require.NoError(err) + msg = &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + sub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.ErrorIs(err, context.DeadlineExceeded) + assert.Empty(sm) + }) + <-ch + + assert.True(listener2.closed.Load()) + + mcu.Settings().SetTimeout(oldTimeout) + + listener3 := &TestMcuListener{ + id: pubId, + } + + sub, err = mcu.NewSubscriber(ctx, listener3, pubId, sfu.StreamTypeVideo, initiator2) + require.NoError(err) + + defer sub.Close(context.Background()) + + sub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpOfferAudioAndVideo, sm["sdp"]) + }) + <-ch +} + +func Test_SubscriberCloseEmptyStreams(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + + stats := &mockJanusStats{} + + t.Cleanup(func() { + if !t.Failed() { + assert.True(stats.called.Load(), "stats were not called") + assert.Equal(0, stats.Value("video")) + } + }) + + mcu, gateway := newMcuJanusForTesting(t) + mcu.SetStats(stats) + gateway.RegisterHandlers(map[string]janustest.JanusHandler{ + "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + assert.EqualValues(1, room.Id()) + return &janus.EventMsg{ + Jsep: api.StringMap{ + "type": "answer", + "sdp": mock.MockSdpAnswerAudioAndVideo, + }, + }, nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("publisher-id") + listener1 := &TestMcuListener{ + id: pubId, + } + + settings1 := sfu.NewPublisherSettings{} + initiator1 := &TestMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", sfu.StreamTypeVideo, settings1, initiator1) + require.NoError(err) + + defer pub.Close(context.Background()) + + msgData := &api.MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + } + require.NoError(msgData.CheckValid()) + payload, err := json.Marshal(msgData) + require.NoError(err) + msg := &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + ch := make(chan struct{}) + pub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpAnswerAudioAndVideo, sm["sdp"]) + }) + <-ch + + listener2 := &TestMcuListener{ + id: pubId, + } + initiator2 := &TestMcuInitiator{ + country: "DE", + } + + sub, err := mcu.NewSubscriber(ctx, listener2, pubId, sfu.StreamTypeVideo, initiator2) + require.NoError(err) + + defer sub.Close(context.Background()) + + msgData = &api.MessageClientMessageData{ + Type: "requestoffer", + RoomType: "video", + } + require.NoError(msgData.CheckValid()) + payload, err = json.Marshal(msgData) + require.NoError(err) + msg = &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + sub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpOfferAudioAndVideo, sm["sdp"]) + }) + <-ch + + subscriber, ok := sub.(*janusSubscriber) + require.True(ok) + handle := subscriber.JanusHandle() + require.NotNil(handle) + + for ctx.Err() == nil { + if handle = subscriber.JanusHandle(); handle == nil && listener2.closed.Load() { + break + } + + time.Sleep(time.Millisecond) + } + + assert.Nil(handle, "subscriber should have been closed") + assert.True(listener2.closed.Load()) +} + +func Test_SubscriberRoomDestroyed(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + + stats := &mockJanusStats{} + + t.Cleanup(func() { + if !t.Failed() { + assert.True(stats.called.Load(), "stats were not called") + assert.Equal(0, stats.Value("video")) + } + }) + + mcu, gateway := newMcuJanusForTesting(t) + mcu.SetStats(stats) + gateway.RegisterHandlers(map[string]janustest.JanusHandler{ + "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + assert.EqualValues(1, room.Id()) + return &janus.EventMsg{ + Jsep: api.StringMap{ + "type": "answer", + "sdp": mock.MockSdpAnswerAudioAndVideo, + }, + }, nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("publisher-id") + listener1 := &TestMcuListener{ + id: pubId, + } + + settings1 := sfu.NewPublisherSettings{} + initiator1 := &TestMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", sfu.StreamTypeVideo, settings1, initiator1) + require.NoError(err) + + defer pub.Close(context.Background()) + + msgData := &api.MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + } + require.NoError(msgData.CheckValid()) + payload, err := json.Marshal(msgData) + require.NoError(err) + msg := &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + ch := make(chan struct{}) + pub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpAnswerAudioAndVideo, sm["sdp"]) + }) + <-ch + + listener2 := &TestMcuListener{ + id: pubId, + } + initiator2 := &TestMcuInitiator{ + country: "DE", + } + + sub, err := mcu.NewSubscriber(ctx, listener2, pubId, sfu.StreamTypeVideo, initiator2) + require.NoError(err) + + defer sub.Close(context.Background()) + + msgData = &api.MessageClientMessageData{ + Type: "requestoffer", + RoomType: "video", + } + require.NoError(msgData.CheckValid()) + payload, err = json.Marshal(msgData) + require.NoError(err) + msg = &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + sub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpOfferAudioAndVideo, sm["sdp"]) + }) + <-ch + + subscriber, ok := sub.(*janusSubscriber) + require.True(ok) + handle := subscriber.JanusHandle() + require.NotNil(handle) + + for ctx.Err() == nil { + if handle = subscriber.JanusHandle(); handle == nil && listener2.closed.Load() { + break + } + + time.Sleep(time.Millisecond) + } + + assert.Nil(handle, "subscriber should have been closed") + assert.True(listener2.closed.Load()) +} + +func Test_SubscriberUpdateOffer(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + + stats := &mockJanusStats{} + + t.Cleanup(func() { + if !t.Failed() { + assert.True(stats.called.Load(), "stats were not called") + assert.Equal(0, stats.Value("video")) + } + }) + + mcu, gateway := newMcuJanusForTesting(t) + mcu.SetStats(stats) + gateway.RegisterHandlers(map[string]janustest.JanusHandler{ + "configure": func(room *janustest.JanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + assert.EqualValues(1, room.Id()) + return &janus.EventMsg{ + Jsep: api.StringMap{ + "type": "answer", + "sdp": mock.MockSdpAnswerAudioAndVideo, + }, + }, nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("publisher-id") + listener1 := &TestMcuListener{ + id: pubId, + } + + settings1 := sfu.NewPublisherSettings{} + initiator1 := &TestMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", sfu.StreamTypeVideo, settings1, initiator1) + require.NoError(err) + + defer pub.Close(context.Background()) + + msgData := &api.MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: api.StringMap{ + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + } + require.NoError(msgData.CheckValid()) + payload, err := json.Marshal(msgData) + require.NoError(err) + msg := &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + ch := make(chan struct{}) + pub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpAnswerAudioAndVideo, sm["sdp"]) + }) + <-ch + + listener2 := NewTestMcuListener(pubId) + initiator2 := &TestMcuInitiator{ + country: "DE", + } + + sub, err := mcu.NewSubscriber(ctx, listener2, pubId, sfu.StreamTypeVideo, initiator2) + require.NoError(err) + + defer sub.Close(context.Background()) + + msgData = &api.MessageClientMessageData{ + Type: "requestoffer", + RoomType: "video", + } + require.NoError(msgData.CheckValid()) + payload, err = json.Marshal(msgData) + require.NoError(err) + msg = &api.MessageClientMessage{ + Recipient: api.MessageClientMessageRecipient{ + Type: "session", + SessionId: pubId, + }, + Data: payload, + } + + sub.SendMessage(ctx, msg, msgData, func(err error, sm api.StringMap) { + defer func() { + ch <- struct{}{} + }() + + assert.NoError(err) + assert.Equal(mock.MockSdpOfferAudioAndVideo, sm["sdp"]) + }) + <-ch + + // Test MCU will trigger an updated offer. + select { + case offer := <-listener2.updatedOffer: + assert.Equal(mock.MockSdpOfferAudioOnly, offer["sdp"]) + case <-ctx.Done(): + assert.NoError(ctx.Err()) + } +} diff --git a/sfu/janus/publisher_stats_counter_test.go b/sfu/janus/publisher_stats_counter_test.go index 221b9f2..123d710 100644 --- a/sfu/janus/publisher_stats_counter_test.go +++ b/sfu/janus/publisher_stats_counter_test.go @@ -88,6 +88,7 @@ func TestPublisherStatsPrometheus(t *testing.T) { t.Parallel() RegisterStats() + UnregisterStats() } func TestPublisherStatsCounter(t *testing.T) { diff --git a/sfu/janus/stream_selection.go b/sfu/janus/stream_selection.go index 046b579..aa7ad10 100644 --- a/sfu/janus/stream_selection.go +++ b/sfu/janus/stream_selection.go @@ -22,35 +22,35 @@ package janus import ( - "database/sql" "fmt" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/internal" ) type streamSelection struct { - substream sql.NullInt16 - temporal sql.NullInt16 - audio sql.NullBool - video sql.NullBool + substream *int + temporal *int + audio *bool + video *bool } func (s *streamSelection) HasValues() bool { - return s.substream.Valid || s.temporal.Valid || s.audio.Valid || s.video.Valid + return s.substream != nil || s.temporal != nil || s.audio != nil || s.video != nil } func (s *streamSelection) AddToMessage(message api.StringMap) { - if s.substream.Valid { - message["substream"] = s.substream.Int16 + if s.substream != nil { + message["substream"] = *s.substream } - if s.temporal.Valid { - message["temporal"] = s.temporal.Int16 + if s.temporal != nil { + message["temporal"] = *s.temporal } - if s.audio.Valid { - message["audio"] = s.audio.Bool + if s.audio != nil { + message["audio"] = *s.audio } - if s.video.Valid { - message["video"] = s.video.Bool + if s.video != nil { + message["video"] = *s.video } } @@ -59,14 +59,11 @@ func parseStreamSelection(payload api.StringMap) (*streamSelection, error) { if value, found := payload["substream"]; found { switch value := value.(type) { case int: - stream.substream.Valid = true - stream.substream.Int16 = int16(value) + stream.substream = &value case float32: - stream.substream.Valid = true - stream.substream.Int16 = int16(value) + stream.substream = internal.MakePtr(int(value)) case float64: - stream.substream.Valid = true - stream.substream.Int16 = int16(value) + stream.substream = internal.MakePtr(int(value)) default: return nil, fmt.Errorf("unsupported substream value: %v", value) } @@ -75,14 +72,11 @@ func parseStreamSelection(payload api.StringMap) (*streamSelection, error) { if value, found := payload["temporal"]; found { switch value := value.(type) { case int: - stream.temporal.Valid = true - stream.temporal.Int16 = int16(value) + stream.temporal = &value case float32: - stream.temporal.Valid = true - stream.temporal.Int16 = int16(value) + stream.temporal = internal.MakePtr(int(value)) case float64: - stream.temporal.Valid = true - stream.temporal.Int16 = int16(value) + stream.temporal = internal.MakePtr(int(value)) default: return nil, fmt.Errorf("unsupported temporal value: %v", value) } @@ -91,8 +85,7 @@ func parseStreamSelection(payload api.StringMap) (*streamSelection, error) { if value, found := payload["audio"]; found { switch value := value.(type) { case bool: - stream.audio.Valid = true - stream.audio.Bool = value + stream.audio = &value default: return nil, fmt.Errorf("unsupported audio value: %v", value) } @@ -101,8 +94,7 @@ func parseStreamSelection(payload api.StringMap) (*streamSelection, error) { if value, found := payload["video"]; found { switch value := value.(type) { case bool: - stream.video.Valid = true - stream.video.Bool = value + stream.video = &value default: return nil, fmt.Errorf("unsupported video value: %v", value) } diff --git a/sfu/janus/stream_selection_test.go b/sfu/janus/stream_selection_test.go new file mode 100644 index 0000000..8d5463d --- /dev/null +++ b/sfu/janus/stream_selection_test.go @@ -0,0 +1,92 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2026 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package janus + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/strukturag/nextcloud-spreed-signaling/api" +) + +func TestStreamSelection(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + testcases := []api.StringMap{ + {}, + { + "substream": 1.0, + }, + { + "temporal": 1.0, + }, + { + "substream": float32(1.0), + }, + { + "temporal": float32(1.0), + }, + { + "substream": 1, + "temporal": 3, + }, + { + "substream": 1, + "audio": true, + "video": false, + }, + } + + for idx, tc := range testcases { + parsed, err := parseStreamSelection(tc) + if assert.NoError(err, "failed for testcase %d: %+v", idx, tc) { + assert.Equal(len(tc) > 0, parsed.HasValues(), "failed for testcase %d: %+v", idx, tc) + m := make(api.StringMap) + parsed.AddToMessage(m) + for k, v := range tc { + assert.EqualValues(v, m[k], "failed for key %s in testcase %d", k, idx) + } + } + } + + _, err := parseStreamSelection(api.StringMap{ + "substream": "foo", + }) + assert.ErrorContains(err, "unsupported substream value") + + _, err = parseStreamSelection(api.StringMap{ + "temporal": "foo", + }) + assert.ErrorContains(err, "unsupported temporal value") + + _, err = parseStreamSelection(api.StringMap{ + "audio": 1, + }) + assert.ErrorContains(err, "unsupported audio value") + + _, err = parseStreamSelection(api.StringMap{ + "video": "true", + }) + assert.ErrorContains(err, "unsupported video value") +} diff --git a/sfu/janus/subscriber.go b/sfu/janus/subscriber.go index a1f99c6..d581d40 100644 --- a/sfu/janus/subscriber.go +++ b/sfu/janus/subscriber.go @@ -32,10 +32,6 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/sfu/janus/janus" ) -type Subscriber interface { - JanusHandle() *janus.Handle -} - type janusSubscriber struct { janusClient diff --git a/sfu/proxy/proxy_test.go b/sfu/proxy/proxy_test.go index 41049fd..47c225f 100644 --- a/sfu/proxy/proxy_test.go +++ b/sfu/proxy/proxy_test.go @@ -25,6 +25,7 @@ import ( "context" "crypto/rand" "crypto/rsa" + "encoding/json" "fmt" "net" "net/url" @@ -32,6 +33,7 @@ import ( "slices" "strconv" "strings" + "sync" "testing" "time" @@ -40,10 +42,12 @@ import ( "github.com/stretchr/testify/require" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/async" dnstest "github.com/strukturag/nextcloud-spreed-signaling/dns/test" "github.com/strukturag/nextcloud-spreed-signaling/etcd" etcdtest "github.com/strukturag/nextcloud-spreed-signaling/etcd/test" "github.com/strukturag/nextcloud-spreed-signaling/geoip" + "github.com/strukturag/nextcloud-spreed-signaling/grpc" grpctest "github.com/strukturag/nextcloud-spreed-signaling/grpc/test" "github.com/strukturag/nextcloud-spreed-signaling/internal" "github.com/strukturag/nextcloud-spreed-signaling/log" @@ -1238,3 +1242,509 @@ func Test_ProxyResumeFail(t *testing.T) { assert.NotEqual(sessionId, connections[0].SessionId()) } } + +type publisherHub struct { + grpctest.MockHub + + mu sync.Mutex + // +checklocks:mu + publishers map[api.PublicSessionId]*proxyPublisher + waiter async.ChannelWaiters // +checklocksignore: Has its own locking. +} + +func newPublisherHub() *publisherHub { + return &publisherHub{ + publishers: make(map[api.PublicSessionId]*proxyPublisher), + } +} + +func (h *publisherHub) addPublisher(publisher *proxyPublisher) { + h.mu.Lock() + defer h.mu.Unlock() + + h.publishers[publisher.PublisherId()] = publisher + h.waiter.Wakeup() +} + +func (h *publisherHub) GetPublisherIdForSessionId(ctx context.Context, sessionId api.PublicSessionId, streamType sfu.StreamType) (*grpc.GetPublisherIdReply, error) { + h.mu.Lock() + defer h.mu.Unlock() + + pub, found := h.publishers[sessionId] + if !found { + ch := make(chan struct{}, 1) + id := h.waiter.Add(ch) + defer h.waiter.Remove(id) + + for !found { + h.mu.Unlock() + select { + case <-ch: + h.mu.Lock() + pub, found = h.publishers[sessionId] + case <-ctx.Done(): + h.mu.Lock() + return nil, ctx.Err() + } + } + } + + connToken, err := pub.conn.proxy.CreateToken("") + if err != nil { + return nil, err + } + pubToken, err := pub.conn.proxy.CreateToken(string(pub.Id())) + if err != nil { + return nil, err + } + + reply := &grpc.GetPublisherIdReply{ + PublisherId: pub.Id(), + ProxyUrl: pub.conn.rawUrl, + ConnectToken: connToken, + PublisherToken: pubToken, + } + if ip := pub.conn.ip; len(ip) > 0 { + reply.Ip = ip.String() + } + return reply, nil +} + +func Test_ProxyRemotePublisher(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + embedEtcd := etcdtest.NewServerForTest(t) + + grpcServer1, addr1 := grpctest.NewServerForTest(t) + grpcServer2, addr2 := grpctest.NewServerForTest(t) + + hub1 := newPublisherHub() + hub2 := newPublisherHub() + grpcServer1.SetHub(hub1) + grpcServer2.SetHub(hub2) + + embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + + server1 := testserver.NewProxyServerForTest(t, "DE") + server2 := testserver.NewProxyServerForTest(t, "DE") + + mcu1, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + server2, + }, + }, 1, nil) + mcu2, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + server2, + }, + }, 2, nil) + + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("the-publisher") + pubSid := "1234567890" + pubListener := mock.NewListener(pubId + "-public") + pubInitiator := mock.NewInitiator("DE") + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ + MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, + }, pubInitiator) + require.NoError(t, err) + + defer pub.Close(context.Background()) + + if proxyPub, ok := pub.(*proxyPublisher); assert.True(ok) { + hub1.addPublisher(proxyPub) + } + + subListener := mock.NewListener("subscriber-public") + subInitiator := mock.NewInitiator("DE") + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) + require.NoError(t, err) + + defer sub.Close(context.Background()) +} + +func Test_ProxyMultipleRemotePublisher(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + embedEtcd := etcdtest.NewServerForTest(t) + + grpcServer1, addr1 := grpctest.NewServerForTest(t) + grpcServer2, addr2 := grpctest.NewServerForTest(t) + grpcServer3, addr3 := grpctest.NewServerForTest(t) + + hub1 := newPublisherHub() + hub2 := newPublisherHub() + hub3 := newPublisherHub() + grpcServer1.SetHub(hub1) + grpcServer2.SetHub(hub2) + grpcServer3.SetHub(hub3) + + embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + embedEtcd.SetValue("/grpctargets/three", []byte("{\"address\":\""+addr3+"\"}")) + + server1 := testserver.NewProxyServerForTest(t, "DE") + server2 := testserver.NewProxyServerForTest(t, "US") + server3 := testserver.NewProxyServerForTest(t, "US") + + mcu1, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + server2, + server3, + }, + }, 1, nil) + mcu2, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + server2, + server3, + }, + }, 2, nil) + mcu3, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + server2, + server3, + }, + }, 3, nil) + + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("the-publisher") + pubSid := "1234567890" + pubListener := mock.NewListener(pubId + "-public") + pubInitiator := mock.NewInitiator("DE") + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ + MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, + }, pubInitiator) + require.NoError(t, err) + + defer pub.Close(context.Background()) + + if proxyPub, ok := pub.(*proxyPublisher); assert.True(ok) { + hub1.addPublisher(proxyPub) + } + + sub1Listener := mock.NewListener("subscriber-public-1") + sub1Initiator := mock.NewInitiator("US") + sub1, err := mcu2.NewSubscriber(ctx, sub1Listener, pubId, sfu.StreamTypeVideo, sub1Initiator) + require.NoError(t, err) + + defer sub1.Close(context.Background()) + + sub2Listener := mock.NewListener("subscriber-public-2") + sub2Initiator := mock.NewInitiator("US") + sub2, err := mcu3.NewSubscriber(ctx, sub2Listener, pubId, sfu.StreamTypeVideo, sub2Initiator) + require.NoError(t, err) + + defer sub2.Close(context.Background()) +} + +func Test_ProxyRemotePublisherWait(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + embedEtcd := etcdtest.NewServerForTest(t) + + grpcServer1, addr1 := grpctest.NewServerForTest(t) + grpcServer2, addr2 := grpctest.NewServerForTest(t) + + hub1 := newPublisherHub() + hub2 := newPublisherHub() + grpcServer1.SetHub(hub1) + grpcServer2.SetHub(hub2) + + embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + + server1 := testserver.NewProxyServerForTest(t, "DE") + server2 := testserver.NewProxyServerForTest(t, "DE") + + mcu1, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + server2, + }, + }, 1, nil) + mcu2, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + server2, + }, + }, 2, nil) + + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("the-publisher") + pubSid := "1234567890" + pubListener := mock.NewListener(pubId + "-public") + pubInitiator := mock.NewInitiator("DE") + + subListener := mock.NewListener("subscriber-public") + subInitiator := mock.NewInitiator("DE") + + done := make(chan struct{}) + go func() { + defer close(done) + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) + if !assert.NoError(err) { + return + } + + defer sub.Close(context.Background()) + }() + + // Give subscriber goroutine some time to start + time.Sleep(100 * time.Millisecond) + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ + MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, + }, pubInitiator) + require.NoError(t, err) + + defer pub.Close(context.Background()) + + if proxyPub, ok := pub.(*proxyPublisher); assert.True(ok) { + hub1.addPublisher(proxyPub) + } + + select { + case <-done: + case <-ctx.Done(): + assert.NoError(ctx.Err()) + } +} + +func Test_ProxyRemotePublisherTemporary(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + embedEtcd := etcdtest.NewServerForTest(t) + + server, addr1 := grpctest.NewServerForTest(t) + hub := newPublisherHub() + server.SetHub(hub) + + target := grpc.TargetInformationEtcd{ + Address: addr1, + } + encoded, err := json.Marshal(target) + require.NoError(err) + embedEtcd.SetValue("/grpctargets/server", encoded) + + server1 := testserver.NewProxyServerForTest(t, "DE") + server2 := testserver.NewProxyServerForTest(t, "DE") + + mcu1, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + }, + }, 1, nil) + mcu2, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server2, + }, + }, 2, nil) + + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("the-publisher") + pubSid := "1234567890" + pubListener := mock.NewListener(pubId + "-public") + pubInitiator := mock.NewInitiator("DE") + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ + MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, + }, pubInitiator) + require.NoError(err) + + defer pub.Close(context.Background()) + + if proxyPub, ok := pub.(*proxyPublisher); assert.True(ok) { + hub.addPublisher(proxyPub) + } + + subListener := mock.NewListener("subscriber-public") + subInitiator := mock.NewInitiator("DE") + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) + require.NoError(err) + + defer sub.Close(context.Background()) + + if connSub, ok := sub.(*proxySubscriber); assert.True(ok) { + assert.Equal(server1.URL(), connSub.conn.rawUrl) + assert.Empty(connSub.conn.ip) + } + + // The temporary connection has been added + assert.Equal(2, mcu2.ConnectionsCount()) + + sub.Close(context.Background()) + + // Wait for temporary connection to be removed. +loop: + for { + select { + case <-ctx.Done(): + assert.NoError(ctx.Err()) + default: + if mcu2.ConnectionsCount() == 1 { + break loop + } + } + } +} + +func Test_ProxyConnectToken(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + embedEtcd := etcdtest.NewServerForTest(t) + + grpcServer1, addr1 := grpctest.NewServerForTest(t) + grpcServer2, addr2 := grpctest.NewServerForTest(t) + + hub1 := newPublisherHub() + hub2 := newPublisherHub() + grpcServer1.SetHub(hub1) + grpcServer2.SetHub(hub2) + + embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + + server1 := testserver.NewProxyServerForTest(t, "DE") + server2 := testserver.NewProxyServerForTest(t, "DE") + + // Signaling server instances are in a cluster but don't share their proxies, + // i.e. they are only known to their local proxy, not the one of the other + // signaling server - so the connection token must be passed between them. + mcu1, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + }, + }, 1, nil) + mcu2, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server2, + }, + }, 2, nil) + + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("the-publisher") + pubSid := "1234567890" + pubListener := mock.NewListener(pubId + "-public") + pubInitiator := mock.NewInitiator("DE") + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ + MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, + }, pubInitiator) + require.NoError(t, err) + + defer pub.Close(context.Background()) + + if proxyPub, ok := pub.(*proxyPublisher); assert.True(ok) { + hub1.addPublisher(proxyPub) + } + + subListener := mock.NewListener("subscriber-public") + subInitiator := mock.NewInitiator("DE") + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) + require.NoError(t, err) + + defer sub.Close(context.Background()) +} + +func Test_ProxyPublisherToken(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + embedEtcd := etcdtest.NewServerForTest(t) + + grpcServer1, addr1 := grpctest.NewServerForTest(t) + grpcServer2, addr2 := grpctest.NewServerForTest(t) + + hub1 := newPublisherHub() + hub2 := newPublisherHub() + grpcServer1.SetHub(hub1) + grpcServer2.SetHub(hub2) + + embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + + server1 := testserver.NewProxyServerForTest(t, "DE") + server2 := testserver.NewProxyServerForTest(t, "US") + + // Signaling server instances are in a cluster but don't share their proxies, + // i.e. they are only known to their local proxy, not the one of the other + // signaling server - so the connection token must be passed between them. + // Also the subscriber is connecting from a different country, so a remote + // stream will be created that needs a valid token from the remote proxy. + mcu1, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server1, + }, + }, 1, nil) + mcu2, _ := newMcuProxyForTestWithOptions(t, testserver.ProxyTestOptions{ + Etcd: embedEtcd, + Servers: []testserver.ProxyTestServer{ + server2, + }, + }, 2, nil) + // Support remote subscribers for the tests. + server1.Servers = append(server1.Servers, server2) + server2.Servers = append(server2.Servers, server1) + + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) + defer cancel() + + pubId := api.PublicSessionId("the-publisher") + pubSid := "1234567890" + pubListener := mock.NewListener(pubId + "-public") + pubInitiator := mock.NewInitiator("DE") + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, sfu.StreamTypeVideo, sfu.NewPublisherSettings{ + MediaTypes: sfu.MediaTypeVideo | sfu.MediaTypeAudio, + }, pubInitiator) + require.NoError(t, err) + + defer pub.Close(context.Background()) + + if proxyPub, ok := pub.(*proxyPublisher); assert.True(ok) { + hub1.addPublisher(proxyPub) + } + + subListener := mock.NewListener("subscriber-public") + subInitiator := mock.NewInitiator("US") + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, sfu.StreamTypeVideo, subInitiator) + require.NoError(t, err) + + defer sub.Close(context.Background()) +} diff --git a/sfu/test/sfu.go b/sfu/test/sfu.go index bd3db0a..dabd659 100644 --- a/sfu/test/sfu.go +++ b/sfu/test/sfu.go @@ -26,11 +26,13 @@ import ( "errors" "fmt" "maps" + "net" "sync" "sync/atomic" "testing" "github.com/dlintw/goconf" + "github.com/stretchr/testify/assert" "github.com/strukturag/nextcloud-spreed-signaling/api" "github.com/strukturag/nextcloud-spreed-signaling/internal" @@ -144,6 +146,14 @@ func (m *SFU) GetPublisher(id api.PublicSessionId) *SFUPublisher { return m.publishers[id] } +func (m *SFU) GetSubscriber(id api.PublicSessionId, streamType sfu.StreamType) *SFUSubscriber { + m.mu.Lock() + defer m.mu.Unlock() + + key := fmt.Sprintf("%s|%s", id, streamType) + return m.subscribers[key] +} + func (m *SFU) NewSubscriber(ctx context.Context, listener sfu.Listener, publisher api.PublicSessionId, streamType sfu.StreamType, initiator sfu.Initiator) (sfu.Subscriber, error) { m.mu.Lock() defer m.mu.Unlock() @@ -163,6 +173,9 @@ func (m *SFU) NewSubscriber(ctx context.Context, listener sfu.Listener, publishe publisher: pub, } + key := fmt.Sprintf("%s|%s", publisher, streamType) + assert.Empty(m.t, m.subscribers[key], "duplicate subscriber") + m.subscribers[key] = sub return sub, nil } @@ -272,6 +285,10 @@ func (p *SFUPublisher) UnpublishRemote(ctx context.Context, remoteId api.PublicS return errors.New("remote publishing not supported") } +func (p *SFUPublisher) GetConnectionURL() (string, net.IP) { + return "https://proxy.domain.invalid", net.ParseIP("10.20.30.40") +} + type SFUSubscriber struct { SFUClient