diff --git a/api_proxy.go b/api_proxy.go index 62a0cfd..fc5f34b 100644 --- a/api_proxy.go +++ b/api_proxy.go @@ -24,6 +24,7 @@ package signaling import ( "encoding/json" "fmt" + "net/url" "github.com/golang-jwt/jwt/v4" ) @@ -201,6 +202,14 @@ type CommandProxyClientMessage struct { ClientId string `json:"clientId,omitempty"` Bitrate int `json:"bitrate,omitempty"` MediaTypes MediaType `json:"mediatypes,omitempty"` + + RemoteUrl string `json:"remoteUrl,omitempty"` + remoteUrl *url.URL + RemoteToken string `json:"remoteToken,omitempty"` + + Hostname string `json:"hostname,omitempty"` + Port int `json:"port,omitempty"` + RtcpPort int `json:"rtcpPort,omitempty"` } func (m *CommandProxyClientMessage) CheckValid() error { @@ -218,6 +227,17 @@ func (m *CommandProxyClientMessage) CheckValid() error { if m.StreamType == "" { return fmt.Errorf("stream type missing") } + if m.RemoteUrl != "" { + if m.RemoteToken == "" { + return fmt.Errorf("remote token missing") + } + + remoteUrl, err := url.Parse(m.RemoteUrl) + if err != nil { + return fmt.Errorf("invalid remote url: %w", err) + } + m.remoteUrl = remoteUrl + } case "delete-publisher": fallthrough case "delete-subscriber": diff --git a/clientsession.go b/clientsession.go index 72c22d1..d4e8c40 100644 --- a/clientsession.go +++ b/clientsession.go @@ -934,9 +934,10 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s subscriber, found := s.subscribers[getStreamId(id, streamType)] if !found { + client := s.getClientUnlocked() s.mu.Unlock() var err error - subscriber, err = mcu.NewSubscriber(ctx, s, id, streamType) + subscriber, err = mcu.NewSubscriber(ctx, s, id, streamType, client) s.mu.Lock() if err != nil { return nil, err diff --git a/docker/README.md b/docker/README.md index c19e078..ffa13bb 100644 --- a/docker/README.md +++ b/docker/README.md @@ -100,6 +100,9 @@ The running container can be configured through different environment variables: - `CONFIG`: Optional name of configuration file to use. - `HTTP_LISTEN`: Address of HTTP listener. - `COUNTRY`: Optional ISO 3166 country this proxy is located at. +- `EXTERNAL_HOSTNAME`: The external hostname for remote streams. Will try to autodetect if omitted. +- `TOKEN_ID`: Id of the token to use when connecting remote streams. +- `TOKEN_KEY`: Private key for the configured token id. - `JANUS_URL`: Url to Janus server. - `MAX_STREAM_BITRATE`: Optional maximum bitrate for audio/video streams. - `MAX_SCREEN_BITRATE`: Optional maximum bitrate for screensharing streams. diff --git a/docker/proxy/entrypoint.sh b/docker/proxy/entrypoint.sh index 31e37f0..483541d 100755 --- a/docker/proxy/entrypoint.sh +++ b/docker/proxy/entrypoint.sh @@ -44,6 +44,16 @@ if [ ! -f "$CONFIG" ]; then sed -i "s|#country =.*|country = $COUNTRY|" "$CONFIG" fi + if [ -n "$EXTERNAL_HOSTNAME" ]; then + sed -i "s|#hostname =.*|hostname = $EXTERNAL_HOSTNAME|" "$CONFIG" + fi + if [ -n "$TOKEN_ID" ]; then + sed -i "s|#token_id =.*|token_id = $TOKEN_ID|" "$CONFIG" + fi + if [ -n "$TOKEN_KEY" ]; then + sed -i "s|#token_key =.*|token_key = $TOKEN_KEY|" "$CONFIG" + fi + HAS_ETCD= if [ -n "$ETCD_ENDPOINTS" ]; then sed -i "s|#endpoints =.*|endpoints = $ETCD_ENDPOINTS|" "$CONFIG" diff --git a/mcu_common.go b/mcu_common.go index 3bea933..0df9cb0 100644 --- a/mcu_common.go +++ b/mcu_common.go @@ -76,7 +76,18 @@ type Mcu interface { GetStats() interface{} NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) - NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) + NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) +} + +type RemotePublisherController interface { + PublisherId() string + + StartPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error +} + +type RemoteMcu interface { + NewRemotePublisher(ctx context.Context, listener McuListener, controller RemotePublisherController, streamType StreamType) (McuRemotePublisher, error) + NewRemoteSubscriber(ctx context.Context, listener McuListener, publisher McuRemotePublisher) (McuRemoteSubscriber, error) } type StreamType string @@ -116,6 +127,8 @@ type McuPublisher interface { HasMedia(MediaType) bool SetMedia(MediaType) + + PublishRemote(ctx context.Context, hostname string, port int, rtcpPort int) error } type McuSubscriber interface { @@ -123,3 +136,18 @@ type McuSubscriber interface { Publisher() string } + +type McuRemotePublisherProperties interface { + Port() int + RtcpPort() int +} + +type McuRemotePublisher interface { + McuClient + + McuRemotePublisherProperties +} + +type McuRemoteSubscriber interface { + McuSubscriber +} diff --git a/mcu_janus.go b/mcu_janus.go index 6048a7c..059e06a 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -25,6 +25,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "log" "reflect" @@ -53,6 +54,8 @@ const ( ) var ( + ErrRemoteStreamsNotSupported = errors.New("Need Janus 1.1.0 for remote streams") + streamTypeUserIds = map[StreamType]uint64{ StreamTypeVideo: videoPublisherUserId, StreamTypeScreen: screenPublisherUserId, @@ -143,6 +146,7 @@ type mcuJanus struct { gw *JanusGateway session *JanusSession handle *JanusHandle + version int closeChan chan struct{} @@ -154,6 +158,7 @@ type mcuJanus struct { publishers map[string]*mcuJanusPublisher publisherCreated Notifier publisherConnected Notifier + remotePublishers map[string]*mcuJanusRemotePublisher reconnectTimer *time.Timer reconnectInterval time.Duration @@ -189,7 +194,8 @@ func NewMcuJanus(url string, config *goconf.ConfigFile) (Mcu, error) { closeChan: make(chan struct{}, 1), clients: make(map[clientInterface]bool), - publishers: make(map[string]*mcuJanusPublisher), + publishers: make(map[string]*mcuJanusPublisher), + remotePublishers: make(map[string]*mcuJanusRemotePublisher), reconnectInterval: initialReconnectInterval, } @@ -288,6 +294,10 @@ func (m *mcuJanus) isMultistream() bool { return m.version >= 1000 } +func (m *mcuJanus) hasRemotePublisher() bool { + return m.version >= 1100 +} + func (m *mcuJanus) Start() error { ctx := context.TODO() info, err := m.gw.Info(ctx) @@ -719,17 +729,7 @@ func (m *mcuJanus) SubscriberDisconnected(id string, publisher string, streamTyp } } -func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType StreamType, bitrate int) (*JanusHandle, uint64, uint64, int, error) { - session := m.session - if session == nil { - return nil, 0, 0, 0, ErrNotConnected - } - handle, err := session.Attach(ctx, pluginVideoRoom) - if err != nil { - return nil, 0, 0, 0, err - } - - log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id) +func (m *mcuJanus) createPublisherRoom(ctx context.Context, handle *JanusHandle, id string, streamType StreamType, bitrate int) (uint64, int, error) { create_msg := map[string]interface{}{ "request": "create", "description": getStreamId(id, streamType), @@ -756,7 +756,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st if _, err2 := handle.Detach(ctx); err2 != nil { log.Printf("Error detaching handle %d: %s", handle.Id, err2) } - return nil, 0, 0, 0, err + return 0, 0, err } roomId := getPluginIntValue(create_response.PluginData, pluginVideoRoom, "room") @@ -764,10 +764,32 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st if _, err := handle.Detach(ctx); err != nil { log.Printf("Error detaching handle %d: %s", handle.Id, err) } - return nil, 0, 0, 0, fmt.Errorf("No room id received: %+v", create_response) + return 0, 0, fmt.Errorf("No room id received: %+v", create_response) } log.Println("Created room", roomId, create_response.PluginData) + return roomId, bitrate, nil +} + +func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType StreamType, bitrate int) (*JanusHandle, uint64, uint64, int, error) { + session := m.session + if session == nil { + return nil, 0, 0, 0, ErrNotConnected + } + handle, err := session.Attach(ctx, pluginVideoRoom) + if err != nil { + return nil, 0, 0, 0, err + } + + log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id) + + roomId, bitrate, err := m.createPublisherRoom(ctx, handle, id, streamType, bitrate) + if err != nil { + if _, err2 := handle.Detach(ctx); err2 != nil { + log.Printf("Error detaching handle %d: %s", handle.Id, err2) + } + return nil, 0, 0, 0, err + } msg := map[string]interface{}{ "request": "join", @@ -975,6 +997,97 @@ func (p *mcuJanusPublisher) SendMessage(ctx context.Context, message *MessageCli } } +func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, hostname string, port int, rtcpPort int) error { + msg := map[string]interface{}{ + "request": "publish_remotely", + "room": p.roomId, + "publisher_id": streamTypeUserIds[p.streamType], + "remote_id": p.id, + "host": hostname, + "port": port, + "rtcp_port": rtcpPort, + } + response, err := p.handle.Request(ctx, msg) + if err != nil { + return err + } + + errorMessage := getPluginStringValue(response.PluginData, pluginVideoRoom, "error") + errorCode := getPluginIntValue(response.PluginData, pluginVideoRoom, "error_code") + if errorMessage != "" || errorCode != 0 { + if errorMessage == "" { + errorMessage = "unknown error" + } + return fmt.Errorf("%s (%d)", errorMessage, errorCode) + } + + log.Printf("Publishing %s to %s (port=%d, rtcpPort=%d)", p.id, hostname, port, rtcpPort) + return nil +} + +type mcuJanusRemotePublisher struct { + mcuJanusClient + + ref atomic.Int64 + publisher string + port int + rtcpPort int +} + +func (p *mcuJanusRemotePublisher) addRef() int64 { + return p.ref.Add(1) +} + +func (p *mcuJanusRemotePublisher) release() bool { + return p.ref.Add(-1) == 0 +} + +func (p *mcuJanusRemotePublisher) Port() int { + return p.port +} + +func (p *mcuJanusRemotePublisher) RtcpPort() int { + return p.rtcpPort +} + +func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { + if !p.release() { + return + } + + p.mu.Lock() + if handle := p.handle; handle != nil { + response, err := p.handle.Request(ctx, map[string]interface{}{ + "request": "remove_remote_publisher", + "room": p.roomId, + "id": streamTypeUserIds[p.streamType], + }) + if err != nil { + log.Printf("Error removing remote publisher %d in room %d: %s", p.id, p.roomId, err) + } else { + log.Printf("Removed remote publisher: %+v", response) + } + if p.roomId != 0 { + destroy_msg := map[string]interface{}{ + "request": "destroy", + "room": p.roomId, + } + if _, err := handle.Request(ctx, destroy_msg); err != nil { + log.Printf("Error destroying room %d: %s", p.roomId, err) + } else { + log.Printf("Room %d destroyed", p.roomId) + } + p.mcu.mu.Lock() + delete(p.mcu.remotePublishers, getStreamId(p.publisher, p.streamType)) + p.mcu.mu.Unlock() + p.roomId = 0 + } + } + + p.closeClient(ctx) + p.mu.Unlock() +} + type mcuJanusSubscriber struct { mcuJanusClient @@ -1029,7 +1142,7 @@ func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher st return handle, pub, nil } -func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) { +func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) { if _, found := streamTypeUserIds[streamType]; !found { return nil, fmt.Errorf("Unsupported stream type %s", streamType) } @@ -1070,6 +1183,186 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ return client, nil } +type mcuJanusRemoteSubscriber struct { + mcuJanusSubscriber + + remote atomic.Pointer[mcuJanusRemotePublisher] +} + +func (s *mcuJanusRemoteSubscriber) Close(ctx context.Context) { + s.mcuJanusSubscriber.Close(ctx) + + if remote := s.remote.Swap(nil); remote != nil { + remote.Close(context.Background()) + } +} + +func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller RemotePublisherController, streamType StreamType, bitrate int) (*mcuJanusRemotePublisher, error) { + m.mu.Lock() + defer m.mu.Unlock() + pub, found := m.remotePublishers[getStreamId(controller.PublisherId(), streamType)] + if found { + return pub, nil + } + + session := m.session + if session == nil { + return nil, ErrNotConnected + } + + handle, err := session.Attach(ctx, pluginVideoRoom) + if err != nil { + return nil, err + } + + roomId, bitrate, err := m.createPublisherRoom(ctx, handle, controller.PublisherId(), streamType, bitrate) + if err != nil { + if _, err2 := handle.Detach(ctx); err2 != nil { + log.Printf("Error detaching handle %d: %s", handle.Id, err2) + } + return nil, err + } + + response, err := handle.Request(ctx, map[string]interface{}{ + "request": "add_remote_publisher", + "room": roomId, + "id": streamTypeUserIds[streamType], + "streams": []map[string]interface{}{ + { + "mid": "0", + "mindex": 0, + "type": "audio", + "codec": "opus", + "fec": true, + }, + { + "mid": "1", + "mindex": 1, + "type": "video", + "codec": "vp8", + "simulcast": true, + }, + { + "mid": "2", + "mindex": 2, + "type": "data", + }, + }, + }) + if err != nil { + if _, err2 := handle.Detach(ctx); err2 != nil { + log.Printf("Error detaching handle %d: %s", handle.Id, err2) + } + return nil, err + } + + id := getPluginIntValue(response.PluginData, pluginVideoRoom, "id") + port := getPluginIntValue(response.PluginData, pluginVideoRoom, "port") + rtcp_port := getPluginIntValue(response.PluginData, pluginVideoRoom, "rtcp_port") + + pub = &mcuJanusRemotePublisher{ + mcuJanusClient: mcuJanusClient{ + mcu: m, + + id: id, + session: response.Session, + roomId: roomId, + sid: strconv.FormatUint(handle.Id, 10), + streamType: streamType, + maxBitrate: bitrate, + + handle: handle, + handleId: handle.Id, + closeChan: make(chan struct{}, 1), + deferred: make(chan func(), 64), + }, + + publisher: controller.PublisherId(), + port: int(port), + rtcpPort: int(rtcp_port), + } + + if err := controller.StartPublishing(ctx, pub); err != nil { + go pub.Close(context.Background()) + return nil, err + } + + m.remotePublishers[getStreamId(controller.PublisherId(), streamType)] = pub + + return pub, nil +} + +func (m *mcuJanus) NewRemotePublisher(ctx context.Context, listener McuListener, controller RemotePublisherController, streamType StreamType) (McuRemotePublisher, error) { + if _, found := streamTypeUserIds[streamType]; !found { + return nil, fmt.Errorf("Unsupported stream type %s", streamType) + } + + if !m.hasRemotePublisher() { + return nil, ErrRemoteStreamsNotSupported + } + + pub, err := m.getOrCreateRemotePublisher(ctx, controller, streamType, 0) + if err != nil { + return nil, err + } + + pub.addRef() + return pub, nil +} + +func (m *mcuJanus) NewRemoteSubscriber(ctx context.Context, listener McuListener, publisher McuRemotePublisher) (McuRemoteSubscriber, error) { + pub, ok := publisher.(*mcuJanusRemotePublisher) + if !ok { + return nil, errors.New("unsupported remote publisher") + } + + session := m.session + if session == nil { + return nil, ErrNotConnected + } + + handle, err := session.Attach(ctx, pluginVideoRoom) + if err != nil { + return nil, err + } + + log.Printf("Attached subscriber to room %d of publisher %s in plugin %s in session %d as %d", pub.roomId, pub.publisher, pluginVideoRoom, session.Id, handle.Id) + + client := &mcuJanusRemoteSubscriber{ + mcuJanusSubscriber: mcuJanusSubscriber{ + mcuJanusClient: mcuJanusClient{ + mcu: m, + listener: listener, + + id: m.clientId.Add(1), + roomId: pub.roomId, + sid: strconv.FormatUint(handle.Id, 10), + streamType: publisher.StreamType(), + maxBitrate: pub.MaxBitrate(), + + handle: handle, + handleId: handle.Id, + closeChan: make(chan struct{}, 1), + deferred: make(chan func(), 64), + }, + publisher: pub.publisher, + }, + } + client.remote.Store(pub) + pub.addRef() + client.mcuJanusClient.handleEvent = client.handleEvent + client.mcuJanusClient.handleHangup = client.handleHangup + client.mcuJanusClient.handleDetached = client.handleDetached + client.mcuJanusClient.handleConnected = client.handleConnected + client.mcuJanusClient.handleSlowLink = client.handleSlowLink + client.mcuJanusClient.handleMedia = client.handleMedia + m.registerClient(client) + go client.run(handle, client.closeChan) + statsSubscribersCurrent.WithLabelValues(string(publisher.StreamType())).Inc() + statsSubscribersTotal.WithLabelValues(string(publisher.StreamType())).Inc() + return client, nil +} + func (p *mcuJanusSubscriber) Publisher() string { return p.publisher } diff --git a/mcu_proxy.go b/mcu_proxy.go index 643dd47..3e5c243 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -218,13 +218,18 @@ func (p *mcuProxyPublisher) ProcessEvent(msg *EventProxyServerMessage) { } } +func (p *mcuProxyPublisher) PublishRemote(ctx context.Context, hostname string, port int, rtcpPort int) error { + return errors.New("remote publishing not supported for proxy publishers") +} + type mcuProxySubscriber struct { mcuProxyPubSubCommon - publisherId string + publisherId string + publisherConn *mcuProxyConnection } -func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType, maxBitrate int, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber { +func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType, maxBitrate int, proxyId string, conn *mcuProxyConnection, listener McuListener, publisherConn *mcuProxyConnection) *mcuProxySubscriber { return &mcuProxySubscriber{ mcuProxyPubSubCommon: mcuProxyPubSubCommon{ sid: sid, @@ -235,7 +240,8 @@ func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType listener: listener, }, - publisherId: publisherId, + publisherId: publisherId, + publisherConn: publisherConn, } } @@ -244,7 +250,11 @@ func (s *mcuProxySubscriber) Publisher() string { } func (s *mcuProxySubscriber) NotifyClosed() { - log.Printf("Subscriber %s at %s was closed", s.proxyId, s.conn) + if s.publisherConn != nil { + log.Printf("Remote subscriber %s at %s (forwarded to %s) was closed", s.proxyId, s.conn, s.publisherConn) + } else { + log.Printf("Subscriber %s at %s was closed", s.proxyId, s.conn) + } s.listener.SubscriberClosed(s) s.conn.removeSubscriber(s) } @@ -261,14 +271,26 @@ func (s *mcuProxySubscriber) Close(ctx context.Context) { } if response, err := s.conn.performSyncRequest(ctx, msg); err != nil { - log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, err) + if s.publisherConn != nil { + log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, err) + } else { + log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, err) + } return } else if response.Type == "error" { - log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, response.Error) + if s.publisherConn != nil { + log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, response.Error) + } else { + log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, response.Error) + } return } - log.Printf("Deleted subscriber %s at %s", s.proxyId, s.conn) + if s.publisherConn != nil { + log.Printf("Deleted remote subscriber %s at %s (forwarded to %s)", s.proxyId, s.conn, s.publisherConn) + } else { + log.Printf("Deleted subscriber %s at %s", s.proxyId, s.conn) + } } func (s *mcuProxySubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { @@ -373,6 +395,54 @@ func (c *mcuProxyConnection) String() string { return c.rawUrl } +func (c *mcuProxyConnection) IsSameCountry(initiator McuInitiator) bool { + if initiator == nil { + return true + } + + initiatorCountry := initiator.Country() + if initiatorCountry == "" { + return true + } + + connCountry := c.Country() + if connCountry == "" { + return true + } + + return initiatorCountry == connCountry +} + +func (c *mcuProxyConnection) IsSameContinent(initiator McuInitiator) bool { + if initiator == nil { + return true + } + + initiatorCountry := initiator.Country() + if initiatorCountry == "" { + return true + } + + connCountry := c.Country() + if connCountry == "" { + return true + } + + initiatorContinents, found := ContinentMap[initiatorCountry] + if found { + m := c.proxy.getContinentsMap() + // Map continents to other continents (e.g. use Europe for Africa). + for _, continent := range initiatorContinents { + if toAdd, found := m[continent]; found { + initiatorContinents = append(initiatorContinents, toAdd...) + } + } + + } + connContinents := ContinentMap[connCountry] + return ContinentsOverlap(initiatorContinents, connContinents) +} + type mcuProxyConnectionStats struct { Url string `json:"url"` IP net.IP `json:"ip,omitempty"` @@ -982,14 +1052,7 @@ func (c *mcuProxyConnection) sendHello() error { if sessionId := c.SessionId(); sessionId != "" { msg.Hello.ResumeId = sessionId } else { - claims := &TokenClaims{ - jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: c.proxy.tokenId, - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - tokenString, err := token.SignedString(c.proxy.tokenKey) + tokenString, err := c.proxy.createToken("") if err != nil { return err } @@ -1110,7 +1173,48 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList proxyId := response.Command.Id log.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c, publisherSessionId) - subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener) + subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, nil) + c.subscribersLock.Lock() + c.subscribers[proxyId] = subscriber + c.subscribersLock.Unlock() + statsSubscribersCurrent.WithLabelValues(string(streamType)).Inc() + statsSubscribersTotal.WithLabelValues(string(streamType)).Inc() + return subscriber, nil +} + +func (c *mcuProxyConnection) newRemoteSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType StreamType, publisherConn *mcuProxyConnection) (McuSubscriber, error) { + if c == publisherConn { + return c.newSubscriber(ctx, listener, publisherId, publisherSessionId, streamType) + } + + remoteToken, err := c.proxy.createToken(publisherId) + if err != nil { + return nil, err + } + + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "create-subscriber", + StreamType: streamType, + PublisherId: publisherId, + + RemoteUrl: publisherConn.rawUrl, + RemoteToken: remoteToken, + }, + } + + response, err := c.performSyncRequest(ctx, msg) + if err != nil { + // TODO: Cancel request + return nil, err + } else if response.Type == "error" { + return nil, fmt.Errorf("Error creating remote %s subscriber for %s on %s (forwarded to %s): %+v", streamType, publisherSessionId, c, publisherConn, response.Error) + } + + proxyId := response.Command.Id + log.Printf("Created remote %s subscriber %s on %s for %s (forwarded to %s)", streamType, proxyId, c, publisherSessionId, publisherConn) + subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, publisherConn) c.subscribersLock.Lock() c.subscribers[proxyId] = subscriber c.subscribersLock.Unlock() @@ -1293,6 +1397,23 @@ func (m *mcuProxy) Stop() { m.config.Stop() } +func (m *mcuProxy) createToken(subject string) (string, error) { + claims := &TokenClaims{ + jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: m.tokenId, + Subject: subject, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(m.tokenKey) + if err != nil { + return "", err + } + + return tokenString, nil +} + func (m *mcuProxy) hasConnections() bool { m.connectionsMu.RLock() defer m.connectionsMu.RUnlock() @@ -1685,7 +1806,14 @@ func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher str } } -func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) { +type proxyPublisherInfo struct { + id string + conn *mcuProxyConnection + err error +} + +func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) { + var publisherInfo *proxyPublisherInfo if conn := m.getPublisherConnection(publisher, streamType); conn != nil { // Fast common path: publisher is available locally. conn.publishersLock.Lock() @@ -1695,113 +1823,159 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ return nil, fmt.Errorf("Unknown publisher %s", publisher) } - return conn.newSubscriber(ctx, listener, id, publisher, streamType) - } - - log.Printf("No %s publisher %s found yet, deferring", streamType, publisher) - ch := make(chan McuSubscriber) - getctx, cancel := context.WithCancel(ctx) - defer cancel() - - // Wait for publisher to be created locally. - go func() { - if conn := m.waitForPublisherConnection(getctx, publisher, streamType); conn != nil { - cancel() // Cancel pending RPC calls. - - conn.publishersLock.Lock() - id, found := conn.publisherIds[getStreamId(publisher, streamType)] - conn.publishersLock.Unlock() - if !found { - log.Printf("Unknown id for local %s publisher %s", streamType, publisher) - return - } - - subscriber, err := conn.newSubscriber(ctx, listener, id, publisher, streamType) - if subscriber != nil { - ch <- subscriber - } else if err != nil { - log.Printf("Error creating local subscriber for %s publisher %s: %s", streamType, publisher, err) - } + publisherInfo = &proxyPublisherInfo{ + id: id, + conn: conn, } - }() + } else { + log.Printf("No %s publisher %s found yet, deferring", streamType, publisher) + ch := make(chan *proxyPublisherInfo, 1) + getctx, cancel := context.WithCancel(ctx) + defer cancel() - // Wait for publisher to be created on one of the other servers in the cluster. - if clients := m.rpcClients.GetClients(); len(clients) > 0 { - for _, client := range clients { - go func(client *GrpcClient) { - id, url, ip, err := client.GetPublisherId(getctx, publisher, streamType) - if errors.Is(err, context.Canceled) { - return - } else if err != nil { - log.Printf("Error getting %s publisher id %s from %s: %s", streamType, publisher, client.Target(), err) - return - } else if id == "" { - // Publisher not found on other server - return - } + var wg sync.WaitGroup + // Wait for publisher to be created locally. + wg.Add(1) + go func() { + defer wg.Done() + if conn := m.waitForPublisherConnection(getctx, publisher, streamType); conn != nil { cancel() // Cancel pending RPC calls. - log.Printf("Found publisher id %s through %s on proxy %s", id, client.Target(), url) - m.connectionsMu.RLock() - connections := m.connections - m.connectionsMu.RUnlock() - var publisherConn *mcuProxyConnection - for _, conn := range connections { - if conn.rawUrl != url || !ip.Equal(conn.ip) { - continue + conn.publishersLock.Lock() + id, found := conn.publisherIds[getStreamId(publisher, streamType)] + conn.publishersLock.Unlock() + if !found { + ch <- &proxyPublisherInfo{ + err: fmt.Errorf("Unknown id for local %s publisher %s", streamType, publisher), } - - // Simple case, signaling server has a connection to the same endpoint - publisherConn = conn - break - } - - if publisherConn == nil { - publisherConn, err = newMcuProxyConnection(m, url, ip) - if err != nil { - log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err) - return - } - - publisherConn.setTemporary() - publisherConn.start() - if err := publisherConn.waitUntilConnected(ctx); err != nil { - log.Printf("Could not establish new connection to %s: %s", publisherConn, err) - publisherConn.closeIfEmpty() - return - } - - m.connectionsMu.Lock() - m.connections = append(m.connections, publisherConn) - conns, found := m.connectionsMap[url] - if found { - conns = append(conns, publisherConn) - } else { - conns = []*mcuProxyConnection{publisherConn} - } - m.connectionsMap[url] = conns - m.connectionsMu.Unlock() - } - - subscriber, err := publisherConn.newSubscriber(ctx, listener, id, publisher, streamType) - if err != nil { - if publisherConn.IsTemporary() { - publisherConn.closeIfEmpty() - } - log.Printf("Could not create subscriber for %s publisher %s: %s", streamType, publisher, err) return } - ch <- subscriber - }(client) + ch <- &proxyPublisherInfo{ + id: id, + conn: conn, + } + } + }() + + // Wait for publisher to be created on one of the other servers in the cluster. + if clients := m.rpcClients.GetClients(); len(clients) > 0 { + for _, client := range clients { + wg.Add(1) + go func(client *GrpcClient) { + defer wg.Done() + id, url, ip, err := client.GetPublisherId(getctx, publisher, streamType) + if errors.Is(err, context.Canceled) { + return + } else if err != nil { + log.Printf("Error getting %s publisher id %s from %s: %s", streamType, publisher, client.Target(), err) + return + } else if id == "" { + // Publisher not found on other server + return + } + + cancel() // Cancel pending RPC calls. + log.Printf("Found publisher id %s through %s on proxy %s", id, client.Target(), url) + + m.connectionsMu.RLock() + connections := m.connections + m.connectionsMu.RUnlock() + var publisherConn *mcuProxyConnection + for _, conn := range connections { + if conn.rawUrl != url || !ip.Equal(conn.ip) { + continue + } + + // Simple case, signaling server has a connection to the same endpoint + publisherConn = conn + break + } + + if publisherConn == nil { + publisherConn, err = newMcuProxyConnection(m, url, ip) + if err != nil { + log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err) + return + } + + publisherConn.setTemporary() + publisherConn.start() + if err := publisherConn.waitUntilConnected(ctx); err != nil { + log.Printf("Could not establish new connection to %s: %s", publisherConn, err) + publisherConn.closeIfEmpty() + return + } + + m.connectionsMu.Lock() + m.connections = append(m.connections, publisherConn) + conns, found := m.connectionsMap[url] + if found { + conns = append(conns, publisherConn) + } else { + conns = []*mcuProxyConnection{publisherConn} + } + m.connectionsMap[url] = conns + m.connectionsMu.Unlock() + } + + ch <- &proxyPublisherInfo{ + id: id, + conn: publisherConn, + } + }(client) + } + } + + wg.Wait() + select { + case ch <- &proxyPublisherInfo{ + err: fmt.Errorf("No %s publisher %s found", streamType, publisher), + }: + default: + } + + select { + case info := <-ch: + publisherInfo = info + case <-ctx.Done(): + return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher) } } - select { - case subscriber := <-ch: - return subscriber, nil - case <-ctx.Done(): - return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher) + if publisherInfo.err != nil { + return nil, publisherInfo.err } + + if !publisherInfo.conn.IsSameCountry(initiator) { + connections := m.getSortedConnections(initiator) + if len(connections) > 0 && !connections[0].IsSameCountry(publisherInfo.conn) { + // Connect to remote publisher through "closer" gateway. + for _, conn := range connections { + if conn.IsShutdownScheduled() || conn.IsTemporary() || conn == publisherInfo.conn { + continue + } + + subscriber, err := conn.newRemoteSubscriber(ctx, listener, publisherInfo.id, publisher, streamType, publisherInfo.conn) + if err != nil { + log.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, conn, err) + continue + } + + return subscriber, nil + } + } + } + + subscriber, err := publisherInfo.conn.newSubscriber(ctx, listener, publisherInfo.id, publisher, streamType) + if err != nil { + if publisherInfo.conn.IsTemporary() { + publisherInfo.conn.closeIfEmpty() + } + log.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, publisherInfo.conn, err) + return nil, err + } + + return subscriber, nil } diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index 7d39e4e..428c8e0 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -195,12 +195,14 @@ type testProxyServerSubscriber struct { id string sid string pub *testProxyServerPublisher + + remoteUrl string } type testProxyServerClient struct { t *testing.T - server *testProxyServerHandler + server *TestProxyServerHandler ws *websocket.Conn processMessage proxyServerClientHandler @@ -273,7 +275,20 @@ func (c *testProxyServerClient) processCommandMessage(msg *ProxyClientMessage) ( c.server.updateLoad(-1) } case "create-subscriber": - pub := c.server.getPublisher(msg.Command.PublisherId) + var pub *testProxyServerPublisher + if msg.Command.RemoteUrl != "" { + for _, server := range c.server.servers { + if server.URL != msg.Command.RemoteUrl { + continue + } + + pub = server.getPublisher(msg.Command.PublisherId) + break + } + } else { + pub = c.server.getPublisher(msg.Command.PublisherId) + } + if pub == nil { response = msg.NewWrappedErrorServerMessage(fmt.Errorf("publisher %s not found", msg.Command.PublisherId)) } else { @@ -292,6 +307,11 @@ func (c *testProxyServerClient) processCommandMessage(msg *ProxyClientMessage) ( if sub, found := c.server.deleteSubscriber(msg.Command.ClientId); !found { response = msg.NewWrappedErrorServerMessage(fmt.Errorf("subscriber %s not found", msg.Command.ClientId)) } else { + if msg.Command.RemoteUrl != sub.remoteUrl { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("remote subscriber %s not found", msg.Command.ClientId)) + return response, nil + } + response = &ProxyServerMessage{ Id: msg.Id, Type: "command", @@ -415,9 +435,12 @@ func (c *testProxyServerClient) run() { } } -type testProxyServerHandler struct { +type TestProxyServerHandler struct { t *testing.T + URL string + server *httptest.Server + servers []*TestProxyServerHandler upgrader *websocket.Upgrader country string @@ -428,7 +451,7 @@ type testProxyServerHandler struct { subscribers map[string]*testProxyServerSubscriber } -func (h *testProxyServerHandler) createPublisher() *testProxyServerPublisher { +func (h *TestProxyServerHandler) createPublisher() *testProxyServerPublisher { h.mu.Lock() defer h.mu.Unlock() pub := &testProxyServerPublisher{ @@ -446,14 +469,14 @@ func (h *testProxyServerHandler) createPublisher() *testProxyServerPublisher { return pub } -func (h *testProxyServerHandler) getPublisher(id string) *testProxyServerPublisher { +func (h *TestProxyServerHandler) getPublisher(id string) *testProxyServerPublisher { h.mu.Lock() defer h.mu.Unlock() return h.publishers[id] } -func (h *testProxyServerHandler) deletePublisher(id string) (*testProxyServerPublisher, bool) { +func (h *TestProxyServerHandler) deletePublisher(id string) (*testProxyServerPublisher, bool) { h.mu.Lock() defer h.mu.Unlock() @@ -466,7 +489,7 @@ func (h *testProxyServerHandler) deletePublisher(id string) (*testProxyServerPub return pub, true } -func (h *testProxyServerHandler) createSubscriber(pub *testProxyServerPublisher) *testProxyServerSubscriber { +func (h *TestProxyServerHandler) createSubscriber(pub *testProxyServerPublisher) *testProxyServerSubscriber { h.mu.Lock() defer h.mu.Unlock() @@ -487,7 +510,7 @@ func (h *testProxyServerHandler) createSubscriber(pub *testProxyServerPublisher) return sub } -func (h *testProxyServerHandler) deleteSubscriber(id string) (*testProxyServerSubscriber, bool) { +func (h *TestProxyServerHandler) deleteSubscriber(id string) (*testProxyServerSubscriber, bool) { h.mu.Lock() defer h.mu.Unlock() @@ -500,7 +523,7 @@ func (h *testProxyServerHandler) deleteSubscriber(id string) (*testProxyServerSu return sub, true } -func (h *testProxyServerHandler) updateLoad(delta int64) { +func (h *TestProxyServerHandler) updateLoad(delta int64) { if delta == 0 { return } @@ -522,7 +545,7 @@ func (h *testProxyServerHandler) updateLoad(delta int64) { } } -func (h *testProxyServerHandler) sendLoad(c *testProxyServerClient) { +func (h *TestProxyServerHandler) sendLoad(c *testProxyServerClient) { c.sendMessage(&ProxyServerMessage{ Type: "event", Event: &EventProxyServerMessage{ @@ -532,13 +555,13 @@ func (h *testProxyServerHandler) sendLoad(c *testProxyServerClient) { }) } -func (h *testProxyServerHandler) removeClient(client *testProxyServerClient) { +func (h *TestProxyServerHandler) removeClient(client *testProxyServerClient) { h.mu.Lock() defer h.mu.Unlock() delete(h.clients, client.sessionId) } -func (h *testProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *TestProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ws, err := h.upgrader.Upgrade(w, r, nil) if err != nil { h.t.Error(err) @@ -559,11 +582,11 @@ func (h *testProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques go client.run() } -func NewProxyServerForTest(t *testing.T, country string) *httptest.Server { +func NewProxyServerForTest(t *testing.T, country string) *TestProxyServerHandler { t.Helper() upgrader := websocket.Upgrader{} - proxyHandler := &testProxyServerHandler{ + proxyHandler := &TestProxyServerHandler{ t: t, upgrader: &upgrader, country: country, @@ -572,6 +595,8 @@ func NewProxyServerForTest(t *testing.T, country string) *httptest.Server { subscribers: make(map[string]*testProxyServerSubscriber), } server := httptest.NewServer(proxyHandler) + proxyHandler.server = server + proxyHandler.URL = server.URL t.Cleanup(func() { server.Close() proxyHandler.mu.Lock() @@ -581,12 +606,12 @@ func NewProxyServerForTest(t *testing.T, country string) *httptest.Server { } }) - return server + return proxyHandler } type proxyTestOptions struct { etcd *embed.Etcd - servers []*httptest.Server + servers []*TestProxyServerHandler } func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuProxy { @@ -611,11 +636,12 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP var urls []string waitingMap := make(map[string]bool) if len(options.servers) == 0 { - options.servers = []*httptest.Server{ + options.servers = []*TestProxyServerHandler{ NewProxyServerForTest(t, "DE"), } } for _, s := range options.servers { + s.servers = options.servers urls = append(urls, s.URL) waitingMap[s.URL] = true } @@ -680,7 +706,7 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP return proxy } -func newMcuProxyForTestWithServers(t *testing.T, servers []*httptest.Server) *mcuProxy { +func newMcuProxyForTestWithServers(t *testing.T, servers []*TestProxyServerHandler) *mcuProxy { t.Helper() return newMcuProxyForTestWithOptions(t, proxyTestOptions{ @@ -692,7 +718,7 @@ func newMcuProxyForTest(t *testing.T) *mcuProxy { t.Helper() server := NewProxyServerForTest(t, "DE") - return newMcuProxyForTestWithServers(t, []*httptest.Server{server}) + return newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{server}) } func Test_ProxyPublisherSubscriber(t *testing.T) { @@ -722,7 +748,10 @@ func Test_ProxyPublisherSubscriber(t *testing.T) { subListener := &MockMcuListener{ publicId: "subscriber-public", } - sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo) + subInitiator := &MockMcuInitiator{ + country: "DE", + } + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) if err != nil { t.Fatal(err) } @@ -750,10 +779,13 @@ func Test_ProxyWaitForPublisher(t *testing.T) { subListener := &MockMcuListener{ publicId: "subscriber-public", } + subInitiator := &MockMcuInitiator{ + country: "DE", + } done := make(chan struct{}) go func() { defer close(done) - sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo) + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) if err != nil { t.Error(err) return @@ -783,7 +815,7 @@ func Test_ProxyPublisherLoad(t *testing.T) { t.Parallel() server1 := NewProxyServerForTest(t, "DE") server2 := NewProxyServerForTest(t, "DE") - mcu := newMcuProxyForTestWithServers(t, []*httptest.Server{ + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ server1, server2, }) @@ -835,7 +867,7 @@ func Test_ProxyPublisherCountry(t *testing.T) { t.Parallel() serverDE := NewProxyServerForTest(t, "DE") serverUS := NewProxyServerForTest(t, "US") - mcu := newMcuProxyForTestWithServers(t, []*httptest.Server{ + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ serverDE, serverUS, }) @@ -887,7 +919,7 @@ func Test_ProxyPublisherContinent(t *testing.T) { t.Parallel() serverDE := NewProxyServerForTest(t, "DE") serverUS := NewProxyServerForTest(t, "US") - mcu := newMcuProxyForTestWithServers(t, []*httptest.Server{ + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ serverDE, serverUS, }) @@ -933,3 +965,53 @@ func Test_ProxyPublisherContinent(t *testing.T) { t.Errorf("expected server %s, go %s", serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl) } } + +func Test_ProxySubscriberCountry(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + serverDE := NewProxyServerForTest(t, "DE") + serverUS := NewProxyServerForTest(t, "US") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + serverDE, + serverUS, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) + } + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "US", + } + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Fatal(err) + } + + defer sub.Close(context.Background()) + + if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL { + t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl) + } +} diff --git a/mcu_test.go b/mcu_test.go index 903a2bc..eac0012 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -23,6 +23,7 @@ package signaling import ( "context" + "errors" "fmt" "log" "sync" @@ -117,7 +118,7 @@ func (m *TestMCU) GetPublisher(id string) *TestMCUPublisher { return m.publishers[id] } -func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) { +func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) { m.mu.Lock() defer m.mu.Unlock() @@ -222,6 +223,10 @@ func (p *TestMCUPublisher) SendMessage(ctx context.Context, message *MessageClie }() } +func (p *TestMCUPublisher) PublishRemote(ctx context.Context, hostname string, port int, rtcpPort int) error { + return errors.New("remote publishing not supported") +} + type TestMCUSubscriber struct { TestMCUClient diff --git a/proxy.conf.in b/proxy.conf.in index 0fcb350..272ccdc 100644 --- a/proxy.conf.in +++ b/proxy.conf.in @@ -25,6 +25,17 @@ # - etcd: Token information are retrieved from an etcd cluster (see below). tokentype = static +# The external hostname for remote streams. Leaving this empty will autodetect +# and use the first public IP found on the available network interfaces. +#hostname = + +# The token id to use when connecting remote stream. +#token_id = server1 + +# The private key for the configured token id to use when connecting remote +# streams. +#token_key = privkey.pem + [tokens] # For token type "static": Mapping of = of signaling # servers allowed to connect. diff --git a/proxy/proxy_remote.go b/proxy/proxy_remote.go new file mode 100644 index 0000000..70ea21f --- /dev/null +++ b/proxy/proxy_remote.go @@ -0,0 +1,340 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "context" + "crypto/rsa" + "crypto/tls" + "encoding/json" + "errors" + "io" + "log" + "net/http" + "net/url" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/gorilla/websocket" + + signaling "github.com/strukturag/nextcloud-spreed-signaling" +) + +var ( + ErrNotConnected = errors.New("not connected") +) + +type RemoteConnection struct { + mu sync.Mutex + url *url.URL + conn *websocket.Conn + + tokenId string + tokenKey *rsa.PrivateKey + + msgId atomic.Int64 + helloMsgId string + sessionId string + + messageCallbacks map[string]chan *signaling.ProxyServerMessage +} + +func NewRemoteConnection(proxyUrl string, tokenId string, tokenKey *rsa.PrivateKey) (*RemoteConnection, error) { + u, err := url.Parse(proxyUrl) + if err != nil { + return nil, err + } + + result := &RemoteConnection{ + url: u, + + tokenId: tokenId, + tokenKey: tokenKey, + + messageCallbacks: make(map[string]chan *signaling.ProxyServerMessage), + } + return result, nil +} + +func (c *RemoteConnection) String() string { + return c.url.String() +} + +func (c *RemoteConnection) Connect(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn != nil { + return nil + } + + u, err := c.url.Parse("proxy") + if err != nil { + return err + } + if u.Scheme == "http" { + u.Scheme = "ws" + } else if u.Scheme == "https" { + u.Scheme = "wss" + } + + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{ + // TODO: Make this configurable. + InsecureSkipVerify: true, + }, + } + + conn, _, err := dialer.DialContext(ctx, u.String(), nil) + if err != nil { + return err + } + + c.conn = conn + go c.readPump() + + return c.sendHello() +} + +func (c *RemoteConnection) sendHello() error { + c.helloMsgId = strconv.FormatInt(c.msgId.Add(1), 10) + msg := &signaling.ProxyClientMessage{ + Id: c.helloMsgId, + Type: "hello", + Hello: &signaling.HelloProxyClientMessage{ + Version: "1.0", + }, + } + if sessionId := c.sessionId; sessionId != "" { + msg.Hello.ResumeId = sessionId + } else { + tokenString, err := c.createToken("") + if err != nil { + return err + } + + msg.Hello.Token = tokenString + } + + return c.sendMessageLocked(msg) +} + +func (c *RemoteConnection) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.conn == nil { + return nil + } + + err1 := c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}) + err2 := c.conn.Close() + c.conn = nil + if err1 != nil { + return err1 + } + return err2 +} + +func (c *RemoteConnection) createToken(subject string) (string, error) { + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: c.tokenId, + Subject: subject, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(c.tokenKey) + if err != nil { + return "", err + } + + return tokenString, nil +} + +func (c *RemoteConnection) SendMessage(msg *signaling.ProxyClientMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + + return c.sendMessageLocked(msg) +} + +func (c *RemoteConnection) sendMessageLocked(msg *signaling.ProxyClientMessage) error { + if c.conn == nil { + return ErrNotConnected + } + + return c.conn.WriteJSON(msg) +} + +func (c *RemoteConnection) readPump() { + for { + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + if conn == nil { + return + } + + msgType, reader, err := conn.NextReader() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { + log.Printf("error reading: %s", err) + } + c.mu.Lock() + c.conn = nil + c.mu.Unlock() + return + } + + body, err := io.ReadAll(reader) + if err != nil { + log.Printf("error reading message: %s", err) + continue + } + + if msgType != websocket.TextMessage { + log.Printf("unexpected message type %q (%s)", msgType, string(body)) + continue + } + + var msg signaling.ProxyServerMessage + if err := json.Unmarshal(body, &msg); err != nil { + log.Printf("could not decode message %s: %s", string(body), err) + continue + } + + c.mu.Lock() + helloMsgId := c.helloMsgId + c.mu.Unlock() + + if helloMsgId != "" && msg.Id == helloMsgId { + c.processHello(&msg) + } else { + c.processMessage(&msg) + } + } +} + +func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { + c.mu.Lock() + defer c.mu.Unlock() + + c.helloMsgId = "" + switch msg.Type { + case "error": + if msg.Error.Code == "no_such_session" { + log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c) + c.sessionId = "" + if err := c.sendHello(); err != nil { + log.Printf("Could not send hello request to %s: %s", c, err) + // TODO: c.scheduleReconnect() + } + return + } + + log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) + // TODO: c.scheduleReconnect() + case "hello": + resumed := c.sessionId == msg.Hello.SessionId + c.sessionId = msg.Hello.SessionId + country := "" + if msg.Hello.Server != nil { + if country = msg.Hello.Server.Country; country != "" && !signaling.IsValidCountry(country) { + log.Printf("Proxy %s sent invalid country %s in hello response", c, country) + country = "" + } + } + if resumed { + log.Printf("Resumed session %s on %s", c.sessionId, c) + } else if country != "" { + log.Printf("Received session %s from %s (in %s)", c.sessionId, c, country) + } else { + log.Printf("Received session %s from %s", c.sessionId, c) + } + default: + log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c) + // TODO: c.scheduleReconnect() + } +} + +func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) { + if msg.Id != "" { + c.mu.Lock() + ch, found := c.messageCallbacks[msg.Id] + if found { + delete(c.messageCallbacks, msg.Id) + c.mu.Unlock() + ch <- msg + return + } + c.mu.Unlock() + } + + switch msg.Type { + case "event": + c.processEvent(msg) + default: + log.Printf("Received unsupported message %+v from %s", msg, c) + } +} + +func (c *RemoteConnection) processEvent(msg *signaling.ProxyServerMessage) { + switch msg.Event.Type { + case "update-load": + default: + log.Printf("Received unsupported event %+v from %s", msg, c) + } +} + +func (c *RemoteConnection) RequestMessage(ctx context.Context, msg *signaling.ProxyClientMessage) (*signaling.ProxyServerMessage, error) { + msg.Id = strconv.FormatInt(c.msgId.Add(1), 10) + + c.mu.Lock() + defer c.mu.Unlock() + + if err := c.sendMessageLocked(msg); err != nil { + return nil, err + } + ch := make(chan *signaling.ProxyServerMessage, 1) + c.messageCallbacks[msg.Id] = ch + c.mu.Unlock() + defer func() { + c.mu.Lock() + delete(c.messageCallbacks, msg.Id) + }() + + select { + case <-ctx.Done(): + // TODO: Cancel request. + return nil, ctx.Err() + case response := <-ch: + if response.Type == "error" { + return nil, response.Error + } + return response, nil + } +} diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index feaf306..bf348fa 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -24,7 +24,9 @@ package main import ( "context" "crypto/rand" + "crypto/rsa" "encoding/json" + "errors" "fmt" "io" "log" @@ -63,6 +65,8 @@ const ( // Maximum age a token may have to prevent reuse of old tokens. maxTokenAge = 5 * time.Minute + + remotePublisherTimeout = 5 * time.Second ) type ContextKey string @@ -70,22 +74,24 @@ type ContextKey string var ( ContextKeySession = ContextKey("session") - TimeoutCreatingPublisher = signaling.NewError("timeout", "Timeout creating publisher.") - TimeoutCreatingSubscriber = signaling.NewError("timeout", "Timeout creating subscriber.") - TokenAuthFailed = signaling.NewError("auth_failed", "The token could not be authenticated.") - TokenExpired = signaling.NewError("token_expired", "The token is expired.") - TokenNotValidYet = signaling.NewError("token_not_valid_yet", "The token is not valid yet.") - UnknownClient = signaling.NewError("unknown_client", "Unknown client id given.") - UnsupportedCommand = signaling.NewError("bad_request", "Unsupported command received.") - UnsupportedMessage = signaling.NewError("bad_request", "Unsupported message received.") - UnsupportedPayload = signaling.NewError("unsupported_payload", "Unsupported payload type.") - ShutdownScheduled = signaling.NewError("shutdown_scheduled", "The server is scheduled to shutdown.") + TimeoutCreatingPublisher = signaling.NewError("timeout", "Timeout creating publisher.") + TimeoutCreatingSubscriber = signaling.NewError("timeout", "Timeout creating subscriber.") + TokenAuthFailed = signaling.NewError("auth_failed", "The token could not be authenticated.") + TokenExpired = signaling.NewError("token_expired", "The token is expired.") + TokenNotValidYet = signaling.NewError("token_not_valid_yet", "The token is not valid yet.") + UnknownClient = signaling.NewError("unknown_client", "Unknown client id given.") + UnsupportedCommand = signaling.NewError("bad_request", "Unsupported command received.") + UnsupportedMessage = signaling.NewError("bad_request", "Unsupported message received.") + UnsupportedPayload = signaling.NewError("unsupported_payload", "Unsupported payload type.") + ShutdownScheduled = signaling.NewError("shutdown_scheduled", "The server is scheduled to shutdown.") + RemoteSubscribersNotSupported = signaling.NewError("unsupported_subscriber", "Remote subscribers are not supported.") ) type ProxyServer struct { version string country string welcomeMessage string + config *goconf.ConfigFile url string mcu signaling.Mcu @@ -109,6 +115,47 @@ type ProxyServer struct { clients map[string]signaling.McuClient clientIds map[string]string clientsLock sync.RWMutex + + tokenId string + tokenKey *rsa.PrivateKey + remoteHostname string + remoteConnections map[string]*RemoteConnection + remoteConnectionsLock sync.Mutex +} + +func IsPublicIP(IP net.IP) bool { + if IP.IsLoopback() || IP.IsLinkLocalMulticast() || IP.IsLinkLocalUnicast() { + return false + } + if ip4 := IP.To4(); ip4 != nil { + switch { + case ip4[0] == 10: + return false + case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31: + return false + case ip4[0] == 192 && ip4[1] == 168: + return false + default: + return true + } + } + return false +} + +func GetLocalIP() (string, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", err + } + + for _, address := range addrs { + if ipnet, ok := address.(*net.IPNet); ok && IsPublicIP(ipnet.IP) { + if ipnet.IP.To4() != nil { + return ipnet.IP.String(), nil + } + } + } + return "", nil } func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*ProxyServer, error) { @@ -187,10 +234,45 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* return nil, err } + tokenId, _ := config.GetString("app", "token_id") + var tokenKey *rsa.PrivateKey + var remoteHostname string + if tokenId != "" { + tokenKeyFilename, _ := config.GetString("app", "token_key") + if tokenKeyFilename == "" { + return nil, fmt.Errorf("No token key configured") + } + tokenKeyData, err := os.ReadFile(tokenKeyFilename) + if err != nil { + return nil, fmt.Errorf("Could not read private key from %s: %s", tokenKeyFilename, err) + } + tokenKey, err = jwt.ParseRSAPrivateKeyFromPEM(tokenKeyData) + if err != nil { + return nil, fmt.Errorf("Could not parse private key from %s: %s", tokenKeyFilename, err) + } + log.Printf("Using \"%s\" as token id for remote streams", tokenId) + + remoteHostname, _ = config.GetString("app", "hostname") + if remoteHostname == "" { + remoteHostname, err = GetLocalIP() + if err != nil { + return nil, fmt.Errorf("could not get local ip: %w", err) + } + } + if remoteHostname == "" { + log.Printf("WARNING: Could not determine hostname for remote streams, will be disabled. Please configure manually.") + } else { + log.Printf("Using \"%s\" as hostname for remote streams", remoteHostname) + } + } else { + log.Printf("No token id configured, remote streams will be disabled") + } + result := &ProxyServer{ version: version, country: country, welcomeMessage: string(welcomeMessage) + "\n", + config: config, shutdownChannel: make(chan struct{}), @@ -208,6 +290,11 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* clients: make(map[string]signaling.McuClient), clientIds: make(map[string]string), + + tokenId: tokenId, + tokenKey: tokenKey, + remoteHostname: remoteHostname, + remoteConnections: make(map[string]*RemoteConnection), } result.upgrader.CheckOrigin = result.checkOrigin @@ -610,6 +697,40 @@ func (i *emptyInitiator) Country() string { return "" } +type proxyRemotePublisher struct { + proxy *ProxyServer + remoteUrl string + + publisherId string +} + +func (p *proxyRemotePublisher) PublisherId() string { + return p.publisherId +} + +func (p *proxyRemotePublisher) StartPublishing(ctx context.Context, publisher signaling.McuRemotePublisherProperties) error { + var conn *RemoteConnection + conn, err := p.proxy.getRemoteConnection(ctx, p.remoteUrl) + if err != nil { + return err + } + + if _, err := conn.RequestMessage(ctx, &signaling.ProxyClientMessage{ + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "publish-remote", + ClientId: p.publisherId, + Hostname: p.proxy.remoteHostname, + Port: publisher.Port(), + RtcpPort: publisher.RtcpPort(), + }, + }); err != nil { + return err + } + + return nil +} + func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { cmd := message.Command @@ -652,18 +773,89 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s case "create-subscriber": id := uuid.New().String() publisherId := cmd.PublisherId - subscriber, err := s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType) - if err == context.DeadlineExceeded { - log.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId()) - session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingSubscriber)) - return - } else if err != nil { + var subscriber signaling.McuSubscriber + var err error + + handleCreateError := func(err error) { + if err == context.DeadlineExceeded { + log.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId()) + session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingSubscriber)) + return + } else if errors.Is(err, signaling.ErrRemoteStreamsNotSupported) { + session.sendMessage(message.NewErrorServerMessage(RemoteSubscribersNotSupported)) + return + } + log.Printf("Error while creating %s subscriber on %s for %s: %s", cmd.StreamType, publisherId, session.PublicId(), err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) - return } - log.Printf("Created %s subscriber %s as %s for %s", cmd.StreamType, subscriber.Id(), id, session.PublicId()) + if cmd.RemoteUrl != "" { + if s.tokenId == "" || s.tokenKey == nil || s.remoteHostname == "" { + session.sendMessage(message.NewErrorServerMessage(RemoteSubscribersNotSupported)) + return + } + + remoteMcu, ok := s.mcu.(signaling.RemoteMcu) + if !ok { + session.sendMessage(message.NewErrorServerMessage(RemoteSubscribersNotSupported)) + return + } + + claims, _, err := s.parseToken(cmd.RemoteToken) + if err != nil { + if e, ok := err.(*signaling.Error); ok { + client.SendMessage(message.NewErrorServerMessage(e)) + } else { + client.SendMessage(message.NewWrappedErrorServerMessage(err)) + } + return + } + + if claims.Subject != publisherId { + session.sendMessage(message.NewErrorServerMessage(TokenAuthFailed)) + return + } + + subCtx, cancel := context.WithTimeout(ctx, remotePublisherTimeout) + defer cancel() + + log.Printf("Creating remote subscriber for %s on %s", publisherId, cmd.RemoteUrl) + + controller := &proxyRemotePublisher{ + proxy: s, + remoteUrl: cmd.RemoteUrl, + publisherId: publisherId, + } + + var publisher signaling.McuRemotePublisher + publisher, err = remoteMcu.NewRemotePublisher(subCtx, session, controller, cmd.StreamType) + if err != nil { + handleCreateError(err) + return + } + + defer func() { + go publisher.Close(context.Background()) + }() + + subscriber, err = remoteMcu.NewRemoteSubscriber(subCtx, session, publisher) + if err != nil { + handleCreateError(err) + return + } + + log.Printf("Created remote %s subscriber %s as %s for %s on %s", cmd.StreamType, subscriber.Id(), id, session.PublicId(), cmd.RemoteUrl) + } else { + subscriber, err = s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType, &emptyInitiator{}) + if err != nil { + handleCreateError(err) + return + } + + log.Printf("Created %s subscriber %s as %s for %s", cmd.StreamType, subscriber.Id(), id, session.PublicId()) + } + session.StoreSubscriber(ctx, id, subscriber) s.StoreClient(id, subscriber) @@ -740,6 +932,33 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s client.Close(context.Background()) }() + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: cmd.ClientId, + }, + } + session.sendMessage(response) + case "publish-remote": + client := s.GetClient(cmd.ClientId) + if client == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + publisher, ok := client.(signaling.McuPublisher) + if !ok { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + if err := publisher.PublishRemote(ctx, cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { + log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + response := &signaling.ProxyServerMessage{ Id: message.Id, Type: "command", @@ -830,13 +1049,9 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s }) } -func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*ProxySession, error) { - if proxyDebugMessages { - log.Printf("Hello: %+v", hello) - } - +func (s *ProxyServer) parseToken(tokenValue string) (*signaling.TokenClaims, string, error) { reason := "auth-failed" - token, err := jwt.ParseWithClaims(hello.Token, &signaling.TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenValue, &signaling.TokenClaims{}, func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { log.Printf("Unexpected signing method: %v", token.Header["alg"]) @@ -868,25 +1083,35 @@ func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*Pro }) if err, ok := err.(*jwt.ValidationError); ok { if err.Errors&jwt.ValidationErrorIssuedAt == jwt.ValidationErrorIssuedAt { - statsTokenErrorsTotal.WithLabelValues("not-valid-yet").Inc() - return nil, TokenNotValidYet + return nil, "not-valid-yet", TokenNotValidYet } } if err != nil { - statsTokenErrorsTotal.WithLabelValues(reason).Inc() - return nil, TokenAuthFailed + return nil, reason, TokenAuthFailed } claims, ok := token.Claims.(*signaling.TokenClaims) if !ok || !token.Valid { - statsTokenErrorsTotal.WithLabelValues("auth-failed").Inc() - return nil, TokenAuthFailed + return nil, "auth-failed", TokenAuthFailed } minIssuedAt := time.Now().Add(-maxTokenAge) if issuedAt := claims.IssuedAt; issuedAt != nil && issuedAt.Before(minIssuedAt) { - statsTokenErrorsTotal.WithLabelValues("expired").Inc() - return nil, TokenExpired + return nil, "expired", TokenExpired + } + + return claims, "", nil +} + +func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*ProxySession, error) { + if proxyDebugMessages { + log.Printf("Hello: %+v", hello) + } + + claims, reason, err := s.parseToken(hello.Token) + if err != nil { + statsTokenErrorsTotal.WithLabelValues(reason).Inc() + return nil, err } sid := s.sid.Add(1) @@ -999,6 +1224,22 @@ func (s *ProxyServer) GetClient(id string) signaling.McuClient { return s.clients[id] } +func (s *ProxyServer) GetPublisher(publisherId string) signaling.McuPublisher { + s.clientsLock.RLock() + defer s.clientsLock.RUnlock() + for _, c := range s.clients { + pub, ok := c.(signaling.McuPublisher) + if !ok { + continue + } + + if pub.Id() == publisherId { + return pub + } + } + return nil +} + func (s *ProxyServer) GetClientId(client signaling.McuClient) string { s.clientsLock.RLock() defer s.clientsLock.RUnlock() @@ -1054,3 +1295,25 @@ func (s *ProxyServer) metricsHandler(w http.ResponseWriter, r *http.Request) { // Expose prometheus metrics at "/metrics". promhttp.Handler().ServeHTTP(w, r) } + +func (s *ProxyServer) getRemoteConnection(ctx context.Context, url string) (*RemoteConnection, error) { + s.remoteConnectionsLock.Lock() + defer s.remoteConnectionsLock.Unlock() + + conn, found := s.remoteConnections[url] + if found { + return conn, nil + } + + conn, err := NewRemoteConnection(url, s.tokenId, s.tokenKey) + if err != nil { + return nil, err + } + + if err := conn.Connect(ctx); err != nil { + return nil, err + } + + s.remoteConnections[url] = conn + return conn, nil +} diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index 7ed87df..5000dfc 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -26,6 +26,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "net" "os" "testing" "time" @@ -120,3 +121,38 @@ func TestTokenInFuture(t *testing.T) { t.Errorf("could have failed with TokenNotValidYet, got %s", err) } } + +func TestPublicIPs(t *testing.T) { + public := []string{ + "8.8.8.8", + "172.15.1.2", + "172.32.1.2", + "192.167.0.1", + "192.169.0.1", + } + private := []string{ + "127.0.0.1", + "10.1.2.3", + "172.16.1.2", + "172.31.1.2", + "192.168.0.1", + "192.168.254.254", + } + for _, s := range public { + ip := net.ParseIP(s) + if len(ip) == 0 { + t.Errorf("invalid IP: %s", s) + } else if !IsPublicIP(ip) { + t.Errorf("should be public IP: %s", s) + } + } + + for _, s := range private { + ip := net.ParseIP(s) + if len(ip) == 0 { + t.Errorf("invalid IP: %s", s) + } else if IsPublicIP(ip) { + t.Errorf("should be private IP: %s", s) + } + } +}