From 2d81afe1909e13083e67236e86af80df45e2662d Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 17 Apr 2024 10:32:16 +0200 Subject: [PATCH] Add basic tests for mcu proxy client. --- mcu_common_test.go | 40 +++ mcu_proxy.go | 54 +++- mcu_proxy_test.go | 619 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 706 insertions(+), 7 deletions(-) diff --git a/mcu_common_test.go b/mcu_common_test.go index 0609ef3..6304638 100644 --- a/mcu_common_test.go +++ b/mcu_common_test.go @@ -28,3 +28,43 @@ import ( func TestCommonMcuStats(t *testing.T) { collectAndLint(t, commonMcuStats...) } + +type MockMcuListener struct { + publicId string +} + +func (m *MockMcuListener) PublicId() string { + return m.publicId +} + +func (m *MockMcuListener) OnUpdateOffer(client McuClient, offer map[string]interface{}) { + +} + +func (m *MockMcuListener) OnIceCandidate(client McuClient, candidate interface{}) { + +} + +func (m *MockMcuListener) OnIceCompleted(client McuClient) { + +} + +func (m *MockMcuListener) SubscriberSidUpdated(subscriber McuSubscriber) { + +} + +func (m *MockMcuListener) PublisherClosed(publisher McuPublisher) { + +} + +func (m *MockMcuListener) SubscriberClosed(subscriber McuSubscriber) { + +} + +type MockMcuInitiator struct { + country string +} + +func (m *MockMcuInitiator) Country() string { + return m.country +} diff --git a/mcu_proxy.go b/mcu_proxy.go index 131fa98..9e6cce8 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -326,7 +326,7 @@ type mcuProxyConnection struct { msgId atomic.Int64 helloMsgId string - sessionId string + sessionId atomic.Value country atomic.Value callbacks map[string]func(*ProxyServerMessage) @@ -418,6 +418,21 @@ func (c *mcuProxyConnection) Country() string { return c.country.Load().(string) } +func (c *mcuProxyConnection) SessionId() string { + sid := c.sessionId.Load() + if sid == nil { + return "" + } + + return sid.(string) +} + +func (c *mcuProxyConnection) IsConnected() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.conn != nil && c.SessionId() != "" +} + func (c *mcuProxyConnection) IsTemporary() bool { return c.temporary.Load() } @@ -810,7 +825,7 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { c.clearPublishers() c.clearSubscribers() c.clearCallbacks() - c.sessionId = "" + c.sessionId.Store("") if err := c.sendHello(); err != nil { log.Printf("Could not send hello request to %s: %s", c, err) c.scheduleReconnect() @@ -821,8 +836,8 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) c.scheduleReconnect() case "hello": - resumed := c.sessionId == msg.Hello.SessionId - c.sessionId = msg.Hello.SessionId + resumed := c.SessionId() == msg.Hello.SessionId + c.sessionId.Store(msg.Hello.SessionId) country := "" if msg.Hello.Server != nil { if country = msg.Hello.Server.Country; country != "" && !IsValidCountry(country) { @@ -945,7 +960,7 @@ func (c *mcuProxyConnection) processBye(msg *ProxyServerMessage) { switch bye.Reason { case "session_resumed": log.Printf("Session %s on %s was resumed by other client, resetting", c.sessionId, c) - c.sessionId = "" + c.sessionId.Store("") default: log.Printf("Received bye with unsupported reason from %s %+v", c, bye) } @@ -960,8 +975,8 @@ func (c *mcuProxyConnection) sendHello() error { Version: "1.0", }, } - if c.sessionId != "" { - msg.Hello.ResumeId = c.sessionId + if sessionId := c.SessionId(); sessionId != "" { + msg.Hello.ResumeId = sessionId } else { claims := &TokenClaims{ jwt.RegisteredClaims{ @@ -1274,6 +1289,31 @@ func (m *mcuProxy) Stop() { m.config.Stop() } +func (m *mcuProxy) hasConnections() bool { + m.connectionsMu.RLock() + defer m.connectionsMu.RUnlock() + for _, conn := range m.connections { + if conn.IsConnected() { + return true + } + } + return false +} + +func (m *mcuProxy) WaitForConnections(ctx context.Context) error { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for !m.hasConnections() { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } + return nil +} + func (m *mcuProxy) AddConnection(ignoreErrors bool, url string, ips ...net.IP) error { m.connectionsMu.Lock() defer m.connectionsMu.Unlock() diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index e518e6d..1aa9ffc 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -22,7 +22,23 @@ package signaling import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "path" + "strings" + "sync" + "sync/atomic" "testing" + "time" + + "github.com/dlintw/goconf" + "github.com/gorilla/websocket" ) func TestMcuProxyStats(t *testing.T) { @@ -166,3 +182,606 @@ func Test_sortConnectionsForCountryWithOverride(t *testing.T) { }) } } + +type proxyServerClientHandler func(msg *ProxyClientMessage) (*ProxyServerMessage, error) + +type testProxyServerPublisher struct { + id string +} + +type testProxyServerSubscriber struct { + id string + sid string + pub *testProxyServerPublisher +} + +type testProxyServerClient struct { + t *testing.T + + server *testProxyServerHandler + ws *websocket.Conn + processMessage proxyServerClientHandler + + mu sync.Mutex + sessionId string + publishers map[string]*testProxyServerPublisher + subscribers map[string]*testProxyServerSubscriber +} + +func (c *testProxyServerClient) processHello(msg *ProxyClientMessage) (*ProxyServerMessage, error) { + if msg.Type != "hello" { + return nil, fmt.Errorf("expected hello, got %+v", msg) + } + + response := &ProxyServerMessage{ + Id: msg.Id, + Type: "hello", + Hello: &HelloProxyServerMessage{ + Version: "1.0", + SessionId: c.sessionId, + Server: &WelcomeServerMessage{ + Version: "1.0", + Country: c.server.country, + }, + }, + } + c.processMessage = c.processRegularMessage + return response, nil +} + +func (c *testProxyServerClient) processRegularMessage(msg *ProxyClientMessage) (*ProxyServerMessage, error) { + var handler proxyServerClientHandler + switch msg.Type { + case "command": + handler = c.processCommandMessage + } + + if handler == nil { + response := msg.NewWrappedErrorServerMessage(fmt.Errorf("type \"%s\" is not implemented", msg.Type)) + return response, nil + } + + return handler(msg) +} + +func (c *testProxyServerClient) processCommandMessage(msg *ProxyClientMessage) (*ProxyServerMessage, error) { + var response *ProxyServerMessage + switch msg.Command.Type { + case "create-publisher": + pub := &testProxyServerPublisher{ + id: newRandomString(32), + } + + response = &ProxyServerMessage{ + Id: msg.Id, + Type: "command", + Command: &CommandProxyServerMessage{ + Id: pub.id, + Bitrate: msg.Command.Bitrate, + }, + } + c.mu.Lock() + defer c.mu.Unlock() + c.publishers[pub.id] = pub + c.server.updateLoad(1) + case "delete-publisher": + c.mu.Lock() + defer c.mu.Unlock() + pub, found := c.publishers[msg.Command.ClientId] + if !found { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("publisher %s not found", msg.Command.ClientId)) + } else { + delete(c.publishers, pub.id) + response = &ProxyServerMessage{ + Id: msg.Id, + Type: "command", + Command: &CommandProxyServerMessage{ + Id: pub.id, + }, + } + c.server.updateLoad(-1) + } + case "create-subscriber": + c.mu.Lock() + defer c.mu.Unlock() + pub, found := c.publishers[msg.Command.PublisherId] + if !found { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("publisher %s not found", msg.Command.PublisherId)) + } else { + sub := &testProxyServerSubscriber{ + id: newRandomString(32), + sid: newRandomString(8), + pub: pub, + } + response = &ProxyServerMessage{ + Id: msg.Id, + Type: "command", + Command: &CommandProxyServerMessage{ + Id: sub.id, + Sid: sub.sid, + }, + } + c.subscribers[sub.id] = sub + c.server.updateLoad(1) + } + case "delete-subscriber": + c.mu.Lock() + defer c.mu.Unlock() + sub, found := c.subscribers[msg.Command.ClientId] + if !found { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("subscriber %s not found", msg.Command.ClientId)) + } else { + delete(c.subscribers, sub.id) + response = &ProxyServerMessage{ + Id: msg.Id, + Type: "command", + Command: &CommandProxyServerMessage{ + Id: sub.id, + }, + } + c.server.updateLoad(-1) + } + } + if response == nil { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("command \"%s\" is not implemented", msg.Command.Type)) + } + + return response, nil +} + +func (c *testProxyServerClient) close() { + c.ws.Close() +} + +func (c *testProxyServerClient) sendMessage(msg *ProxyServerMessage) { + c.mu.Lock() + defer c.mu.Unlock() + + data, err := json.Marshal(msg) + if err != nil { + c.t.Error(err) + return + } + + w, err := c.ws.NextWriter(websocket.TextMessage) + if err != nil { + c.t.Error(err) + return + } + + if _, err := w.Write(data); err != nil { + c.t.Error(err) + return + } + + if err := w.Close(); err != nil { + c.t.Error(err) + } +} + +func (c *testProxyServerClient) run() { + c.processMessage = c.processHello + for { + msgType, reader, err := c.ws.NextReader() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { + c.t.Error(err) + } + return + } + + body, err := io.ReadAll(reader) + if err != nil { + c.t.Error(err) + continue + } + + if msgType != websocket.TextMessage { + c.t.Errorf("unexpected message type %q (%s)", msgType, string(body)) + continue + } + + var msg ProxyClientMessage + if err := json.Unmarshal(body, &msg); err != nil { + c.t.Errorf("could not decode message %s: %s", string(body), err) + continue + } + + if err := msg.CheckValid(); err != nil { + c.t.Errorf("invalid message %s: %s", string(body), err) + continue + } + + response, err := c.processMessage(&msg) + if err != nil { + c.t.Error(err) + continue + } + + c.sendMessage(response) + if response.Type == "hello" { + c.server.sendLoad(c) + } + } +} + +type testProxyServerHandler struct { + t *testing.T + + upgrader *websocket.Upgrader + country string + + mu sync.Mutex + load atomic.Int64 + clients map[string]*testProxyServerClient +} + +func (h *testProxyServerHandler) updateLoad(delta int64) { + if delta == 0 { + return + } + + h.mu.Lock() + defer h.mu.Unlock() + + load := h.load.Add(delta) + for _, c := range h.clients { + go func(c *testProxyServerClient, load int64) { + c.sendMessage(&ProxyServerMessage{ + Type: "event", + Event: &EventProxyServerMessage{ + Type: "update-load", + Load: load, + }, + }) + }(c, load) + } +} + +func (h *testProxyServerHandler) sendLoad(c *testProxyServerClient) { + c.sendMessage(&ProxyServerMessage{ + Type: "event", + Event: &EventProxyServerMessage{ + Type: "update-load", + Load: h.load.Load(), + }, + }) +} + +func (h *testProxyServerHandler) removeClient(client *testProxyServerClient) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.clients, client.sessionId) +} + +func (h *testProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ws, err := h.upgrader.Upgrade(w, r, nil) + if err != nil { + h.t.Error(err) + return + } + + client := &testProxyServerClient{ + t: h.t, + server: h, + ws: ws, + sessionId: newRandomString(32), + publishers: make(map[string]*testProxyServerPublisher), + subscribers: make(map[string]*testProxyServerSubscriber), + } + + h.mu.Lock() + h.clients[client.sessionId] = client + h.mu.Unlock() + + go func(client *testProxyServerClient) { + defer h.removeClient(client) + client.run() + }(client) +} + +func NewProxyServerForTest(t *testing.T, country string) *httptest.Server { + t.Helper() + + upgrader := websocket.Upgrader{} + proxyHandler := &testProxyServerHandler{ + t: t, + upgrader: &upgrader, + country: country, + clients: make(map[string]*testProxyServerClient), + } + server := httptest.NewServer(proxyHandler) + t.Cleanup(func() { + server.Close() + proxyHandler.mu.Lock() + defer proxyHandler.mu.Unlock() + for _, c := range proxyHandler.clients { + c.close() + } + }) + + return server +} + +func newMcuProxyForTestWithServers(t *testing.T, servers []*httptest.Server) *mcuProxy { + etcd, etcdClient := NewEtcdClientForTest(t) + grpcClients, dnsMonitor := NewGrpcClientsWithEtcdForTest(t, etcd) + + tokenKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + dir := t.TempDir() + privkeyFile := path.Join(dir, "privkey.pem") + pubkeyFile := path.Join(dir, "pubkey.pem") + WritePrivateKey(tokenKey, privkeyFile) // nolint + WritePublicKey(&tokenKey.PublicKey, pubkeyFile) // nolint + + cfg := goconf.NewConfigFile() + cfg.AddOption("mcu", "urltype", "static") + var urls []string + for _, s := range servers { + urls = append(urls, s.URL) + } + cfg.AddOption("mcu", "url", strings.Join(urls, " ")) + cfg.AddOption("mcu", "token_id", "test-token") + cfg.AddOption("mcu", "token_key", privkeyFile) + + mcu, err := NewMcuProxy(cfg, etcdClient, grpcClients, dnsMonitor) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + mcu.Stop() + }) + + if err := mcu.Start(); err != nil { + t.Fatal(err) + } + + proxy := mcu.(*mcuProxy) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + if err := proxy.WaitForConnections(ctx); err != nil { + t.Fatal(err) + } + + return proxy +} + +func newMcuProxyForTest(t *testing.T) *mcuProxy { + t.Helper() + server := NewProxyServerForTest(t, "DE") + + return newMcuProxyForTestWithServers(t, []*httptest.Server{server}) +} + +func Test_ProxyPublisherSubscriber(t *testing.T) { + mcu := newMcuProxyForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo) + if err != nil { + t.Fatal(err) + } + + defer sub.Close(context.Background()) +} + +func Test_ProxyWaitForPublisher(t *testing.T) { + mcu := newMcuProxyForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + done := make(chan struct{}) + go func() { + defer close(done) + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo) + if err != nil { + t.Error(err) + return + } + + defer sub.Close(context.Background()) + }() + + // Give subscriber goroutine some time to start + time.Sleep(100 * time.Millisecond) + + pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + select { + case <-done: + case <-ctx.Done(): + t.Error(ctx.Err()) + } + defer pub.Close(context.Background()) +} + +func Test_ProxyPublisherLoad(t *testing.T) { + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "DE") + mcu := newMcuProxyForTestWithServers(t, []*httptest.Server{ + server1, + server2, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pub1Id := "the-publisher-1" + pub1Sid := "1234567890" + pub1Listener := &MockMcuListener{ + publicId: pub1Id + "-public", + } + pub1Initiator := &MockMcuInitiator{ + country: "DE", + } + pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator) + if err != nil { + t.Fatal(err) + } + + defer pub1.Close(context.Background()) + + // Make sure connections are re-sorted. + mcu.nextSort.Store(0) + time.Sleep(100 * time.Millisecond) + + pub2Id := "the-publisher-2" + pub2id := "1234567890" + pub2Listener := &MockMcuListener{ + publicId: pub2Id + "-public", + } + pub2Initiator := &MockMcuInitiator{ + country: "DE", + } + pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator) + if err != nil { + t.Fatal(err) + } + + defer pub2.Close(context.Background()) + + if pub1.(*mcuProxyPublisher).conn.rawUrl == pub2.(*mcuProxyPublisher).conn.rawUrl { + t.Errorf("servers should be different, got %s", pub1.(*mcuProxyPublisher).conn.rawUrl) + } +} + +func Test_ProxyPublisherCountry(t *testing.T) { + serverDE := NewProxyServerForTest(t, "DE") + serverUS := NewProxyServerForTest(t, "US") + mcu := newMcuProxyForTestWithServers(t, []*httptest.Server{ + serverDE, + serverUS, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubDEId := "the-publisher-de" + pubDESid := "1234567890" + pubDEListener := &MockMcuListener{ + publicId: pubDEId + "-public", + } + pubDEInitiator := &MockMcuInitiator{ + country: "DE", + } + pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator) + if err != nil { + t.Fatal(err) + } + + defer pubDE.Close(context.Background()) + + if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl) + } + + pubUSId := "the-publisher-us" + pubUSSid := "1234567890" + pubUSListener := &MockMcuListener{ + publicId: pubUSId + "-public", + } + pubUSInitiator := &MockMcuInitiator{ + country: "US", + } + pubUS, err := mcu.NewPublisher(ctx, pubUSListener, pubUSId, pubUSSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubUSInitiator) + if err != nil { + t.Fatal(err) + } + + defer pubUS.Close(context.Background()) + + if pubUS.(*mcuProxyPublisher).conn.rawUrl != serverUS.URL { + t.Errorf("expected server %s, go %s", serverUS.URL, pubUS.(*mcuProxyPublisher).conn.rawUrl) + } +} + +func Test_ProxyPublisherContinent(t *testing.T) { + serverDE := NewProxyServerForTest(t, "DE") + serverUS := NewProxyServerForTest(t, "US") + mcu := newMcuProxyForTestWithServers(t, []*httptest.Server{ + serverDE, + serverUS, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubDEId := "the-publisher-de" + pubDESid := "1234567890" + pubDEListener := &MockMcuListener{ + publicId: pubDEId + "-public", + } + pubDEInitiator := &MockMcuInitiator{ + country: "DE", + } + pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator) + if err != nil { + t.Fatal(err) + } + + defer pubDE.Close(context.Background()) + + if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl) + } + + pubFRId := "the-publisher-fr" + pubFRSid := "1234567890" + pubFRListener := &MockMcuListener{ + publicId: pubFRId + "-public", + } + pubFRInitiator := &MockMcuInitiator{ + country: "FR", + } + pubFR, err := mcu.NewPublisher(ctx, pubFRListener, pubFRId, pubFRSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubFRInitiator) + if err != nil { + t.Fatal(err) + } + + defer pubFR.Close(context.Background()) + + if pubFR.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl) + } +}