diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index e12b66b..cb6f42f 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -618,3 +618,220 @@ func TestProxyCodecs(t *testing.T) { } } } + +type RemoteSubscriberTestMCU struct { + TestMCU + + publisher *TestRemotePublisher + subscriber *TestRemoteSubscriber +} + +func NewRemoteSubscriberTestMCU(t *testing.T) *RemoteSubscriberTestMCU { + return &RemoteSubscriberTestMCU{ + TestMCU: TestMCU{ + t: t, + }, + } +} + +type TestRemotePublisher struct { + t *testing.T + + streamType signaling.StreamType + refcnt atomic.Int32 + closed context.Context + closeFunc context.CancelFunc +} + +func (p *TestRemotePublisher) Id() string { + return "id" +} + +func (p *TestRemotePublisher) Sid() string { + return "sid" +} + +func (p *TestRemotePublisher) StreamType() signaling.StreamType { + return p.streamType +} + +func (p *TestRemotePublisher) MaxBitrate() int { + return 0 +} + +func (p *TestRemotePublisher) Close(ctx context.Context) { + if count := p.refcnt.Add(-1); assert.True(p.t, count >= 0) && count == 0 { + p.closeFunc() + } +} + +func (p *TestRemotePublisher) SendMessage(ctx context.Context, message *signaling.MessageClientMessage, data *signaling.MessageClientMessageData, callback func(error, map[string]interface{})) { + callback(errors.New("not implemented"), nil) +} + +func (p *TestRemotePublisher) Port() int { + return 1 +} + +func (p *TestRemotePublisher) RtcpPort() int { + return 2 +} + +func (m *RemoteSubscriberTestMCU) NewRemotePublisher(ctx context.Context, listener signaling.McuListener, controller signaling.RemotePublisherController, streamType signaling.StreamType) (signaling.McuRemotePublisher, error) { + require.Nil(m.t, m.publisher) + assert.EqualValues(m.t, "video", streamType) + closeCtx, closeFunc := context.WithCancel(context.Background()) + m.publisher = &TestRemotePublisher{ + t: m.t, + + streamType: streamType, + closed: closeCtx, + closeFunc: closeFunc, + } + m.publisher.refcnt.Add(1) + return m.publisher, nil +} + +type TestRemoteSubscriber struct { + t *testing.T + + publisher *TestRemotePublisher + closed context.Context + closeFunc context.CancelFunc +} + +func (s *TestRemoteSubscriber) Id() string { + return "id" +} + +func (s *TestRemoteSubscriber) Sid() string { + return "sid" +} + +func (s *TestRemoteSubscriber) StreamType() signaling.StreamType { + return s.publisher.StreamType() +} + +func (s *TestRemoteSubscriber) MaxBitrate() int { + return 0 +} + +func (s *TestRemoteSubscriber) Close(ctx context.Context) { + s.publisher.Close(ctx) + s.closeFunc() +} + +func (s *TestRemoteSubscriber) SendMessage(ctx context.Context, message *signaling.MessageClientMessage, data *signaling.MessageClientMessageData, callback func(error, map[string]interface{})) { + callback(errors.New("not implemented"), nil) +} + +func (s *TestRemoteSubscriber) Publisher() string { + return s.publisher.Id() +} + +func (m *RemoteSubscriberTestMCU) NewRemoteSubscriber(ctx context.Context, listener signaling.McuListener, publisher signaling.McuRemotePublisher) (signaling.McuRemoteSubscriber, error) { + require.Nil(m.t, m.subscriber) + pub, ok := publisher.(*TestRemotePublisher) + require.True(m.t, ok) + closeCtx, closeFunc := context.WithCancel(context.Background()) + m.subscriber = &TestRemoteSubscriber{ + t: m.t, + + publisher: pub, + closed: closeCtx, + closeFunc: closeFunc, + } + pub.refcnt.Add(1) + return m.subscriber, nil +} + +func TestProxyRemoteSubscriber(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewRemoteSubscriberTestMCU(t) + proxy.mcu = mcu + // Unused but must be set so remote subscribing works + proxy.tokenId = "token" + proxy.tokenKey = key + proxy.remoteHostname = "test-hostname" + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest, + Subject: publisherId, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + require.NoError(err) + + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-subscriber", + StreamType: signaling.StreamTypeVideo, + PublisherId: publisherId, + RemoteUrl: "https://remote-hostname", + RemoteToken: tokenString, + }, + })) + + var clientId string + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ + Id: "3456", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "delete-subscriber", + ClientId: clientId, + }, + })) + + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + assert.Equal(clientId, message.Command.Id) + } + } + + if assert.NotNil(mcu.publisher) && assert.NotNil(mcu.subscriber) { + select { + case <-mcu.subscriber.closed.Done(): + case <-ctx.Done(): + assert.Fail("subscriber was not closed") + } + select { + case <-mcu.publisher.closed.Done(): + case <-ctx.Done(): + assert.Fail("publisher was not closed") + } + } +}