From cf5ee8e4a1bd0dfadfa29fcf01ffa9ef0c965ef4 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 7 Feb 2024 12:52:04 +0100 Subject: [PATCH] Fix deadlock when entry is removed while receiver holds lock in lookup. --- dns_monitor.go | 46 +++++++++++++++++++++- dns_monitor_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 1 deletion(-) diff --git a/dns_monitor.go b/dns_monitor.go index 3240682..c1cbdcd 100644 --- a/dns_monitor.go +++ b/dns_monitor.go @@ -28,6 +28,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" ) @@ -42,6 +43,7 @@ const ( type DnsMonitorCallback = func(entry *DnsMonitorEntry, all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) type DnsMonitorEntry struct { + removing atomic.Bool entry *dnsMonitorEntry url string callback DnsMonitorCallback @@ -148,6 +150,8 @@ type DnsMonitor struct { cond *sync.Cond hostnames map[string]*dnsMonitorEntry + hasRemoved atomic.Bool + // Can be overwritten from tests. checkHostnames func() } @@ -222,7 +226,22 @@ func (m *DnsMonitor) Add(target string, callback DnsMonitorCallback) (*DnsMonito } func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) { - m.mu.Lock() + if !entry.removing.CompareAndSwap(false, true) { + // Already removed. + return + } + + locked := m.mu.TryLock() + // Spin-lock for simple cases that resolve immediately to avoid deferred removal. + for i := 0; !locked && i < 1000; i++ { + time.Sleep(time.Nanosecond) + locked = m.mu.TryLock() + } + if !locked { + // Currently processing callbacks for this entry, need to defer removal. + m.hasRemoved.Store(true) + return + } defer m.mu.Unlock() if entry.entry == nil { @@ -240,6 +259,29 @@ func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) { } } +func (m *DnsMonitor) clearRemoved() { + if !m.hasRemoved.CompareAndSwap(true, false) { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + for hostname, entry := range m.hostnames { + deleted := false + for e := range entry.entries { + if e.removing.Load() { + delete(entry.entries, e) + deleted = true + } + } + + if deleted && len(entry.entries) == 0 { + delete(m.hostnames, hostname) + } + } +} + func (m *DnsMonitor) waitForEntries() (waited bool) { m.mu.Lock() defer m.mu.Unlock() @@ -276,6 +318,8 @@ func (m *DnsMonitor) run() { } func (m *DnsMonitor) doCheckHostnames() { + m.clearRemoved() + m.mu.RLock() defer m.mu.RUnlock() diff --git a/dns_monitor_test.go b/dns_monitor_test.go index 0b1466d..eedae1c 100644 --- a/dns_monitor_test.go +++ b/dns_monitor_test.go @@ -332,3 +332,97 @@ func TestDnsMonitorNoLookupIfEmpty(t *testing.T) { t.Error("should not have checked hostnames") } } + +type deadlockMonitorReceiver struct { + t *testing.T + monitor *DnsMonitor + + mu sync.RWMutex + wg sync.WaitGroup + + entry *DnsMonitorEntry + started chan struct{} + triggered bool + closed atomic.Bool +} + +func newDeadlockMonitorReceiver(t *testing.T, monitor *DnsMonitor) *deadlockMonitorReceiver { + return &deadlockMonitorReceiver{ + t: t, + monitor: monitor, + started: make(chan struct{}), + } +} + +func (r *deadlockMonitorReceiver) OnLookup(entry *DnsMonitorEntry, all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) { + if r.closed.Load() { + r.t.Error("received lookup after closed") + return + } + + r.mu.Lock() + defer r.mu.Unlock() + + if r.triggered { + return + } + + r.triggered = true + r.wg.Add(1) + go func() { + defer r.wg.Done() + + r.mu.RLock() + defer r.mu.RUnlock() + + close(r.started) + time.Sleep(50 * time.Millisecond) + }() +} + +func (r *deadlockMonitorReceiver) Start() { + r.mu.Lock() + defer r.mu.Unlock() + + entry, err := r.monitor.Add("foo", r.OnLookup) + if err != nil { + r.t.Errorf("error adding listener: %s", err) + return + } + + r.entry = entry +} + +func (r *deadlockMonitorReceiver) Close() { + r.mu.Lock() + defer r.mu.Unlock() + + if r.entry != nil { + r.monitor.Remove(r.entry) + r.closed.Store(true) + } + r.wg.Wait() +} + +func TestDnsMonitorDeadlock(t *testing.T) { + lookup := newMockDnsLookupForTest(t) + ip1 := net.ParseIP("192.168.0.1") + ip2 := net.ParseIP("192.168.0.2") + lookup.Set("foo", []net.IP{ip1}) + + interval := time.Millisecond + monitor := newDnsMonitorForTest(t, interval) + + r := newDeadlockMonitorReceiver(t, monitor) + r.Start() + <-r.started + lookup.Set("foo", []net.IP{ip2}) + r.Close() + lookup.Set("foo", []net.IP{ip1}) + time.Sleep(10 * interval) + monitor.mu.Lock() + defer monitor.mu.Unlock() + if len(monitor.hostnames) > 0 { + t.Errorf("should have cleared hostnames, got %+v", monitor.hostnames) + } +}