From 8f2933071e4c9948196dc04e1eac5460173de549 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 1 Dec 2023 23:42:59 +0100 Subject: [PATCH] Move proxy configuration code to different files. --- mcu_proxy.go | 477 +++++++----------------------------- proxy_config.go | 18 ++ proxy_config_etcd.go | 172 +++++++++++++ proxy_config_etcd_test.go | 88 +++++++ proxy_config_static.go | 225 +++++++++++++++++ proxy_config_static_test.go | 91 +++++++ proxy_config_test.go | 165 +++++++++++++ 7 files changed, 846 insertions(+), 390 deletions(-) create mode 100644 proxy_config.go create mode 100644 proxy_config_etcd.go create mode 100644 proxy_config_etcd_test.go create mode 100644 proxy_config_static.go create mode 100644 proxy_config_static_test.go create mode 100644 proxy_config_test.go diff --git a/mcu_proxy.go b/mcu_proxy.go index ff0adfd..0eb732c 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -43,8 +43,6 @@ import ( "github.com/dlintw/goconf" "github.com/golang-jwt/jwt/v4" "github.com/gorilla/websocket" - - clientv3 "go.etcd.io/etcd/client/v3" ) const ( @@ -73,6 +71,12 @@ const ( updateDnsInterval = 10 * time.Second ) +type McuProxy interface { + AddConnection(ignoreErrors bool, url string, ips ...net.IP) error + KeepConnection(url string, ips ...net.IP) + RemoveConnection(url string, ips ...net.IP) +} + type mcuProxyPubSubCommon struct { sid string streamType string @@ -519,9 +523,8 @@ func (c *mcuProxyConnection) writePump() { } } -func (c *mcuProxyConnection) start() error { +func (c *mcuProxyConnection) start() { go c.writePump() - return nil } func (c *mcuProxyConnection) sendClose() error { @@ -1095,12 +1098,7 @@ type mcuProxy struct { urlType string tokenId string tokenKey *rsa.PrivateKey - - etcdMu sync.Mutex - etcdClient *EtcdClient - keyPrefix string - keyInfos map[string]*ProxyInformationEtcd - urlToKey map[string]string + config ProxyConfig dialer *websocket.Dialer connections []*mcuProxyConnection @@ -1110,10 +1108,6 @@ type mcuProxy struct { connRequests atomic.Int64 nextSort atomic.Int64 - dnsDiscovery bool - stopping chan struct{} - stopped chan struct{} - maxStreamBitrate int maxScreenBitrate int @@ -1171,8 +1165,6 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients * tokenId: tokenId, tokenKey: tokenKey, - etcdClient: etcdClient, - dialer: &websocket.Dialer{ Proxy: http.ProxyFromEnvironment, HandshakeTimeout: proxyTimeout, @@ -1180,9 +1172,6 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients * connectionsMap: make(map[string][]*mcuProxyConnection), proxyTimeout: proxyTimeout, - stopping: make(chan struct{}, 1), - stopped: make(chan struct{}, 1), - maxStreamBitrate: maxStreamBitrate, maxScreenBitrate: maxScreenBitrate, @@ -1205,25 +1194,14 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients * switch urlType { case proxyUrlTypeStatic: - if err := mcu.configureStatic(config, false); err != nil { - return nil, err - } - if len(mcu.connections) == 0 { - return nil, fmt.Errorf("No MCU proxy connections configured") - } + mcu.config, err = NewProxyConfigStatic(config, mcu) case proxyUrlTypeEtcd: - if !etcdClient.IsConfigured() { - return nil, fmt.Errorf("No etcd endpoints configured") - } - - mcu.keyInfos = make(map[string]*ProxyInformationEtcd) - mcu.urlToKey = make(map[string]string) - if err := mcu.configureEtcd(config, false); err != nil { - return nil, err - } - mcu.etcdClient.AddListener(mcu) + mcu.config, err = NewProxyConfigEtcd(config, etcdClient, mcu) default: - return nil, fmt.Errorf("Unsupported proxy URL type %s", urlType) + err = fmt.Errorf("Unsupported proxy URL type %s", urlType) + } + if err != nil { + return nil, err } return mcu, nil @@ -1271,310 +1249,121 @@ func (m *mcuProxy) loadContinentsMap(config *goconf.ConfigFile) error { } func (m *mcuProxy) Start() error { - m.connectionsMu.RLock() - defer m.connectionsMu.RUnlock() - log.Printf("Maximum bandwidth %d bits/sec per publishing stream", m.maxStreamBitrate) log.Printf("Maximum bandwidth %d bits/sec per screensharing stream", m.maxScreenBitrate) - for _, c := range m.connections { - if err := c.start(); err != nil { - return err - } - } - - if m.urlType == proxyUrlTypeStatic && m.dnsDiscovery { - go m.monitorProxyIPs() - } - - return nil + return m.config.Start() } func (m *mcuProxy) Stop() { - m.etcdClient.RemoveListener(m) m.connectionsMu.RLock() defer m.connectionsMu.RUnlock() + ctx, cancel := context.WithTimeout(context.Background(), closeTimeout) + defer cancel() for _, c := range m.connections { - ctx, cancel := context.WithTimeout(context.Background(), closeTimeout) - defer cancel() c.stop(ctx) } - if m.urlType == proxyUrlTypeStatic && m.dnsDiscovery { - m.stopping <- struct{}{} - <-m.stopped - } + m.config.Stop() } -func (m *mcuProxy) monitorProxyIPs() { - log.Printf("Start monitoring proxy IPs") - ticker := time.NewTicker(updateDnsInterval) - for { - select { - case <-ticker.C: - m.updateProxyIPs() - case <-m.stopping: - m.stopped <- struct{}{} - return - } - } -} - -func (m *mcuProxy) updateProxyIPs() { +func (m *mcuProxy) AddConnection(ignoreErrors bool, url string, ips ...net.IP) error { m.connectionsMu.Lock() defer m.connectionsMu.Unlock() - for u, conns := range m.connectionsMap { - if len(conns) == 0 { - continue - } - - host := conns[0].url.Host - if h, _, err := net.SplitHostPort(host); err == nil { - host = h - } - - if net.ParseIP(host) != nil { - // No need to lookup endpoints that connect to IP addresses. - continue - } - - ips, err := net.LookupIP(host) + var conns []*mcuProxyConnection + if len(ips) == 0 { + conn, err := newMcuProxyConnection(m, url, nil) if err != nil { - log.Printf("Could not lookup %s: %s", host, err) - continue - } - - var newConns []*mcuProxyConnection - changed := false - for _, conn := range conns { - found := false - for idx, ip := range ips { - if ip.Equal(conn.ip) { - ips = append(ips[:idx], ips[idx+1:]...) - found = true - conn.stopCloseIfEmpty() - conn.clearTemporary() - newConns = append(newConns, conn) - break - } + if ignoreErrors { + log.Printf("Could not create proxy connection to %s: %s", url, err) + return nil } - if !found { - changed = true - log.Printf("Removing connection to %s", conn) - conn.closeIfEmpty() - } + return err } + conns = append(conns, conn) + } else { for _, ip := range ips { - conn, err := newMcuProxyConnection(m, u, ip) + conn, err := newMcuProxyConnection(m, url, ip) if err != nil { - log.Printf("Could not create proxy connection to %s (%s): %s", u, ip, err) - continue - } - - if err := conn.start(); err != nil { - log.Printf("Could not start new connection to %s: %s", conn, err) - continue - } - - log.Printf("Adding new connection to %s", conn) - m.connections = append(m.connections, conn) - newConns = append(newConns, conn) - changed = true - } - - if changed { - m.connectionsMap[u] = newConns - } - } -} - -func (m *mcuProxy) configureStatic(config *goconf.ConfigFile, fromReload bool) error { - m.connectionsMu.Lock() - defer m.connectionsMu.Unlock() - - remove := make(map[string][]*mcuProxyConnection) - for u, conns := range m.connectionsMap { - remove[u] = conns - } - created := make(map[string][]*mcuProxyConnection) - changed := false - - mcuUrl, _ := config.GetString("mcu", "url") - dnsDiscovery, _ := config.GetBool("mcu", "dnsdiscovery") - if dnsDiscovery != m.dnsDiscovery { - if !dnsDiscovery && fromReload { - m.stopping <- struct{}{} - <-m.stopped - } - m.dnsDiscovery = dnsDiscovery - if dnsDiscovery && fromReload { - go m.monitorProxyIPs() - } - } - - for _, u := range strings.Split(mcuUrl, " ") { - if existing, found := remove[u]; found { - // Proxy connection still exists in new configuration - delete(remove, u) - for _, conn := range existing { - conn.stopCloseIfEmpty() - conn.clearTemporary() - } - continue - } - - var ips []net.IP - if dnsDiscovery { - parsed, err := url.Parse(u) - if err != nil { - if !fromReload { - return err + if ignoreErrors { + log.Printf("Could not create proxy connection to %s (%s): %s", url, ip, err) + continue } - log.Printf("Could not parse URL %s: %s", u, err) - continue - } - - if host, _, err := net.SplitHostPort(parsed.Host); err == nil { - parsed.Host = host - } - - ips, err = net.LookupIP(parsed.Host) - if err != nil { - // Will be retried later. - log.Printf("Could not lookup %s: %s\n", parsed.Host, err) - continue - } - } - - var conns []*mcuProxyConnection - if ips == nil { - conn, err := newMcuProxyConnection(m, u, nil) - if err != nil { - if !fromReload { - return err - } - - log.Printf("Could not create proxy connection to %s: %s", u, err) - continue + return err } conns = append(conns, conn) + } + } + + for _, conn := range conns { + log.Printf("Adding new connection to %s", conn) + conn.start() + + m.connections = append(m.connections, conn) + if existing, found := m.connectionsMap[url]; found { + m.connectionsMap[url] = append(existing, conn) } else { - for _, ip := range ips { - conn, err := newMcuProxyConnection(m, u, ip) - if err != nil { - if !fromReload { - return err - } - - log.Printf("Could not create proxy connection to %s (%s): %s", u, ip, err) - continue - } - - conns = append(conns, conn) - } - } - created[u] = conns - } - - for _, conns := range remove { - for _, conn := range conns { - go conn.closeIfEmpty() + m.connectionsMap[url] = []*mcuProxyConnection{conn} } } - if fromReload { - for u, conns := range created { - var started []*mcuProxyConnection - for _, conn := range conns { - if err := conn.start(); err != nil { - log.Printf("Could not start new connection to %s: %s", conn, err) - continue - } + m.nextSort.Store(0) + return nil +} - log.Printf("Adding new connection to %s", conn) - started = append(started, conn) - m.connections = append(m.connections, conn) - } - - if len(started) > 0 { - m.connectionsMap[u] = started - changed = true - } +func containsIP(ips []net.IP, ip net.IP) bool { + for _, i := range ips { + if i.Equal(ip) { + return true } + } - if changed { - m.nextSort.Store(0) - } + return false +} + +func (m *mcuProxy) iterateConnections(url string, ips []net.IP, f func(conn *mcuProxyConnection)) { + m.connectionsMu.Lock() + defer m.connectionsMu.Unlock() + + conns, found := m.connectionsMap[url] + if !found { + return + } + + var toRemove []*mcuProxyConnection + if len(ips) == 0 { + toRemove = conns } else { - for u, conns := range created { - m.connections = append(m.connections, conns...) - m.connectionsMap[u] = conns + for _, conn := range conns { + if containsIP(ips, conn.ip) { + toRemove = append(toRemove, conn) + } } } - return nil -} - -func (m *mcuProxy) configureEtcd(config *goconf.ConfigFile, ignoreErrors bool) error { - keyPrefix, _ := config.GetString("mcu", "keyprefix") - if keyPrefix == "" { - keyPrefix = "/%s" + for _, conn := range toRemove { + f(conn) } - - m.keyPrefix = keyPrefix - return nil } -func (m *mcuProxy) EtcdClientCreated(client *EtcdClient) { - go func() { - if err := client.Watch(context.Background(), m.keyPrefix, m, clientv3.WithPrefix()); err != nil { - log.Printf("Error processing watch for %s: %s", m.keyPrefix, err) - } - }() - - go func() { - client.WaitForConnection() - - waitDelay := initialWaitDelay - for { - response, err := m.getProxyUrls(client, m.keyPrefix) - if err != nil { - if err == context.DeadlineExceeded { - log.Printf("Timeout getting initial list of proxy URLs, retry in %s", waitDelay) - } else { - log.Printf("Could not get initial list of proxy URLs, retry in %s: %s", waitDelay, err) - } - - time.Sleep(waitDelay) - waitDelay = waitDelay * 2 - if waitDelay > maxWaitDelay { - waitDelay = maxWaitDelay - } - continue - } - - for _, ev := range response.Kvs { - m.EtcdKeyUpdated(client, string(ev.Key), ev.Value) - } - return - } - }() +func (m *mcuProxy) RemoveConnection(url string, ips ...net.IP) { + m.iterateConnections(url, ips, func(conn *mcuProxyConnection) { + log.Printf("Removing connection to %s", conn) + conn.closeIfEmpty() + }) } -func (m *mcuProxy) EtcdWatchCreated(client *EtcdClient, key string) { -} - -func (m *mcuProxy) getProxyUrls(client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - return client.Get(ctx, keyPrefix, clientv3.WithPrefix()) +func (m *mcuProxy) KeepConnection(url string, ips ...net.IP) { + m.iterateConnections(url, ips, func(conn *mcuProxyConnection) { + conn.stopCloseIfEmpty() + conn.clearTemporary() + }) } func (m *mcuProxy) Reload(config *goconf.ConfigFile) { @@ -1582,95 +1371,8 @@ func (m *mcuProxy) Reload(config *goconf.ConfigFile) { log.Printf("Error loading continents map: %s", err) } - switch m.urlType { - case proxyUrlTypeStatic: - if err := m.configureStatic(config, true); err != nil { - log.Printf("Could not configure static proxy urls: %s", err) - } - default: - // Reloading not supported yet. - } -} - -func (m *mcuProxy) EtcdKeyUpdated(client *EtcdClient, key string, data []byte) { - var info ProxyInformationEtcd - if err := json.Unmarshal(data, &info); err != nil { - log.Printf("Could not decode proxy information %s: %s", string(data), err) - return - } - if err := info.CheckValid(); err != nil { - log.Printf("Received invalid proxy information %s: %s", string(data), err) - return - } - - m.etcdMu.Lock() - defer m.etcdMu.Unlock() - - prev, found := m.keyInfos[key] - if found && info.Address != prev.Address { - // Address of a proxy has changed. - m.removeEtcdProxyLocked(key) - } - - if otherKey, found := m.urlToKey[info.Address]; found && otherKey != key { - log.Printf("Address %s is already registered for key %s, ignoring %s", info.Address, otherKey, key) - return - } - - m.connectionsMu.Lock() - defer m.connectionsMu.Unlock() - if conns, found := m.connectionsMap[info.Address]; found { - m.keyInfos[key] = &info - m.urlToKey[info.Address] = key - for _, conn := range conns { - conn.stopCloseIfEmpty() - conn.clearTemporary() - } - } else { - conn, err := newMcuProxyConnection(m, info.Address, nil) - if err != nil { - log.Printf("Could not create proxy connection to %s: %s", info.Address, err) - return - } - - if err := conn.start(); err != nil { - log.Printf("Could not start new connection to %s: %s", info.Address, err) - return - } - - log.Printf("Adding new connection to %s (from %s)", info.Address, key) - m.keyInfos[key] = &info - m.urlToKey[info.Address] = key - m.connections = append(m.connections, conn) - m.connectionsMap[info.Address] = []*mcuProxyConnection{conn} - m.nextSort.Store(0) - } -} - -func (m *mcuProxy) EtcdKeyDeleted(client *EtcdClient, key string) { - m.etcdMu.Lock() - defer m.etcdMu.Unlock() - - m.removeEtcdProxyLocked(key) -} - -func (m *mcuProxy) removeEtcdProxyLocked(key string) { - info, found := m.keyInfos[key] - if !found { - return - } - - delete(m.keyInfos, key) - delete(m.urlToKey, info.Address) - - log.Printf("Removing connection to %s (from %s)", info.Address, key) - - m.connectionsMu.RLock() - defer m.connectionsMu.RUnlock() - if conns, found := m.connectionsMap[info.Address]; found { - for _, conn := range conns { - go conn.closeIfEmpty() - } + if err := m.config.Reload(config); err != nil { + log.Printf("could not reload proxy configuration: %s", err) } } @@ -2011,14 +1713,9 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err) return } + publisherConn.setTemporary() - - if err := publisherConn.start(); err != nil { - log.Printf("Could not start new connection to %s: %s", publisherConn, err) - publisherConn.closeIfEmpty() - return - } - + publisherConn.start() if err := publisherConn.waitUntilConnected(ctx); err != nil { log.Printf("Could not establish new connection to %s: %s", publisherConn, err) publisherConn.closeIfEmpty() diff --git a/proxy_config.go b/proxy_config.go new file mode 100644 index 0000000..be250e4 --- /dev/null +++ b/proxy_config.go @@ -0,0 +1,18 @@ +package signaling + +import ( + "net" + + "github.com/dlintw/goconf" +) + +var ( + lookupProxyIP = net.LookupIP +) + +type ProxyConfig interface { + Start() error + Stop() + + Reload(config *goconf.ConfigFile) error +} diff --git a/proxy_config_etcd.go b/proxy_config_etcd.go new file mode 100644 index 0000000..f79919b --- /dev/null +++ b/proxy_config_etcd.go @@ -0,0 +1,172 @@ +package signaling + +import ( + "context" + "encoding/json" + "errors" + "log" + "sync" + "time" + + "github.com/dlintw/goconf" + clientv3 "go.etcd.io/etcd/client/v3" +) + +type proxyConfigEtcd struct { + mu sync.Mutex + proxy McuProxy + + client *EtcdClient + keyPrefix string + keyInfos map[string]*ProxyInformationEtcd + urlToKey map[string]string +} + +func NewProxyConfigEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient, proxy McuProxy) (ProxyConfig, error) { + if !etcdClient.IsConfigured() { + return nil, errors.New("No etcd endpoints configured") + } + + result := &proxyConfigEtcd{ + proxy: proxy, + + client: etcdClient, + keyInfos: make(map[string]*ProxyInformationEtcd), + urlToKey: make(map[string]string), + } + if err := result.configure(config, false); err != nil { + return nil, err + } + return result, nil +} + +func (p *proxyConfigEtcd) configure(config *goconf.ConfigFile, fromReload bool) error { + keyPrefix, _ := config.GetString("mcu", "keyprefix") + if keyPrefix == "" { + keyPrefix = "/%s" + } + + p.keyPrefix = keyPrefix + return nil +} + +func (p *proxyConfigEtcd) Start() error { + p.client.AddListener(p) + return nil +} + +func (p *proxyConfigEtcd) Reload(config *goconf.ConfigFile) error { + // not implemented + return nil +} + +func (p *proxyConfigEtcd) Stop() { + p.client.RemoveListener(p) +} + +func (p *proxyConfigEtcd) EtcdClientCreated(client *EtcdClient) { + go func() { + if err := client.Watch(context.Background(), p.keyPrefix, p, clientv3.WithPrefix()); err != nil { + log.Printf("Error processing watch for %s: %s", p.keyPrefix, err) + } + }() + + go func() { + client.WaitForConnection() + + waitDelay := initialWaitDelay + for { + response, err := p.getProxyUrls(client, p.keyPrefix) + if err != nil { + if err == context.DeadlineExceeded { + log.Printf("Timeout getting initial list of proxy URLs, retry in %s", waitDelay) + } else { + log.Printf("Could not get initial list of proxy URLs, retry in %s: %s", waitDelay, err) + } + + time.Sleep(waitDelay) + waitDelay = waitDelay * 2 + if waitDelay > maxWaitDelay { + waitDelay = maxWaitDelay + } + continue + } + + for _, ev := range response.Kvs { + p.EtcdKeyUpdated(client, string(ev.Key), ev.Value) + } + return + } + }() +} + +func (p *proxyConfigEtcd) EtcdWatchCreated(client *EtcdClient, key string) { +} + +func (p *proxyConfigEtcd) getProxyUrls(client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + return client.Get(ctx, keyPrefix, clientv3.WithPrefix()) +} + +func (p *proxyConfigEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data []byte) { + var info ProxyInformationEtcd + if err := json.Unmarshal(data, &info); err != nil { + log.Printf("Could not decode proxy information %s: %s", string(data), err) + return + } + if err := info.CheckValid(); err != nil { + log.Printf("Received invalid proxy information %s: %s", string(data), err) + return + } + + p.mu.Lock() + defer p.mu.Unlock() + + prev, found := p.keyInfos[key] + if found && info.Address != prev.Address { + // Address of a proxy has changed. + p.removeEtcdProxyLocked(key) + found = false + } + + if otherKey, otherFound := p.urlToKey[info.Address]; otherFound && otherKey != key { + log.Printf("Address %s is already registered for key %s, ignoring %s", info.Address, otherKey, key) + return + } + + if found { + p.keyInfos[key] = &info + p.proxy.KeepConnection(info.Address) + } else { + if err := p.proxy.AddConnection(false, info.Address); err != nil { + log.Printf("Could not create proxy connection to %s: %s", info.Address, err) + return + } + + log.Printf("Added new connection to %s (from %s)", info.Address, key) + p.keyInfos[key] = &info + p.urlToKey[info.Address] = key + } +} + +func (p *proxyConfigEtcd) EtcdKeyDeleted(client *EtcdClient, key string) { + p.mu.Lock() + defer p.mu.Unlock() + + p.removeEtcdProxyLocked(key) +} + +func (p *proxyConfigEtcd) removeEtcdProxyLocked(key string) { + info, found := p.keyInfos[key] + if !found { + return + } + + delete(p.keyInfos, key) + delete(p.urlToKey, info.Address) + + log.Printf("Removing connection to %s (from %s)", info.Address, key) + p.proxy.RemoveConnection(info.Address) +} diff --git a/proxy_config_etcd_test.go b/proxy_config_etcd_test.go new file mode 100644 index 0000000..456fae4 --- /dev/null +++ b/proxy_config_etcd_test.go @@ -0,0 +1,88 @@ +package signaling + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/dlintw/goconf" + "go.etcd.io/etcd/server/v3/embed" +) + +type TestProxyInformationEtcd struct { + Address string `json:"address"` + + OtherData string `json:"otherdata,omitempty"` +} + +func newProxyConfigEtcd(t *testing.T, proxy McuProxy) (*embed.Etcd, ProxyConfig) { + t.Helper() + etcd, client := NewEtcdClientForTest(t) + cfg := goconf.NewConfigFile() + cfg.AddOption("mcu", "keyprefix", "proxies/") + p, err := NewProxyConfigEtcd(cfg, client, proxy) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + p.Stop() + }) + return etcd, p +} + +func SetEtcdProxy(t *testing.T, etcd *embed.Etcd, path string, proxy *TestProxyInformationEtcd) { + t.Helper() + data, err := json.Marshal(proxy) + if err != nil { + t.Fatal(err) + } + SetEtcdValue(etcd, path, data) +} + +func TestProxyConfigEtcd(t *testing.T) { + proxy := newMcuProxyForConfig(t) + etcd, config := newProxyConfigEtcd(t, proxy) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + SetEtcdProxy(t, etcd, "proxies/a", &TestProxyInformationEtcd{ + Address: "https://foo/", + }) + proxy.Expect("add", "https://foo/") + if err := config.Start(); err != nil { + t.Fatal(err) + } + proxy.WaitForEvents(ctx) + + proxy.Expect("add", "https://bar/") + SetEtcdProxy(t, etcd, "proxies/b", &TestProxyInformationEtcd{ + Address: "https://bar/", + }) + proxy.WaitForEvents(ctx) + + proxy.Expect("keep", "https://bar/") + SetEtcdProxy(t, etcd, "proxies/b", &TestProxyInformationEtcd{ + Address: "https://bar/", + OtherData: "ignore-me", + }) + proxy.WaitForEvents(ctx) + + proxy.Expect("remove", "https://foo/") + DeleteEtcdValue(etcd, "proxies/a") + proxy.WaitForEvents(ctx) + + proxy.Expect("remove", "https://bar/") + proxy.Expect("add", "https://baz/") + SetEtcdProxy(t, etcd, "proxies/b", &TestProxyInformationEtcd{ + Address: "https://baz/", + }) + proxy.WaitForEvents(ctx) + + // Adding the same hostname multiple times should not trigger an event. + SetEtcdProxy(t, etcd, "proxies/c", &TestProxyInformationEtcd{ + Address: "https://baz/", + }) + time.Sleep(100 * time.Millisecond) +} diff --git a/proxy_config_static.go b/proxy_config_static.go new file mode 100644 index 0000000..45f467a --- /dev/null +++ b/proxy_config_static.go @@ -0,0 +1,225 @@ +package signaling + +import ( + "errors" + "log" + "net" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/dlintw/goconf" +) + +type ipList struct { + hostname string + + ips []net.IP +} + +type proxyConfigStatic struct { + mu sync.Mutex + proxy McuProxy + + dnsDiscovery atomic.Bool + stopping chan struct{} + stopped chan struct{} + + connectionsMap map[string]*ipList +} + +func NewProxyConfigStatic(config *goconf.ConfigFile, proxy McuProxy) (ProxyConfig, error) { + result := &proxyConfigStatic{ + proxy: proxy, + + stopping: make(chan struct{}, 1), + stopped: make(chan struct{}, 1), + + connectionsMap: make(map[string]*ipList), + } + if err := result.configure(config, false); err != nil { + return nil, err + } + if len(result.connectionsMap) == 0 { + return nil, errors.New("No MCU proxy connections configured") + } + return result, nil +} + +func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool) error { + dnsDiscovery, _ := config.GetBool("mcu", "dnsdiscovery") + if p.dnsDiscovery.CompareAndSwap(!dnsDiscovery, dnsDiscovery) && fromReload { + if !dnsDiscovery { + p.stopping <- struct{}{} + <-p.stopped + } else { + go p.monitorProxyIPs() + } + } + + p.mu.Lock() + defer p.mu.Unlock() + + remove := make(map[string]*ipList) + for u, ips := range p.connectionsMap { + remove[u] = ips + } + + mcuUrl, _ := config.GetString("mcu", "url") + for _, u := range strings.Split(mcuUrl, " ") { + u = strings.TrimSpace(u) + if u == "" { + continue + } + + if existing, found := remove[u]; found { + // Proxy connection still exists in new configuration + delete(remove, u) + p.proxy.KeepConnection(u, existing.ips...) + continue + } + + parsed, err := url.Parse(u) + if err != nil { + if !fromReload { + return err + } + + log.Printf("Could not parse URL %s: %s", u, err) + continue + } + + if host, _, err := net.SplitHostPort(parsed.Host); err == nil { + parsed.Host = host + } + + var ips []net.IP + if dnsDiscovery { + ips, err = lookupProxyIP(parsed.Host) + if err != nil { + // Will be retried later. + log.Printf("Could not lookup %s: %s\n", parsed.Host, err) + continue + } + } + + if fromReload { + if err := p.proxy.AddConnection(fromReload, u, ips...); err != nil { + if !fromReload { + return err + } + + log.Printf("Could not create proxy connection to %s: %s", u, err) + continue + } + } + + p.connectionsMap[u] = &ipList{ + hostname: parsed.Host, + ips: ips, + } + } + + for u, entry := range remove { + p.proxy.RemoveConnection(u, entry.ips...) + delete(p.connectionsMap, u) + } + + return nil +} + +func (p *proxyConfigStatic) Start() error { + p.mu.Lock() + defer p.mu.Unlock() + + for u, ipList := range p.connectionsMap { + if err := p.proxy.AddConnection(false, u, ipList.ips...); err != nil { + return err + } + } + + if p.dnsDiscovery.Load() { + go p.monitorProxyIPs() + } + return nil +} + +func (p *proxyConfigStatic) Stop() { + if p.dnsDiscovery.CompareAndSwap(true, false) { + p.stopping <- struct{}{} + <-p.stopped + } +} + +func (p *proxyConfigStatic) Reload(config *goconf.ConfigFile) error { + return p.configure(config, true) +} + +func (p *proxyConfigStatic) monitorProxyIPs() { + log.Printf("Start monitoring proxy IPs") + ticker := time.NewTicker(updateDnsInterval) + for { + select { + case <-ticker.C: + p.updateProxyIPs() + case <-p.stopping: + p.stopped <- struct{}{} + return + } + } +} + +func (p *proxyConfigStatic) updateProxyIPs() { + p.mu.Lock() + defer p.mu.Unlock() + + for u, iplist := range p.connectionsMap { + if len(iplist.ips) == 0 { + continue + } + + if net.ParseIP(iplist.hostname) != nil { + // No need to lookup endpoints that connect to IP addresses. + continue + } + + ips, err := lookupProxyIP(iplist.hostname) + if err != nil { + log.Printf("Could not lookup %s: %s", iplist.hostname, err) + continue + } + + var newIPs []net.IP + var removedIPs []net.IP + for _, oldIP := range iplist.ips { + found := false + for idx, newIP := range ips { + if oldIP.Equal(newIP) { + ips = append(ips[:idx], ips[idx+1:]...) + found = true + p.proxy.KeepConnection(u, oldIP) + newIPs = append(newIPs, oldIP) + break + } + } + + if !found { + removedIPs = append(removedIPs, oldIP) + } + } + + if len(ips) > 0 { + newIPs = append(newIPs, ips...) + if err := p.proxy.AddConnection(true, u, ips...); err != nil { + log.Printf("Could not add proxy connection to %s with %+v: %s", u, ips, err) + } + } + iplist.ips = newIPs + + if len(removedIPs) > 0 { + p.proxy.RemoveConnection(u, removedIPs...) + } + } +} diff --git a/proxy_config_static_test.go b/proxy_config_static_test.go new file mode 100644 index 0000000..e9962bf --- /dev/null +++ b/proxy_config_static_test.go @@ -0,0 +1,91 @@ +package signaling + +import ( + "net" + "strings" + "testing" + + "github.com/dlintw/goconf" +) + +func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string) ProxyConfig { + cfg := goconf.NewConfigFile() + cfg.AddOption("mcu", "url", strings.Join(urls, " ")) + if dns { + cfg.AddOption("mcu", "dnsdiscovery", "true") + } + p, err := NewProxyConfigStatic(cfg, proxy) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + p.Stop() + }) + return p +} + +func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls ...string) { + cfg := goconf.NewConfigFile() + cfg.AddOption("mcu", "url", strings.Join(urls, " ")) + if dns { + cfg.AddOption("mcu", "dnsdiscovery", "true") + } + if err := config.Reload(cfg); err != nil { + t.Fatal(err) + } +} + +func TestProxyConfigStaticSimple(t *testing.T) { + proxy := newMcuProxyForConfig(t) + config := newProxyConfigStatic(t, proxy, false, "https://foo/") + proxy.Expect("add", "https://foo/") + if err := config.Start(); err != nil { + t.Fatal(err) + } + + proxy.Expect("keep", "https://foo/") + proxy.Expect("add", "https://bar/") + updateProxyConfigStatic(t, config, false, "https://foo/", "https://bar/") + + proxy.Expect("keep", "https://bar/") + proxy.Expect("add", "https://baz/") + proxy.Expect("remove", "https://foo/") + updateProxyConfigStatic(t, config, false, "https://bar/", "https://baz/") +} + +func TestProxyConfigStaticDNS(t *testing.T) { + old := lookupProxyIP + t.Cleanup(func() { + lookupProxyIP = old + }) + proxyIPs := make(map[string][]net.IP) + lookupProxyIP = func(hostname string) ([]net.IP, error) { + ips := append([]net.IP{}, proxyIPs[hostname]...) + return ips, nil + } + proxyIPs["foo"] = []net.IP{ + net.ParseIP("192.168.0.1"), + net.ParseIP("10.1.2.3"), + } + + proxy := newMcuProxyForConfig(t) + config := newProxyConfigStatic(t, proxy, true, "https://foo/").(*proxyConfigStatic) + proxy.Expect("add", "https://foo/", proxyIPs["foo"]...) + if err := config.Start(); err != nil { + t.Fatal(err) + } + + proxyIPs["foo"] = []net.IP{ + net.ParseIP("192.168.0.1"), + net.ParseIP("192.168.1.1"), + net.ParseIP("192.168.1.2"), + } + proxy.Expect("keep", "https://foo/", net.ParseIP("192.168.0.1")) + proxy.Expect("add", "https://foo/", net.ParseIP("192.168.1.1"), net.ParseIP("192.168.1.2")) + proxy.Expect("remove", "https://foo/", net.ParseIP("10.1.2.3")) + config.updateProxyIPs() + + proxy.Expect("add", "https://bar/") + proxy.Expect("remove", "https://foo/", proxyIPs["foo"]...) + updateProxyConfigStatic(t, config, false, "https://bar/") +} diff --git a/proxy_config_test.go b/proxy_config_test.go new file mode 100644 index 0000000..1004106 --- /dev/null +++ b/proxy_config_test.go @@ -0,0 +1,165 @@ +package signaling + +import ( + "context" + "net" + "reflect" + "runtime" + "strings" + "sync" + "testing" +) + +var ( + thisFilename string +) + +func init() { + pc := make([]uintptr, 1) + count := runtime.Callers(1, pc) + frames := runtime.CallersFrames(pc[:count]) + frame, _ := frames.Next() + thisFilename = frame.File +} + +type proxyConfigEvent struct { + action string + url string + ips []net.IP +} + +type mcuProxyForConfig struct { + t *testing.T + expected []proxyConfigEvent + mu sync.Mutex + waiters []chan struct{} +} + +func newMcuProxyForConfig(t *testing.T) *mcuProxyForConfig { + proxy := &mcuProxyForConfig{ + t: t, + } + t.Cleanup(func() { + if len(proxy.expected) > 0 { + t.Errorf("expected events %+v were not triggered", proxy.expected) + } + }) + return proxy +} + +func (p *mcuProxyForConfig) Expect(action string, url string, ips ...net.IP) { + if len(ips) == 0 { + ips = nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + p.expected = append(p.expected, proxyConfigEvent{ + action: action, + url: url, + ips: ips, + }) +} + +func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) { + p.t.Helper() + + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.expected) == 0 { + return + } + + waiter := make(chan struct{}) + p.waiters = append(p.waiters, waiter) + p.mu.Unlock() + defer p.mu.Lock() + select { + case <-ctx.Done(): + p.t.Error(ctx.Err()) + case <-waiter: + } +} + +func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) { + p.t.Helper() + pc := make([]uintptr, 32) + count := runtime.Callers(2, pc) + frames := runtime.CallersFrames(pc[:count]) + var caller runtime.Frame + for { + frame, more := frames.Next() + if frame.File != thisFilename && strings.HasSuffix(frame.File, "_test.go") { + caller = frame + break + } + if !more { + break + } + } + + if len(p.expected) == 0 { + p.t.Errorf("no event expected, got %+v from %s:%d", event, caller.File, caller.Line) + return + } + + defer func() { + if len(p.expected) == 0 { + p.mu.Lock() + waiters := p.waiters + p.waiters = nil + p.mu.Unlock() + + for _, ch := range waiters { + ch <- struct{}{} + } + } + }() + + p.mu.Lock() + defer p.mu.Unlock() + expected := p.expected[0] + p.expected = p.expected[1:] + if !reflect.DeepEqual(expected, *event) { + p.t.Errorf("expected %+v, got %+v from %s:%d", expected, event, caller.File, caller.Line) + } +} + +func (p *mcuProxyForConfig) AddConnection(ignoreErrors bool, url string, ips ...net.IP) error { + p.t.Helper() + if len(ips) == 0 { + ips = nil + } + p.checkEvent(&proxyConfigEvent{ + action: "add", + url: url, + ips: ips, + }) + return nil +} + +func (p *mcuProxyForConfig) KeepConnection(url string, ips ...net.IP) { + p.t.Helper() + if len(ips) == 0 { + ips = nil + } + p.checkEvent(&proxyConfigEvent{ + action: "keep", + url: url, + ips: ips, + }) +} + +func (p *mcuProxyForConfig) RemoveConnection(url string, ips ...net.IP) { + p.t.Helper() + if len(ips) == 0 { + ips = nil + } + p.checkEvent(&proxyConfigEvent{ + action: "remove", + url: url, + ips: ips, + }) +}