From 9235b80125b36eb08902c8173b4d67958eb72e9c Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 23 Jul 2025 13:57:27 +0200 Subject: [PATCH] Return connection / publisher tokens for remote publishers. This supports connecting to and subscribing streams from proxies that don't know the signaling server sending the request. --- grpc_client.go | 8 +- grpc_mcu.proto | 2 + grpc_server.go | 12 ++ hub.go | 9 ++ mcu_proxy.go | 56 +++++---- mcu_proxy_test.go | 284 +++++++++++++++++++++++++++++++++++++++++----- 6 files changed, 316 insertions(+), 55 deletions(-) diff --git a/grpc_client.go b/grpc_client.go index 9cf0dec..fac9d0c 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -292,7 +292,7 @@ func (c *GrpcClient) GetInternalSessions(ctx context.Context, roomId string, bac return } -func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, streamType StreamType) (string, string, net.IP, error) { +func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, streamType StreamType) (string, string, net.IP, string, string, error) { statsGrpcClientCalls.WithLabelValues("GetPublisherId").Inc() // TODO: Remove debug logging log.Printf("Get %s publisher id %s on %s", streamType, sessionId, c.Target()) @@ -301,12 +301,12 @@ func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, strea StreamType: string(streamType), }, grpc.WaitForReady(true)) if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - return "", "", nil, nil + return "", "", nil, "", "", nil } else if err != nil { - return "", "", nil, err + return "", "", nil, "", "", err } - return response.GetPublisherId(), response.GetProxyUrl(), net.ParseIP(response.GetIp()), nil + return response.GetPublisherId(), response.GetProxyUrl(), net.ParseIP(response.GetIp()), response.GetConnectToken(), response.GetPublisherToken(), nil } func (c *GrpcClient) GetSessionCount(ctx context.Context, url string) (uint32, error) { diff --git a/grpc_mcu.proto b/grpc_mcu.proto index b2313d2..c766ddc 100644 --- a/grpc_mcu.proto +++ b/grpc_mcu.proto @@ -38,4 +38,6 @@ message GetPublisherIdReply { string publisherId = 1; string proxyUrl = 2; string ip = 3; + string connectToken = 4; + string publisherToken = 5; } diff --git a/grpc_server.go b/grpc_server.go index 17327fd..cee0ab7 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -41,6 +41,8 @@ import ( var ( GrpcServerId string + + ErrNoProxyMcu = errors.New("no proxy mcu") ) func init() { @@ -62,6 +64,7 @@ type GrpcServerHub interface { GetRoomForBackend(roomId string, backend *Backend) *Room GetBackend(u *url.URL) *Backend + CreateProxyToken(publisherId string) (string, error) } type GrpcServer struct { @@ -276,6 +279,15 @@ func (s *GrpcServer) GetPublisherId(ctx context.Context, request *GetPublisherId if ip := publisher.conn.ip; ip != nil { reply.Ip = ip.String() } + var err error + if reply.ConnectToken, err = s.hub.CreateProxyToken(""); err != nil && !errors.Is(err, ErrNoProxyMcu) { + log.Printf("Error creating proxy token for connection: %s", err) + return nil, status.Error(codes.Internal, "error creating proxy connect token") + } + if reply.PublisherToken, err = s.hub.CreateProxyToken(publisher.Id()); err != nil && !errors.Is(err, ErrNoProxyMcu) { + log.Printf("Error creating proxy token for publisher %s: %s", publisher.Id(), err) + return nil, status.Error(codes.Internal, "error creating proxy publisher token") + } return reply, nil } diff --git a/hub.go b/hub.go index 5e3b4da..f8800d9 100644 --- a/hub.go +++ b/hub.go @@ -722,6 +722,15 @@ func (h *Hub) GetBackend(u *url.URL) *Backend { return h.backend.GetBackend(u) } +func (h *Hub) CreateProxyToken(publisherId string) (string, error) { + proxy, ok := h.mcu.(*mcuProxy) + if !ok { + return "", ErrNoProxyMcu + } + + return proxy.createToken(publisherId) +} + func (h *Hub) checkExpiredSessions(now time.Time) { for session, expires := range h.expiredSessions { if now.After(expires) { diff --git a/mcu_proxy.go b/mcu_proxy.go index 715535d..1fad338 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -340,10 +340,11 @@ func (s *mcuProxySubscriber) ProcessEvent(msg *EventProxyServerMessage) { } type mcuProxyConnection struct { - proxy *mcuProxy - rawUrl string - url *url.URL - ip net.IP + proxy *mcuProxy + rawUrl string + url *url.URL + ip net.IP + connectToken string load atomic.Int64 bandwidth atomic.Pointer[EventProxyServerBandwidth] @@ -380,7 +381,7 @@ type mcuProxyConnection struct { subscribers map[string]*mcuProxySubscriber } -func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP) (*mcuProxyConnection, error) { +func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP, token string) (*mcuProxyConnection, error) { parsed, err := url.Parse(baseUrl) if err != nil { return nil, err @@ -391,6 +392,7 @@ func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP) (*mcuProx rawUrl: baseUrl, url: parsed, ip: ip, + connectToken: token, closer: NewCloser(), closedDone: NewCloser(), callbacks: make(map[string]func(*ProxyServerMessage)), @@ -1105,6 +1107,8 @@ func (c *mcuProxyConnection) sendHello() error { } if sessionId := c.SessionId(); sessionId != "" { msg.Hello.ResumeId = sessionId + } else if c.connectToken != "" { + msg.Hello.Token = c.connectToken } else { tokenString, err := c.proxy.createToken("") if err != nil { @@ -1238,14 +1242,16 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList return subscriber, nil } -func (c *mcuProxyConnection) newRemoteSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType StreamType, publisherConn *mcuProxyConnection) (McuSubscriber, error) { +func (c *mcuProxyConnection) newRemoteSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType StreamType, publisherConn *mcuProxyConnection, remoteToken string) (McuSubscriber, error) { if c == publisherConn { return c.newSubscriber(ctx, listener, publisherId, publisherSessionId, streamType) } - remoteToken, err := c.proxy.createToken(publisherId) - if err != nil { - return nil, err + if remoteToken == "" { + var err error + if remoteToken, err = c.proxy.createToken(publisherId); err != nil { + return nil, err + } } msg := &ProxyClientMessage{ @@ -1516,7 +1522,7 @@ func (m *mcuProxy) AddConnection(ignoreErrors bool, url string, ips ...net.IP) e var conns []*mcuProxyConnection if len(ips) == 0 { - conn, err := newMcuProxyConnection(m, url, nil) + conn, err := newMcuProxyConnection(m, url, nil, "") if err != nil { if ignoreErrors { log.Printf("Could not create proxy connection to %s: %s", url, err) @@ -1529,7 +1535,7 @@ func (m *mcuProxy) AddConnection(ignoreErrors bool, url string, ips ...net.IP) e conns = append(conns, conn) } else { for _, ip := range ips { - conn, err := newMcuProxyConnection(m, url, ip) + conn, err := newMcuProxyConnection(m, url, ip, "") if err != nil { if ignoreErrors { log.Printf("Could not create proxy connection to %s (%s): %s", url, ip, err) @@ -1974,12 +1980,13 @@ func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher str } type proxyPublisherInfo struct { - id string - conn *mcuProxyConnection - err error + id string + conn *mcuProxyConnection + token string + err error } -func (m *mcuProxy) createSubscriber(ctx context.Context, listener McuListener, id string, publisher string, streamType StreamType, publisherConn *mcuProxyConnection, connections []*mcuProxyConnection, isAllowed func(c *mcuProxyConnection) bool) McuSubscriber { +func (m *mcuProxy) createSubscriber(ctx context.Context, listener McuListener, info *proxyPublisherInfo, publisher string, streamType StreamType, connections []*mcuProxyConnection, isAllowed func(c *mcuProxyConnection) bool) McuSubscriber { for _, conn := range connections { if !isAllowed(conn) || conn.IsShutdownScheduled() || conn.IsTemporary() { continue @@ -1987,10 +1994,10 @@ func (m *mcuProxy) createSubscriber(ctx context.Context, listener McuListener, i var subscriber McuSubscriber var err error - if conn == publisherConn { - subscriber, err = conn.newSubscriber(ctx, listener, id, publisher, streamType) + if conn == info.conn { + subscriber, err = conn.newSubscriber(ctx, listener, info.id, publisher, streamType) } else { - subscriber, err = conn.newRemoteSubscriber(ctx, listener, id, publisher, streamType, publisherConn) + subscriber, err = conn.newRemoteSubscriber(ctx, listener, info.id, publisher, streamType, info.conn, info.token) } if err != nil { log.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, conn, err) @@ -2056,7 +2063,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ wg.Add(1) go func(client *GrpcClient) { defer wg.Done() - id, url, ip, err := client.GetPublisherId(getctx, publisher, streamType) + id, url, ip, connectToken, publisherToken, err := client.GetPublisherId(getctx, publisher, streamType) if errors.Is(err, context.Canceled) { return } else if err != nil { @@ -2085,7 +2092,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ } if publisherConn == nil { - publisherConn, err = newMcuProxyConnection(m, url, ip) + publisherConn, err = newMcuProxyConnection(m, url, ip, connectToken) if err != nil { log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err) return @@ -2112,8 +2119,9 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ } ch <- &proxyPublisherInfo{ - id: id, - conn: publisherConn, + id: id, + conn: publisherConn, + token: publisherToken, } }(client) } @@ -2145,7 +2153,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ connections := m.getSortedConnections(initiator) if !allowOutgoing || len(connections) > 0 && !connections[0].IsSameCountry(publisherInfo.conn) { // Connect to remote publisher through "closer" gateway. - subscriber := m.createSubscriber(ctx, listener, publisherInfo.id, publisher, streamType, publisherInfo.conn, connections, func(c *mcuProxyConnection) bool { + subscriber := m.createSubscriber(ctx, listener, publisherInfo, publisher, streamType, connections, func(c *mcuProxyConnection) bool { bw := c.Bandwidth() return bw == nil || bw.AllowOutgoing() }) @@ -2180,7 +2188,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ } return 0 }) - subscriber = m.createSubscriber(ctx, listener, publisherInfo.id, publisher, streamType, publisherInfo.conn, connections2, func(c *mcuProxyConnection) bool { + subscriber = m.createSubscriber(ctx, listener, publisherInfo, publisher, streamType, connections2, func(c *mcuProxyConnection) bool { return true }) } diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index d68e217..96bab02 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -40,6 +40,7 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/golang-jwt/jwt/v5" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -214,6 +215,26 @@ func (c *testProxyServerClient) processHello(msg *ProxyClientMessage) (*ProxySer return nil, fmt.Errorf("expected hello, got %+v", msg) } + token, err := jwt.ParseWithClaims(msg.Hello.Token, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + claims, ok := token.Claims.(*TokenClaims) + if !assert.True(c.t, ok, "unsupported claims type: %+v", token.Claims) { + return nil, errors.New("unsupported claims type") + } + + key, found := c.server.tokens[claims.Issuer] + if !assert.True(c.t, found) { + return nil, fmt.Errorf("no key found for issuer") + } + + return key, nil + }) + if assert.NoError(c.t, err) { + if assert.True(c.t, token.Valid) { + _, ok := token.Claims.(*TokenClaims) + assert.True(c.t, ok) + } + } + response := &ProxyServerMessage{ Id: msg.Id, Type: "hello", @@ -295,6 +316,25 @@ func (c *testProxyServerClient) processCommandMessage(msg *ProxyClientMessage) ( continue } + token, err := jwt.ParseWithClaims(msg.Command.RemoteToken, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + claims, ok := token.Claims.(*TokenClaims) + if !assert.True(c.t, ok, "unsupported claims type: %+v", token.Claims) { + return nil, errors.New("unsupported claims type") + } + + key, found := server.tokens[claims.Issuer] + if !assert.True(c.t, found) { + return nil, fmt.Errorf("no key found for issuer") + } + + return key, nil + }) + if assert.NoError(c.t, err) { + if claims, ok := token.Claims.(*TokenClaims); assert.True(c.t, token.Valid) && assert.True(c.t, ok) { + assert.Equal(c.t, msg.Command.PublisherId, claims.Subject) + } + } + pub = server.getPublisher(msg.Command.PublisherId) break } @@ -450,6 +490,7 @@ type TestProxyServerHandler struct { URL string server *httptest.Server servers []*TestProxyServerHandler + tokens map[string]*rsa.PublicKey upgrader *websocket.Upgrader country string @@ -637,6 +678,7 @@ func NewProxyServerForTest(t *testing.T, country string) *TestProxyServerHandler upgrader := websocket.Upgrader{} proxyHandler := &TestProxyServerHandler{ t: t, + tokens: make(map[string]*rsa.PublicKey), upgrader: &upgrader, country: country, clients: make(map[string]*testProxyServerClient), @@ -663,7 +705,7 @@ type proxyTestOptions struct { servers []*TestProxyServerHandler } -func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuProxy { +func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions, idx int) *mcuProxy { t.Helper() require := require.New(t) if options.etcd == nil { @@ -689,13 +731,15 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP NewProxyServerForTest(t, "DE"), } } + tokenId := fmt.Sprintf("test-token-%d", idx) for _, s := range options.servers { s.servers = options.servers + s.tokens[tokenId] = &tokenKey.PublicKey urls = append(urls, s.URL) waitingMap[s.URL] = true } cfg.AddOption("mcu", "url", strings.Join(urls, " ")) - cfg.AddOption("mcu", "token_id", "test-token") + cfg.AddOption("mcu", "token_id", tokenId) cfg.AddOption("mcu", "token_key", privkeyFile) etcdConfig := goconf.NewConfigFile() @@ -744,25 +788,25 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP return proxy } -func newMcuProxyForTestWithServers(t *testing.T, servers []*TestProxyServerHandler) *mcuProxy { +func newMcuProxyForTestWithServers(t *testing.T, servers []*TestProxyServerHandler, idx int) *mcuProxy { t.Helper() return newMcuProxyForTestWithOptions(t, proxyTestOptions{ servers: servers, - }) + }, idx) } -func newMcuProxyForTest(t *testing.T) *mcuProxy { +func newMcuProxyForTest(t *testing.T, idx int) *mcuProxy { t.Helper() server := NewProxyServerForTest(t, "DE") - return newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{server}) + return newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{server}, idx) } func Test_ProxyPublisherSubscriber(t *testing.T) { CatchLogForTest(t) t.Parallel() - mcu := newMcuProxyForTest(t) + mcu := newMcuProxyForTest(t, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -798,7 +842,7 @@ func Test_ProxyPublisherSubscriber(t *testing.T) { func Test_ProxyPublisherCodecs(t *testing.T) { CatchLogForTest(t) t.Parallel() - mcu := newMcuProxyForTest(t) + mcu := newMcuProxyForTest(t, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -825,7 +869,7 @@ func Test_ProxyPublisherCodecs(t *testing.T) { func Test_ProxyWaitForPublisher(t *testing.T) { CatchLogForTest(t) t.Parallel() - mcu := newMcuProxyForTest(t) + mcu := newMcuProxyForTest(t, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -880,7 +924,7 @@ func Test_ProxyPublisherBandwidth(t *testing.T) { mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ server1, server2, - }) + }, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -950,7 +994,7 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) { mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ server1, server2, - }) + }, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1023,7 +1067,7 @@ func Test_ProxyPublisherLoad(t *testing.T) { mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ server1, server2, - }) + }, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1073,7 +1117,7 @@ func Test_ProxyPublisherCountry(t *testing.T) { mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ serverDE, serverUS, - }) + }, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1121,7 +1165,7 @@ func Test_ProxyPublisherContinent(t *testing.T) { mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ serverDE, serverUS, - }) + }, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1169,7 +1213,7 @@ func Test_ProxySubscriberCountry(t *testing.T) { mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ serverDE, serverUS, - }) + }, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1213,7 +1257,7 @@ func Test_ProxySubscriberContinent(t *testing.T) { mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ serverDE, serverUS, - }) + }, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1257,7 +1301,7 @@ func Test_ProxySubscriberBandwidth(t *testing.T) { mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ serverDE, serverUS, - }) + }, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1321,7 +1365,7 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) { mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ serverDE, serverUS, - }) + }, 0) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1379,6 +1423,7 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) { } type mockGrpcServerHub struct { + proxy atomic.Pointer[mcuProxy] sessionsLock sync.Mutex sessionByPublicId map[string]Session } @@ -1420,6 +1465,15 @@ func (h *mockGrpcServerHub) GetRoomForBackend(roomId string, backend *Backend) * return nil } +func (h *mockGrpcServerHub) CreateProxyToken(publisherId string) (string, error) { + proxy := h.proxy.Load() + if proxy == nil { + return "", errors.New("not a proxy mcu") + } + + return proxy.createToken(publisherId) +} + func Test_ProxyRemotePublisher(t *testing.T) { CatchLogForTest(t) t.Parallel() @@ -1446,14 +1500,16 @@ func Test_ProxyRemotePublisher(t *testing.T) { server1, server2, }, - }) + }, 1) + hub1.proxy.Store(mcu1) mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ etcd: etcd, servers: []*TestProxyServerHandler{ server1, server2, }, - }) + }, 2) + hub2.proxy.Store(mcu2) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1530,7 +1586,8 @@ func Test_ProxyMultipleRemotePublisher(t *testing.T) { server2, server3, }, - }) + }, 1) + hub1.proxy.Store(mcu1) mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ etcd: etcd, servers: []*TestProxyServerHandler{ @@ -1538,7 +1595,8 @@ func Test_ProxyMultipleRemotePublisher(t *testing.T) { server2, server3, }, - }) + }, 2) + hub2.proxy.Store(mcu2) mcu3 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ etcd: etcd, servers: []*TestProxyServerHandler{ @@ -1546,7 +1604,8 @@ func Test_ProxyMultipleRemotePublisher(t *testing.T) { server2, server3, }, - }) + }, 3) + hub3.proxy.Store(mcu3) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1628,14 +1687,16 @@ func Test_ProxyRemotePublisherWait(t *testing.T) { server1, server2, }, - }) + }, 1) + hub1.proxy.Store(mcu1) mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ etcd: etcd, servers: []*TestProxyServerHandler{ server1, server2, }, - }) + }, 2) + hub2.proxy.Store(mcu2) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1721,13 +1782,15 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) { servers: []*TestProxyServerHandler{ server1, }, - }) + }, 1) + hub1.proxy.Store(mcu1) mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ etcd: etcd, servers: []*TestProxyServerHandler{ server2, }, - }) + }, 2) + hub2.proxy.Store(mcu2) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -1802,3 +1865,170 @@ loop: } } } + +func Test_ProxyConnectToken(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + + etcd := NewEtcdForTest(t) + + grpcServer1, addr1 := NewGrpcServerForTest(t) + grpcServer2, addr2 := NewGrpcServerForTest(t) + + hub1 := &mockGrpcServerHub{} + hub2 := &mockGrpcServerHub{} + grpcServer1.hub = hub1 + grpcServer2.hub = hub2 + + SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "DE") + + // Signaling server instances are in a cluster but don't share their proxies, + // i.e. they are only known to their local proxy, not the one of the other + // signaling server - so the connection token must be passed between them. + mcu1 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + }, + }, 1) + hub1.proxy.Store(mcu1) + mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server2, + }, + }, 2) + hub2.proxy.Store(mcu2) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + session1 := &ClientSession{ + publicId: pubId, + publishers: make(map[StreamType]McuPublisher), + } + hub1.addSession(session1) + defer hub1.removeSession(session1) + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, NewPublisherSettings{ + MediaTypes: MediaTypeVideo | MediaTypeAudio, + }, pubInitiator) + require.NoError(t, err) + + defer pub.Close(context.Background()) + + session1.mu.Lock() + session1.publishers[StreamTypeVideo] = pub + session1.publisherWaiters.Wakeup() + session1.mu.Unlock() + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "DE", + } + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + require.NoError(t, err) + + defer sub.Close(context.Background()) +} + +func Test_ProxyPublisherToken(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + + etcd := NewEtcdForTest(t) + + grpcServer1, addr1 := NewGrpcServerForTest(t) + grpcServer2, addr2 := NewGrpcServerForTest(t) + + hub1 := &mockGrpcServerHub{} + hub2 := &mockGrpcServerHub{} + grpcServer1.hub = hub1 + grpcServer2.hub = hub2 + + SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "US") + + // Signaling server instances are in a cluster but don't share their proxies, + // i.e. they are only known to their local proxy, not the one of the other + // signaling server - so the connection token must be passed between them. + // Also the subscriber is connecting from a different country, so a remote + // stream will be created that needs a valid token from the remote proxy. + mcu1 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + }, + }, 1) + hub1.proxy.Store(mcu1) + mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server2, + }, + }, 2) + hub2.proxy.Store(mcu2) + // Support remote subscribers for the tests. + server1.servers = append(server1.servers, server2) + server2.servers = append(server2.servers, server1) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + session1 := &ClientSession{ + publicId: pubId, + publishers: make(map[StreamType]McuPublisher), + } + hub1.addSession(session1) + defer hub1.removeSession(session1) + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, NewPublisherSettings{ + MediaTypes: MediaTypeVideo | MediaTypeAudio, + }, pubInitiator) + require.NoError(t, err) + + defer pub.Close(context.Background()) + + session1.mu.Lock() + session1.publishers[StreamTypeVideo] = pub + session1.publisherWaiters.Wakeup() + session1.mu.Unlock() + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "US", + } + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + require.NoError(t, err) + + defer sub.Close(context.Background()) +}