From 8417f37cbaf8ee471fa62d0c7c22c71f70b77b76 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 21 Dec 2023 16:55:37 +0100 Subject: [PATCH] Move common DNS monitor code to own class. --- dns_monitor.go | 248 ++++++++++++++++++++++++++++++++++ dns_monitor_test.go | 317 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 565 insertions(+) create mode 100644 dns_monitor.go create mode 100644 dns_monitor_test.go diff --git a/dns_monitor.go b/dns_monitor.go new file mode 100644 index 0000000..be19d63 --- /dev/null +++ b/dns_monitor.go @@ -0,0 +1,248 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2023 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "log" + "net" + "net/url" + "strings" + "sync" + "time" +) + +var ( + lookupDnsMonitorIP = net.LookupIP +) + +const ( + defaultDnsMonitorInterval = time.Second +) + +type DnsMonitorCallback = func(entry *DnsMonitorEntry, all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) + +type DnsMonitorEntry struct { + entry *dnsMonitorEntry + url string + callback DnsMonitorCallback +} + +func (e *DnsMonitorEntry) URL() string { + return e.url +} + +type dnsMonitorEntry struct { + hostname string + hostIP net.IP + + ips []net.IP + entries map[*DnsMonitorEntry]bool +} + +func (e *dnsMonitorEntry) setIPs(ips []net.IP, fromIP bool) { + empty := len(e.ips) == 0 + if empty { + // Simple case: initial lookup. + if len(ips) > 0 { + e.ips = ips + e.runCallbacks(ips, ips, nil, nil) + } + return + } else if fromIP { + // No more updates possible for IP addresses. + return + } else if len(ips) == 0 { + // Simple case: no records received from lookup. + if !empty { + removed := e.ips + e.ips = nil + e.runCallbacks(nil, nil, nil, removed) + } + return + } + + var newIPs []net.IP + var addedIPs []net.IP + var removedIPs []net.IP + var keepIPs []net.IP + for _, oldIP := range e.ips { + found := false + for idx, newIP := range ips { + if oldIP.Equal(newIP) { + ips = append(ips[:idx], ips[idx+1:]...) + found = true + keepIPs = append(keepIPs, oldIP) + newIPs = append(newIPs, oldIP) + break + } + } + + if !found { + removedIPs = append(removedIPs, oldIP) + } + } + + if len(ips) > 0 { + addedIPs = append(addedIPs, ips...) + newIPs = append(newIPs, ips...) + } + e.ips = newIPs + + if len(addedIPs) > 0 || len(removedIPs) > 0 { + e.runCallbacks(newIPs, addedIPs, keepIPs, removedIPs) + } +} + +func (e *dnsMonitorEntry) runCallbacks(all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) { + for entry := range e.entries { + entry.callback(entry, all, add, keep, remove) + } +} + +type DnsMonitor struct { + interval time.Duration + + stopCtx context.Context + stopFunc func() + + mu sync.RWMutex + hostnames map[string]*dnsMonitorEntry +} + +func NewDnsMonitor(interval time.Duration) (*DnsMonitor, error) { + if interval < 0 { + interval = defaultDnsMonitorInterval + } + + stopCtx, stopFunc := context.WithCancel(context.Background()) + monitor := &DnsMonitor{ + interval: interval, + + stopCtx: stopCtx, + stopFunc: stopFunc, + + hostnames: make(map[string]*dnsMonitorEntry), + } + return monitor, nil +} + +func (m *DnsMonitor) Start() error { + go m.run() + return nil +} + +func (m *DnsMonitor) Stop() { + m.stopFunc() +} + +func (m *DnsMonitor) Add(target string, callback DnsMonitorCallback) (*DnsMonitorEntry, error) { + var hostname string + if strings.Contains(target, "://") { + // Full URL passed. + parsed, err := url.Parse(target) + if err != nil { + return nil, err + } + hostname = parsed.Host + } else { + // Hostname with optional port passed. + hostname = target + if h, _, err := net.SplitHostPort(target); err == nil { + hostname = h + } + } + + m.mu.Lock() + defer m.mu.Unlock() + + e := &DnsMonitorEntry{ + url: target, + callback: callback, + } + + entry, found := m.hostnames[hostname] + if !found { + entry = &dnsMonitorEntry{ + hostname: hostname, + hostIP: net.ParseIP(hostname), + entries: make(map[*DnsMonitorEntry]bool), + } + m.hostnames[hostname] = entry + } + e.entry = entry + entry.entries[e] = true + return e, nil +} + +func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) { + m.mu.Lock() + defer m.mu.Unlock() + + if entry.entry == nil { + return + } + + e, found := m.hostnames[entry.entry.hostname] + if !found { + return + } + + entry.entry = nil + delete(e.entries, entry) +} + +func (m *DnsMonitor) run() { + ticker := time.NewTicker(m.interval) + for { + select { + case <-m.stopCtx.Done(): + return + case <-ticker.C: + m.checkHostnames() + } + } +} + +func (m *DnsMonitor) checkHostnames() { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, entry := range m.hostnames { + m.checkHostname(entry) + } +} + +func (m *DnsMonitor) checkHostname(entry *dnsMonitorEntry) { + if len(entry.hostIP) > 0 { + entry.setIPs([]net.IP{entry.hostIP}, true) + return + } + + ips, err := lookupDnsMonitorIP(entry.hostname) + if err != nil { + log.Printf("Could not lookup %s: %s", entry.hostname, err) + return + } + + entry.setIPs(ips, false) +} diff --git a/dns_monitor_test.go b/dns_monitor_test.go new file mode 100644 index 0000000..4fa7b49 --- /dev/null +++ b/dns_monitor_test.go @@ -0,0 +1,317 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2023 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "fmt" + "net" + "reflect" + "sync" + "testing" + "time" +) + +type mockDnsLookup struct { + sync.RWMutex + + ips map[string][]net.IP +} + +func newMockDnsLookupForTest(t *testing.T) *mockDnsLookup { + mock := &mockDnsLookup{ + ips: make(map[string][]net.IP), + } + prev := lookupDnsMonitorIP + t.Cleanup(func() { + lookupDnsMonitorIP = prev + }) + lookupDnsMonitorIP = mock.lookup + return mock +} + +func (m *mockDnsLookup) Set(host string, ips []net.IP) { + m.Lock() + defer m.Unlock() + + m.ips[host] = ips +} + +func (m *mockDnsLookup) Get(host string) []net.IP { + m.Lock() + defer m.Unlock() + + return m.ips[host] +} + +func (m *mockDnsLookup) lookup(host string) ([]net.IP, error) { + m.RLock() + defer m.RUnlock() + + ips, found := m.ips[host] + if !found { + return nil, &net.DNSError{ + Err: fmt.Sprintf("could not resolve %s", host), + Name: host, + IsNotFound: true, + } + } + + return append([]net.IP{}, ips...), nil +} + +func newDnsMonitorForTest(t *testing.T, interval time.Duration) *DnsMonitor { + t.Helper() + + monitor, err := NewDnsMonitor(interval) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + monitor.Stop() + }) + + if err := monitor.Start(); err != nil { + t.Fatal(err) + } + + return monitor +} + +type dnsMonitorReceiverRecord struct { + all []net.IP + add []net.IP + keep []net.IP + remove []net.IP +} + +func (r *dnsMonitorReceiverRecord) Equal(other *dnsMonitorReceiverRecord) bool { + return r == other || (reflect.DeepEqual(r.add, other.add) && + reflect.DeepEqual(r.keep, other.keep) && + reflect.DeepEqual(r.remove, other.remove)) +} + +func (r *dnsMonitorReceiverRecord) String() string { + return fmt.Sprintf("all=%v, add=%v, keep=%v, remove=%v", r.all, r.add, r.keep, r.remove) +} + +var ( + expectNone = &dnsMonitorReceiverRecord{} +) + +type dnsMonitorReceiver struct { + sync.Mutex + + t *testing.T + expected *dnsMonitorReceiverRecord + received *dnsMonitorReceiverRecord +} + +func newDnsMonitorReceiverForTest(t *testing.T) *dnsMonitorReceiver { + return &dnsMonitorReceiver{ + t: t, + } +} + +func (r *dnsMonitorReceiver) OnLookup(entry *DnsMonitorEntry, all, add, keep, remove []net.IP) { + r.Lock() + defer r.Unlock() + + received := &dnsMonitorReceiverRecord{ + all: all, + add: add, + keep: keep, + remove: remove, + } + + expected := r.expected + r.expected = nil + if expected == expectNone { + r.t.Errorf("expected no event, got %v", received) + return + } + + if expected == nil { + if r.received != nil && !r.received.Equal(received) { + r.t.Errorf("already received %v, got %v", r.received, received) + } + return + } + + if !expected.Equal(received) { + r.t.Errorf("expected %v, got %v", expected, received) + } + r.received = nil + r.expected = nil +} + +func (r *dnsMonitorReceiver) WaitForExpected(ctx context.Context) { + r.t.Helper() + r.Lock() + defer r.Unlock() + + ticker := time.NewTicker(time.Microsecond) + abort := false + for r.expected != nil && !abort { + r.Unlock() + select { + case <-ticker.C: + case <-ctx.Done(): + r.t.Error(ctx.Err()) + abort = true + } + r.Lock() + } +} + +func (r *dnsMonitorReceiver) Expect(all, add, keep, remove []net.IP) { + r.t.Helper() + r.Lock() + defer r.Unlock() + + if r.expected != nil && r.expected != expectNone { + r.t.Errorf("didn't get previously expected %v", r.expected) + } + + expected := &dnsMonitorReceiverRecord{ + all: all, + add: add, + keep: keep, + remove: remove, + } + if r.received != nil && r.received.Equal(expected) { + r.received = nil + return + } + + r.expected = expected +} + +func (r *dnsMonitorReceiver) ExpectNone() { + r.t.Helper() + r.Lock() + defer r.Unlock() + + if r.expected != nil && r.expected != expectNone { + r.t.Errorf("didn't get previously expected %v", r.expected) + } + + r.expected = expectNone +} + +func TestDnsMonitor(t *testing.T) { + lookup := newMockDnsLookupForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + interval := time.Millisecond + monitor := newDnsMonitorForTest(t, interval) + + ip1 := net.ParseIP("192.168.0.1") + ip2 := net.ParseIP("192.168.1.1") + ip3 := net.ParseIP("10.1.2.3") + ips1 := []net.IP{ + ip1, + ip2, + } + lookup.Set("foo", ips1) + + rec1 := newDnsMonitorReceiverForTest(t) + rec1.Expect(ips1, ips1, nil, nil) + + entry1, err := monitor.Add("https://foo", rec1.OnLookup) + if err != nil { + t.Fatal(err) + } + defer monitor.Remove(entry1) + + rec1.WaitForExpected(ctx) + + ips2 := []net.IP{ + ip1, + ip2, + ip3, + } + add2 := []net.IP{ip3} + keep2 := []net.IP{ip1, ip2} + lookup.Set("foo", ips2) + rec1.Expect(ips2, add2, keep2, nil) + rec1.WaitForExpected(ctx) + + ips3 := []net.IP{ + ip2, + ip3, + } + lookup.Set("foo", ips3) + keep3 := []net.IP{ip2, ip3} + remove3 := []net.IP{ip1} + rec1.Expect(ips3, nil, keep3, remove3) + rec1.WaitForExpected(ctx) + + rec1.ExpectNone() + time.Sleep(5 * interval) + + lookup.Set("foo", nil) + remove4 := []net.IP{ip2, ip3} + rec1.Expect(nil, nil, nil, remove4) + rec1.WaitForExpected(ctx) + + rec1.ExpectNone() + time.Sleep(5 * interval) + + // Removing multiple times is supported. + monitor.Remove(entry1) + monitor.Remove(entry1) + + // No more events after removing. + lookup.Set("foo", ips1) + rec1.ExpectNone() + time.Sleep(5 * interval) +} + +func TestDnsMonitorIP(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + interval := time.Millisecond + monitor := newDnsMonitorForTest(t, interval) + + ip := "192.168.0.1" + ips := []net.IP{ + net.ParseIP(ip), + } + + rec1 := newDnsMonitorReceiverForTest(t) + rec1.Expect(ips, ips, nil, nil) + + entry, err := monitor.Add("https://"+ip, rec1.OnLookup) + if err != nil { + t.Fatal(err) + } + defer monitor.Remove(entry) + + rec1.WaitForExpected(ctx) + + rec1.ExpectNone() + time.Sleep(5 * interval) +}