diff --git a/etcd_client.go b/etcd_client.go new file mode 100644 index 0000000..7d8f1bd --- /dev/null +++ b/etcd_client.go @@ -0,0 +1,263 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2022 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "fmt" + "log" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/dlintw/goconf" + "go.etcd.io/etcd/client/pkg/v3/srv" + "go.etcd.io/etcd/client/pkg/v3/transport" + clientv3 "go.etcd.io/etcd/client/v3" +) + +type EtcdClientListener interface { + EtcdClientCreated(client *EtcdClient) +} + +type EtcdClientWatcher interface { + EtcdKeyUpdated(client *EtcdClient, key string, value []byte) + EtcdKeyDeleted(client *EtcdClient, key string) +} + +type EtcdClient struct { + compatSection string + + mu sync.Mutex + client atomic.Value + listeners map[EtcdClientListener]bool +} + +func NewEtcdClient(config *goconf.ConfigFile, compatSection string) (*EtcdClient, error) { + result := &EtcdClient{ + compatSection: compatSection, + } + if err := result.load(config, false); err != nil { + return nil, err + } + + return result, nil +} + +func (c *EtcdClient) getConfigStringWithFallback(config *goconf.ConfigFile, option string) string { + value, _ := config.GetString("etcd", option) + if value == "" && c.compatSection != "" { + value, _ = config.GetString(c.compatSection, option) + if value != "" { + log.Printf("WARNING: Configuring etcd option \"%s\" in section \"%s\" is deprecated, use section \"etcd\" instead", option, c.compatSection) + } + } + + return value +} + +func (c *EtcdClient) load(config *goconf.ConfigFile, ignoreErrors bool) error { + var endpoints []string + if endpointsString := c.getConfigStringWithFallback(config, "endpoints"); endpointsString != "" { + for _, ep := range strings.Split(endpointsString, ",") { + ep := strings.TrimSpace(ep) + if ep != "" { + endpoints = append(endpoints, ep) + } + } + } else if discoverySrv := c.getConfigStringWithFallback(config, "discoverysrv"); discoverySrv != "" { + discoveryService := c.getConfigStringWithFallback(config, "discoveryservice") + clients, err := srv.GetClient("etcd-client", discoverySrv, discoveryService) + if err != nil { + if !ignoreErrors { + return fmt.Errorf("Could not discover etcd endpoints for %s: %w", discoverySrv, err) + } + } else { + endpoints = clients.Endpoints + } + } + + if len(endpoints) == 0 { + if !ignoreErrors { + return nil + } + + log.Printf("No etcd endpoints configured, not changing client") + } else { + cfg := clientv3.Config{ + Endpoints: endpoints, + + // set timeout per request to fail fast when the target endpoint is unavailable + DialTimeout: time.Second, + } + + clientKey := c.getConfigStringWithFallback(config, "clientkey") + clientCert := c.getConfigStringWithFallback(config, "clientcert") + caCert := c.getConfigStringWithFallback(config, "cacert") + if clientKey != "" && clientCert != "" && caCert != "" { + tlsInfo := transport.TLSInfo{ + CertFile: clientCert, + KeyFile: clientKey, + TrustedCAFile: caCert, + } + tlsConfig, err := tlsInfo.ClientConfig() + if err != nil { + if !ignoreErrors { + return fmt.Errorf("Could not setup etcd TLS configuration: %w", err) + } + + log.Printf("Could not setup TLS configuration, will be disabled (%s)", err) + } else { + cfg.TLS = tlsConfig + } + } + + client, err := clientv3.New(cfg) + if err != nil { + if !ignoreErrors { + return err + } + + log.Printf("Could not create new client from etd endpoints %+v: %s", endpoints, err) + } else { + prev := c.getEtcdClient() + if prev != nil { + prev.Close() + } + c.client.Store(client) + log.Printf("Using etcd endpoints %+v", endpoints) + c.notifyListeners() + } + } + + return nil +} + +func (c *EtcdClient) Close() error { + client := c.getEtcdClient() + if client != nil { + return client.Close() + } + + return nil +} + +func (c *EtcdClient) IsConfigured() bool { + return c.getEtcdClient() != nil +} + +func (c *EtcdClient) getEtcdClient() *clientv3.Client { + client := c.client.Load() + if client == nil { + return nil + } + + return client.(*clientv3.Client) +} + +func (c *EtcdClient) syncClient() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + return c.getEtcdClient().Sync(ctx) +} + +func (c *EtcdClient) notifyListeners() { + c.mu.Lock() + defer c.mu.Unlock() + + for listener := range c.listeners { + listener.EtcdClientCreated(c) + } +} + +func (c *EtcdClient) AddListener(listener EtcdClientListener) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.listeners == nil { + c.listeners = make(map[EtcdClientListener]bool) + } + c.listeners[listener] = true + if client := c.getEtcdClient(); client != nil { + go listener.EtcdClientCreated(c) + } +} + +func (c *EtcdClient) RemoveListener(listener EtcdClientListener) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.listeners, listener) +} + +func (c *EtcdClient) WaitForConnection() { + waitDelay := initialWaitDelay + for { + if err := c.syncClient(); err != nil { + if err == context.DeadlineExceeded { + log.Printf("Timeout waiting for etcd client to connect to the cluster, retry in %s", waitDelay) + } else { + log.Printf("Could not sync etcd client with the cluster, retry in %s: %s", waitDelay, err) + } + + time.Sleep(waitDelay) + waitDelay = waitDelay * 2 + if waitDelay > maxWaitDelay { + waitDelay = maxWaitDelay + } + continue + } + + log.Printf("Client synced, using endpoints %+v", c.getEtcdClient().Endpoints()) + return + } +} + +func (c *EtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + return c.getEtcdClient().Get(ctx, key, opts...) +} + +func (c *EtcdClient) Watch(ctx context.Context, key string, watcher EtcdClientWatcher, opts ...clientv3.OpOption) error { + log.Printf("Wait for leader and start watching on %s", key) + ch := c.getEtcdClient().Watch(clientv3.WithRequireLeader(ctx), key, opts...) + log.Printf("Watch created for %s", key) + for response := range ch { + if err := response.Err(); err != nil { + return err + } + + for _, ev := range response.Events { + switch ev.Type { + case clientv3.EventTypePut: + watcher.EtcdKeyUpdated(c, string(ev.Kv.Key), ev.Kv.Value) + case clientv3.EventTypeDelete: + watcher.EtcdKeyDeleted(c, string(ev.Kv.Key)) + default: + log.Printf("Unsupported watch event %s %q -> %q", ev.Type, ev.Kv.Key, ev.Kv.Value) + } + } + } + + return nil +} diff --git a/etcd_client_test.go b/etcd_client_test.go new file mode 100644 index 0000000..695c4a0 --- /dev/null +++ b/etcd_client_test.go @@ -0,0 +1,290 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2022 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "errors" + "net" + "net/url" + "os" + "runtime" + "strconv" + "syscall" + "testing" + "time" + + "github.com/dlintw/goconf" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/server/v3/embed" + "go.etcd.io/etcd/server/v3/lease" +) + +var ( + etcdListenUrl = "http://localhost:8080" +) + +func isErrorAddressAlreadyInUse(err error) bool { + var eOsSyscall *os.SyscallError + if !errors.As(err, &eOsSyscall) { + return false + } + var errErrno syscall.Errno // doesn't need a "*" (ptr) because it's already a ptr (uintptr) + if !errors.As(eOsSyscall, &errErrno) { + return false + } + if errErrno == syscall.EADDRINUSE { + return true + } + const WSAEADDRINUSE = 10048 + if runtime.GOOS == "windows" && errErrno == WSAEADDRINUSE { + return true + } + return false +} + +func NewEtcdForTest(t *testing.T) *embed.Etcd { + cfg := embed.NewConfig() + cfg.Dir = t.TempDir() + os.Chmod(cfg.Dir, 0700) // nolint + cfg.LogLevel = "warn" + + u, err := url.Parse(etcdListenUrl) + if err != nil { + t.Fatal(err) + } + + // Find a free port to bind the server to. + var etcd *embed.Etcd + for port := 50000; port < 50100; port++ { + u.Host = net.JoinHostPort("localhost", strconv.Itoa(port)) + cfg.LCUrls = []url.URL{*u} + cfg.ACUrls = []url.URL{*u} + etcd, err = embed.StartEtcd(cfg) + if isErrorAddressAlreadyInUse(err) { + continue + } else if err != nil { + t.Fatal(err) + } + break + } + if etcd == nil { + t.Fatal("could not find free port") + } + + t.Cleanup(func() { + etcd.Close() + }) + // Wait for server to be ready. + <-etcd.Server.ReadyNotify() + + return etcd +} + +func NewEtcdClientForTest(t *testing.T) (*embed.Etcd, *EtcdClient) { + etcd := NewEtcdForTest(t) + + config := goconf.NewConfigFile() + config.AddOption("etcd", "endpoints", etcd.Config().LCUrls[0].String()) + + client, err := NewEtcdClient(config, "") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := client.Close(); err != nil { + t.Error(err) + } + }) + return etcd, client +} + +func SetEtcdValue(etcd *embed.Etcd, key string, value []byte) { + if kv := etcd.Server.KV(); kv != nil { + kv.Put([]byte(key), value, lease.NoLease) + kv.Commit() + } +} + +func DeleteEtcdValue(etcd *embed.Etcd, key string) { + if kv := etcd.Server.KV(); kv != nil { + kv.DeleteRange([]byte(key), nil) + kv.Commit() + } +} + +func Test_EtcdClient_Get(t *testing.T) { + etcd, client := NewEtcdClientForTest(t) + + if response, err := client.Get(context.Background(), "foo"); err != nil { + t.Error(err) + } else if response.Count != 0 { + t.Errorf("expected 0 response, got %d", response.Count) + } + + SetEtcdValue(etcd, "foo", []byte("bar")) + + if response, err := client.Get(context.Background(), "foo"); err != nil { + t.Error(err) + } else if response.Count != 1 { + t.Errorf("expected 1 responses, got %d", response.Count) + } else if string(response.Kvs[0].Key) != "foo" { + t.Errorf("expected key \"foo\", got \"%s\"", string(response.Kvs[0].Key)) + } else if string(response.Kvs[0].Value) != "bar" { + t.Errorf("expected value \"bar\", got \"%s\"", string(response.Kvs[0].Value)) + } +} + +func Test_EtcdClient_GetPrefix(t *testing.T) { + etcd, client := NewEtcdClientForTest(t) + + if response, err := client.Get(context.Background(), "foo"); err != nil { + t.Error(err) + } else if response.Count != 0 { + t.Errorf("expected 0 response, got %d", response.Count) + } + + SetEtcdValue(etcd, "foo", []byte("1")) + SetEtcdValue(etcd, "foo/lala", []byte("2")) + SetEtcdValue(etcd, "lala/foo", []byte("3")) + + if response, err := client.Get(context.Background(), "foo", clientv3.WithPrefix()); err != nil { + t.Error(err) + } else if response.Count != 2 { + t.Errorf("expected 2 responses, got %d", response.Count) + } else if string(response.Kvs[0].Key) != "foo" { + t.Errorf("expected key \"foo\", got \"%s\"", string(response.Kvs[0].Key)) + } else if string(response.Kvs[0].Value) != "1" { + t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value)) + } else if string(response.Kvs[1].Key) != "foo/lala" { + t.Errorf("expected key \"foo/lala\", got \"%s\"", string(response.Kvs[1].Key)) + } else if string(response.Kvs[1].Value) != "2" { + t.Errorf("expected value \"2\", got \"%s\"", string(response.Kvs[1].Value)) + } +} + +type etcdEvent struct { + t mvccpb.Event_EventType + key string + value string +} + +type EtcdClientTestListener struct { + t *testing.T + + ctx context.Context + cancel context.CancelFunc + + initial chan bool + events chan etcdEvent +} + +func NewEtcdClientTestListener(ctx context.Context, t *testing.T) *EtcdClientTestListener { + ctx, cancel := context.WithCancel(ctx) + return &EtcdClientTestListener{ + t: t, + + ctx: ctx, + cancel: cancel, + + initial: make(chan bool), + events: make(chan etcdEvent), + } +} + +func (l *EtcdClientTestListener) Close() { + l.cancel() +} + +func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) { + go func() { + if err := client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", l, clientv3.WithPrefix()); err != nil { + l.t.Error(err) + } + }() + + go func() { + client.WaitForConnection() + + ctx, cancel := context.WithTimeout(l.ctx, time.Second) + defer cancel() + + if response, err := client.Get(ctx, "foo", clientv3.WithPrefix()); err != nil { + l.t.Error(err) + } else if response.Count != 1 { + l.t.Errorf("expected 1 responses, got %d", response.Count) + } else if string(response.Kvs[0].Key) != "foo/a" { + l.t.Errorf("expected key \"foo/a\", got \"%s\"", string(response.Kvs[0].Key)) + } else if string(response.Kvs[0].Value) != "1" { + l.t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value)) + } + l.initial <- true + }() +} + +func (l *EtcdClientTestListener) EtcdKeyUpdated(client *EtcdClient, key string, value []byte) { + l.events <- etcdEvent{ + t: clientv3.EventTypePut, + key: string(key), + value: string(value), + } +} + +func (l *EtcdClientTestListener) EtcdKeyDeleted(client *EtcdClient, key string) { + l.events <- etcdEvent{ + t: clientv3.EventTypeDelete, + key: string(key), + } +} + +func Test_EtcdClient_Watch(t *testing.T) { + etcd, client := NewEtcdClientForTest(t) + + SetEtcdValue(etcd, "foo/a", []byte("1")) + + listener := NewEtcdClientTestListener(context.Background(), t) + defer listener.Close() + + client.AddListener(listener) + defer client.RemoveListener(listener) + + <-listener.initial + + SetEtcdValue(etcd, "foo/b", []byte("2")) + event := <-listener.events + if event.t != clientv3.EventTypePut { + t.Errorf("expected type %d, got %d", clientv3.EventTypePut, event.t) + } else if event.key != "foo/b" { + t.Errorf("expected key %s, got %s", "foo/b", event.key) + } else if event.value != "2" { + t.Errorf("expected value %s, got %s", "2", event.value) + } + + DeleteEtcdValue(etcd, "foo/a") + event = <-listener.events + if event.t != clientv3.EventTypeDelete { + t.Errorf("expected type %d, got %d", clientv3.EventTypeDelete, event.t) + } else if event.key != "foo/a" { + t.Errorf("expected key %s, got %s", "foo/a", event.key) + } +} diff --git a/mcu_proxy.go b/mcu_proxy.go index 3c9baad..d395c82 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -43,8 +43,6 @@ import ( "github.com/golang-jwt/jwt" "github.com/gorilla/websocket" - "go.etcd.io/etcd/client/pkg/v3/srv" - "go.etcd.io/etcd/client/pkg/v3/transport" clientv3 "go.etcd.io/etcd/client/v3" ) @@ -1051,10 +1049,11 @@ type mcuProxy struct { tokenId string tokenKey *rsa.PrivateKey - etcdMu sync.Mutex - client atomic.Value - keyInfos map[string]*ProxyInformationEtcd - urlToKey map[string]string + etcdMu sync.Mutex + etcdClient *EtcdClient + keyPrefix string + keyInfos map[string]*ProxyInformationEtcd + urlToKey map[string]string dialer *websocket.Dialer connections []*mcuProxyConnection @@ -1078,7 +1077,7 @@ type mcuProxy struct { continentsMap atomic.Value } -func NewMcuProxy(config *goconf.ConfigFile) (Mcu, error) { +func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient) (Mcu, error) { urlType, _ := config.GetString("mcu", "urltype") if urlType == "" { urlType = proxyUrlTypeStatic @@ -1122,6 +1121,8 @@ func NewMcuProxy(config *goconf.ConfigFile) (Mcu, error) { tokenId: tokenId, tokenKey: tokenKey, + etcdClient: etcdClient, + dialer: &websocket.Dialer{ Proxy: http.ProxyFromEnvironment, HandshakeTimeout: proxyTimeout, @@ -1161,11 +1162,16 @@ func NewMcuProxy(config *goconf.ConfigFile) (Mcu, error) { return nil, fmt.Errorf("No MCU proxy connections configured") } case proxyUrlTypeEtcd: + if !etcdClient.IsConfigured() { + return nil, fmt.Errorf("No etcd endpoints configured") + } + mcu.keyInfos = make(map[string]*ProxyInformationEtcd) mcu.urlToKey = make(map[string]string) if err := mcu.configureEtcd(config, false); err != nil { return nil, err } + mcu.etcdClient.AddListener(mcu) default: return nil, fmt.Errorf("Unsupported proxy URL type %s", urlType) } @@ -1211,15 +1217,6 @@ func (m *mcuProxy) loadContinentsMap(config *goconf.ConfigFile) error { return nil } -func (m *mcuProxy) getEtcdClient() *clientv3.Client { - c := m.client.Load() - if c == nil { - return nil - } - - return c.(*clientv3.Client) -} - func (m *mcuProxy) Start() error { m.connectionsMu.RLock() defer m.connectionsMu.RUnlock() @@ -1241,6 +1238,7 @@ func (m *mcuProxy) Start() error { } func (m *mcuProxy) Stop() { + m.etcdClient.RemoveListener(m) m.connectionsMu.RLock() defer m.connectionsMu.RUnlock() @@ -1469,151 +1467,51 @@ func (m *mcuProxy) configureEtcd(config *goconf.ConfigFile, ignoreErrors bool) e keyPrefix = "/%s" } - var endpoints []string - if endpointsString, _ := config.GetString("mcu", "endpoints"); endpointsString != "" { - for _, ep := range strings.Split(endpointsString, ",") { - ep := strings.TrimSpace(ep) - if ep != "" { - endpoints = append(endpoints, ep) - } - } - } else if discoverySrv, _ := config.GetString("mcu", "discoverysrv"); discoverySrv != "" { - discoveryService, _ := config.GetString("mcu", "discoveryservice") - clients, err := srv.GetClient("etcd-client", discoverySrv, discoveryService) - if err != nil { - if !ignoreErrors { - return fmt.Errorf("Could not discover endpoints for %s: %s", discoverySrv, err) - } - } else { - endpoints = clients.Endpoints - } - } - - if len(endpoints) == 0 { - if !ignoreErrors { - return fmt.Errorf("No proxy URL endpoints configured") - } - - log.Printf("No proxy URL endpoints configured, not changing client") - } else { - cfg := clientv3.Config{ - Endpoints: endpoints, - - // set timeout per request to fail fast when the target endpoint is unavailable - DialTimeout: time.Second, - } - - clientKey, _ := config.GetString("mcu", "clientkey") - clientCert, _ := config.GetString("mcu", "clientcert") - caCert, _ := config.GetString("mcu", "cacert") - if clientKey != "" && clientCert != "" && caCert != "" { - tlsInfo := transport.TLSInfo{ - CertFile: clientCert, - KeyFile: clientKey, - TrustedCAFile: caCert, - } - tlsConfig, err := tlsInfo.ClientConfig() - if err != nil { - if !ignoreErrors { - return fmt.Errorf("Could not setup TLS configuration: %s", err) - } - - log.Printf("Could not setup TLS configuration, will be disabled (%s)", err) - } else { - cfg.TLS = tlsConfig - } - } - - c, err := clientv3.New(cfg) - if err != nil { - if !ignoreErrors { - return err - } - - log.Printf("Could not create new client from proxy URL endpoints %+v: %s", endpoints, err) - } else { - prev := m.getEtcdClient() - if prev != nil { - prev.Close() - } - m.client.Store(c) - log.Printf("Using proxy URL endpoints %+v", endpoints) - - go func(client *clientv3.Client) { - log.Printf("Wait for leader and start watching on %s", keyPrefix) - ch := client.Watch(clientv3.WithRequireLeader(context.Background()), keyPrefix, clientv3.WithPrefix()) - log.Printf("Watch created for %s", keyPrefix) - m.processWatches(ch) - }(c) - - go func() { - m.waitForConnection() - - waitDelay := initialWaitDelay - for { - response, err := m.getProxyUrls(keyPrefix) - if err != nil { - if err == context.DeadlineExceeded { - log.Printf("Timeout getting initial list of proxy URLs, retry in %s", waitDelay) - } else { - log.Printf("Could not get initial list of proxy URLs, retry in %s: %s", waitDelay, err) - } - - time.Sleep(waitDelay) - waitDelay = waitDelay * 2 - if waitDelay > maxWaitDelay { - waitDelay = maxWaitDelay - } - continue - } - - for _, ev := range response.Kvs { - m.addEtcdProxy(string(ev.Key), ev.Value) - } - return - } - }() - } - } - + m.keyPrefix = keyPrefix return nil } -func (m *mcuProxy) getProxyUrls(keyPrefix string) (*clientv3.GetResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - return m.getEtcdClient().Get(ctx, keyPrefix, clientv3.WithPrefix()) -} - -func (m *mcuProxy) waitForConnection() { - waitDelay := initialWaitDelay - for { - if err := m.syncClient(); err != nil { - if err == context.DeadlineExceeded { - log.Printf("Timeout waiting for etcd client to connect to the cluster, retry in %s", waitDelay) - } else { - log.Printf("Could not sync etcd client with the cluster, retry in %s: %s", waitDelay, err) - } - - time.Sleep(waitDelay) - waitDelay = waitDelay * 2 - if waitDelay > maxWaitDelay { - waitDelay = maxWaitDelay - } - continue +func (m *mcuProxy) EtcdClientCreated(client *EtcdClient) { + go func() { + if err := client.Watch(context.Background(), m.keyPrefix, m, clientv3.WithPrefix()); err != nil { + log.Printf("Error processing watch for %s: %s", m.keyPrefix, err) } + }() - log.Printf("Client using endpoints %+v", m.getEtcdClient().Endpoints()) - return - } + go func() { + client.WaitForConnection() + + waitDelay := initialWaitDelay + for { + response, err := m.getProxyUrls(client, m.keyPrefix) + if err != nil { + if err == context.DeadlineExceeded { + log.Printf("Timeout getting initial list of proxy URLs, retry in %s", waitDelay) + } else { + log.Printf("Could not get initial list of proxy URLs, retry in %s: %s", waitDelay, err) + } + + time.Sleep(waitDelay) + waitDelay = waitDelay * 2 + if waitDelay > maxWaitDelay { + waitDelay = maxWaitDelay + } + continue + } + + for _, ev := range response.Kvs { + m.EtcdKeyUpdated(client, string(ev.Key), ev.Value) + } + return + } + }() } -func (m *mcuProxy) syncClient() error { +func (m *mcuProxy) getProxyUrls(client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - return m.getEtcdClient().Sync(ctx) + return client.Get(ctx, keyPrefix, clientv3.WithPrefix()) } func (m *mcuProxy) Reload(config *goconf.ConfigFile) { @@ -1631,22 +1529,7 @@ func (m *mcuProxy) Reload(config *goconf.ConfigFile) { } } -func (m *mcuProxy) processWatches(ch clientv3.WatchChan) { - for response := range ch { - for _, ev := range response.Events { - switch ev.Type { - case clientv3.EventTypePut: - m.addEtcdProxy(string(ev.Kv.Key), ev.Kv.Value) - case clientv3.EventTypeDelete: - m.removeEtcdProxy(string(ev.Kv.Key)) - default: - log.Printf("Unsupported event %s %q -> %q", ev.Type, ev.Kv.Key, ev.Kv.Value) - } - } - } -} - -func (m *mcuProxy) addEtcdProxy(key string, data []byte) { +func (m *mcuProxy) EtcdKeyUpdated(client *EtcdClient, key string, data []byte) { var info ProxyInformationEtcd if err := json.Unmarshal(data, &info); err != nil { log.Printf("Could not decode proxy information %s: %s", string(data), err) @@ -1700,7 +1583,7 @@ func (m *mcuProxy) addEtcdProxy(key string, data []byte) { } } -func (m *mcuProxy) removeEtcdProxy(key string) { +func (m *mcuProxy) EtcdKeyDeleted(client *EtcdClient, key string) { m.etcdMu.Lock() defer m.etcdMu.Unlock() diff --git a/proxy.conf.in b/proxy.conf.in index c978939..ac17412 100644 --- a/proxy.conf.in +++ b/proxy.conf.in @@ -26,21 +26,6 @@ tokentype = static #server1 = pubkey1.pem #server2 = pubkey2.pem -# For token type "etcd": Comma-separated list of static etcd endpoints to -# connect to. -#endpoints = 127.0.0.1:2379,127.0.0.1:22379,127.0.0.1:32379 - -# For token type "etcd": Options to perform endpoint discovery through DNS SRV. -# Only used if no endpoints are configured manually. -#discoverysrv = example.com -#discoveryservice = foo - -# For token type "etcd": Path to private key, client certificate and CA -# certificate if TLS authentication should be used. -#clientkey = /path/to/etcd-client.key -#clientcert = /path/to/etcd-client.crt -#cacert = /path/to/etcd-ca.crt - # For token type "etcd": Format of key name to retrieve the public key from, # "%s" will be replaced with the token id. Multiple possible formats can be # comma-separated. @@ -65,3 +50,18 @@ url = ws://localhost:8188/ # Comma-separated list of IP addresses that are allowed to access the stats # endpoint. Leave empty (or commented) to only allow access from "127.0.0.1". #allowed_ips = + +[etcd] +# Comma-separated list of static etcd endpoints to connect to. +#endpoints = 127.0.0.1:2379,127.0.0.1:22379,127.0.0.1:32379 + +# Options to perform endpoint discovery through DNS SRV. +# Only used if no endpoints are configured manually. +#discoverysrv = example.com +#discoveryservice = foo + +# Path to private key, client certificate and CA certificate if TLS +# authentication should be used. +#clientkey = /path/to/etcd-client.key +#clientcert = /path/to/etcd-client.crt +#cacert = /path/to/etcd-ca.crt diff --git a/proxy/proxy_tokens_etcd.go b/proxy/proxy_tokens_etcd.go index 74879aa..2779bf1 100644 --- a/proxy/proxy_tokens_etcd.go +++ b/proxy/proxy_tokens_etcd.go @@ -33,10 +33,6 @@ import ( "github.com/dlintw/goconf" "github.com/golang-jwt/jwt" - "go.etcd.io/etcd/client/pkg/v3/srv" - "go.etcd.io/etcd/client/pkg/v3/transport" - clientv3 "go.etcd.io/etcd/client/v3" - signaling "github.com/strukturag/nextcloud-spreed-signaling" ) @@ -50,14 +46,24 @@ type tokenCacheEntry struct { } type tokensEtcd struct { - client atomic.Value + client *signaling.EtcdClient tokenFormats atomic.Value tokenCache *signaling.LruCache } func NewProxyTokensEtcd(config *goconf.ConfigFile) (ProxyTokens, error) { + client, err := signaling.NewEtcdClient(config, "tokens") + if err != nil { + return nil, err + } + + if !client.IsConfigured() { + return nil, fmt.Errorf("No etcd endpoints configured") + } + result := &tokensEtcd{ + client: client, tokenCache: signaling.NewLruCache(tokenCacheSize), } if err := result.load(config, false); err != nil { @@ -67,15 +73,6 @@ func NewProxyTokensEtcd(config *goconf.ConfigFile) (ProxyTokens, error) { return result, nil } -func (t *tokensEtcd) getClient() *clientv3.Client { - c := t.client.Load() - if c == nil { - return nil - } - - return c.(*clientv3.Client) -} - func (t *tokensEtcd) getKeys(id string) []string { format := t.tokenFormats.Load().([]string) var result []string @@ -89,7 +86,7 @@ func (t *tokensEtcd) getByKey(id string, key string) (*ProxyToken, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - resp, err := t.getClient().Get(ctx, key) + resp, err := t.client.Get(ctx, key) if err != nil { return nil, err } @@ -139,82 +136,7 @@ func (t *tokensEtcd) Get(id string) (*ProxyToken, error) { } func (t *tokensEtcd) load(config *goconf.ConfigFile, ignoreErrors bool) error { - var endpoints []string - if endpointsString, _ := config.GetString("tokens", "endpoints"); endpointsString != "" { - for _, ep := range strings.Split(endpointsString, ",") { - ep := strings.TrimSpace(ep) - if ep != "" { - endpoints = append(endpoints, ep) - } - } - } else if discoverySrv, _ := config.GetString("tokens", "discoverysrv"); discoverySrv != "" { - discoveryService, _ := config.GetString("tokens", "discoveryservice") - clients, err := srv.GetClient("etcd-client", discoverySrv, discoveryService) - if err != nil { - if !ignoreErrors { - return fmt.Errorf("Could not discover endpoints for %s: %s", discoverySrv, err) - } - } else { - endpoints = clients.Endpoints - } - } - - if len(endpoints) == 0 { - if !ignoreErrors { - return fmt.Errorf("No token endpoints configured") - } - - log.Printf("No token endpoints configured, not changing client") - } else { - cfg := clientv3.Config{ - Endpoints: endpoints, - - // set timeout per request to fail fast when the target endpoint is unavailable - DialTimeout: time.Second, - } - - clientKey, _ := config.GetString("tokens", "clientkey") - clientCert, _ := config.GetString("tokens", "clientcert") - caCert, _ := config.GetString("tokens", "cacert") - if clientKey != "" && clientCert != "" && caCert != "" { - tlsInfo := transport.TLSInfo{ - CertFile: clientCert, - KeyFile: clientKey, - TrustedCAFile: caCert, - } - tlsConfig, err := tlsInfo.ClientConfig() - if err != nil { - if !ignoreErrors { - return fmt.Errorf("Could not setup TLS configuration: %s", err) - } - - log.Printf("Could not setup TLS configuration, will be disabled (%s)", err) - } else { - cfg.TLS = tlsConfig - } - } - - c, err := clientv3.New(cfg) - if err != nil { - if !ignoreErrors { - return err - } - - log.Printf("Could not create new client from token endpoints %+v: %s", endpoints, err) - } else { - prev := t.getClient() - if prev != nil { - prev.Close() - } - t.client.Store(c) - log.Printf("Using token endpoints %+v", endpoints) - } - } - tokenFormat, _ := config.GetString("tokens", "keyformat") - if tokenFormat == "" { - tokenFormat = "/%s" - } formats := strings.Split(tokenFormat, ",") var tokenFormats []string @@ -224,6 +146,9 @@ func (t *tokensEtcd) load(config *goconf.ConfigFile, ignoreErrors bool) error { tokenFormats = append(tokenFormats, f) } } + if len(tokenFormats) == 0 { + tokenFormats = []string{"/%s"} + } t.tokenFormats.Store(tokenFormats) log.Printf("Using %v as token formats", tokenFormats) @@ -237,7 +162,7 @@ func (t *tokensEtcd) Reload(config *goconf.ConfigFile) { } func (t *tokensEtcd) Close() { - if client := t.getClient(); client != nil { - client.Close() + if err := t.client.Close(); err != nil { + log.Printf("Error while closing etcd client: %s", err) } } diff --git a/proxy/proxy_tokens_etcd_test.go b/proxy/proxy_tokens_etcd_test.go index 4fbe8a8..c7a78af 100644 --- a/proxy/proxy_tokens_etcd_test.go +++ b/proxy/proxy_tokens_etcd_test.go @@ -105,7 +105,7 @@ func newTokensEtcdForTesting(t *testing.T) (*tokensEtcd, *embed.Etcd) { etcd := newEtcdForTesting(t) cfg := goconf.NewConfigFile() - cfg.AddOption("tokens", "endpoints", etcd.Config().LCUrls[0].String()) + cfg.AddOption("etcd", "endpoints", etcd.Config().LCUrls[0].String()) cfg.AddOption("tokens", "keyformat", "/%s, /testing/%s/key") tokens, err := NewProxyTokensEtcd(cfg) diff --git a/server.conf.in b/server.conf.in index 9125b6c..f06b475 100644 --- a/server.conf.in +++ b/server.conf.in @@ -165,21 +165,6 @@ connectionsperhost = 8 # or deleted as necessary. #dnsdiscovery = true -# For url type "etcd": Comma-separated list of static etcd endpoints to -# connect to. -#endpoints = 127.0.0.1:2379,127.0.0.1:22379,127.0.0.1:32379 - -# For url type "etcd": Options to perform endpoint discovery through DNS SRV. -# Only used if no endpoints are configured manually. -#discoverysrv = example.com -#discoveryservice = foo - -# For url type "etcd": Path to private key, client certificate and CA -# certificate if TLS authentication should be used. -#clientkey = /path/to/etcd-client.key -#clientcert = /path/to/etcd-client.crt -#cacert = /path/to/etcd-ca.crt - # For url type "etcd": Key prefix of MCU proxy entries. All keys below will be # watched and assumed to contain a JSON document. The entry "address" from this # document will be used as proxy URL, other contents in the document will be @@ -234,3 +219,18 @@ connectionsperhost = 8 # Comma-separated list of IP addresses that are allowed to access the stats # endpoint. Leave empty (or commented) to only allow access from "127.0.0.1". #allowed_ips = + +[etcd] +# Comma-separated list of static etcd endpoints to connect to. +#endpoints = 127.0.0.1:2379,127.0.0.1:22379,127.0.0.1:32379 + +# Options to perform endpoint discovery through DNS SRV. +# Only used if no endpoints are configured manually. +#discoverysrv = example.com +#discoveryservice = foo + +# Path to private key, client certificate and CA certificate if TLS +# authentication should be used. +#clientkey = /path/to/etcd-client.key +#clientcert = /path/to/etcd-client.crt +#cacert = /path/to/etcd-ca.crt diff --git a/server/main.go b/server/main.go index ea14bd6..9db6dcb 100644 --- a/server/main.go +++ b/server/main.go @@ -153,6 +153,16 @@ func main() { log.Fatal("Could not create NATS client: ", err) } + etcdClient, err := signaling.NewEtcdClient(config, "mcu") + if err != nil { + log.Fatalf("Could not create etcd client: %s", err) + } + defer func() { + if err := etcdClient.Close(); err != nil { + log.Printf("Error while closing etcd client: %s", err) + } + }() + r := mux.NewRouter() hub, err := signaling.NewHub(config, nats, r, version) if err != nil { @@ -181,7 +191,7 @@ func main() { signaling.UnregisterProxyMcuStats() signaling.RegisterJanusMcuStats() case signaling.McuTypeProxy: - mcu, err = signaling.NewMcuProxy(config) + mcu, err = signaling.NewMcuProxy(config, etcdClient) signaling.UnregisterJanusMcuStats() signaling.RegisterProxyMcuStats() default: