diff --git a/dns_monitor.go b/dns_monitor.go index c1cbdcd..59f7f2e 100644 --- a/dns_monitor.go +++ b/dns_monitor.go @@ -43,8 +43,7 @@ const ( type DnsMonitorCallback = func(entry *DnsMonitorEntry, all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) type DnsMonitorEntry struct { - removing atomic.Bool - entry *dnsMonitorEntry + entry atomic.Pointer[dnsMonitorEntry] url string callback DnsMonitorCallback } @@ -145,6 +144,7 @@ type DnsMonitor struct { stopCtx context.Context stopFunc func() + stopped chan struct{} mu sync.RWMutex cond *sync.Cond @@ -167,6 +167,7 @@ func NewDnsMonitor(interval time.Duration) (*DnsMonitor, error) { stopCtx: stopCtx, stopFunc: stopFunc, + stopped: make(chan struct{}), hostnames: make(map[string]*dnsMonitorEntry), } @@ -183,6 +184,7 @@ func (m *DnsMonitor) Start() error { func (m *DnsMonitor) Stop() { m.stopFunc() m.cond.Signal() + <-m.stopped } func (m *DnsMonitor) Add(target string, callback DnsMonitorCallback) (*DnsMonitorEntry, error) { @@ -219,14 +221,15 @@ func (m *DnsMonitor) Add(target string, callback DnsMonitorCallback) (*DnsMonito } m.hostnames[hostname] = entry } - e.entry = entry + e.entry.Store(entry) entry.addEntry(e) m.cond.Signal() return e, nil } func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) { - if !entry.removing.CompareAndSwap(false, true) { + oldEntry := entry.entry.Swap(nil) + if oldEntry == nil { // Already removed. return } @@ -244,16 +247,11 @@ func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) { } defer m.mu.Unlock() - if entry.entry == nil { - return - } - - e, found := m.hostnames[entry.entry.hostname] + e, found := m.hostnames[oldEntry.hostname] if !found { return } - entry.entry = nil if e.removeEntry(entry) { delete(m.hostnames, e.hostname) } @@ -270,7 +268,7 @@ func (m *DnsMonitor) clearRemoved() { for hostname, entry := range m.hostnames { deleted := false for e := range entry.entries { - if e.removing.Load() { + if e.entry.Load() == nil { delete(entry.entries, e) deleted = true } @@ -296,6 +294,7 @@ func (m *DnsMonitor) waitForEntries() (waited bool) { func (m *DnsMonitor) run() { ticker := time.NewTicker(m.interval) defer ticker.Stop() + defer close(m.stopped) for { if m.waitForEntries() { diff --git a/proxy_config_static.go b/proxy_config_static.go index 9e1a887..84f7548 100644 --- a/proxy_config_static.go +++ b/proxy_config_static.go @@ -151,6 +151,10 @@ func (p *proxyConfigStatic) Start() error { if p.dnsDiscovery { for u, ips := range p.connectionsMap { + if ips.entry != nil { + continue + } + entry, err := p.dnsMonitor.Add(u, p.onLookup) if err != nil { return err @@ -170,6 +174,19 @@ func (p *proxyConfigStatic) Start() error { } func (p *proxyConfigStatic) Stop() { + p.mu.Lock() + defer p.mu.Unlock() + + if p.dnsDiscovery { + for _, ips := range p.connectionsMap { + if ips.entry == nil { + continue + } + + p.dnsMonitor.Remove(ips.entry) + ips.entry = nil + } + } } func (p *proxyConfigStatic) Reload(config *goconf.ConfigFile) error {