diff --git a/mcu_proxy.go b/mcu_proxy.go index 522d065..67ea4fb 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -1123,7 +1123,7 @@ type mcuProxy struct { rpcClients *GrpcClients } -func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients *GrpcClients) (Mcu, error) { +func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients *GrpcClients, dnsMonitor *DnsMonitor) (Mcu, error) { urlType, _ := config.GetString("mcu", "urltype") if urlType == "" { urlType = proxyUrlTypeStatic @@ -1196,7 +1196,7 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients * switch urlType { case proxyUrlTypeStatic: - mcu.config, err = NewProxyConfigStatic(config, mcu) + mcu.config, err = NewProxyConfigStatic(config, mcu, dnsMonitor) case proxyUrlTypeEtcd: mcu.config, err = NewProxyConfigEtcd(config, etcdClient, mcu) default: diff --git a/proxy_config.go b/proxy_config.go index 7d964ee..2a4102c 100644 --- a/proxy_config.go +++ b/proxy_config.go @@ -22,15 +22,9 @@ package signaling import ( - "net" - "github.com/dlintw/goconf" ) -var ( - lookupProxyIP = net.LookupIP -) - type ProxyConfig interface { Start() error Stop() diff --git a/proxy_config_static.go b/proxy_config_static.go index 21bfb6d..9e1a887 100644 --- a/proxy_config_static.go +++ b/proxy_config_static.go @@ -28,8 +28,6 @@ import ( "net/url" "strings" "sync" - "sync/atomic" - "time" "github.com/dlintw/goconf" ) @@ -37,27 +35,24 @@ import ( type ipList struct { hostname string - ips []net.IP + entry *DnsMonitorEntry + ips []net.IP } type proxyConfigStatic struct { mu sync.Mutex proxy McuProxy - dnsDiscovery atomic.Bool - stopping chan struct{} - stopped chan struct{} + dnsMonitor *DnsMonitor + dnsDiscovery bool connectionsMap map[string]*ipList } -func NewProxyConfigStatic(config *goconf.ConfigFile, proxy McuProxy) (ProxyConfig, error) { +func NewProxyConfigStatic(config *goconf.ConfigFile, proxy McuProxy, dnsMonitor *DnsMonitor) (ProxyConfig, error) { result := &proxyConfigStatic{ - proxy: proxy, - - stopping: make(chan struct{}, 1), - stopped: make(chan struct{}, 1), - + proxy: proxy, + dnsMonitor: dnsMonitor, connectionsMap: make(map[string]*ipList), } if err := result.configure(config, false); err != nil { @@ -70,19 +65,22 @@ func NewProxyConfigStatic(config *goconf.ConfigFile, proxy McuProxy) (ProxyConfi } func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool) error { - dnsDiscovery, _ := config.GetBool("mcu", "dnsdiscovery") - if p.dnsDiscovery.CompareAndSwap(!dnsDiscovery, dnsDiscovery) && fromReload { - if !dnsDiscovery { - p.stopping <- struct{}{} - <-p.stopped - } else { - go p.monitorProxyIPs() - } - } - p.mu.Lock() defer p.mu.Unlock() + dnsDiscovery, _ := config.GetBool("mcu", "dnsdiscovery") + if dnsDiscovery != p.dnsDiscovery { + if !dnsDiscovery { + for _, ips := range p.connectionsMap { + if ips.entry != nil { + p.dnsMonitor.Remove(ips.entry) + ips.entry = nil + } + } + } + p.dnsDiscovery = dnsDiscovery + } + remove := make(map[string]*ipList) for u, ips := range p.connectionsMap { remove[u] = ips @@ -116,18 +114,15 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool parsed.Host = host } - var ips []net.IP if dnsDiscovery { - ips, err = lookupProxyIP(parsed.Host) - if err != nil { - // Will be retried later. - log.Printf("Could not lookup %s: %s\n", parsed.Host, err) - continue + p.connectionsMap[u] = &ipList{ + hostname: parsed.Host, } + continue } if fromReload { - if err := p.proxy.AddConnection(fromReload, u, ips...); err != nil { + if err := p.proxy.AddConnection(fromReload, u); err != nil { if !fromReload { return err } @@ -139,7 +134,6 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool p.connectionsMap[u] = &ipList{ hostname: parsed.Host, - ips: ips, } } @@ -155,92 +149,53 @@ func (p *proxyConfigStatic) Start() error { p.mu.Lock() defer p.mu.Unlock() - for u, ipList := range p.connectionsMap { - if err := p.proxy.AddConnection(false, u, ipList.ips...); err != nil { - return err + if p.dnsDiscovery { + for u, ips := range p.connectionsMap { + entry, err := p.dnsMonitor.Add(u, p.onLookup) + if err != nil { + return err + } + + ips.entry = entry + } + } else { + for u, ipList := range p.connectionsMap { + if err := p.proxy.AddConnection(false, u, ipList.ips...); err != nil { + return err + } } } - if p.dnsDiscovery.Load() { - go p.monitorProxyIPs() - } return nil } func (p *proxyConfigStatic) Stop() { - if p.dnsDiscovery.CompareAndSwap(true, false) { - p.stopping <- struct{}{} - <-p.stopped - } } func (p *proxyConfigStatic) Reload(config *goconf.ConfigFile) error { return p.configure(config, true) } -func (p *proxyConfigStatic) monitorProxyIPs() { - log.Printf("Start monitoring proxy IPs") - ticker := time.NewTicker(updateDnsInterval) - for { - select { - case <-ticker.C: - p.updateProxyIPs() - case <-p.stopping: - p.stopped <- struct{}{} - return - } - } -} - -func (p *proxyConfigStatic) updateProxyIPs() { +func (p *proxyConfigStatic) onLookup(entry *DnsMonitorEntry, all []net.IP, added []net.IP, keep []net.IP, removed []net.IP) { p.mu.Lock() defer p.mu.Unlock() - for u, iplist := range p.connectionsMap { - if len(iplist.ips) == 0 { - continue - } + u := entry.URL() + for _, ip := range keep { + p.proxy.KeepConnection(u, ip) + } - if net.ParseIP(iplist.hostname) != nil { - // No need to lookup endpoints that connect to IP addresses. - continue - } - - ips, err := lookupProxyIP(iplist.hostname) - if err != nil { - log.Printf("Could not lookup %s: %s", iplist.hostname, err) - continue - } - - var newIPs []net.IP - var removedIPs []net.IP - for _, oldIP := range iplist.ips { - found := false - for idx, newIP := range ips { - if oldIP.Equal(newIP) { - ips = append(ips[:idx], ips[idx+1:]...) - found = true - p.proxy.KeepConnection(u, oldIP) - newIPs = append(newIPs, oldIP) - break - } - } - - if !found { - removedIPs = append(removedIPs, oldIP) - } - } - - if len(ips) > 0 { - newIPs = append(newIPs, ips...) - if err := p.proxy.AddConnection(true, u, ips...); err != nil { - log.Printf("Could not add proxy connection to %s with %+v: %s", u, ips, err) - } - } - iplist.ips = newIPs - - if len(removedIPs) > 0 { - p.proxy.RemoveConnection(u, removedIPs...) + if len(added) > 0 { + if err := p.proxy.AddConnection(true, u, added...); err != nil { + log.Printf("Could not add proxy connection to %s with %+v: %s", u, added, err) } } + + if len(removed) > 0 { + p.proxy.RemoveConnection(u, removed...) + } + + if ipList, found := p.connectionsMap[u]; found { + ipList.ips = all + } } diff --git a/proxy_config_static_test.go b/proxy_config_static_test.go index 89896fa..de331b3 100644 --- a/proxy_config_static_test.go +++ b/proxy_config_static_test.go @@ -25,24 +25,26 @@ import ( "net" "strings" "testing" + "time" "github.com/dlintw/goconf" ) -func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string) ProxyConfig { +func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string) (ProxyConfig, *DnsMonitor) { cfg := goconf.NewConfigFile() cfg.AddOption("mcu", "url", strings.Join(urls, " ")) if dns { cfg.AddOption("mcu", "dnsdiscovery", "true") } - p, err := NewProxyConfigStatic(cfg, proxy) + dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually + p, err := NewProxyConfigStatic(cfg, proxy, dnsMonitor) if err != nil { t.Fatal(err) } t.Cleanup(func() { p.Stop() }) - return p + return p, dnsMonitor } func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls ...string) { @@ -58,7 +60,7 @@ func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls .. func TestProxyConfigStaticSimple(t *testing.T) { proxy := newMcuProxyForConfig(t) - config := newProxyConfigStatic(t, proxy, false, "https://foo/") + config, _ := newProxyConfigStatic(t, proxy, false, "https://foo/") proxy.Expect("add", "https://foo/") if err := config.Start(); err != nil { t.Fatal(err) @@ -75,38 +77,31 @@ func TestProxyConfigStaticSimple(t *testing.T) { } func TestProxyConfigStaticDNS(t *testing.T) { - old := lookupProxyIP - t.Cleanup(func() { - lookupProxyIP = old - }) - proxyIPs := make(map[string][]net.IP) - lookupProxyIP = func(hostname string) ([]net.IP, error) { - ips := append([]net.IP{}, proxyIPs[hostname]...) - return ips, nil - } - proxyIPs["foo"] = []net.IP{ + lookup := newMockDnsLookupForTest(t) + lookup.Set("foo", []net.IP{ net.ParseIP("192.168.0.1"), net.ParseIP("10.1.2.3"), - } + }) proxy := newMcuProxyForConfig(t) - config := newProxyConfigStatic(t, proxy, true, "https://foo/").(*proxyConfigStatic) - proxy.Expect("add", "https://foo/", proxyIPs["foo"]...) + config, dnsMonitor := newProxyConfigStatic(t, proxy, true, "https://foo/") + proxy.Expect("add", "https://foo/", lookup.Get("foo")...) if err := config.Start(); err != nil { t.Fatal(err) } - proxyIPs["foo"] = []net.IP{ + dnsMonitor.checkHostnames() + lookup.Set("foo", []net.IP{ net.ParseIP("192.168.0.1"), net.ParseIP("192.168.1.1"), net.ParseIP("192.168.1.2"), - } + }) proxy.Expect("keep", "https://foo/", net.ParseIP("192.168.0.1")) proxy.Expect("add", "https://foo/", net.ParseIP("192.168.1.1"), net.ParseIP("192.168.1.2")) proxy.Expect("remove", "https://foo/", net.ParseIP("10.1.2.3")) - config.updateProxyIPs() + dnsMonitor.checkHostnames() proxy.Expect("add", "https://bar/") - proxy.Expect("remove", "https://foo/", proxyIPs["foo"]...) + proxy.Expect("remove", "https://foo/", lookup.Get("foo")...) updateProxyConfigStatic(t, config, false, "https://bar/") } diff --git a/server/main.go b/server/main.go index 11147a9..8e7d944 100644 --- a/server/main.go +++ b/server/main.go @@ -62,6 +62,8 @@ const ( initialMcuRetry = time.Second maxMcuRetry = time.Second * 16 + + dnsMonitorInterval = time.Second ) func createListener(addr string) (net.Listener, error) { @@ -154,6 +156,12 @@ func main() { } defer events.Close() + dnsMonitor, err := signaling.NewDnsMonitor(dnsMonitorInterval) + if err != nil { + log.Fatal("Could not create DNS monitor: ", err) + } + defer dnsMonitor.Stop() + etcdClient, err := signaling.NewEtcdClient(config, "mcu") if err != nil { log.Fatalf("Could not create etcd client: %s", err) @@ -209,7 +217,7 @@ func main() { signaling.UnregisterProxyMcuStats() signaling.RegisterJanusMcuStats() case signaling.McuTypeProxy: - mcu, err = signaling.NewMcuProxy(config, etcdClient, rpcClients) + mcu, err = signaling.NewMcuProxy(config, etcdClient, rpcClients, dnsMonitor) signaling.UnregisterJanusMcuStats() signaling.RegisterProxyMcuStats() default: