/** * Standalone signaling server for the Nextcloud Spreed app. * Copyright (C) 2023 struktur AG * * @author Joachim Bauch * * @license GNU AGPL version 3 or any later version * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ package signaling import ( "context" "log" "net" "net/url" "strings" "sync" "sync/atomic" "time" ) var ( lookupDnsMonitorIP = net.LookupIP ) const ( defaultDnsMonitorInterval = time.Second ) type DnsMonitorCallback = func(entry *DnsMonitorEntry, all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) type DnsMonitorEntry struct { entry atomic.Pointer[dnsMonitorEntry] url string callback DnsMonitorCallback } func (e *DnsMonitorEntry) URL() string { return e.url } type dnsMonitorEntry struct { hostname string hostIP net.IP mu sync.Mutex ips []net.IP entries map[*DnsMonitorEntry]bool } func (e *dnsMonitorEntry) setIPs(ips []net.IP, fromIP bool) { e.mu.Lock() defer e.mu.Unlock() empty := len(e.ips) == 0 if empty { // Simple case: initial lookup. if len(ips) > 0 { e.ips = ips e.runCallbacks(ips, ips, nil, nil) } return } else if fromIP { // No more updates possible for IP addresses. return } else if len(ips) == 0 { // Simple case: no records received from lookup. if !empty { removed := e.ips e.ips = nil e.runCallbacks(nil, nil, nil, removed) } return } var newIPs []net.IP var addedIPs []net.IP var removedIPs []net.IP var keepIPs []net.IP for _, oldIP := range e.ips { found := false for idx, newIP := range ips { if oldIP.Equal(newIP) { ips = append(ips[:idx], ips[idx+1:]...) found = true keepIPs = append(keepIPs, oldIP) newIPs = append(newIPs, oldIP) break } } if !found { removedIPs = append(removedIPs, oldIP) } } if len(ips) > 0 { addedIPs = append(addedIPs, ips...) newIPs = append(newIPs, ips...) } e.ips = newIPs if len(addedIPs) > 0 || len(removedIPs) > 0 { e.runCallbacks(newIPs, addedIPs, keepIPs, removedIPs) } } func (e *dnsMonitorEntry) addEntry(entry *DnsMonitorEntry) { e.mu.Lock() defer e.mu.Unlock() e.entries[entry] = true } func (e *dnsMonitorEntry) removeEntry(entry *DnsMonitorEntry) bool { e.mu.Lock() defer e.mu.Unlock() delete(e.entries, entry) return len(e.entries) == 0 } func (e *dnsMonitorEntry) runCallbacks(all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) { for entry := range e.entries { entry.callback(entry, all, add, keep, remove) } } type DnsMonitor struct { interval time.Duration stopCtx context.Context stopFunc func() stopped chan struct{} mu sync.RWMutex cond *sync.Cond hostnames map[string]*dnsMonitorEntry hasRemoved atomic.Bool // Can be overwritten from tests. checkHostnames func() } func NewDnsMonitor(interval time.Duration) (*DnsMonitor, error) { if interval < 0 { interval = defaultDnsMonitorInterval } stopCtx, stopFunc := context.WithCancel(context.Background()) monitor := &DnsMonitor{ interval: interval, stopCtx: stopCtx, stopFunc: stopFunc, stopped: make(chan struct{}), hostnames: make(map[string]*dnsMonitorEntry), } monitor.cond = sync.NewCond(&monitor.mu) monitor.checkHostnames = monitor.doCheckHostnames return monitor, nil } func (m *DnsMonitor) Start() error { go m.run() return nil } func (m *DnsMonitor) Stop() { m.stopFunc() m.cond.Signal() <-m.stopped } func (m *DnsMonitor) Add(target string, callback DnsMonitorCallback) (*DnsMonitorEntry, error) { var hostname string if strings.Contains(target, "://") { // Full URL passed. parsed, err := url.Parse(target) if err != nil { return nil, err } hostname = parsed.Host } else { // Hostname only passed. hostname = target } if h, _, err := net.SplitHostPort(hostname); err == nil { hostname = h } m.mu.Lock() defer m.mu.Unlock() e := &DnsMonitorEntry{ url: target, callback: callback, } entry, found := m.hostnames[hostname] if !found { entry = &dnsMonitorEntry{ hostname: hostname, hostIP: net.ParseIP(hostname), entries: make(map[*DnsMonitorEntry]bool), } m.hostnames[hostname] = entry } e.entry.Store(entry) entry.addEntry(e) m.cond.Signal() return e, nil } func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) { oldEntry := entry.entry.Swap(nil) if oldEntry == nil { // Already removed. return } locked := m.mu.TryLock() // Spin-lock for simple cases that resolve immediately to avoid deferred removal. for i := 0; !locked && i < 1000; i++ { time.Sleep(time.Nanosecond) locked = m.mu.TryLock() } if !locked { // Currently processing callbacks for this entry, need to defer removal. m.hasRemoved.Store(true) return } defer m.mu.Unlock() e, found := m.hostnames[oldEntry.hostname] if !found { return } if e.removeEntry(entry) { delete(m.hostnames, e.hostname) } } func (m *DnsMonitor) clearRemoved() { if !m.hasRemoved.CompareAndSwap(true, false) { return } m.mu.Lock() defer m.mu.Unlock() for hostname, entry := range m.hostnames { deleted := false for e := range entry.entries { if e.entry.Load() == nil { delete(entry.entries, e) deleted = true } } if deleted && len(entry.entries) == 0 { delete(m.hostnames, hostname) } } } func (m *DnsMonitor) waitForEntries() (waited bool) { m.mu.Lock() defer m.mu.Unlock() for len(m.hostnames) == 0 && m.stopCtx.Err() == nil { m.cond.Wait() waited = true } return } func (m *DnsMonitor) run() { ticker := time.NewTicker(m.interval) defer ticker.Stop() defer close(m.stopped) for { if m.waitForEntries() { ticker.Reset(m.interval) if m.stopCtx.Err() == nil { // Initial check when a new entry was added. More checks will be // triggered by the Ticker. m.checkHostnames() continue } } select { case <-m.stopCtx.Done(): return case <-ticker.C: m.checkHostnames() } } } func (m *DnsMonitor) doCheckHostnames() { m.clearRemoved() m.mu.RLock() defer m.mu.RUnlock() for _, entry := range m.hostnames { m.checkHostname(entry) } } func (m *DnsMonitor) checkHostname(entry *dnsMonitorEntry) { if len(entry.hostIP) > 0 { entry.setIPs([]net.IP{entry.hostIP}, true) return } ips, err := lookupDnsMonitorIP(entry.hostname) if err != nil { log.Printf("Could not lookup %s: %s", entry.hostname, err) return } entry.setIPs(ips, false) }