diff --git a/.gitignore b/.gitignore index ae1d4af..7296465 100644 --- a/.gitignore +++ b/.gitignore @@ -2,9 +2,11 @@ bin/ vendor/ *_easyjson.go +*.pem *.prof *.socket *.tar.gz cover.out +proxy.conf server.conf diff --git a/Makefile b/Makefile index ee0c9c1..ea49861 100644 --- a/Makefile +++ b/Makefile @@ -97,6 +97,7 @@ coverhtml: dependencies vet common common: easyjson \ src/signaling/api_signaling_easyjson.go \ src/signaling/api_backend_easyjson.go \ + src/signaling/api_proxy_easyjson.go \ src/signaling/natsclient_easyjson.go \ src/signaling/room_easyjson.go @@ -108,10 +109,14 @@ server: dependencies common mkdir -p $(BINDIR) GOPATH=$(GOPATH) $(GO) build $(BUILDARGS) -ldflags '$(INTERNALLDFLAGS)' -o $(BINDIR)/signaling ./src/server/... +proxy: dependencies common + mkdir -p $(BINDIR) + GOPATH=$(GOPATH) $(GO) build $(BUILDARGS) -ldflags '$(INTERNALLDFLAGS)' -o $(BINDIR)/proxy ./src/proxy/... + clean: rm -f src/signaling/*_easyjson.go -build: server +build: server proxy tarball: git archive \ diff --git a/README.md b/README.md index 358ab23..785c34a 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,33 @@ The maximum bandwidth per publishing stream can also be configured in the section `[mcu]`, see properties `maxstreambitrate` and `maxscreenbitrate`. +### Use multiple Janus servers + +To scale the setup and add high availability, a signaling server can connect to +one or multiple proxy servers that each provide access to a single Janus server. + +For that, set the `type` key in section `[mcu]` to `proxy` and set `url` to a +space-separated list of URLs where a proxy server is running. + +Each signaling server that connects to a proxy needs a unique token id and a +public / private RSA keypair. The token id must be configured as `token_id` in +section `[mcu]`, the path to the private key file as `token_key`. + + +### Setup of proxy server + +The proxy server is built with the standard make command `make build` as +`bin/proxy` binary. Copy the `proxy.conf.in` as `proxy.conf` and edit section +`[tokens]` to the list of allowed token ids and filenames of the public keys +for each token id. See the comments in `proxy.conf.in` for other configuration +options. + +When the proxy process receives a `SIGHUP` signal, the list of allowed token +ids / public keys is reloaded. A `SIGUSR1` signal can be used to shutdown a +proxy process gracefully after all clients have been disconnected. No new +publishers will be accepted in this case. + + ## Setup of frontend webserver Usually the standalone signaling server is running behind a webserver that does diff --git a/dependencies.tsv b/dependencies.tsv index 9855e5f..020bea7 100644 --- a/dependencies.tsv +++ b/dependencies.tsv @@ -1,4 +1,5 @@ github.com/dlintw/goconf git dcc070983490608a14480e3bf943bad464785df5 2012-02-28T08:26:10Z +github.com/google/uuid git 0e4e31197428a347842d152773b4cace4645ca25 2020-07-02T18:56:42Z github.com/gorilla/context git 08b5f424b9271eedf6f9f0ce86cb9396ed337a42 2016-08-17T18:46:32Z github.com/gorilla/mux git ac112f7d75a0714af1bd86ab17749b31f7809640 2017-07-04T07:43:45Z github.com/gorilla/securecookie git e59506cc896acb7f7bf732d4fdf5e25f7ccd8983 2017-02-24T19:38:04Z @@ -10,3 +11,4 @@ github.com/notedit/janus-go git 10eb8b95d1a0469ac8921c5ce5fb55b4c0d3ad7d 2020-05 github.com/oschwald/maxminddb-golang git 1960b16a5147df3a4c61ac83b2f31cd8f811d609 2019-05-23T23:57:38Z golang.org/x/net git f01ecb60fe3835d80d9a0b7b2bf24b228c89260e 2017-07-11T18:12:19Z golang.org/x/sys git ac767d655b305d4e9612f5f6e33120b9176c4ad4 2018-07-15T08:55:29Z +gopkg.in/dgrijalva/jwt-go.v3 git 06ea1031745cb8b3dab3f6a236daf2b0aa468b7e 2018-03-08T23:13:08Z diff --git a/proxy.conf.in b/proxy.conf.in new file mode 100644 index 0000000..4770eda --- /dev/null +++ b/proxy.conf.in @@ -0,0 +1,55 @@ +[http] +# IP and port to listen on for HTTP requests. +# Comment line to disable the listener. +#listen = 127.0.0.1:9090 + +[app] +# Set to "true" to install pprof debug handlers. +# See "https://golang.org/pkg/net/http/pprof/" for further information. +#debug = false + +# ISO 3166 country this proxy is located at. This will be used by the signaling +# servers to determine the closest proxy for publishers. +#country = DE + +[sessions] +# Secret value used to generate checksums of sessions. This should be a random +# string of 32 or 64 bytes. +hashkey = secret-for-session-checksums + +# Optional key for encrypting data in the sessions. Must be either 16, 24 or +# 32 bytes. +# If no key is specified, data will not be encrypted (not recommended). +blockkey = -encryption-key- + +[nats] +# Url of NATS backend to use. This can also be a list of URLs to connect to +# multiple backends. For local development, this can be set to ":loopback:" +# to process NATS messages internally instead of sending them through an +# external NATS backend. +#url = nats://localhost:4222 + +[tokens] +# Mapping of = of signaling servers allowed to connect. +#server1 = pubkey1.pem +#server2 = pubkey2.pem + +[mcu] +# The type of the MCU to use. Currently only "janus" is supported. +type = janus + +# The URL to the websocket endpoint of the MCU server. +url = ws://localhost:8188/ + +# The maximum bitrate per publishing stream (in bits per second). +# Defaults to 1 mbit/sec. +#maxstreambitrate = 1048576 + +# The maximum bitrate per screensharing stream (in bits per second). +# Default is 2 mbit/sec. +#maxscreenbitrate = 2097152 + +[stats] +# Comma-separated list of IP addresses that are allowed to access the stats +# endpoint. Leave empty (or commented) to only allow access from "127.0.0.1". +#allowed_ips = diff --git a/server.conf.in b/server.conf.in index 6488e4e..aaf10c2 100644 --- a/server.conf.in +++ b/server.conf.in @@ -98,21 +98,31 @@ connectionsperhost = 8 #url = nats://localhost:4222 [mcu] -# The type of the MCU to use. Currently only "janus" is supported. +# The type of the MCU to use. Currently only "janus" and "proxy" are supported. type = janus -# The URL to the websocket endpoint of the MCU server. Leave empty to disable -# MCU functionality. +# For type "janus": the URL to the websocket endpoint of the MCU server. +# For type "proxy": a space-separated list of proxy URLs to connect to. +# Leave empty to disable MCU functionality. url = -# The maximum bitrate per publishing stream (in bits per second). +# For type "janus": the maximum bitrate per publishing stream (in bits per +# second). # Defaults to 1 mbit/sec. #maxstreambitrate = 1048576 -# The maximum bitrate per screensharing stream (in bits per second). +# For type "janus": the maximum bitrate per screensharing stream (in bits per +# second). # Default is 2 mbit/sec. #maxscreenbitrate = 2097152 +# For type "proxy": the id of the token to use when connecting to proxy servers. +#token_id = server1 + +# For type "proxy": the private key for the configured token id to use when +# connecting to proxy servers. +#token_key = privkey.pem + [turn] # API key that the MCU will need to send when requesting TURN credentials. #apikey = the-api-key-for-the-rest-service diff --git a/src/proxy/main.go b/src/proxy/main.go new file mode 100644 index 0000000..e3505e8 --- /dev/null +++ b/src/proxy/main.go @@ -0,0 +1,161 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "runtime" + "strings" + "syscall" + "time" + + "github.com/dlintw/goconf" + "github.com/gorilla/mux" + "github.com/nats-io/go-nats" + + "signaling" +) + +var ( + version = "unreleased" + + configFlag = flag.String("config", "proxy.conf", "config file to use") + + showVersion = flag.Bool("version", false, "show version and quit") +) + +const ( + defaultReadTimeout = 15 + defaultWriteTimeout = 15 + + proxyDebugMessages = false +) + +func main() { + log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile) + flag.Parse() + + if *showVersion { + fmt.Printf("nextcloud-spreed-signaling-proxy version %s/%s\n", version, runtime.Version()) + os.Exit(0) + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + signal.Notify(sigChan, syscall.SIGHUP) + signal.Notify(sigChan, syscall.SIGUSR1) + + log.Printf("Starting up version %s/%s as pid %d", version, runtime.Version(), os.Getpid()) + + config, err := goconf.ReadConfigFile(*configFlag) + if err != nil { + log.Fatal("Could not read configuration: ", err) + } + + cpus := runtime.NumCPU() + runtime.GOMAXPROCS(cpus) + log.Printf("Using a maximum of %d CPUs\n", cpus) + + natsUrl, _ := config.GetString("nats", "url") + if natsUrl == "" { + natsUrl = nats.DefaultURL + } + + nats, err := signaling.NewNatsClient(natsUrl) + if err != nil { + log.Fatal("Could not create NATS client: ", err) + } + + r := mux.NewRouter() + + proxy, err := NewProxyServer(r, version, config, nats) + if err != nil { + log.Fatal(err) + } + + if err := proxy.Start(config); err != nil { + log.Fatal(err) + } + defer proxy.Stop() + + if addr, _ := config.GetString("http", "listen"); addr != "" { + readTimeout, _ := config.GetInt("http", "readtimeout") + if readTimeout <= 0 { + readTimeout = defaultReadTimeout + } + writeTimeout, _ := config.GetInt("http", "writetimeout") + if writeTimeout <= 0 { + writeTimeout = defaultWriteTimeout + } + + for _, address := range strings.Split(addr, " ") { + go func(address string) { + log.Println("Listening on", address) + listener, err := net.Listen("tcp", address) + if err != nil { + log.Fatal("Could not start listening: ", err) + } + srv := &http.Server{ + Handler: r, + Addr: addr, + + ReadTimeout: time.Duration(readTimeout) * time.Second, + WriteTimeout: time.Duration(writeTimeout) * time.Second, + } + if err := srv.Serve(listener); err != nil { + log.Fatal("Could not start server: ", err) + } + }(address) + } + } + +loop: + for { + select { + case sig := <-sigChan: + switch sig { + case os.Interrupt: + log.Println("Interrupted") + break loop + case syscall.SIGHUP: + log.Printf("Received SIGHUP, reloading %s", *configFlag) + if config, err := goconf.ReadConfigFile(*configFlag); err != nil { + log.Printf("Could not read configuration from %s: %s", *configFlag, err) + } else { + proxy.Reload(config) + } + case syscall.SIGUSR1: + log.Printf("Received SIGUSR1, scheduling server to shutdown") + proxy.ScheduleShutdown() + } + case <-proxy.ShutdownChannel(): + log.Printf("All clients disconnected, shutting down") + break loop + } + } +} diff --git a/src/proxy/proxy_client.go b/src/proxy/proxy_client.go new file mode 100644 index 0000000..5cb8d6a --- /dev/null +++ b/src/proxy/proxy_client.go @@ -0,0 +1,55 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "sync/atomic" + "unsafe" + + "github.com/gorilla/websocket" + + "signaling" +) + +type ProxyClient struct { + signaling.Client + + proxy *ProxyServer + + session unsafe.Pointer +} + +func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) { + client := &ProxyClient{ + proxy: proxy, + } + client.SetConn(conn, addr) + return client, nil +} + +func (c *ProxyClient) GetSession() *ProxySession { + return (*ProxySession)(atomic.LoadPointer(&c.session)) +} + +func (c *ProxyClient) SetSession(session *ProxySession) { + atomic.StorePointer(&c.session, unsafe.Pointer(session)) +} diff --git a/src/proxy/proxy_server.go b/src/proxy/proxy_server.go new file mode 100644 index 0000000..05c4fa0 --- /dev/null +++ b/src/proxy/proxy_server.go @@ -0,0 +1,1037 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "crypto/rsa" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/pprof" + "os" + "os/signal" + runtimepprof "runtime/pprof" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/dlintw/goconf" + "github.com/google/uuid" + "github.com/gorilla/mux" + "github.com/gorilla/securecookie" + "github.com/gorilla/websocket" + + "golang.org/x/net/context" + + "gopkg.in/dgrijalva/jwt-go.v3" + + "signaling" +) + +const ( + // Buffer sizes when reading/writing websocket connections. + websocketReadBufferSize = 4096 + websocketWriteBufferSize = 4096 + + initialMcuRetry = time.Second + maxMcuRetry = time.Second * 16 + + updateLoadInterval = time.Second + expireSessionsInterval = 10 * time.Second + + // Maximum age a token may have to prevent reuse of old tokens. + maxTokenAge = 5 * time.Minute +) + +type ContextKey string + +var ( + ContextKeySession = ContextKey("session") + + TimeoutCreatingPublisher = signaling.NewError("timeout", "Timeout creating publisher.") + TimeoutCreatingSubscriber = signaling.NewError("timeout", "Timeout creating subscriber.") + TokenAuthFailed = signaling.NewError("auth_failed", "The token could not be authenticated.") + TokenExpired = signaling.NewError("token_expired", "The token is expired.") + UnknownClient = signaling.NewError("unknown_client", "Unknown client id given.") + UnsupportedCommand = signaling.NewError("bad_request", "Unsupported command received.") + UnsupportedMessage = signaling.NewError("bad_request", "Unsupported message received.") + UnsupportedPayload = signaling.NewError("unsupported_payload", "Unsupported payload type.") + ShutdownScheduled = signaling.NewError("shutdown_scheduled", "The server is scheduled to shutdown.") +) + +type ProxyServer struct { + version string + country string + + url string + nats signaling.NatsClient + mcu signaling.Mcu + stopped uint32 + load int64 + + shutdownChannel chan bool + shutdownScheduled uint32 + + upgrader websocket.Upgrader + + tokenKeys atomic.Value + statsAllowedIps map[string]bool + + sid uint64 + cookie *securecookie.SecureCookie + sessions map[uint64]*ProxySession + sessionsLock sync.RWMutex + + clients map[string]signaling.McuClient + clientIds map[string]string + clientsLock sync.RWMutex +} + +func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile, nats signaling.NatsClient) (*ProxyServer, error) { + hashKey, _ := config.GetString("sessions", "hashkey") + switch len(hashKey) { + case 32: + case 64: + default: + log.Printf("WARNING: The sessions hash key should be 32 or 64 bytes but is %d bytes", len(hashKey)) + } + + blockKey, _ := config.GetString("sessions", "blockkey") + blockBytes := []byte(blockKey) + switch len(blockKey) { + case 0: + blockBytes = nil + case 16: + case 24: + case 32: + default: + return nil, fmt.Errorf("The sessions block key must be 16, 24 or 32 bytes but is %d bytes", len(blockKey)) + } + + tokenKeys := make(map[string]*rsa.PublicKey) + options, _ := config.GetOptions("tokens") + for _, id := range options { + filename, _ := config.GetString("tokens", id) + if filename == "" { + return nil, fmt.Errorf("No filename given for token %s", id) + } + + keyData, err := ioutil.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("Could not read public key from %s: %s", filename, err) + } + key, err := jwt.ParseRSAPublicKeyFromPEM(keyData) + if err != nil { + return nil, fmt.Errorf("Could not parse public key from %s: %s", filename, err) + } + + tokenKeys[id] = key + } + + var keyIds []string + for k, _ := range tokenKeys { + keyIds = append(keyIds, k) + } + sort.Strings(keyIds) + log.Printf("Enabled token keys: %v", keyIds) + + statsAllowed, _ := config.GetString("stats", "allowed_ips") + var statsAllowedIps map[string]bool + if statsAllowed == "" { + log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1") + statsAllowedIps = map[string]bool{ + "127.0.0.1": true, + } + } else { + log.Printf("Only allowing access to the stats endpoing from %s", statsAllowed) + statsAllowedIps = make(map[string]bool) + for _, ip := range strings.Split(statsAllowed, ",") { + ip = strings.TrimSpace(ip) + if ip != "" { + statsAllowedIps[ip] = true + } + } + } + + country, _ := config.GetString("app", "country") + country = strings.ToUpper(country) + if signaling.IsValidCountry(country) { + log.Printf("Sending %s as country information", country) + } else if country != "" { + return nil, fmt.Errorf("Invalid country: %s", country) + } else { + log.Printf("Not sending country information") + } + + result := &ProxyServer{ + version: version, + country: country, + + nats: nats, + + shutdownChannel: make(chan bool, 1), + + upgrader: websocket.Upgrader{ + ReadBufferSize: websocketReadBufferSize, + WriteBufferSize: websocketWriteBufferSize, + }, + + statsAllowedIps: statsAllowedIps, + + cookie: securecookie.New([]byte(hashKey), blockBytes).MaxAge(0), + sessions: make(map[uint64]*ProxySession), + + clients: make(map[string]signaling.McuClient), + clientIds: make(map[string]string), + } + + result.setTokenKeys(tokenKeys) + result.upgrader.CheckOrigin = result.checkOrigin + + if debug, _ := config.GetBool("app", "debug"); debug { + log.Println("Installing debug handlers in \"/debug/pprof\"") + r.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index)) + r.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + r.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) + r.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) + r.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace)) + for _, profile := range runtimepprof.Profiles() { + name := profile.Name() + r.Handle("/debug/pprof/"+name, pprof.Handler(name)) + } + } + + r.HandleFunc("/proxy", result.setCommonHeaders(result.proxyHandler)).Methods("GET") + r.HandleFunc("/stats", result.setCommonHeaders(result.validateStatsRequest(result.statsHandler))).Methods("GET") + return result, nil +} + +func (s *ProxyServer) checkOrigin(r *http.Request) bool { + // We allow any Origin to connect to the service. + return true +} + +func (s *ProxyServer) setTokenKeys(keys map[string]*rsa.PublicKey) { + s.tokenKeys.Store(keys) +} + +func (s *ProxyServer) getTokenKeys() map[string]*rsa.PublicKey { + return s.tokenKeys.Load().(map[string]*rsa.PublicKey) +} + +func (s *ProxyServer) Start(config *goconf.ConfigFile) error { + s.url, _ = config.GetString("mcu", "url") + if s.url == "" { + return fmt.Errorf("No MCU server url configured") + } + + mcuType, _ := config.GetString("mcu", "type") + if mcuType == "" { + mcuType = signaling.McuTypeDefault + } + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + defer signal.Stop(interrupt) + + var err error + var mcu signaling.Mcu + mcuRetry := initialMcuRetry + mcuRetryTimer := time.NewTimer(mcuRetry) + for { + switch mcuType { + case signaling.McuTypeJanus: + mcu, err = signaling.NewMcuJanus(s.url, config, s.nats) + default: + return fmt.Errorf("Unsupported MCU type: %s", mcuType) + } + if err == nil { + mcu.SetOnConnected(s.onMcuConnected) + mcu.SetOnDisconnected(s.onMcuDisconnected) + err = mcu.Start() + if err != nil { + log.Printf("Could not create %s MCU at %s: %s", mcuType, s.url, err) + } + } + if err == nil { + break + } + + log.Printf("Could not initialize %s MCU at %s (%s) will retry in %s", mcuType, s.url, err, mcuRetry) + mcuRetryTimer.Reset(mcuRetry) + select { + case <-interrupt: + return fmt.Errorf("Cancelled") + case <-mcuRetryTimer.C: + // Retry connection + mcuRetry = mcuRetry * 2 + if mcuRetry > maxMcuRetry { + mcuRetry = maxMcuRetry + } + } + } + + s.mcu = mcu + + go s.run() + + return nil +} + +func (s *ProxyServer) run() { + updateLoadTicker := time.NewTicker(updateLoadInterval) + expireSessionsTicker := time.NewTicker(expireSessionsInterval) +loop: + for { + select { + case <-updateLoadTicker.C: + if atomic.LoadUint32(&s.stopped) != 0 { + break loop + } + s.updateLoad() + case <-expireSessionsTicker.C: + if atomic.LoadUint32(&s.stopped) != 0 { + break loop + } + s.expireSessions() + } + } +} + +func (s *ProxyServer) updateLoad() { + // TODO: Take maximum bandwidth of clients into account when calculating + // load (screensharing requires more than regular audio/video). + load := s.GetClientCount() + if load == atomic.LoadInt64(&s.load) { + return + } + + atomic.StoreInt64(&s.load, load) + if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + // Server is scheduled to shutdown, no need to update clients with current load. + return + } + + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "update-load", + Load: load, + }, + } + + s.IterateSessions(func(session *ProxySession) { + session.sendMessage(msg) + }) +} + +func (s *ProxyServer) getExpiredSessions() []*ProxySession { + var expired []*ProxySession + s.IterateSessions(func(session *ProxySession) { + if session.IsExpired() { + expired = append(expired, session) + } + }) + return expired +} + +func (s *ProxyServer) expireSessions() { + expired := s.getExpiredSessions() + if len(expired) == 0 { + return + } + + s.sessionsLock.Lock() + defer s.sessionsLock.Unlock() + for _, session := range expired { + if !session.IsExpired() { + // Session was used while waiting for the lock. + continue + } + + log.Printf("Delete expired session %s", session.PublicId()) + s.deleteSessionLocked(session.Sid()) + } +} + +func (s *ProxyServer) Stop() { + if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) { + return + } + + s.mcu.Stop() +} + +func (s *ProxyServer) ShutdownChannel() chan bool { + return s.shutdownChannel +} + +func (s *ProxyServer) ScheduleShutdown() { + if !atomic.CompareAndSwapUint32(&s.shutdownScheduled, 0, 1) { + return + } + + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "shutdown-scheduled", + }, + } + s.IterateSessions(func(session *ProxySession) { + session.sendMessage(msg) + }) + + if s.GetClientCount() == 0 { + go func() { + s.shutdownChannel <- true + }() + } +} + +func (s *ProxyServer) Reload(config *goconf.ConfigFile) { + tokenKeys := make(map[string]*rsa.PublicKey) + options, _ := config.GetOptions("tokens") + for _, id := range options { + filename, _ := config.GetString("tokens", id) + if filename == "" { + log.Printf("No filename given for token %s, ignoring", id) + continue + } + + keyData, err := ioutil.ReadFile(filename) + if err != nil { + log.Printf("Could not read public key from %s, ignoring: %s", filename, err) + continue + } + key, err := jwt.ParseRSAPublicKeyFromPEM(keyData) + if err != nil { + log.Printf("Could not parse public key from %s, ignoring: %s", filename, err) + continue + } + + tokenKeys[id] = key + } + + if len(tokenKeys) == 0 { + log.Printf("No token keys loaded") + } else { + var keyIds []string + for k, _ := range tokenKeys { + keyIds = append(keyIds, k) + } + sort.Strings(keyIds) + log.Printf("Enabled token keys: %v", keyIds) + } + s.setTokenKeys(tokenKeys) +} + +func (s *ProxyServer) setCommonHeaders(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Server", "nextcloud-spreed-signaling-proxy/"+s.version) + f(w, r) + } +} + +func getRealUserIP(r *http.Request) string { + // Note this function assumes it is running behind a trusted proxy, so + // the headers can be trusted. + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + // Result could be a list "clientip, proxy1, proxy2", so only use first element. + if pos := strings.Index(ip, ","); pos >= 0 { + ip = strings.TrimSpace(ip[:pos]) + } + return ip + } + + return r.RemoteAddr +} + +func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { + addr := getRealUserIP(r) + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Could not upgrade request from %s: %s", addr, err) + return + } + + client, err := NewProxyClient(s, conn, addr) + if err != nil { + log.Printf("Could not create client for %s: %s", addr, err) + return + } + + client.OnClosed = s.clientClosed + client.OnMessageReceived = func(c *signaling.Client, data []byte) { + s.processMessage(client, data) + } + client.OnRTTReceived = func(c *signaling.Client, rtt time.Duration) { + if session := client.GetSession(); session != nil { + session.MarkUsed() + } + } + + go client.WritePump() + go client.ReadPump() +} + +func (s *ProxyServer) clientClosed(client *signaling.Client) { + log.Printf("Connection from %s closed", client.RemoteAddr()) +} + +func (s *ProxyServer) onMcuConnected() { + log.Printf("Connection to %s established", s.url) + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "backend-connected", + }, + } + + s.IterateSessions(func(session *ProxySession) { + session.sendMessage(msg) + }) +} + +func (s *ProxyServer) onMcuDisconnected() { + if atomic.LoadUint32(&s.stopped) != 0 { + // Shutting down, no need to notify. + return + } + + log.Printf("Connection to %s lost", s.url) + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "backend-disconnected", + }, + } + + s.IterateSessions(func(session *ProxySession) { + session.sendMessage(msg) + session.NotifyDisconnected() + }) +} + +func (s *ProxyServer) sendCurrentLoad(session *ProxySession) { + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "update-load", + Load: atomic.LoadInt64(&s.load), + }, + } + session.sendMessage(msg) +} + +func (s *ProxyServer) sendShutdownScheduled(session *ProxySession) { + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "shutdown-scheduled", + }, + } + session.sendMessage(msg) +} + +func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { + if proxyDebugMessages { + log.Printf("Message: %s", string(data)) + } + var message signaling.ProxyClientMessage + if err := message.UnmarshalJSON(data); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) + } else { + log.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) + } + client.SendError(signaling.InvalidFormat) + return + } + + if err := message.CheckValid(); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + } else { + log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) + } + client.SendMessage(message.NewErrorServerMessage(signaling.InvalidFormat)) + return + } + + session := client.GetSession() + if session == nil { + if message.Type != "hello" { + client.SendMessage(message.NewErrorServerMessage(signaling.HelloExpected)) + return + } + + var session *ProxySession + if resumeId := message.Hello.ResumeId; resumeId != "" { + var data signaling.SessionIdData + if s.cookie.Decode("session", resumeId, &data) == nil { + session = s.GetSession(data.Sid) + } + + if session == nil { + client.SendMessage(message.NewErrorServerMessage(signaling.NoSuchSession)) + return + } + + log.Printf("Resumed session %s", session.PublicId()) + if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + s.sendShutdownScheduled(session) + } else { + s.sendCurrentLoad(session) + } + } else { + var err error + if session, err = s.NewSession(message.Hello); err != nil { + if e, ok := err.(*signaling.Error); ok { + client.SendMessage(message.NewErrorServerMessage(e)) + } else { + client.SendMessage(message.NewWrappedErrorServerMessage(err)) + } + return + } + } + + prev := session.SetClient(client) + if prev != nil { + msg := &signaling.ProxyServerMessage{ + Type: "bye", + Bye: &signaling.ByeProxyServerMessage{ + Reason: "session_resumed", + }, + } + prev.SendMessage(msg) + } + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "hello", + Hello: &signaling.HelloProxyServerMessage{ + Version: signaling.HelloVersion, + SessionId: session.PublicId(), + Server: &signaling.HelloServerMessageServer{ + Version: s.version, + Country: s.country, + }, + }, + } + client.SendMessage(response) + if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + s.sendShutdownScheduled(session) + } else { + s.sendCurrentLoad(session) + } + return + } + + ctx := context.WithValue(context.Background(), ContextKeySession, session) + session.MarkUsed() + + switch message.Type { + case "command": + s.processCommand(ctx, client, session, &message) + case "payload": + s.processPayload(ctx, client, session, &message) + default: + session.sendMessage(message.NewErrorServerMessage(UnsupportedMessage)) + } +} + +type emptyInitiator struct{} + +func (i *emptyInitiator) Country() string { + return "" +} + +func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { + cmd := message.Command + switch cmd.Type { + case "create-publisher": + if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + session.sendMessage(message.NewErrorServerMessage(ShutdownScheduled)) + return + } + + id := uuid.New().String() + publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.StreamType, &emptyInitiator{}) + if err == context.DeadlineExceeded { + log.Printf("Timeout while creating %s publisher %s for %s", cmd.StreamType, id, session.PublicId()) + session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingPublisher)) + return + } else if err != nil { + log.Printf("Error while creating %s publisher %s for %s: %s", cmd.StreamType, id, session.PublicId(), err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + + log.Printf("Created %s publisher %s as %s", cmd.StreamType, publisher.Id(), id) + session.StorePublisher(ctx, id, publisher) + s.StoreClient(id, publisher) + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: id, + }, + } + session.sendMessage(response) + case "create-subscriber": + id := uuid.New().String() + publisherId := cmd.PublisherId + subscriber, err := s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType) + if err == context.DeadlineExceeded { + log.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId()) + session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingSubscriber)) + return + } else if err != nil { + log.Printf("Error while creating %s subscriber on %s for %s: %s", cmd.StreamType, publisherId, session.PublicId(), err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + + log.Printf("Created %s subscriber %s as %s", cmd.StreamType, subscriber.Id(), id) + session.StoreSubscriber(ctx, id, subscriber) + s.StoreClient(id, subscriber) + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: id, + }, + } + session.sendMessage(response) + case "delete-publisher": + client := s.GetClient(cmd.ClientId) + if client == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + if session.DeletePublisher(client) == "" { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + s.DeleteClient(cmd.ClientId, client) + + go func() { + log.Printf("Closing %s publisher %s as %s", client.StreamType(), client.Id(), cmd.ClientId) + client.Close(context.Background()) + }() + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: cmd.ClientId, + }, + } + session.sendMessage(response) + case "delete-subscriber": + client := s.GetClient(cmd.ClientId) + if client == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + subscriber, ok := client.(signaling.McuSubscriber) + if !ok { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + if session.DeleteSubscriber(subscriber) == "" { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + s.DeleteClient(cmd.ClientId, client) + + go func() { + log.Printf("Closing %s subscriber %s as %s", client.StreamType(), client.Id(), cmd.ClientId) + client.Close(context.Background()) + }() + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: cmd.ClientId, + }, + } + session.sendMessage(response) + default: + log.Printf("Unsupported command %+v", message.Command) + session.sendMessage(message.NewErrorServerMessage(UnsupportedCommand)) + } +} + +func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { + payload := message.Payload + mcuClient := s.GetClient(payload.ClientId) + if mcuClient == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + var mcuData *signaling.MessageClientMessageData + switch payload.Type { + case "offer": + fallthrough + case "answer": + fallthrough + case "candidate": + mcuData = &signaling.MessageClientMessageData{ + Type: payload.Type, + Payload: payload.Payload, + } + case "endOfCandidates": + // Ignore but confirm, not passed along to Janus anyway. + session.sendMessage(&signaling.ProxyServerMessage{ + Id: message.Id, + Type: "payload", + Payload: &signaling.PayloadProxyServerMessage{ + Type: payload.Type, + ClientId: payload.ClientId, + }, + }) + return + case "requestoffer": + fallthrough + case "sendoffer": + mcuData = &signaling.MessageClientMessageData{ + Type: payload.Type, + } + default: + session.sendMessage(message.NewErrorServerMessage(UnsupportedPayload)) + return + } + + mcuClient.SendMessage(ctx, nil, mcuData, func(err error, response map[string]interface{}) { + var responseMsg *signaling.ProxyServerMessage + if err != nil { + log.Printf("Error sending %s to %s client %s: %s", mcuData, mcuClient.StreamType(), payload.ClientId, err) + responseMsg = message.NewWrappedErrorServerMessage(err) + } else { + responseMsg = &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "payload", + Payload: &signaling.PayloadProxyServerMessage{ + Type: payload.Type, + ClientId: payload.ClientId, + Payload: response, + }, + } + } + + session.sendMessage(responseMsg) + }) +} + +func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*ProxySession, error) { + if proxyDebugMessages { + log.Printf("Hello: %+v", hello) + } + + token, err := jwt.ParseWithClaims(hello.Token, &signaling.TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + log.Printf("Unexpected signing method: %v", token.Header["alg"]) + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + + claims, ok := token.Claims.(*signaling.TokenClaims) + if !ok { + log.Printf("Unsupported claims type: %+v", token.Claims) + return nil, fmt.Errorf("Unsupported claims type") + } + + tokenKeys := s.getTokenKeys() + publicKey := tokenKeys[claims.Issuer] + if publicKey == nil { + log.Printf("Issuer %s is not supported", claims.Issuer) + return nil, fmt.Errorf("No key found for issuer") + } + return publicKey, nil + }) + if err != nil { + return nil, TokenAuthFailed + } + + claims, ok := token.Claims.(*signaling.TokenClaims) + if !ok || !token.Valid { + return nil, TokenAuthFailed + } + + minIssuedAt := time.Now().Add(-maxTokenAge) + if issuedAt := time.Unix(claims.IssuedAt, 0); issuedAt.Before(minIssuedAt) { + return nil, TokenExpired + } + + sid := atomic.AddUint64(&s.sid, 1) + for sid == 0 { + sid = atomic.AddUint64(&s.sid, 1) + } + + sessionIdData := &signaling.SessionIdData{ + Sid: sid, + Created: time.Now(), + } + + encoded, err := s.cookie.Encode("session", sessionIdData) + if err != nil { + return nil, err + } + + log.Printf("Created session %s for %+v", encoded, claims) + session := NewProxySession(s, sid, encoded) + s.StoreSession(sid, session) + return session, nil +} + +func (s *ProxyServer) StoreSession(id uint64, session *ProxySession) { + s.sessionsLock.Lock() + defer s.sessionsLock.Unlock() + s.sessions[id] = session +} + +func (s *ProxyServer) GetSession(id uint64) *ProxySession { + s.sessionsLock.RLock() + defer s.sessionsLock.RUnlock() + return s.sessions[id] +} + +func (s *ProxyServer) GetSessionsCount() int64 { + s.sessionsLock.RLock() + defer s.sessionsLock.RUnlock() + return int64(len(s.sessions)) +} + +func (s *ProxyServer) IterateSessions(f func(*ProxySession)) { + s.sessionsLock.RLock() + defer s.sessionsLock.RUnlock() + + for _, session := range s.sessions { + f(session) + } +} + +func (s *ProxyServer) DeleteSession(id uint64) { + s.sessionsLock.Lock() + defer s.sessionsLock.Unlock() + s.deleteSessionLocked(id) +} + +func (s *ProxyServer) deleteSessionLocked(id uint64) { + delete(s.sessions, id) +} + +func (s *ProxyServer) StoreClient(id string, client signaling.McuClient) { + s.clientsLock.Lock() + defer s.clientsLock.Unlock() + s.clients[id] = client + s.clientIds[client.Id()] = id +} + +func (s *ProxyServer) DeleteClient(id string, client signaling.McuClient) { + s.clientsLock.Lock() + defer s.clientsLock.Unlock() + delete(s.clients, id) + delete(s.clientIds, client.Id()) + + if len(s.clients) == 0 && atomic.LoadUint32(&s.shutdownScheduled) != 0 { + go func() { + s.shutdownChannel <- true + }() + } +} + +func (s *ProxyServer) GetClientCount() int64 { + s.clientsLock.RLock() + defer s.clientsLock.RUnlock() + return int64(len(s.clients)) +} + +func (s *ProxyServer) GetClient(id string) signaling.McuClient { + s.clientsLock.RLock() + defer s.clientsLock.RUnlock() + return s.clients[id] +} + +func (s *ProxyServer) GetClientId(client signaling.McuClient) string { + s.clientsLock.RLock() + defer s.clientsLock.RUnlock() + return s.clientIds[client.Id()] +} + +func (s *ProxyServer) getStats() map[string]interface{} { + result := map[string]interface{}{ + "sessions": s.GetSessionsCount(), + "mcu": s.mcu.GetStats(), + } + return result +} + +func (s *ProxyServer) validateStatsRequest(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + addr := getRealUserIP(r) + if strings.Contains(addr, ":") { + if host, _, err := net.SplitHostPort(addr); err == nil { + addr = host + } + } + if !s.statsAllowedIps[addr] { + http.Error(w, "Authentication check failed", http.StatusForbidden) + return + } + + f(w, r) + } +} + +func (s *ProxyServer) statsHandler(w http.ResponseWriter, r *http.Request) { + stats := s.getStats() + statsData, err := json.MarshalIndent(stats, "", " ") + if err != nil { + log.Printf("Could not serialize stats %+v: %s", stats, err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusOK) + w.Write(statsData) +} diff --git a/src/proxy/proxy_session.go b/src/proxy/proxy_session.go new file mode 100644 index 0000000..ebbb85d --- /dev/null +++ b/src/proxy/proxy_session.go @@ -0,0 +1,272 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "log" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/context" + + "signaling" +) + +const ( + // Sessions expire if they have not been used for one minute. + sessionExpirationTime = time.Minute +) + +type ProxySession struct { + proxy *ProxyServer + id string + sid uint64 + lastUsed int64 + + clientLock sync.Mutex + client *ProxyClient + pendingMessages []*signaling.ProxyServerMessage + + publishersLock sync.Mutex + publishers map[string]signaling.McuPublisher + publisherIds map[string]string + + subscribersLock sync.Mutex + subscribers map[string]signaling.McuSubscriber + subscriberIds map[string]string +} + +func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession { + return &ProxySession{ + proxy: proxy, + id: id, + sid: sid, + lastUsed: time.Now().UnixNano(), + + publishers: make(map[string]signaling.McuPublisher), + publisherIds: make(map[string]string), + + subscribers: make(map[string]signaling.McuSubscriber), + subscriberIds: make(map[string]string), + } +} + +func (s *ProxySession) PublicId() string { + return s.id +} + +func (s *ProxySession) Sid() uint64 { + return s.sid +} + +func (s *ProxySession) LastUsed() time.Time { + lastUsed := atomic.LoadInt64(&s.lastUsed) + return time.Unix(0, lastUsed) +} + +func (s *ProxySession) IsExpired() bool { + expiresAt := s.LastUsed().Add(sessionExpirationTime) + return expiresAt.Before(time.Now()) +} + +func (s *ProxySession) MarkUsed() { + now := time.Now() + atomic.StoreInt64(&s.lastUsed, now.UnixNano()) +} + +func (s *ProxySession) SetClient(client *ProxyClient) *ProxyClient { + s.clientLock.Lock() + prev := s.client + s.client = client + var messages []*signaling.ProxyServerMessage + if client != nil { + messages, s.pendingMessages = s.pendingMessages, nil + } + s.clientLock.Unlock() + if prev != nil { + prev.SetSession(nil) + } + if client != nil { + s.MarkUsed() + client.SetSession(s) + for _, msg := range messages { + client.SendMessage(msg) + } + } + return prev +} + +func (s *ProxySession) OnIceCandidate(client signaling.McuClient, candidate interface{}) { + id := s.proxy.GetClientId(client) + if id == "" { + log.Printf("Received candidate %+v from unknown %s client %s (%+v)", candidate, client.StreamType(), client.Id(), client) + return + } + + msg := &signaling.ProxyServerMessage{ + Type: "payload", + Payload: &signaling.PayloadProxyServerMessage{ + Type: "candidate", + ClientId: id, + Payload: map[string]interface{}{ + "candidate": candidate, + }, + }, + } + s.sendMessage(msg) +} + +func (s *ProxySession) sendMessage(message *signaling.ProxyServerMessage) { + var client *ProxyClient + s.clientLock.Lock() + client = s.client + if client == nil { + s.pendingMessages = append(s.pendingMessages, message) + } + s.clientLock.Unlock() + if client != nil { + client.SendMessage(message) + } +} + +func (s *ProxySession) OnIceCompleted(client signaling.McuClient) { + id := s.proxy.GetClientId(client) + if id == "" { + log.Printf("Received ice completed event from unknown %s client %s (%+v)", client.StreamType(), client.Id(), client) + return + } + + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "ice-completed", + ClientId: id, + }, + } + s.sendMessage(msg) +} + +func (s *ProxySession) PublisherClosed(publisher signaling.McuPublisher) { + if id := s.DeletePublisher(publisher); id != "" { + s.proxy.DeleteClient(id, publisher) + + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "publisher-closed", + ClientId: id, + }, + } + s.sendMessage(msg) + } +} + +func (s *ProxySession) SubscriberClosed(subscriber signaling.McuSubscriber) { + if id := s.DeleteSubscriber(subscriber); id != "" { + s.proxy.DeleteClient(id, subscriber) + + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "subscriber-closed", + ClientId: id, + }, + } + s.sendMessage(msg) + } +} + +func (s *ProxySession) StorePublisher(ctx context.Context, id string, publisher signaling.McuPublisher) { + s.publishersLock.Lock() + defer s.publishersLock.Unlock() + + s.publishers[id] = publisher + s.publisherIds[publisher.Id()] = id +} + +func (s *ProxySession) DeletePublisher(publisher signaling.McuPublisher) string { + s.publishersLock.Lock() + defer s.publishersLock.Unlock() + + id, found := s.publisherIds[publisher.Id()] + if !found { + return "" + } + + delete(s.publishers, id) + delete(s.publisherIds, publisher.Id()) + return id +} + +func (s *ProxySession) StoreSubscriber(ctx context.Context, id string, subscriber signaling.McuSubscriber) { + s.subscribersLock.Lock() + defer s.subscribersLock.Unlock() + + s.subscribers[id] = subscriber + s.subscriberIds[subscriber.Id()] = id +} + +func (s *ProxySession) DeleteSubscriber(subscriber signaling.McuSubscriber) string { + s.subscribersLock.Lock() + defer s.subscribersLock.Unlock() + + id, found := s.subscriberIds[subscriber.Id()] + if !found { + return "" + } + + delete(s.subscribers, id) + delete(s.subscriberIds, subscriber.Id()) + return id +} + +func (s *ProxySession) clearPublishers() { + s.publishersLock.Lock() + defer s.publishersLock.Unlock() + + go func(publishers map[string]signaling.McuPublisher) { + for _, publisher := range publishers { + publisher.Close(context.Background()) + } + }(s.publishers) + s.publishers = make(map[string]signaling.McuPublisher) + s.publisherIds = make(map[string]string) +} + +func (s *ProxySession) clearSubscribers() { + s.publishersLock.Lock() + defer s.publishersLock.Unlock() + + go func(subscribers map[string]signaling.McuSubscriber) { + for _, subscriber := range subscribers { + subscriber.Close(context.Background()) + } + }(s.subscribers) + s.subscribers = make(map[string]signaling.McuSubscriber) + s.subscriberIds = make(map[string]string) +} + +func (s *ProxySession) NotifyDisconnected() { + s.clearPublishers() + s.clearSubscribers() +} diff --git a/src/server/main.go b/src/server/main.go index 468b899..f052aa7 100644 --- a/src/server/main.go +++ b/src/server/main.go @@ -166,6 +166,8 @@ func main() { switch mcuType { case signaling.McuTypeJanus: mcu, err = signaling.NewMcuJanus(mcuUrl, config, nats) + case signaling.McuTypeProxy: + mcu, err = signaling.NewMcuProxy(mcuUrl, config) default: log.Fatal("Unsupported MCU type: ", mcuType) } diff --git a/src/signaling/api_proxy.go b/src/signaling/api_proxy.go new file mode 100644 index 0000000..ad78a4c --- /dev/null +++ b/src/signaling/api_proxy.go @@ -0,0 +1,254 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "fmt" + + "gopkg.in/dgrijalva/jwt-go.v3" +) + +type ProxyClientMessage struct { + // The unique request id (optional). + Id string `json:"id,omitempty"` + + // The type of the request. + Type string `json:"type"` + + // Filled for type "hello" + Hello *HelloProxyClientMessage `json:"hello,omitempty"` + + Bye *ByeProxyClientMessage `json:"bye,omitempty"` + + Command *CommandProxyClientMessage `json:"command,omitempty"` + + Payload *PayloadProxyClientMessage `json:"payload,omitempty"` +} + +func (m *ProxyClientMessage) CheckValid() error { + switch m.Type { + case "": + return fmt.Errorf("type missing") + case "hello": + if m.Hello == nil { + return fmt.Errorf("hello missing") + } else if err := m.Hello.CheckValid(); err != nil { + return err + } + case "bye": + if m.Bye != nil { + // Bye contents are optional + if err := m.Bye.CheckValid(); err != nil { + return err + } + } + case "command": + if m.Command == nil { + return fmt.Errorf("command missing") + } else if err := m.Command.CheckValid(); err != nil { + return err + } + case "payload": + if m.Payload == nil { + return fmt.Errorf("payload missing") + } else if err := m.Payload.CheckValid(); err != nil { + return err + } + } + return nil +} + +func (m *ProxyClientMessage) NewErrorServerMessage(e *Error) *ProxyServerMessage { + return &ProxyServerMessage{ + Id: m.Id, + Type: "error", + Error: e, + } +} + +func (m *ProxyClientMessage) NewWrappedErrorServerMessage(e error) *ProxyServerMessage { + return m.NewErrorServerMessage(NewError("internal_error", e.Error())) +} + +// ProxyServerMessage is a message that is sent from the server to a client. +type ProxyServerMessage struct { + Id string `json:"id,omitempty"` + + Type string `json:"type"` + + Error *Error `json:"error,omitempty"` + + Hello *HelloProxyServerMessage `json:"hello,omitempty"` + + Bye *ByeProxyServerMessage `json:"bye,omitempty"` + + Command *CommandProxyServerMessage `json:"command,omitempty"` + + Payload *PayloadProxyServerMessage `json:"payload,omitempty"` + + Event *EventProxyServerMessage `json:"event,omitempty"` +} + +func (r *ProxyServerMessage) CloseAfterSend(session Session) bool { + if r.Type == "bye" { + return true + } + + return false +} + +// Type "hello" + +type TokenClaims struct { + jwt.StandardClaims +} + +type HelloProxyClientMessage struct { + Version string `json:"version"` + + ResumeId string `json:"resumeid"` + + Features []string `json:"features,omitempty"` + + // The authentication credentials. + Token string `json:"token"` +} + +func (m *HelloProxyClientMessage) CheckValid() error { + if m.Version != HelloVersion { + return fmt.Errorf("unsupported hello version: %s", m.Version) + } + if m.ResumeId == "" { + if m.Token == "" { + return fmt.Errorf("token missing") + } + } + return nil +} + +type HelloProxyServerMessage struct { + Version string `json:"version"` + + SessionId string `json:"sessionid"` + Server *HelloServerMessageServer `json:"server,omitempty"` +} + +// Type "bye" + +type ByeProxyClientMessage struct { +} + +func (m *ByeProxyClientMessage) CheckValid() error { + // No additional validation required. + return nil +} + +type ByeProxyServerMessage struct { + Reason string `json:"reason"` +} + +// Type "command" + +type CommandProxyClientMessage struct { + Type string `json:"type"` + + StreamType string `json:"streamType,omitempty"` + PublisherId string `json:"publisherId,omitempty"` + ClientId string `json:"clientId,omitempty"` +} + +func (m *CommandProxyClientMessage) CheckValid() error { + switch m.Type { + case "": + return fmt.Errorf("type missing") + case "create-publisher": + if m.StreamType == "" { + return fmt.Errorf("stream type missing") + } + case "create-subscriber": + if m.PublisherId == "" { + return fmt.Errorf("publisher id missing") + } + if m.StreamType == "" { + return fmt.Errorf("stream type missing") + } + case "delete-publisher": + fallthrough + case "delete-subscriber": + if m.ClientId == "" { + return fmt.Errorf("client id missing") + } + } + return nil +} + +type CommandProxyServerMessage struct { + Id string `json:"id,omitempty"` +} + +// Type "payload" + +type PayloadProxyClientMessage struct { + Type string `json:"type"` + + ClientId string `json:"clientId"` + Payload map[string]interface{} `json:"payload,omitempty"` +} + +func (m *PayloadProxyClientMessage) CheckValid() error { + switch m.Type { + case "": + return fmt.Errorf("type missing") + case "offer": + fallthrough + case "answer": + fallthrough + case "candidate": + if len(m.Payload) == 0 { + return fmt.Errorf("payload missing") + } + case "endOfCandidates": + fallthrough + case "requestoffer": + // No payload required. + } + if m.ClientId == "" { + return fmt.Errorf("client id missing") + } + return nil +} + +type PayloadProxyServerMessage struct { + Type string `json:"type"` + + ClientId string `json:"clientId"` + Payload map[string]interface{} `json:"payload"` +} + +// Type "event" + +type EventProxyServerMessage struct { + Type string `json:"type"` + + ClientId string `json:"clientId,omitempty"` + Load int64 `json:"load,omitempty"` +} diff --git a/src/signaling/api_signaling.go b/src/signaling/api_signaling.go index 5f43969..5762459 100644 --- a/src/signaling/api_signaling.go +++ b/src/signaling/api_signaling.go @@ -278,6 +278,7 @@ const ( type HelloServerMessageServer struct { Version string `json:"version"` Features []string `json:"features,omitempty"` + Country string `json:"country,omitempty"` } type HelloServerMessage struct { diff --git a/src/signaling/client.go b/src/signaling/client.go index e1535c4..ebfef43 100644 --- a/src/signaling/client.go +++ b/src/signaling/client.go @@ -25,7 +25,6 @@ import ( "bytes" "encoding/json" "log" - "net" "strconv" "strings" "sync" @@ -52,16 +51,28 @@ const ( ) var ( - _noCountry string = "no-country" - noCountry *string = &_noCountry + noCountry string = "no-country" - _loopback string = "loopback" - loopback *string = &_loopback + loopback string = "loopback" - _unknownCountry string = "unknown-country" - unknownCountry *string = &_unknownCountry + unknownCountry string = "unknown-country" ) +func IsValidCountry(country string) bool { + switch country { + case "": + fallthrough + case noCountry: + fallthrough + case loopback: + fallthrough + case unknownCountry: + return false + default: + return true + } +} + var ( InvalidFormat = NewError("invalid_format", "Invalid data format.") @@ -72,8 +83,13 @@ var ( } ) +type WritableClientMessage interface { + json.Marshaler + + CloseAfterSend(session Session) bool +} + type Client struct { - hub *Hub conn *websocket.Conn addr string agent string @@ -85,9 +101,14 @@ type Client struct { mu sync.Mutex closeChan chan bool + + OnLookupCountry func(*Client) string + OnClosed func(*Client) + OnMessageReceived func(*Client, []byte) + OnRTTReceived func(*Client, time.Duration) } -func NewClient(hub *Hub, conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { +func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { remoteAddress = strings.TrimSpace(remoteAddress) if remoteAddress == "" { remoteAddress = "unknown remote address" @@ -97,15 +118,28 @@ func NewClient(hub *Hub, conn *websocket.Conn, remoteAddress string, agent strin agent = "unknown user agent" } client := &Client{ - hub: hub, conn: conn, addr: remoteAddress, agent: agent, closeChan: make(chan bool, 1), + + OnLookupCountry: func(client *Client) string { return unknownCountry }, + OnClosed: func(client *Client) {}, + OnMessageReceived: func(client *Client, data []byte) {}, + OnRTTReceived: func(client *Client, rtt time.Duration) {}, } return client, nil } +func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string) { + c.conn = conn + c.addr = remoteAddress + c.closeChan = make(chan bool, 1) + c.OnLookupCountry = func(client *Client) string { return unknownCountry } + c.OnClosed = func(client *Client) {} + c.OnMessageReceived = func(client *Client, data []byte) {} +} + func (c *Client) IsConnected() bool { return atomic.LoadUint32(&c.closed) == 0 } @@ -132,25 +166,7 @@ func (c *Client) UserAgent() string { func (c *Client) Country() string { if c.country == nil { - if c.hub.geoip == nil { - c.country = unknownCountry - return *c.country - } - ip := net.ParseIP(c.RemoteAddr()) - if ip == nil { - c.country = noCountry - return *c.country - } else if ip.IsLoopback() { - c.country = loopback - return *c.country - } - - country, err := c.hub.geoip.LookupCountry(ip) - if err != nil { - log.Printf("Could not lookup country for %s", ip) - c.country = unknownCountry - return *c.country - } + country := c.OnLookupCountry(c) c.country = &country } @@ -164,7 +180,7 @@ func (c *Client) Close() { c.closeChan <- true - c.hub.processUnregister(c) + c.OnClosed(c) c.SetSession(nil) c.mu.Lock() @@ -183,41 +199,6 @@ func (c *Client) SendError(e *Error) bool { return c.SendMessage(message) } -func (c *Client) SendRoom(message *ClientMessage, room *Room) bool { - response := &ServerMessage{ - Type: "room", - } - if message != nil { - response.Id = message.Id - } - if room == nil { - response.Room = &RoomServerMessage{ - RoomId: "", - } - } else { - response.Room = &RoomServerMessage{ - RoomId: room.id, - Properties: room.properties, - } - } - return c.SendMessage(response) -} - -func (c *Client) SendHelloResponse(message *ClientMessage, session *ClientSession) bool { - response := &ServerMessage{ - Id: message.Id, - Type: "hello", - Hello: &HelloServerMessage{ - Version: HelloVersion, - SessionId: session.PublicId(), - ResumeId: session.PrivateId(), - UserId: session.UserId(), - Server: c.hub.GetServerInfo(), - }, - } - return c.SendMessage(response) -} - func (c *Client) SendByeResponse(message *ClientMessage) bool { return c.SendByeResponseWithReason(message, "") } @@ -236,11 +217,11 @@ func (c *Client) SendByeResponseWithReason(message *ClientMessage, reason string return c.SendMessage(response) } -func (c *Client) SendMessage(message *ServerMessage) bool { +func (c *Client) SendMessage(message WritableClientMessage) bool { return c.writeMessage(message) } -func (c *Client) readPump() { +func (c *Client) ReadPump() { defer func() { c.Close() }() @@ -270,6 +251,7 @@ func (c *Client) readPump() { } else { log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) } + c.OnRTTReceived(c, rtt) } return nil }) @@ -312,28 +294,7 @@ func (c *Client) readPump() { break } - var message ClientMessage - if err := message.UnmarshalJSON(decodeBuffer.Bytes()); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) - } else { - log.Printf("Error decoding message from %s: %v", addr, err) - } - c.SendError(InvalidFormat) - continue - } - - if err := message.CheckValid(); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) - } else { - log.Printf("Invalid message %+v from %s: %v", message, addr, err) - } - c.SendMessage(message.NewErrorServerMessage(InvalidFormat)) - continue - } - - c.hub.processMessage(c, &message) + c.OnMessageReceived(c, decodeBuffer.Bytes()) } } @@ -407,7 +368,7 @@ func (c *Client) writeError(e error) bool { return false } -func (c *Client) writeMessage(message *ServerMessage) bool { +func (c *Client) writeMessage(message WritableClientMessage) bool { c.mu.Lock() defer c.mu.Unlock() if c.conn == nil { @@ -417,7 +378,7 @@ func (c *Client) writeMessage(message *ServerMessage) bool { return c.writeMessageLocked(message) } -func (c *Client) writeMessageLocked(message *ServerMessage) bool { +func (c *Client) writeMessageLocked(message WritableClientMessage) bool { if !c.writeInternal(message) { return false } @@ -458,7 +419,7 @@ func (c *Client) sendPing() bool { return true } -func (c *Client) writePump() { +func (c *Client) WritePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() diff --git a/src/signaling/clientsession.go b/src/signaling/clientsession.go index 7a9d2fc..8a0b496 100644 --- a/src/signaling/clientsession.go +++ b/src/signaling/clientsession.go @@ -436,6 +436,10 @@ func (s *ClientSession) GetClient() *Client { s.mu.Lock() defer s.mu.Unlock() + return s.getClientUnlocked() +} + +func (s *ClientSession) getClientUnlocked() *Client { return s.client } @@ -554,9 +558,10 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea publisher, found := s.publishers[streamType] if !found { + client := s.getClientUnlocked() s.mu.Unlock() var err error - publisher, err = mcu.NewPublisher(ctx, s, s.PublicId(), streamType) + publisher, err = mcu.NewPublisher(ctx, s, s.PublicId(), streamType, client) s.mu.Lock() if err != nil { return nil, err diff --git a/src/signaling/hub.go b/src/signaling/hub.go index 1747081..9e352f9 100644 --- a/src/signaling/hub.go +++ b/src/signaling/hub.go @@ -30,6 +30,7 @@ import ( "fmt" "hash/fnv" "log" + "net" "net/http" "strings" "sync" @@ -633,7 +634,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B h.setDecodedSessionId(privateSessionId, privateSessionName, sessionIdData) h.setDecodedSessionId(publicSessionId, publicSessionName, sessionIdData) - client.SendHelloResponse(message, session) + h.sendHelloResponse(client, message, session) } func (h *Hub) processUnregister(client *Client) *ClientSession { @@ -656,7 +657,28 @@ func (h *Hub) processUnregister(client *Client) *ClientSession { return session } -func (h *Hub) processMessage(client *Client, message *ClientMessage) { +func (h *Hub) processMessage(client *Client, data []byte) { + var message ClientMessage + if err := message.UnmarshalJSON(data); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) + } else { + log.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) + } + client.SendError(InvalidFormat) + return + } + + if err := message.CheckValid(); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + } else { + log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) + } + client.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + return + } + session := client.GetSession() if session == nil { if message.Type != "hello" { @@ -664,19 +686,19 @@ func (h *Hub) processMessage(client *Client, message *ClientMessage) { return } - h.processHello(client, message) + h.processHello(client, &message) return } switch message.Type { case "room": - h.processRoom(client, message) + h.processRoom(client, &message) case "message": - h.processMessageMsg(client, message) + h.processMessageMsg(client, &message) case "control": - h.processControlMsg(client, message) + h.processControlMsg(client, &message) case "bye": - h.processByeMsg(client, message) + h.processByeMsg(client, &message) case "hello": log.Printf("Ignore hello %+v for already authenticated connection %s", message.Hello, session.PublicId()) default: @@ -684,6 +706,21 @@ func (h *Hub) processMessage(client *Client, message *ClientMessage) { } } +func (h *Hub) sendHelloResponse(client *Client, message *ClientMessage, session *ClientSession) bool { + response := &ServerMessage{ + Id: message.Id, + Type: "hello", + Hello: &HelloServerMessage{ + Version: HelloVersion, + SessionId: session.PublicId(), + ResumeId: session.PrivateId(), + UserId: session.UserId(), + Server: h.GetServerInfo(), + }, + } + return client.SendMessage(response) +} + func (h *Hub) processHello(client *Client, message *ClientMessage) { resumeId := message.Hello.ResumeId if resumeId != "" { @@ -728,7 +765,7 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) { log.Printf("Resume session from %s in %s (%s) %s (private=%s)", client.RemoteAddr(), client.Country(), client.UserAgent(), session.PublicId(), session.PrivateId()) - client.SendHelloResponse(message, clientSession) + h.sendHelloResponse(client, message, clientSession) clientSession.NotifySessionResumed(client) return } @@ -839,6 +876,26 @@ func (h *Hub) disconnectByRoomSessionId(roomSessionId string) { session.Close() } +func (h *Hub) sendRoom(client *Client, message *ClientMessage, room *Room) bool { + response := &ServerMessage{ + Type: "room", + } + if message != nil { + response.Id = message.Id + } + if room == nil { + response.Room = &RoomServerMessage{ + RoomId: "", + } + } else { + response.Room = &RoomServerMessage{ + RoomId: room.id, + Properties: room.properties, + } + } + return client.SendMessage(response) +} + func (h *Hub) processRoom(client *Client, message *ClientMessage) { session := client.GetSession() roomId := message.Room.RoomId @@ -850,7 +907,7 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { // We can handle leaving a room directly. if session.LeaveRoom(true) != nil { // User was in a room before, so need to notify about leaving it. - client.SendRoom(message, nil) + h.sendRoom(client, message, nil) } if session.UserId() == "" && session.ClientType() != HelloClientTypeInternal { h.startWaitAnonymousClientRoom(client) @@ -965,7 +1022,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back if err := session.SubscribeRoomNats(h.nats, roomId, message.Room.SessionId); err != nil { client.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. - client.SendRoom(nil, nil) + h.sendRoom(client, nil, nil) return } @@ -978,7 +1035,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back client.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. session.UnsubscribeRoomNats() - client.SendRoom(nil, nil) + h.sendRoom(client, nil, nil) return } } @@ -992,7 +1049,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back if room.Room.Permissions != nil { session.SetPermissions(*room.Room.Permissions) } - client.SendRoom(message, r) + h.sendRoom(client, message, r) h.notifyUserJoinedRoom(r, client, session, room.Room.Session) } @@ -1427,7 +1484,7 @@ func (h *Hub) processRoomDeleted(message *BackendServerRoomRequest) { switch sess := session.(type) { case *ClientSession: if client := sess.GetClient(); client != nil { - client.SendRoom(nil, nil) + h.sendRoom(client, nil, nil) } } } @@ -1477,6 +1534,26 @@ func getRealUserIP(r *http.Request) string { return r.RemoteAddr } +func (h *Hub) lookupClientCountry(client *Client) string { + ip := net.ParseIP(client.RemoteAddr()) + if ip == nil { + return noCountry + } else if ip.IsLoopback() { + return loopback + } + + country, err := h.geoip.LookupCountry(ip) + if err != nil { + log.Printf("Could not lookup country for %s: %s", ip, err) + return unknownCountry + } + + if country == "" { + return unknownCountry + } + return country +} + func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { addr := getRealUserIP(r) agent := r.Header.Get("User-Agent") @@ -1487,13 +1564,21 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { return } - client, err := NewClient(h, conn, addr, agent) + client, err := NewClient(conn, addr, agent) if err != nil { log.Printf("Could not create client for %s: %s", addr, err) return } + if h.geoip != nil { + client.OnLookupCountry = h.lookupClientCountry + } + client.OnMessageReceived = h.processMessage + client.OnClosed = func(client *Client) { + h.processUnregister(client) + } + h.processNewClient(client) - go client.writePump() - go client.readPump() + go client.WritePump() + go client.ReadPump() } diff --git a/src/signaling/mcu_common.go b/src/signaling/mcu_common.go index 7bfd250..c821ff0 100644 --- a/src/signaling/mcu_common.go +++ b/src/signaling/mcu_common.go @@ -22,17 +22,24 @@ package signaling import ( + "fmt" + "golang.org/x/net/context" ) const ( McuTypeJanus = "janus" + McuTypeProxy = "proxy" McuTypeDefault = McuTypeJanus ) +var ( + ErrNotConnected = fmt.Errorf("Not connected") +) + type McuListener interface { - Session + PublicId() string OnIceCandidate(client McuClient, candidate interface{}) OnIceCompleted(client McuClient) @@ -41,13 +48,20 @@ type McuListener interface { SubscriberClosed(subscriber McuSubscriber) } +type McuInitiator interface { + Country() string +} + type Mcu interface { Start() error Stop() + SetOnConnected(func()) + SetOnDisconnected(func()) + GetStats() interface{} - NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) + NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) } diff --git a/src/signaling/mcu_janus.go b/src/signaling/mcu_janus.go index 1aa7b9e..4e4f38a 100644 --- a/src/signaling/mcu_janus.go +++ b/src/signaling/mcu_janus.go @@ -28,6 +28,7 @@ import ( "reflect" "strconv" "sync" + "sync/atomic" "time" "github.com/dlintw/goconf" @@ -64,8 +65,6 @@ var ( videoPublisherUserId: streamTypeVideo, screenPublisherUserId: streamTypeScreen, } - - ErrNotConnected = fmt.Errorf("Not connected") ) func getPluginValue(data janus.PluginData, pluginName string, key string) interface{} { @@ -161,8 +160,13 @@ type mcuJanus struct { reconnectInterval time.Duration connectedSince time.Time + onConnected atomic.Value + onDisconnected atomic.Value } +func emptyOnConnected() {} +func emptyOnDisconnected() {} + func NewMcuJanus(url string, config *goconf.ConfigFile, nats NatsClient) (Mcu, error) { maxStreamBitrate, _ := config.GetInt("mcu", "maxstreambitrate") if maxStreamBitrate <= 0 { @@ -190,6 +194,9 @@ func NewMcuJanus(url string, config *goconf.ConfigFile, nats NatsClient) (Mcu, e reconnectInterval: initialReconnectInterval, } + mcu.onConnected.Store(emptyOnConnected) + mcu.onDisconnected.Store(emptyOnDisconnected) + mcu.reconnectTimer = time.AfterFunc(mcu.reconnectInterval, mcu.doReconnect) mcu.reconnectTimer.Stop() if err := mcu.reconnect(); err != nil { @@ -269,6 +276,7 @@ func (m *mcuJanus) scheduleReconnect(err error) { func (m *mcuJanus) ConnectionInterrupted() { m.scheduleReconnect(nil) + m.notifyOnDisconnected() } func (m *mcuJanus) Start() error { @@ -314,6 +322,8 @@ func (m *mcuJanus) Start() error { log.Println("Created Janus handle", m.handle.Id) go m.run() + + m.notifyOnConnected() return nil } @@ -349,6 +359,32 @@ func (m *mcuJanus) Stop() { m.reconnectTimer.Stop() } +func (m *mcuJanus) SetOnConnected(f func()) { + if f == nil { + f = emptyOnConnected + } + + m.onConnected.Store(f) +} + +func (m *mcuJanus) notifyOnConnected() { + f := m.onConnected.Load().(func()) + f() +} + +func (m *mcuJanus) SetOnDisconnected(f func()) { + if f == nil { + f = emptyOnDisconnected + } + + m.onDisconnected.Store(f) +} + +func (m *mcuJanus) notifyOnDisconnected() { + f := m.onDisconnected.Load().(func()) + f() +} + type mcuJanusConnectionStats struct { Url string `json:"url"` Connected bool `json:"connected"` @@ -599,7 +635,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st return handle, response.Session, roomId, nil } -func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { +func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) { if _, found := streamTypeUserIds[streamType]; !found { return nil, fmt.Errorf("Unsupported stream type %s", streamType) } diff --git a/src/signaling/mcu_proxy.go b/src/signaling/mcu_proxy.go new file mode 100644 index 0000000..145c4f9 --- /dev/null +++ b/src/signaling/mcu_proxy.go @@ -0,0 +1,1122 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "crypto/rsa" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/url" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/dlintw/goconf" + "github.com/gorilla/websocket" + + "golang.org/x/net/context" + + "gopkg.in/dgrijalva/jwt-go.v3" +) + +const ( + closeTimeout = time.Second + + proxyDebugMessages = false + + // Very high value so the connections get sorted at the end. + loadNotConnected = 1000000 + + // Sort connections by load every 10 publishing requests or once per second. + connectionSortRequests = 10 + connectionSortInterval = time.Second +) + +type mcuProxyPubSubCommon struct { + streamType string + proxyId string + conn *mcuProxyConnection + listener McuListener +} + +func (c *mcuProxyPubSubCommon) Id() string { + return c.proxyId +} + +func (c *mcuProxyPubSubCommon) StreamType() string { + return c.streamType +} + +func (c *mcuProxyPubSubCommon) doSendMessage(ctx context.Context, msg *ProxyClientMessage, callback func(error, map[string]interface{})) { + c.conn.performAsyncRequest(ctx, msg, func(err error, response *ProxyServerMessage) { + if err != nil { + callback(err, nil) + return + } + + if proxyDebugMessages { + log.Printf("Response from %s: %+v", c.conn.url, response) + } + if response.Type == "error" { + callback(response.Error, nil) + } else if response.Payload != nil { + callback(nil, response.Payload.Payload) + } else { + callback(nil, nil) + } + }) +} + +func (c *mcuProxyPubSubCommon) doProcessPayload(client McuClient, msg *PayloadProxyServerMessage) { + switch msg.Type { + case "candidate": + c.listener.OnIceCandidate(client, msg.Payload["candidate"]) + default: + log.Printf("Unsupported payload from %s: %+v", c.conn.url, msg) + } +} + +type mcuProxyPublisher struct { + mcuProxyPubSubCommon + + id string +} + +func newMcuProxyPublisher(id string, streamType string, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher { + return &mcuProxyPublisher{ + mcuProxyPubSubCommon: mcuProxyPubSubCommon{ + streamType: streamType, + proxyId: proxyId, + conn: conn, + listener: listener, + }, + id: id, + } +} + +func (p *mcuProxyPublisher) NotifyClosed() { + p.listener.PublisherClosed(p) + p.conn.removePublisher(p) +} + +func (p *mcuProxyPublisher) Close(ctx context.Context) { + p.NotifyClosed() + + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "delete-publisher", + ClientId: p.proxyId, + }, + } + + if _, err := p.conn.performSyncRequest(ctx, msg); err != nil { + log.Printf("Could not delete publisher %s at %s: %s", p.proxyId, p.conn.url, err) + return + } +} + +func (p *mcuProxyPublisher) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { + msg := &ProxyClientMessage{ + Type: "payload", + Payload: &PayloadProxyClientMessage{ + Type: data.Type, + ClientId: p.proxyId, + Payload: data.Payload, + }, + } + + p.doSendMessage(ctx, msg, callback) +} + +func (p *mcuProxyPublisher) ProcessPayload(msg *PayloadProxyServerMessage) { + p.doProcessPayload(p, msg) +} + +func (p *mcuProxyPublisher) ProcessEvent(msg *EventProxyServerMessage) { + switch msg.Type { + case "ice-completed": + p.listener.OnIceCompleted(p) + case "publisher-closed": + p.NotifyClosed() + default: + log.Printf("Unsupported event from %s: %+v", p.conn.url, msg) + } +} + +type mcuProxySubscriber struct { + mcuProxyPubSubCommon + + publisherId string +} + +func newMcuProxySubscriber(publisherId string, streamType string, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber { + return &mcuProxySubscriber{ + mcuProxyPubSubCommon: mcuProxyPubSubCommon{ + streamType: streamType, + proxyId: proxyId, + conn: conn, + listener: listener, + }, + + publisherId: publisherId, + } +} + +func (s *mcuProxySubscriber) Publisher() string { + return s.publisherId +} + +func (s *mcuProxySubscriber) NotifyClosed() { + s.listener.SubscriberClosed(s) + s.conn.removeSubscriber(s) +} + +func (s *mcuProxySubscriber) Close(ctx context.Context) { + s.NotifyClosed() + + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "delete-subscriber", + ClientId: s.proxyId, + }, + } + + if _, err := s.conn.performSyncRequest(ctx, msg); err != nil { + log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn.url, err) + return + } +} + +func (s *mcuProxySubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { + msg := &ProxyClientMessage{ + Type: "payload", + Payload: &PayloadProxyClientMessage{ + Type: data.Type, + ClientId: s.proxyId, + Payload: data.Payload, + }, + } + + s.doSendMessage(ctx, msg, callback) +} + +func (s *mcuProxySubscriber) ProcessPayload(msg *PayloadProxyServerMessage) { + s.doProcessPayload(s, msg) +} + +func (s *mcuProxySubscriber) ProcessEvent(msg *EventProxyServerMessage) { + switch msg.Type { + case "ice-completed": + s.listener.OnIceCompleted(s) + case "subscriber-closed": + s.NotifyClosed() + default: + log.Printf("Unsupported event from %s: %+v", s.conn.url, msg) + } +} + +type mcuProxyConnection struct { + proxy *mcuProxy + url *url.URL + + mu sync.Mutex + closeChan chan bool + closedChan chan bool + closed uint32 + conn *websocket.Conn + + connectedSince time.Time + reconnectInterval int64 + reconnectTimer *time.Timer + shutdownScheduled uint32 + + msgId int64 + helloMsgId string + sessionId string + load int64 + country atomic.Value + + callbacks map[string]func(*ProxyServerMessage) + + publishersLock sync.RWMutex + publishers map[string]*mcuProxyPublisher + publisherIds map[string]string + + subscribersLock sync.RWMutex + subscribers map[string]*mcuProxySubscriber +} + +func newMcuProxyConnection(proxy *mcuProxy, baseUrl string) (*mcuProxyConnection, error) { + parsed, err := url.Parse(baseUrl) + if err != nil { + return nil, err + } + + conn := &mcuProxyConnection{ + proxy: proxy, + url: parsed, + closeChan: make(chan bool, 1), + closedChan: make(chan bool, 1), + reconnectInterval: int64(initialReconnectInterval), + load: loadNotConnected, + callbacks: make(map[string]func(*ProxyServerMessage)), + publishers: make(map[string]*mcuProxyPublisher), + publisherIds: make(map[string]string), + subscribers: make(map[string]*mcuProxySubscriber), + } + conn.country.Store("") + return conn, nil +} + +type mcuProxyConnectionStats struct { + Url string `json:"url"` + Connected bool `json:"connected"` + Publishers int64 `json:"publishers"` + Clients int64 `json:"clients"` + Uptime *time.Time `json:"uptime,omitempty"` +} + +func (c *mcuProxyConnection) GetStats() *mcuProxyConnectionStats { + result := &mcuProxyConnectionStats{ + Url: c.url.String(), + } + c.mu.Lock() + if c.conn != nil { + result.Connected = true + result.Uptime = &c.connectedSince + } + c.mu.Unlock() + c.publishersLock.RLock() + result.Publishers = int64(len(c.publishers)) + c.publishersLock.RUnlock() + c.subscribersLock.RLock() + result.Clients = int64(len(c.subscribers)) + c.subscribersLock.RUnlock() + result.Clients += result.Publishers + return result +} + +func (c *mcuProxyConnection) Load() int64 { + return atomic.LoadInt64(&c.load) +} + +func (c *mcuProxyConnection) Country() string { + return c.country.Load().(string) +} + +func (c *mcuProxyConnection) IsShutdownScheduled() bool { + return atomic.LoadUint32(&c.shutdownScheduled) != 0 +} + +func (c *mcuProxyConnection) readPump() { + defer func() { + if atomic.LoadUint32(&c.closed) == 0 { + c.scheduleReconnect() + } else { + c.closedChan <- true + } + }() + defer c.close() + defer atomic.StoreInt64(&c.load, loadNotConnected) + + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + for { + _, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived) { + log.Printf("Error reading from %s: %v", c.url, err) + } + break + } + + var msg ProxyServerMessage + if err := json.Unmarshal(message, &msg); err != nil { + log.Printf("Error unmarshaling %s from %s: %s", string(message), c.url, err) + continue + } + + c.processMessage(&msg) + } +} + +func (c *mcuProxyConnection) writePump() { + c.reconnectTimer = time.NewTimer(0) + for { + select { + case <-c.reconnectTimer.C: + c.reconnect() + case <-c.closeChan: + return + } + } +} + +func (c *mcuProxyConnection) start() error { + go c.writePump() + return nil +} + +func (c *mcuProxyConnection) sendClose() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return ErrNotConnected + } + + return c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) +} + +func (c *mcuProxyConnection) stop(ctx context.Context) { + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return + } + + c.closeChan <- true + if err := c.sendClose(); err != nil { + if err != ErrNotConnected { + log.Printf("Could not send close message to %s: %s", c.url, err) + } + c.close() + return + } + + select { + case <-c.closedChan: + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + log.Printf("Error waiting for connection to %s get closed: %s", c.url, err) + c.close() + } + } +} + +func (c *mcuProxyConnection) close() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } +} + +func (c *mcuProxyConnection) scheduleReconnect() { + if err := c.sendClose(); err != nil && err != ErrNotConnected { + log.Printf("Could not send close message to %s: %s", c.url, err) + c.close() + } + + interval := atomic.LoadInt64(&c.reconnectInterval) + c.reconnectTimer.Reset(time.Duration(interval)) + + interval = interval * 2 + if interval > int64(maxReconnectInterval) { + interval = int64(maxReconnectInterval) + } + atomic.StoreInt64(&c.reconnectInterval, interval) +} + +func (c *mcuProxyConnection) reconnect() { + u, err := c.url.Parse("proxy") + if err != nil { + log.Printf("Could not resolve url to proxy at %s: %s", c.url, err) + c.scheduleReconnect() + return + } + if u.Scheme == "http" { + u.Scheme = "ws" + } else if u.Scheme == "https" { + u.Scheme = "wss" + } + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Printf("Could not connect to %s: %s", u, err) + c.scheduleReconnect() + return + } + + log.Printf("Connected to %s", u) + atomic.StoreUint32(&c.closed, 0) + + c.mu.Lock() + c.connectedSince = time.Now() + c.conn = conn + c.mu.Unlock() + + atomic.StoreInt64(&c.reconnectInterval, int64(initialReconnectInterval)) + atomic.StoreUint32(&c.shutdownScheduled, 0) + if err := c.sendHello(); err != nil { + log.Printf("Could not send hello request to %s: %s", c.url, err) + c.scheduleReconnect() + return + } + + go c.readPump() +} + +func (c *mcuProxyConnection) removePublisher(publisher *mcuProxyPublisher) { + c.proxy.removePublisher(publisher) + + c.publishersLock.Lock() + defer c.publishersLock.Unlock() + + delete(c.publishers, publisher.proxyId) + delete(c.publisherIds, publisher.id+"|"+publisher.StreamType()) +} + +func (c *mcuProxyConnection) clearPublishers() { + c.publishersLock.Lock() + defer c.publishersLock.Unlock() + + go func(publishers map[string]*mcuProxyPublisher) { + for _, publisher := range publishers { + publisher.NotifyClosed() + } + }(c.publishers) + c.publishers = make(map[string]*mcuProxyPublisher) + c.publisherIds = make(map[string]string) +} + +func (c *mcuProxyConnection) removeSubscriber(subscriber *mcuProxySubscriber) { + c.subscribersLock.Lock() + defer c.subscribersLock.Unlock() + + delete(c.subscribers, subscriber.proxyId) +} + +func (c *mcuProxyConnection) clearSubscribers() { + c.subscribersLock.Lock() + defer c.subscribersLock.Unlock() + + go func(subscribers map[string]*mcuProxySubscriber) { + for _, subscriber := range subscribers { + subscriber.NotifyClosed() + } + }(c.subscribers) + c.subscribers = make(map[string]*mcuProxySubscriber) +} + +func (c *mcuProxyConnection) clearCallbacks() { + c.mu.Lock() + defer c.mu.Unlock() + + c.callbacks = make(map[string]func(*ProxyServerMessage)) +} + +func (c *mcuProxyConnection) getCallback(id string) func(*ProxyServerMessage) { + c.mu.Lock() + defer c.mu.Unlock() + + callback, found := c.callbacks[id] + if found { + delete(c.callbacks, id) + } + return callback +} + +func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { + if c.helloMsgId != "" && msg.Id == c.helloMsgId { + c.helloMsgId = "" + switch msg.Type { + case "error": + if msg.Error.Code == "no_such_session" { + log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c.url) + c.clearPublishers() + c.clearSubscribers() + c.clearCallbacks() + c.sessionId = "" + if err := c.sendHello(); err != nil { + log.Printf("Could not send hello request to %s: %s", c.url, err) + c.scheduleReconnect() + } + return + } + + log.Printf("Hello connection to %s failed with %+v, reconnecting", c.url, msg.Error) + c.scheduleReconnect() + case "hello": + c.sessionId = msg.Hello.SessionId + country := "" + if msg.Hello.Server != nil { + if country = msg.Hello.Server.Country; country != "" && !IsValidCountry(country) { + log.Printf("Proxy %s sent invalid country %s in hello response", c.url, country) + country = "" + } + } + c.country.Store(country) + if country != "" { + log.Printf("Received session %s from %s (in %s)", c.sessionId, c.url, country) + } else { + log.Printf("Received session %s from %s", c.sessionId, c.url) + } + default: + log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c.url) + c.scheduleReconnect() + } + return + } + + if proxyDebugMessages { + log.Printf("Received from %s: %+v", c.url, msg) + } + callback := c.getCallback(msg.Id) + if callback != nil { + callback(msg) + return + } + + switch msg.Type { + case "payload": + c.processPayload(msg) + case "event": + c.processEvent(msg) + default: + log.Printf("Unsupported message received from %s: %+v", c.url, msg) + } +} + +func (c *mcuProxyConnection) processPayload(msg *ProxyServerMessage) { + payload := msg.Payload + c.publishersLock.RLock() + publisher, found := c.publishers[payload.ClientId] + c.publishersLock.RUnlock() + if found { + publisher.ProcessPayload(payload) + return + } + + c.subscribersLock.RLock() + subscriber, found := c.subscribers[payload.ClientId] + c.subscribersLock.RUnlock() + if found { + subscriber.ProcessPayload(payload) + return + } + + log.Printf("Received payload for unknown client %+v from %s", payload, c.url) +} + +func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) { + event := msg.Event + switch event.Type { + case "backend-disconnected": + log.Printf("Upstream backend at %s got disconnected, reset MCU objects", c.url) + c.clearPublishers() + c.clearSubscribers() + c.clearCallbacks() + // TODO: Should we also reconnect? + return + case "backend-connected": + log.Printf("Upstream backend at %s is connected", c.url) + return + case "update-load": + if proxyDebugMessages { + log.Printf("Load of %s now at %d", c.url, event.Load) + } + atomic.StoreInt64(&c.load, event.Load) + return + case "shutdown-scheduled": + log.Printf("Proxy %s is scheduled to shutdown", c.url) + atomic.StoreUint32(&c.shutdownScheduled, 1) + return + } + + if proxyDebugMessages { + log.Printf("Process event from %s: %+v", c.url, event) + } + c.publishersLock.RLock() + publisher, found := c.publishers[event.ClientId] + c.publishersLock.RUnlock() + if found { + publisher.ProcessEvent(event) + return + } + + c.subscribersLock.RLock() + subscriber, found := c.subscribers[event.ClientId] + c.subscribersLock.RUnlock() + if found { + subscriber.ProcessEvent(event) + return + } + + log.Printf("Received event for unknown client %+v from %s", event, c.url) +} + +func (c *mcuProxyConnection) sendHello() error { + c.helloMsgId = strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10) + msg := &ProxyClientMessage{ + Id: c.helloMsgId, + Type: "hello", + Hello: &HelloProxyClientMessage{ + Version: "1.0", + }, + } + if c.sessionId != "" { + msg.Hello.ResumeId = c.sessionId + } else { + claims := &TokenClaims{ + jwt.StandardClaims{ + IssuedAt: time.Now().Unix(), + Issuer: c.proxy.tokenId, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(c.proxy.tokenKey) + if err != nil { + return err + } + + msg.Hello.Token = tokenString + } + return c.sendMessage(msg) +} + +func (c *mcuProxyConnection) sendMessage(msg *ProxyClientMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + + return c.sendMessageLocked(msg) +} + +func (c *mcuProxyConnection) sendMessageLocked(msg *ProxyClientMessage) error { + if proxyDebugMessages { + log.Printf("Send message to %s: %+v", c.url, msg) + } + if c.conn == nil { + return ErrNotConnected + } + return c.conn.WriteJSON(msg) +} + +func (c *mcuProxyConnection) performAsyncRequest(ctx context.Context, msg *ProxyClientMessage, callback func(err error, response *ProxyServerMessage)) { + msgId := strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10) + msg.Id = msgId + + c.mu.Lock() + defer c.mu.Unlock() + c.callbacks[msgId] = func(msg *ProxyServerMessage) { + callback(nil, msg) + } + if err := c.sendMessageLocked(msg); err != nil { + delete(c.callbacks, msgId) + go callback(err, nil) + return + } +} + +func (c *mcuProxyConnection) performSyncRequest(ctx context.Context, msg *ProxyClientMessage) (*ProxyServerMessage, error) { + errChan := make(chan error, 1) + responseChan := make(chan *ProxyServerMessage, 1) + c.performAsyncRequest(ctx, msg, func(err error, response *ProxyServerMessage) { + if err != nil { + errChan <- err + } else { + responseChan <- response + } + }) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-errChan: + return nil, err + case response := <-responseChan: + return response, nil + } +} + +func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "create-publisher", + StreamType: streamType, + }, + } + + response, err := c.performSyncRequest(ctx, msg) + if err != nil { + // TODO: Cancel request + return nil, err + } + + proxyId := response.Command.Id + log.Printf("Created %s publisher %s on %s for %s", streamType, proxyId, c.url, id) + publisher := newMcuProxyPublisher(id, streamType, proxyId, c, listener) + c.publishersLock.Lock() + c.publishers[proxyId] = publisher + c.publisherIds[id+"|"+streamType] = proxyId + c.publishersLock.Unlock() + return publisher, nil +} + +func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { + c.publishersLock.Lock() + id, found := c.publisherIds[publisher+"|"+streamType] + c.publishersLock.Unlock() + if !found { + return nil, fmt.Errorf("Unknown publisher %s", publisher) + } + + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "create-subscriber", + StreamType: streamType, + PublisherId: id, + }, + } + + response, err := c.performSyncRequest(ctx, msg) + if err != nil { + // TODO: Cancel request + return nil, err + } + + proxyId := response.Command.Id + log.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c.url, publisher) + subscriber := newMcuProxySubscriber(publisher, streamType, proxyId, c, listener) + c.subscribersLock.Lock() + c.subscribers[proxyId] = subscriber + c.subscribersLock.Unlock() + return subscriber, nil +} + +type mcuProxy struct { + tokenId string + tokenKey *rsa.PrivateKey + + connections atomic.Value + connRequests int64 + nextSort int64 + + mu sync.RWMutex + publishers map[string]*mcuProxyConnection + + publisherWaitersId uint64 + publisherWaiters map[uint64]chan bool +} + +func NewMcuProxy(baseUrl string, config *goconf.ConfigFile) (Mcu, error) { + var connections []*mcuProxyConnection + + tokenId, _ := config.GetString("mcu", "token_id") + if tokenId == "" { + return nil, fmt.Errorf("No token id configured") + } + tokenKeyFilename, _ := config.GetString("mcu", "token_key") + if tokenKeyFilename == "" { + return nil, fmt.Errorf("No token key configured") + } + tokenKeyData, err := ioutil.ReadFile(tokenKeyFilename) + if err != nil { + return nil, fmt.Errorf("Could not read private key from %s: %s", tokenKeyFilename, err) + } + tokenKey, err := jwt.ParseRSAPrivateKeyFromPEM(tokenKeyData) + if err != nil { + return nil, fmt.Errorf("Could not parse private key from %s: %s", tokenKeyFilename, err) + } + + mcu := &mcuProxy{ + tokenId: tokenId, + tokenKey: tokenKey, + + publishers: make(map[string]*mcuProxyConnection), + + publisherWaiters: make(map[uint64]chan bool), + } + + for _, u := range strings.Split(baseUrl, " ") { + conn, err := newMcuProxyConnection(mcu, u) + if err != nil { + return nil, err + } + + connections = append(connections, conn) + } + if len(connections) == 0 { + return nil, fmt.Errorf("No MCU proxy connections configured") + } + + mcu.setConnections(connections) + return mcu, nil +} + +func (m *mcuProxy) setConnections(connections []*mcuProxyConnection) { + m.connections.Store(connections) +} + +func (m *mcuProxy) getConnections() []*mcuProxyConnection { + return m.connections.Load().([]*mcuProxyConnection) +} + +func (m *mcuProxy) Start() error { + for _, c := range m.getConnections() { + if err := c.start(); err != nil { + return err + } + } + return nil +} + +func (m *mcuProxy) Stop() { + for _, c := range m.getConnections() { + ctx, cancel := context.WithTimeout(context.Background(), closeTimeout) + defer cancel() + c.stop(ctx) + } +} + +func (m *mcuProxy) SetOnConnected(f func()) { + // Not supported. +} + +func (m *mcuProxy) SetOnDisconnected(f func()) { + // Not supported. +} + +type mcuProxyStats struct { + Publishers int64 `json:"publishers"` + Clients int64 `json:"clients"` + Details map[string]*mcuProxyConnectionStats `json:"details"` +} + +func (m *mcuProxy) GetStats() interface{} { + details := make(map[string]*mcuProxyConnectionStats) + result := &mcuProxyStats{ + Details: details, + } + for _, conn := range m.getConnections() { + stats := conn.GetStats() + result.Publishers += stats.Publishers + result.Clients += stats.Clients + details[stats.Url] = stats + } + return result +} + +type mcuProxyConnectionsList []*mcuProxyConnection + +func (l mcuProxyConnectionsList) Len() int { + return len(l) +} + +func (l mcuProxyConnectionsList) Less(i, j int) bool { + return l[i].Load() < l[j].Load() +} + +func (l mcuProxyConnectionsList) Swap(i, j int) { + l[i], l[j] = l[j], l[i] +} + +func (l mcuProxyConnectionsList) Sort() { + sort.Sort(l) +} + +func ContinentsOverlap(a, b []string) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + + for _, checkA := range a { + for _, checkB := range b { + if checkA == checkB { + return true + } + } + } + return false +} + +func sortConnectionsForCountry(connections []*mcuProxyConnection, country string) []*mcuProxyConnection { + // Move connections in the same country to the start of the list. + sorted := make(mcuProxyConnectionsList, 0, len(connections)) + unprocessed := make(mcuProxyConnectionsList, 0, len(connections)) + for _, conn := range connections { + if country == conn.Country() { + sorted = append(sorted, conn) + } else { + unprocessed = append(unprocessed, conn) + } + } + if continents, found := ContinentMap[country]; found && len(unprocessed) > 1 { + remaining := make(mcuProxyConnectionsList, 0, len(unprocessed)) + // Next up are connections on the same continent. + for _, conn := range unprocessed { + connCountry := conn.Country() + if IsValidCountry(connCountry) { + connContinents := ContinentMap[connCountry] + if ContinentsOverlap(continents, connContinents) { + sorted = append(sorted, conn) + } else { + remaining = append(remaining, conn) + } + } else { + remaining = append(remaining, conn) + } + } + unprocessed = remaining + } + // Add all other connections by load. + sorted = append(sorted, unprocessed...) + return sorted +} + +func (m *mcuProxy) getSortedConnections(initiator McuInitiator) []*mcuProxyConnection { + connections := m.getConnections() + if len(connections) < 2 { + return connections + } + + // Connections are re-sorted every requests or + // every . + now := time.Now().UnixNano() + if atomic.AddInt64(&m.connRequests, 1)%connectionSortRequests == 0 || atomic.LoadInt64(&m.nextSort) <= now { + atomic.StoreInt64(&m.nextSort, now+int64(connectionSortInterval)) + + sorted := make(mcuProxyConnectionsList, len(connections)) + copy(sorted, connections) + + sorted.Sort() + + m.setConnections(sorted) + connections = sorted + } + + if initiator != nil { + if country := initiator.Country(); IsValidCountry(country) { + connections = sortConnectionsForCountry(connections, country) + } + } + return connections +} + +func (m *mcuProxy) removePublisher(publisher *mcuProxyPublisher) { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.publishers, publisher.id+"|"+publisher.StreamType()) +} + +func (m *mcuProxy) wakeupWaiters() { + m.mu.RLock() + defer m.mu.RUnlock() + for _, ch := range m.publisherWaiters { + ch <- true + } +} + +func (m *mcuProxy) addWaiter(ch chan bool) uint64 { + id := m.publisherWaitersId + 1 + m.publisherWaitersId = id + m.publisherWaiters[id] = ch + return id +} + +func (m *mcuProxy) removeWaiter(id uint64) { + delete(m.publisherWaiters, id) +} + +func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) { + connections := m.getSortedConnections(initiator) + for _, conn := range connections { + if conn.IsShutdownScheduled() { + continue + } + + publisher, err := conn.newPublisher(ctx, listener, id, streamType) + if err != nil { + log.Printf("Could not create %s publisher for %s on %s: %s", streamType, id, conn.url, err) + continue + } + + m.mu.Lock() + m.publishers[id+"|"+streamType] = conn + m.mu.Unlock() + m.wakeupWaiters() + return publisher, nil + } + + return nil, fmt.Errorf("No MCU connection available") +} + +func (m *mcuProxy) getPublisherConnection(ctx context.Context, publisher string, streamType string) *mcuProxyConnection { + m.mu.RLock() + conn := m.publishers[publisher+"|"+streamType] + m.mu.RUnlock() + if conn != nil { + return conn + } + + log.Printf("No %s publisher %s found yet, deferring", streamType, publisher) + m.mu.Lock() + defer m.mu.Unlock() + + conn = m.publishers[publisher+"|"+streamType] + if conn != nil { + return conn + } + + ch := make(chan bool, 1) + id := m.addWaiter(ch) + defer m.removeWaiter(id) + + for { + m.mu.Unlock() + select { + case <-ch: + m.mu.Lock() + conn = m.publishers[publisher+"|"+streamType] + if conn != nil { + return conn + } + case <-ctx.Done(): + m.mu.Lock() + return nil + } + } +} + +func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { + conn := m.getPublisherConnection(ctx, publisher, streamType) + if conn == nil { + return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher) + } + + return conn.newSubscriber(ctx, listener, publisher, streamType) +} diff --git a/src/signaling/mcu_proxy_test.go b/src/signaling/mcu_proxy_test.go new file mode 100644 index 0000000..bf1cbca --- /dev/null +++ b/src/signaling/mcu_proxy_test.go @@ -0,0 +1,86 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "testing" +) + +func newProxyConnectionWithCountry(country string) *mcuProxyConnection { + conn := &mcuProxyConnection{} + conn.country.Store(country) + return conn +} + +func Test_sortConnectionsForCountry(t *testing.T) { + conn_de := newProxyConnectionWithCountry("DE") + conn_at := newProxyConnectionWithCountry("AT") + conn_jp := newProxyConnectionWithCountry("JP") + conn_us := newProxyConnectionWithCountry("US") + + testcases := map[string][][]*mcuProxyConnection{ + // Direct country match + "DE": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_at, conn_jp, conn_de}, + []*mcuProxyConnection{conn_de, conn_at, conn_jp}, + }, + // Direct country match + "AT": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_at, conn_jp, conn_de}, + []*mcuProxyConnection{conn_at, conn_de, conn_jp}, + }, + // Continent match + "CH": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_de, conn_at, conn_jp}, + }, + // Direct country match + "JP": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_jp, conn_de, conn_at}, + }, + // Continent match + "CN": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_jp, conn_de, conn_at}, + }, + // Partial continent match + "RU": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_us, conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_de, conn_jp, conn_at, conn_us}, + }, + // No match + "AU": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_us, conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_us, conn_de, conn_jp, conn_at}, + }, + } + + for country, test := range testcases { + sorted := sortConnectionsForCountry(test[0], country) + for idx, conn := range sorted { + if test[1][idx] != conn { + t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country()) + } + } + } +} diff --git a/src/signaling/mcu_test.go b/src/signaling/mcu_test.go index 8bdaad3..dbfe485 100644 --- a/src/signaling/mcu_test.go +++ b/src/signaling/mcu_test.go @@ -41,11 +41,17 @@ func (m *TestMCU) Start() error { func (m *TestMCU) Stop() { } +func (m *TestMCU) SetOnConnected(f func()) { +} + +func (m *TestMCU) SetOnDisconnected(f func()) { +} + func (m *TestMCU) GetStats() interface{} { return nil } -func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { +func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) { return nil, fmt.Errorf("Not implemented") }