diff --git a/src/proxy/proxy_server.go b/src/proxy/proxy_server.go index 802c916..05c4fa0 100644 --- a/src/proxy/proxy_server.go +++ b/src/proxy/proxy_server.go @@ -178,8 +178,11 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile, na } country, _ := config.GetString("app", "country") - if country != "" { + country = strings.ToUpper(country) + if signaling.IsValidCountry(country) { log.Printf("Sending %s as country information", country) + } else if country != "" { + return nil, fmt.Errorf("Invalid country: %s", country) } else { log.Printf("Not sending country information") } @@ -664,6 +667,12 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { } } +type emptyInitiator struct{} + +func (i *emptyInitiator) Country() string { + return "" +} + func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { cmd := message.Command switch cmd.Type { @@ -674,7 +683,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } id := uuid.New().String() - publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.StreamType) + publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.StreamType, &emptyInitiator{}) if err == context.DeadlineExceeded { log.Printf("Timeout while creating %s publisher %s for %s", cmd.StreamType, id, session.PublicId()) session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingPublisher)) diff --git a/src/signaling/client.go b/src/signaling/client.go index 187ca2d..ebfef43 100644 --- a/src/signaling/client.go +++ b/src/signaling/client.go @@ -58,6 +58,21 @@ var ( unknownCountry string = "unknown-country" ) +func IsValidCountry(country string) bool { + switch country { + case "": + fallthrough + case noCountry: + fallthrough + case loopback: + fallthrough + case unknownCountry: + return false + default: + return true + } +} + var ( InvalidFormat = NewError("invalid_format", "Invalid data format.") diff --git a/src/signaling/clientsession.go b/src/signaling/clientsession.go index 7a9d2fc..8a0b496 100644 --- a/src/signaling/clientsession.go +++ b/src/signaling/clientsession.go @@ -436,6 +436,10 @@ func (s *ClientSession) GetClient() *Client { s.mu.Lock() defer s.mu.Unlock() + return s.getClientUnlocked() +} + +func (s *ClientSession) getClientUnlocked() *Client { return s.client } @@ -554,9 +558,10 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea publisher, found := s.publishers[streamType] if !found { + client := s.getClientUnlocked() s.mu.Unlock() var err error - publisher, err = mcu.NewPublisher(ctx, s, s.PublicId(), streamType) + publisher, err = mcu.NewPublisher(ctx, s, s.PublicId(), streamType, client) s.mu.Lock() if err != nil { return nil, err diff --git a/src/signaling/hub.go b/src/signaling/hub.go index 0bbd173..9e352f9 100644 --- a/src/signaling/hub.go +++ b/src/signaling/hub.go @@ -1544,10 +1544,13 @@ func (h *Hub) lookupClientCountry(client *Client) string { country, err := h.geoip.LookupCountry(ip) if err != nil { - log.Printf("Could not lookup country for %s", ip) + log.Printf("Could not lookup country for %s: %s", ip, err) return unknownCountry } + if country == "" { + return unknownCountry + } return country } diff --git a/src/signaling/mcu_common.go b/src/signaling/mcu_common.go index 20721af..c821ff0 100644 --- a/src/signaling/mcu_common.go +++ b/src/signaling/mcu_common.go @@ -48,6 +48,10 @@ type McuListener interface { SubscriberClosed(subscriber McuSubscriber) } +type McuInitiator interface { + Country() string +} + type Mcu interface { Start() error Stop() @@ -57,7 +61,7 @@ type Mcu interface { GetStats() interface{} - NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) + NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) } diff --git a/src/signaling/mcu_janus.go b/src/signaling/mcu_janus.go index 84bd18d..6f59524 100644 --- a/src/signaling/mcu_janus.go +++ b/src/signaling/mcu_janus.go @@ -635,7 +635,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st return handle, response.Session, roomId, nil } -func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { +func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) { if _, found := streamTypeUserIds[streamType]; !found { return nil, fmt.Errorf("Unsupported stream type %s", streamType) } diff --git a/src/signaling/mcu_proxy.go b/src/signaling/mcu_proxy.go index 032d1f1..145c4f9 100644 --- a/src/signaling/mcu_proxy.go +++ b/src/signaling/mcu_proxy.go @@ -260,6 +260,7 @@ type mcuProxyConnection struct { helloMsgId string sessionId string load int64 + country atomic.Value callbacks map[string]func(*ProxyServerMessage) @@ -289,6 +290,7 @@ func newMcuProxyConnection(proxy *mcuProxy, baseUrl string) (*mcuProxyConnection publisherIds: make(map[string]string), subscribers: make(map[string]*mcuProxySubscriber), } + conn.country.Store("") return conn, nil } @@ -324,6 +326,10 @@ func (c *mcuProxyConnection) Load() int64 { return atomic.LoadInt64(&c.load) } +func (c *mcuProxyConnection) Country() string { + return c.country.Load().(string) +} + func (c *mcuProxyConnection) IsShutdownScheduled() bool { return atomic.LoadUint32(&c.shutdownScheduled) != 0 } @@ -564,7 +570,19 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { c.scheduleReconnect() case "hello": c.sessionId = msg.Hello.SessionId - log.Printf("Received session %s from %s", c.sessionId, c.url) + country := "" + if msg.Hello.Server != nil { + if country = msg.Hello.Server.Country; country != "" && !IsValidCountry(country) { + log.Printf("Proxy %s sent invalid country %s in hello response", c.url, country) + country = "" + } + } + c.country.Store(country) + if country != "" { + log.Printf("Received session %s from %s (in %s)", c.sessionId, c.url, country) + } else { + log.Printf("Received session %s from %s", c.sessionId, c.url) + } default: log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c.url) c.scheduleReconnect() @@ -930,7 +948,56 @@ func (l mcuProxyConnectionsList) Sort() { sort.Sort(l) } -func (m *mcuProxy) getSortedConnections() []*mcuProxyConnection { +func ContinentsOverlap(a, b []string) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + + for _, checkA := range a { + for _, checkB := range b { + if checkA == checkB { + return true + } + } + } + return false +} + +func sortConnectionsForCountry(connections []*mcuProxyConnection, country string) []*mcuProxyConnection { + // Move connections in the same country to the start of the list. + sorted := make(mcuProxyConnectionsList, 0, len(connections)) + unprocessed := make(mcuProxyConnectionsList, 0, len(connections)) + for _, conn := range connections { + if country == conn.Country() { + sorted = append(sorted, conn) + } else { + unprocessed = append(unprocessed, conn) + } + } + if continents, found := ContinentMap[country]; found && len(unprocessed) > 1 { + remaining := make(mcuProxyConnectionsList, 0, len(unprocessed)) + // Next up are connections on the same continent. + for _, conn := range unprocessed { + connCountry := conn.Country() + if IsValidCountry(connCountry) { + connContinents := ContinentMap[connCountry] + if ContinentsOverlap(continents, connContinents) { + sorted = append(sorted, conn) + } else { + remaining = append(remaining, conn) + } + } else { + remaining = append(remaining, conn) + } + } + unprocessed = remaining + } + // Add all other connections by load. + sorted = append(sorted, unprocessed...) + return sorted +} + +func (m *mcuProxy) getSortedConnections(initiator McuInitiator) []*mcuProxyConnection { connections := m.getConnections() if len(connections) < 2 { return connections @@ -951,6 +1018,11 @@ func (m *mcuProxy) getSortedConnections() []*mcuProxyConnection { connections = sorted } + if initiator != nil { + if country := initiator.Country(); IsValidCountry(country) { + connections = sortConnectionsForCountry(connections, country) + } + } return connections } @@ -980,8 +1052,8 @@ func (m *mcuProxy) removeWaiter(id uint64) { delete(m.publisherWaiters, id) } -func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { - connections := m.getSortedConnections() +func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) { + connections := m.getSortedConnections(initiator) for _, conn := range connections { if conn.IsShutdownScheduled() { continue diff --git a/src/signaling/mcu_proxy_test.go b/src/signaling/mcu_proxy_test.go new file mode 100644 index 0000000..bf1cbca --- /dev/null +++ b/src/signaling/mcu_proxy_test.go @@ -0,0 +1,86 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "testing" +) + +func newProxyConnectionWithCountry(country string) *mcuProxyConnection { + conn := &mcuProxyConnection{} + conn.country.Store(country) + return conn +} + +func Test_sortConnectionsForCountry(t *testing.T) { + conn_de := newProxyConnectionWithCountry("DE") + conn_at := newProxyConnectionWithCountry("AT") + conn_jp := newProxyConnectionWithCountry("JP") + conn_us := newProxyConnectionWithCountry("US") + + testcases := map[string][][]*mcuProxyConnection{ + // Direct country match + "DE": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_at, conn_jp, conn_de}, + []*mcuProxyConnection{conn_de, conn_at, conn_jp}, + }, + // Direct country match + "AT": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_at, conn_jp, conn_de}, + []*mcuProxyConnection{conn_at, conn_de, conn_jp}, + }, + // Continent match + "CH": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_de, conn_at, conn_jp}, + }, + // Direct country match + "JP": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_jp, conn_de, conn_at}, + }, + // Continent match + "CN": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_jp, conn_de, conn_at}, + }, + // Partial continent match + "RU": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_us, conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_de, conn_jp, conn_at, conn_us}, + }, + // No match + "AU": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_us, conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_us, conn_de, conn_jp, conn_at}, + }, + } + + for country, test := range testcases { + sorted := sortConnectionsForCountry(test[0], country) + for idx, conn := range sorted { + if test[1][idx] != conn { + t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country()) + } + } + } +} diff --git a/src/signaling/mcu_test.go b/src/signaling/mcu_test.go index 7062a49..dbfe485 100644 --- a/src/signaling/mcu_test.go +++ b/src/signaling/mcu_test.go @@ -51,7 +51,7 @@ func (m *TestMCU) GetStats() interface{} { return nil } -func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { +func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) { return nil, fmt.Errorf("Not implemented") }