From 99762a3ca98e7aef8519733b8ed23368518cf697 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 3 Feb 2026 09:37:07 +0100 Subject: [PATCH] Improxy proxy server test coverage. --- cmd/proxy/proxy_server_test.go | 425 +++++++++++++++++++++++++++++++++ cmd/proxy/proxy_session.go | 2 +- 2 files changed, 426 insertions(+), 1 deletion(-) diff --git a/cmd/proxy/proxy_server_test.go b/cmd/proxy/proxy_server_test.go index f04e55c..4aa212d 100644 --- a/cmd/proxy/proxy_server_test.go +++ b/cmd/proxy/proxy_server_test.go @@ -28,6 +28,7 @@ import ( "crypto/x509" "encoding/pem" "errors" + "fmt" "net" "net/http/httptest" "os" @@ -48,9 +49,11 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/internal" "github.com/strukturag/nextcloud-spreed-signaling/log" logtest "github.com/strukturag/nextcloud-spreed-signaling/log/test" + "github.com/strukturag/nextcloud-spreed-signaling/mock" "github.com/strukturag/nextcloud-spreed-signaling/proxy" "github.com/strukturag/nextcloud-spreed-signaling/sfu" "github.com/strukturag/nextcloud-spreed-signaling/talk" + "github.com/strukturag/nextcloud-spreed-signaling/test" ) const ( @@ -461,6 +464,7 @@ type PublisherTestMCU struct { type TestPublisherWithBandwidth struct { TestMCUPublisher + t *testing.T bandwidth *sfu.ClientBandwidthInfo } @@ -468,6 +472,25 @@ func (p *TestPublisherWithBandwidth) Bandwidth() *sfu.ClientBandwidthInfo { return p.bandwidth } +func (p *TestPublisherWithBandwidth) SendMessage(ctx context.Context, message *api.MessageClientMessage, data *api.MessageClientMessageData, callback func(error, api.StringMap)) { + switch data.Type { + case "offer": + assert.Equal(p.t, mock.MockSdpOfferAudioAndVideo, data.Payload["sdp"]) + assert.NotNil(p.t, data.OfferSdp) + callback(nil, api.StringMap{ + "type": "answer", + "sdp": mock.MockSdpAnswerAudioAndVideo, + }) + case "requestoffer": + callback(nil, api.StringMap{ + "type": "offer", + "sdp": mock.MockSdpOfferAudioOnly, + }) + default: + callback(fmt.Errorf("type %s not implemented", data.Type), nil) + } +} + func (m *PublisherTestMCU) NewPublisher(ctx context.Context, listener sfu.Listener, id api.PublicSessionId, sid string, streamType sfu.StreamType, settings sfu.NewPublisherSettings, initiator sfu.Initiator) (sfu.Publisher, error) { publisher := &TestPublisherWithBandwidth{ TestMCUPublisher: TestMCUPublisher{ @@ -476,6 +499,7 @@ func (m *PublisherTestMCU) NewPublisher(ctx context.Context, listener sfu.Listen streamType: streamType, }, + t: m.t, bandwidth: &sfu.ClientBandwidthInfo{ Sent: api.BandwidthFromBytes(1000), Received: api.BandwidthFromBytes(2000), @@ -497,6 +521,9 @@ func TestProxyPublisherBandwidth(t *testing.T) { assert := assert.New(t) require := require.New(t) proxyServer, key, server := newProxyServerForTest(t) + t.Cleanup(func() { + assert.EqualValues(0, proxyServer.GetSessionsCount()) + }) proxyServer.maxIncoming.Store(api.BandwidthFromBytes(10000)) proxyServer.maxOutgoing.Store(api.BandwidthFromBytes(10000)) @@ -507,6 +534,8 @@ func TestProxyPublisherBandwidth(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() + assert.EqualValues(0, proxyServer.GetSessionsCount()) + client := NewProxyTestClient(ctx, t, server.URL) defer client.CloseWithBye() @@ -528,12 +557,19 @@ func TestProxyPublisherBandwidth(t *testing.T) { }, })) + var clientId string if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { assert.Equal("2345", message.Id) if err := checkMessageType(message, "command"); assert.NoError(err) { assert.NotEmpty(message.Command.Id) + clientId = message.Command.Id } } + require.NotEmpty(clientId) + + if publisher := proxyServer.GetPublisher(clientId); assert.NotNil(publisher) { + assert.Equal(clientId, proxyServer.GetClientId(publisher)) + } proxyServer.updateLoad() @@ -550,6 +586,66 @@ func TestProxyPublisherBandwidth(t *testing.T) { } } } + + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "3456", + Type: "payload", + Payload: &proxy.PayloadClientMessage{ + Type: "offer", + ClientId: clientId, + Payload: api.StringMap{ + "type": "offer", + "sdp": mock.MockSdpOfferAudioAndVideo, + }, + }, + })) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + assert.Equal("payload", message.Type) + if payload := message.Payload; assert.NotNil(payload) { + assert.Equal(clientId, payload.ClientId) + assert.Equal("offer", payload.Type) + assert.Equal("answer", payload.Payload["type"]) + assert.Equal(mock.MockSdpAnswerAudioAndVideo, payload.Payload["sdp"]) + } + } + + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "4567", + Type: "payload", + Payload: &proxy.PayloadClientMessage{ + Type: "endOfCandidates", + ClientId: clientId, + }, + })) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + assert.Equal("payload", message.Type) + if payload := message.Payload; assert.NotNil(payload) { + assert.Equal(clientId, payload.ClientId) + assert.Equal("endOfCandidates", payload.Type) + assert.Empty(payload.Payload) + } + } + + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "5678", + Type: "payload", + Payload: &proxy.PayloadClientMessage{ + Type: "requestoffer", + ClientId: clientId, + }, + })) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("5678", message.Id) + assert.Equal("payload", message.Type) + if payload := message.Payload; assert.NotNil(payload) { + assert.Equal(clientId, payload.ClientId) + assert.Equal("requestoffer", payload.Type) + assert.Equal("offer", payload.Payload["type"]) + assert.Equal(mock.MockSdpOfferAudioOnly, payload.Payload["sdp"]) + } + } } type HangingTestMCU struct { @@ -1605,3 +1701,332 @@ func TestProxyUnpublishRemoteOnSessionClose(t *testing.T) { assert.Nil(publisher.getRemoteData()) } } + +func TestExpireSessions(t *testing.T) { + t.Parallel() + + test.SynctestTest(t, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + proxyServer, key, server := newProxyServerForTest(t) + server.Close() + + // No-op + proxyServer.expireSessions() + + claims := &proxy.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + require.NoError(err) + + hello := &proxy.HelloClientMessage{ + Version: "1.0", + Token: tokenString, + } + session, err := proxyServer.NewSession(hello) + require.NoError(err) + t.Cleanup(func() { + proxyServer.DeleteSession(session.Sid()) + }) + assert.Same(session, proxyServer.GetSession(session.Sid())) + + proxyServer.expireSessions() + assert.Same(session, proxyServer.GetSession(session.Sid())) + + time.Sleep(sessionExpirationTime) + proxyServer.expireSessions() + + assert.Nil(proxyServer.GetSession(session.Sid())) + }) +} + +func TestScheduleShutdownEmpty(t *testing.T) { + t.Parallel() + + proxyServer, _, _ := newProxyServerForTest(t) + + proxyServer.ScheduleShutdown() + <-proxyServer.ShutdownChannel() +} + +func TestScheduleShutdownNoClients(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + proxyServer, key, server := newProxyServerForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + proxyServer.ScheduleShutdown() + + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("event", msg.Type) + if event := msg.Event; assert.NotNil(event) { + assert.Equal("shutdown-scheduled", event.Type) + } + } + + <-proxyServer.ShutdownChannel() +} + +func TestScheduleShutdown(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + proxyServer, key, server := newProxyServerForTest(t) + + mcu := NewPublisherTestMCU(t) + proxyServer.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := api.PublicSessionId("the-publisher-id") + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "2345", + Type: "command", + Command: &proxy.CommandClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: sfu.StreamTypeVideo, + PublisherSettings: &sfu.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: sfu.MediaTypeAudio | sfu.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + readyChan := make(chan struct{}) + var readyReceived atomic.Bool + go func() { + for { + select { + case <-proxyServer.ShutdownChannel(): + return + case <-readyChan: + readyReceived.Store(true) + case <-ctx.Done(): + assert.NoError(ctx.Err()) + return + } + } + }() + + proxyServer.ScheduleShutdown() + + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("event", msg.Type) + if event := msg.Event; assert.NotNil(event) { + assert.Equal("shutdown-scheduled", event.Type) + } + } + close(readyChan) + + proxyServer.ScheduleShutdown() + + select { + case <-proxyServer.ShutdownChannel(): + assert.Fail("should only shutdown after all clients closed") + default: + } + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "4567", + Type: "command", + Command: &proxy.CommandClientMessage{ + Type: "delete-publisher", + ClientId: clientId, + }, + })) + + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + <-proxyServer.ShutdownChannel() + assert.True(readyReceived.Load()) +} + +func TestScheduleShutdownOnResume(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + proxyServer, key, server := newProxyServerForTest(t) + + mcu := NewPublisherTestMCU(t) + proxyServer.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + hello, err := client.RunUntilHello(ctx) + if assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err = client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := api.PublicSessionId("the-publisher-id") + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "2345", + Type: "command", + Command: &proxy.CommandClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: sfu.StreamTypeVideo, + PublisherSettings: &sfu.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: sfu.MediaTypeAudio | sfu.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + readyChan := make(chan struct{}) + var readyReceived atomic.Bool + go func() { + for { + select { + case <-proxyServer.ShutdownChannel(): + return + case <-readyChan: + readyReceived.Store(true) + case <-ctx.Done(): + assert.NoError(ctx.Err()) + return + } + } + }() + + client.Close() + + proxyServer.ScheduleShutdown() + + client = NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + hello2 := &proxy.ClientMessage{ + Id: "1234", + Type: "hello", + Hello: &proxy.HelloClientMessage{ + Version: "1.0", + Features: []string{}, + ResumeId: hello.Hello.SessionId, + }, + } + require.NoError(client.WriteJSON(hello2)) + + if hello3, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(hello.Hello.SessionId, hello3.Hello.SessionId) + } + + if msg, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("event", msg.Type) + if event := msg.Event; assert.NotNil(event) { + assert.Equal("shutdown-scheduled", event.Type) + } + } + + client2 := NewProxyTestClient(ctx, t, server.URL) + defer client2.CloseWithBye() + + require.NoError(client2.SendHello(key)) + + if hello, err := client2.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("event", msg.Type) + if event := msg.Event; assert.NotNil(event) { + assert.Equal("shutdown-scheduled", event.Type) + } + } + close(readyChan) + + proxyServer.ScheduleShutdown() + + select { + case <-proxyServer.ShutdownChannel(): + assert.Fail("should only shutdown after all clients closed") + default: + } + require.NoError(client.WriteJSON(&proxy.ClientMessage{ + Id: "4567", + Type: "command", + Command: &proxy.CommandClientMessage{ + Type: "delete-publisher", + ClientId: clientId, + }, + })) + + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + <-proxyServer.ShutdownChannel() + assert.True(readyReceived.Load()) +} diff --git a/cmd/proxy/proxy_session.go b/cmd/proxy/proxy_session.go index 0c12f23..7bc7133 100644 --- a/cmd/proxy/proxy_session.go +++ b/cmd/proxy/proxy_session.go @@ -117,7 +117,7 @@ func (s *ProxySession) LastUsed() time.Time { func (s *ProxySession) IsExpired() bool { expiresAt := s.LastUsed().Add(sessionExpirationTime) - return expiresAt.Before(time.Now()) + return !expiresAt.After(time.Now()) } func (s *ProxySession) MarkUsed() {