mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-05-02 05:52:44 +02:00
Merge pull request #628 from strukturag/dnsmonitor
Refactor DNS monitoring
This commit is contained in:
commit
32ccc2e50e
2
Makefile
2
Makefile
|
@ -37,7 +37,7 @@ TIMEOUT := 60s
|
|||
endif
|
||||
|
||||
ifneq ($(TEST),)
|
||||
TESTARGS := $(TESTARGS) -run $(TEST)
|
||||
TESTARGS := $(TESTARGS) -run "$(TEST)"
|
||||
endif
|
||||
|
||||
ifneq ($(COUNT),)
|
||||
|
|
|
@ -169,7 +169,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g
|
|||
t.Cleanup(func() {
|
||||
events1.Close()
|
||||
})
|
||||
client1 := NewGrpcClientsForTest(t, addr2)
|
||||
client1, _ := NewGrpcClientsForTest(t, addr2)
|
||||
hub1, err := NewHub(config1, events1, grpcServer1, client1, nil, r1, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -198,7 +198,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g
|
|||
t.Cleanup(func() {
|
||||
events2.Close()
|
||||
})
|
||||
client2 := NewGrpcClientsForTest(t, addr1)
|
||||
client2, _ := NewGrpcClientsForTest(t, addr1)
|
||||
hub2, err := NewHub(config2, events2, grpcServer2, client2, nil, r2, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
300
dns_monitor.go
Normal file
300
dns_monitor.go
Normal file
|
@ -0,0 +1,300 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2023 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
lookupDnsMonitorIP = net.LookupIP
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDnsMonitorInterval = time.Second
|
||||
)
|
||||
|
||||
type DnsMonitorCallback = func(entry *DnsMonitorEntry, all []net.IP, add []net.IP, keep []net.IP, remove []net.IP)
|
||||
|
||||
type DnsMonitorEntry struct {
|
||||
entry *dnsMonitorEntry
|
||||
url string
|
||||
callback DnsMonitorCallback
|
||||
}
|
||||
|
||||
func (e *DnsMonitorEntry) URL() string {
|
||||
return e.url
|
||||
}
|
||||
|
||||
type dnsMonitorEntry struct {
|
||||
hostname string
|
||||
hostIP net.IP
|
||||
|
||||
mu sync.Mutex
|
||||
ips []net.IP
|
||||
entries map[*DnsMonitorEntry]bool
|
||||
}
|
||||
|
||||
func (e *dnsMonitorEntry) setIPs(ips []net.IP, fromIP bool) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
empty := len(e.ips) == 0
|
||||
if empty {
|
||||
// Simple case: initial lookup.
|
||||
if len(ips) > 0 {
|
||||
e.ips = ips
|
||||
e.runCallbacks(ips, ips, nil, nil)
|
||||
}
|
||||
return
|
||||
} else if fromIP {
|
||||
// No more updates possible for IP addresses.
|
||||
return
|
||||
} else if len(ips) == 0 {
|
||||
// Simple case: no records received from lookup.
|
||||
if !empty {
|
||||
removed := e.ips
|
||||
e.ips = nil
|
||||
e.runCallbacks(nil, nil, nil, removed)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var newIPs []net.IP
|
||||
var addedIPs []net.IP
|
||||
var removedIPs []net.IP
|
||||
var keepIPs []net.IP
|
||||
for _, oldIP := range e.ips {
|
||||
found := false
|
||||
for idx, newIP := range ips {
|
||||
if oldIP.Equal(newIP) {
|
||||
ips = append(ips[:idx], ips[idx+1:]...)
|
||||
found = true
|
||||
keepIPs = append(keepIPs, oldIP)
|
||||
newIPs = append(newIPs, oldIP)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
removedIPs = append(removedIPs, oldIP)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
addedIPs = append(addedIPs, ips...)
|
||||
newIPs = append(newIPs, ips...)
|
||||
}
|
||||
e.ips = newIPs
|
||||
|
||||
if len(addedIPs) > 0 || len(removedIPs) > 0 {
|
||||
e.runCallbacks(newIPs, addedIPs, keepIPs, removedIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *dnsMonitorEntry) addEntry(entry *DnsMonitorEntry) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.entries[entry] = true
|
||||
}
|
||||
|
||||
func (e *dnsMonitorEntry) removeEntry(entry *DnsMonitorEntry) bool {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
delete(e.entries, entry)
|
||||
return len(e.entries) == 0
|
||||
}
|
||||
|
||||
func (e *dnsMonitorEntry) runCallbacks(all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) {
|
||||
for entry := range e.entries {
|
||||
entry.callback(entry, all, add, keep, remove)
|
||||
}
|
||||
}
|
||||
|
||||
type DnsMonitor struct {
|
||||
interval time.Duration
|
||||
|
||||
stopCtx context.Context
|
||||
stopFunc func()
|
||||
|
||||
mu sync.RWMutex
|
||||
cond *sync.Cond
|
||||
hostnames map[string]*dnsMonitorEntry
|
||||
|
||||
// Can be overwritten from tests.
|
||||
checkHostnames func()
|
||||
}
|
||||
|
||||
func NewDnsMonitor(interval time.Duration) (*DnsMonitor, error) {
|
||||
if interval < 0 {
|
||||
interval = defaultDnsMonitorInterval
|
||||
}
|
||||
|
||||
stopCtx, stopFunc := context.WithCancel(context.Background())
|
||||
monitor := &DnsMonitor{
|
||||
interval: interval,
|
||||
|
||||
stopCtx: stopCtx,
|
||||
stopFunc: stopFunc,
|
||||
|
||||
hostnames: make(map[string]*dnsMonitorEntry),
|
||||
}
|
||||
monitor.cond = sync.NewCond(&monitor.mu)
|
||||
monitor.checkHostnames = monitor.doCheckHostnames
|
||||
return monitor, nil
|
||||
}
|
||||
|
||||
func (m *DnsMonitor) Start() error {
|
||||
go m.run()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *DnsMonitor) Stop() {
|
||||
m.stopFunc()
|
||||
m.cond.Signal()
|
||||
}
|
||||
|
||||
func (m *DnsMonitor) Add(target string, callback DnsMonitorCallback) (*DnsMonitorEntry, error) {
|
||||
var hostname string
|
||||
if strings.Contains(target, "://") {
|
||||
// Full URL passed.
|
||||
parsed, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hostname = parsed.Host
|
||||
} else {
|
||||
// Hostname with optional port passed.
|
||||
hostname = target
|
||||
if h, _, err := net.SplitHostPort(target); err == nil {
|
||||
hostname = h
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
e := &DnsMonitorEntry{
|
||||
url: target,
|
||||
callback: callback,
|
||||
}
|
||||
|
||||
entry, found := m.hostnames[hostname]
|
||||
if !found {
|
||||
entry = &dnsMonitorEntry{
|
||||
hostname: hostname,
|
||||
hostIP: net.ParseIP(hostname),
|
||||
entries: make(map[*DnsMonitorEntry]bool),
|
||||
}
|
||||
m.hostnames[hostname] = entry
|
||||
}
|
||||
e.entry = entry
|
||||
entry.addEntry(e)
|
||||
m.cond.Signal()
|
||||
return e, nil
|
||||
}
|
||||
|
||||
func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if entry.entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e, found := m.hostnames[entry.entry.hostname]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
entry.entry = nil
|
||||
if e.removeEntry(entry) {
|
||||
delete(m.hostnames, e.hostname)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DnsMonitor) waitForEntries() (waited bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for len(m.hostnames) == 0 && m.stopCtx.Err() == nil {
|
||||
m.cond.Wait()
|
||||
waited = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *DnsMonitor) run() {
|
||||
ticker := time.NewTicker(m.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if m.waitForEntries() {
|
||||
ticker.Reset(m.interval)
|
||||
if m.stopCtx.Err() == nil {
|
||||
// Initial check when a new entry was added. More checks will be
|
||||
// triggered by the Ticker.
|
||||
m.checkHostnames()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-m.stopCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.checkHostnames()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DnsMonitor) doCheckHostnames() {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, entry := range m.hostnames {
|
||||
m.checkHostname(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DnsMonitor) checkHostname(entry *dnsMonitorEntry) {
|
||||
if len(entry.hostIP) > 0 {
|
||||
entry.setIPs([]net.IP{entry.hostIP}, true)
|
||||
return
|
||||
}
|
||||
|
||||
ips, err := lookupDnsMonitorIP(entry.hostname)
|
||||
if err != nil {
|
||||
log.Printf("Could not lookup %s: %s", entry.hostname, err)
|
||||
return
|
||||
}
|
||||
|
||||
entry.setIPs(ips, false)
|
||||
}
|
334
dns_monitor_test.go
Normal file
334
dns_monitor_test.go
Normal file
|
@ -0,0 +1,334 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2023 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockDnsLookup struct {
|
||||
sync.RWMutex
|
||||
|
||||
ips map[string][]net.IP
|
||||
}
|
||||
|
||||
func newMockDnsLookupForTest(t *testing.T) *mockDnsLookup {
|
||||
mock := &mockDnsLookup{
|
||||
ips: make(map[string][]net.IP),
|
||||
}
|
||||
prev := lookupDnsMonitorIP
|
||||
t.Cleanup(func() {
|
||||
lookupDnsMonitorIP = prev
|
||||
})
|
||||
lookupDnsMonitorIP = mock.lookup
|
||||
return mock
|
||||
}
|
||||
|
||||
func (m *mockDnsLookup) Set(host string, ips []net.IP) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
m.ips[host] = ips
|
||||
}
|
||||
|
||||
func (m *mockDnsLookup) Get(host string) []net.IP {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
return m.ips[host]
|
||||
}
|
||||
|
||||
func (m *mockDnsLookup) lookup(host string) ([]net.IP, error) {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
|
||||
ips, found := m.ips[host]
|
||||
if !found {
|
||||
return nil, &net.DNSError{
|
||||
Err: fmt.Sprintf("could not resolve %s", host),
|
||||
Name: host,
|
||||
IsNotFound: true,
|
||||
}
|
||||
}
|
||||
|
||||
return append([]net.IP{}, ips...), nil
|
||||
}
|
||||
|
||||
func newDnsMonitorForTest(t *testing.T, interval time.Duration) *DnsMonitor {
|
||||
t.Helper()
|
||||
|
||||
monitor, err := NewDnsMonitor(interval)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
monitor.Stop()
|
||||
})
|
||||
|
||||
if err := monitor.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return monitor
|
||||
}
|
||||
|
||||
type dnsMonitorReceiverRecord struct {
|
||||
all []net.IP
|
||||
add []net.IP
|
||||
keep []net.IP
|
||||
remove []net.IP
|
||||
}
|
||||
|
||||
func (r *dnsMonitorReceiverRecord) Equal(other *dnsMonitorReceiverRecord) bool {
|
||||
return r == other || (reflect.DeepEqual(r.add, other.add) &&
|
||||
reflect.DeepEqual(r.keep, other.keep) &&
|
||||
reflect.DeepEqual(r.remove, other.remove))
|
||||
}
|
||||
|
||||
func (r *dnsMonitorReceiverRecord) String() string {
|
||||
return fmt.Sprintf("all=%v, add=%v, keep=%v, remove=%v", r.all, r.add, r.keep, r.remove)
|
||||
}
|
||||
|
||||
var (
|
||||
expectNone = &dnsMonitorReceiverRecord{}
|
||||
)
|
||||
|
||||
type dnsMonitorReceiver struct {
|
||||
sync.Mutex
|
||||
|
||||
t *testing.T
|
||||
expected *dnsMonitorReceiverRecord
|
||||
received *dnsMonitorReceiverRecord
|
||||
}
|
||||
|
||||
func newDnsMonitorReceiverForTest(t *testing.T) *dnsMonitorReceiver {
|
||||
return &dnsMonitorReceiver{
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dnsMonitorReceiver) OnLookup(entry *DnsMonitorEntry, all, add, keep, remove []net.IP) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
received := &dnsMonitorReceiverRecord{
|
||||
all: all,
|
||||
add: add,
|
||||
keep: keep,
|
||||
remove: remove,
|
||||
}
|
||||
|
||||
expected := r.expected
|
||||
r.expected = nil
|
||||
if expected == expectNone {
|
||||
r.t.Errorf("expected no event, got %v", received)
|
||||
return
|
||||
}
|
||||
|
||||
if expected == nil {
|
||||
if r.received != nil && !r.received.Equal(received) {
|
||||
r.t.Errorf("already received %v, got %v", r.received, received)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !expected.Equal(received) {
|
||||
r.t.Errorf("expected %v, got %v", expected, received)
|
||||
}
|
||||
r.received = nil
|
||||
r.expected = nil
|
||||
}
|
||||
|
||||
func (r *dnsMonitorReceiver) WaitForExpected(ctx context.Context) {
|
||||
r.t.Helper()
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
ticker := time.NewTicker(time.Microsecond)
|
||||
abort := false
|
||||
for r.expected != nil && !abort {
|
||||
r.Unlock()
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-ctx.Done():
|
||||
r.t.Error(ctx.Err())
|
||||
abort = true
|
||||
}
|
||||
r.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dnsMonitorReceiver) Expect(all, add, keep, remove []net.IP) {
|
||||
r.t.Helper()
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
if r.expected != nil && r.expected != expectNone {
|
||||
r.t.Errorf("didn't get previously expected %v", r.expected)
|
||||
}
|
||||
|
||||
expected := &dnsMonitorReceiverRecord{
|
||||
all: all,
|
||||
add: add,
|
||||
keep: keep,
|
||||
remove: remove,
|
||||
}
|
||||
if r.received != nil && r.received.Equal(expected) {
|
||||
r.received = nil
|
||||
return
|
||||
}
|
||||
|
||||
r.expected = expected
|
||||
}
|
||||
|
||||
func (r *dnsMonitorReceiver) ExpectNone() {
|
||||
r.t.Helper()
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
if r.expected != nil && r.expected != expectNone {
|
||||
r.t.Errorf("didn't get previously expected %v", r.expected)
|
||||
}
|
||||
|
||||
r.expected = expectNone
|
||||
}
|
||||
|
||||
func TestDnsMonitor(t *testing.T) {
|
||||
lookup := newMockDnsLookupForTest(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
interval := time.Millisecond
|
||||
monitor := newDnsMonitorForTest(t, interval)
|
||||
|
||||
ip1 := net.ParseIP("192.168.0.1")
|
||||
ip2 := net.ParseIP("192.168.1.1")
|
||||
ip3 := net.ParseIP("10.1.2.3")
|
||||
ips1 := []net.IP{
|
||||
ip1,
|
||||
ip2,
|
||||
}
|
||||
lookup.Set("foo", ips1)
|
||||
|
||||
rec1 := newDnsMonitorReceiverForTest(t)
|
||||
rec1.Expect(ips1, ips1, nil, nil)
|
||||
|
||||
entry1, err := monitor.Add("https://foo", rec1.OnLookup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer monitor.Remove(entry1)
|
||||
|
||||
rec1.WaitForExpected(ctx)
|
||||
|
||||
ips2 := []net.IP{
|
||||
ip1,
|
||||
ip2,
|
||||
ip3,
|
||||
}
|
||||
add2 := []net.IP{ip3}
|
||||
keep2 := []net.IP{ip1, ip2}
|
||||
lookup.Set("foo", ips2)
|
||||
rec1.Expect(ips2, add2, keep2, nil)
|
||||
rec1.WaitForExpected(ctx)
|
||||
|
||||
ips3 := []net.IP{
|
||||
ip2,
|
||||
ip3,
|
||||
}
|
||||
lookup.Set("foo", ips3)
|
||||
keep3 := []net.IP{ip2, ip3}
|
||||
remove3 := []net.IP{ip1}
|
||||
rec1.Expect(ips3, nil, keep3, remove3)
|
||||
rec1.WaitForExpected(ctx)
|
||||
|
||||
rec1.ExpectNone()
|
||||
time.Sleep(5 * interval)
|
||||
|
||||
lookup.Set("foo", nil)
|
||||
remove4 := []net.IP{ip2, ip3}
|
||||
rec1.Expect(nil, nil, nil, remove4)
|
||||
rec1.WaitForExpected(ctx)
|
||||
|
||||
rec1.ExpectNone()
|
||||
time.Sleep(5 * interval)
|
||||
|
||||
// Removing multiple times is supported.
|
||||
monitor.Remove(entry1)
|
||||
monitor.Remove(entry1)
|
||||
|
||||
// No more events after removing.
|
||||
lookup.Set("foo", ips1)
|
||||
rec1.ExpectNone()
|
||||
time.Sleep(5 * interval)
|
||||
}
|
||||
|
||||
func TestDnsMonitorIP(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
interval := time.Millisecond
|
||||
monitor := newDnsMonitorForTest(t, interval)
|
||||
|
||||
ip := "192.168.0.1"
|
||||
ips := []net.IP{
|
||||
net.ParseIP(ip),
|
||||
}
|
||||
|
||||
rec1 := newDnsMonitorReceiverForTest(t)
|
||||
rec1.Expect(ips, ips, nil, nil)
|
||||
|
||||
entry, err := monitor.Add(ip+":12345", rec1.OnLookup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer monitor.Remove(entry)
|
||||
|
||||
rec1.WaitForExpected(ctx)
|
||||
|
||||
rec1.ExpectNone()
|
||||
time.Sleep(5 * interval)
|
||||
}
|
||||
|
||||
func TestDnsMonitorNoLookupIfEmpty(t *testing.T) {
|
||||
interval := time.Millisecond
|
||||
monitor := newDnsMonitorForTest(t, interval)
|
||||
|
||||
var checked atomic.Bool
|
||||
monitor.checkHostnames = func() {
|
||||
checked.Store(true)
|
||||
monitor.doCheckHostnames()
|
||||
}
|
||||
|
||||
time.Sleep(10 * interval)
|
||||
if checked.Load() {
|
||||
t.Error("should not have checked hostnames")
|
||||
}
|
||||
}
|
262
grpc_client.go
262
grpc_client.go
|
@ -49,8 +49,6 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
lookupGrpcIp = net.LookupIP // can be overwritten from tests
|
||||
|
||||
customResolverPrefix atomic.Uint64
|
||||
)
|
||||
|
||||
|
@ -258,15 +256,19 @@ func (c *GrpcClient) GetSessionCount(ctx context.Context, u *url.URL) (uint32, e
|
|||
return response.GetCount(), nil
|
||||
}
|
||||
|
||||
type grpcClientsList struct {
|
||||
clients []*GrpcClient
|
||||
entry *DnsMonitorEntry
|
||||
}
|
||||
|
||||
type GrpcClients struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
clientsMap map[string][]*GrpcClient
|
||||
clientsMap map[string]*grpcClientsList
|
||||
clients []*GrpcClient
|
||||
|
||||
dnsMonitor *DnsMonitor
|
||||
dnsDiscovery bool
|
||||
stopping chan struct{}
|
||||
stopped chan struct{}
|
||||
|
||||
etcdClient *EtcdClient
|
||||
targetPrefix string
|
||||
|
@ -280,15 +282,13 @@ type GrpcClients struct {
|
|||
selfCheckWaitGroup sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, error) {
|
||||
func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient, dnsMonitor *DnsMonitor) (*GrpcClients, error) {
|
||||
initializedCtx, initializedFunc := context.WithCancel(context.Background())
|
||||
result := &GrpcClients{
|
||||
dnsMonitor: dnsMonitor,
|
||||
etcdClient: etcdClient,
|
||||
initializedCtx: initializedCtx,
|
||||
initializedFunc: initializedFunc,
|
||||
|
||||
stopping: make(chan struct{}, 1),
|
||||
stopped: make(chan struct{}, 1),
|
||||
}
|
||||
if err := result.load(config, false); err != nil {
|
||||
return nil, err
|
||||
|
@ -313,9 +313,6 @@ func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error {
|
|||
switch targetType {
|
||||
case GrpcTargetTypeStatic:
|
||||
err = c.loadTargetsStatic(config, fromReload, opts...)
|
||||
if err == nil && c.dnsDiscovery {
|
||||
go c.monitorGrpcIPs()
|
||||
}
|
||||
case GrpcTargetTypeEtcd:
|
||||
err = c.loadTargetsEtcd(config, fromReload, opts...)
|
||||
default:
|
||||
|
@ -344,7 +341,7 @@ func (c *GrpcClients) isClientAvailable(target string, client *GrpcClient) bool
|
|||
return false
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
for _, entry := range entries.clients {
|
||||
if entry == client {
|
||||
return true
|
||||
}
|
||||
|
@ -401,7 +398,20 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo
|
|||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
clientsMap := make(map[string][]*GrpcClient)
|
||||
dnsDiscovery, _ := config.GetBool("grpc", "dnsdiscovery")
|
||||
if dnsDiscovery != c.dnsDiscovery {
|
||||
if !dnsDiscovery {
|
||||
for _, entry := range c.clientsMap {
|
||||
if entry.entry != nil {
|
||||
c.dnsMonitor.Remove(entry.entry)
|
||||
entry.entry = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
c.dnsDiscovery = dnsDiscovery
|
||||
}
|
||||
|
||||
clientsMap := make(map[string]*grpcClientsList)
|
||||
var clients []*GrpcClient
|
||||
removeTargets := make(map[string]bool, len(c.clientsMap))
|
||||
for target, entries := range c.clientsMap {
|
||||
|
@ -417,7 +427,15 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo
|
|||
}
|
||||
|
||||
if entries, found := clientsMap[target]; found {
|
||||
clients = append(clients, entries...)
|
||||
clients = append(clients, entries.clients...)
|
||||
if dnsDiscovery && entries.entry == nil {
|
||||
entry, err := c.dnsMonitor.Add(target, c.onLookup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
entries.entry = entry
|
||||
}
|
||||
delete(removeTargets, target)
|
||||
continue
|
||||
}
|
||||
|
@ -427,61 +445,58 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo
|
|||
host = h
|
||||
}
|
||||
|
||||
var ips []net.IP
|
||||
if net.ParseIP(host) == nil {
|
||||
if dnsDiscovery && net.ParseIP(host) == nil {
|
||||
// Use dedicated client for each IP address.
|
||||
var err error
|
||||
ips, err = lookupGrpcIp(host)
|
||||
entry, err := c.dnsMonitor.Add(target, c.onLookup)
|
||||
if err != nil {
|
||||
log.Printf("Could not lookup %s: %s", host, err)
|
||||
// Make sure updating continues even if initial lookup failed.
|
||||
clientsMap[target] = nil
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// Connect directly to IP address.
|
||||
ips = []net.IP{nil}
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
client, err := NewGrpcClient(target, ip, opts...)
|
||||
if err != nil {
|
||||
for _, clients := range clientsMap {
|
||||
for _, client := range clients {
|
||||
c.closeClient(client)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c.selfCheckWaitGroup.Add(1)
|
||||
go c.checkIsSelf(context.Background(), target, client)
|
||||
|
||||
log.Printf("Adding %s as GRPC target", client.Target())
|
||||
clientsMap[target] = append(clientsMap[target], client)
|
||||
clients = append(clients, client)
|
||||
clientsMap[target] = &grpcClientsList{
|
||||
entry: entry,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
client, err := NewGrpcClient(target, nil, opts...)
|
||||
if err != nil {
|
||||
for _, entry := range clientsMap {
|
||||
for _, client := range entry.clients {
|
||||
c.closeClient(client)
|
||||
}
|
||||
|
||||
if entry.entry != nil {
|
||||
c.dnsMonitor.Remove(entry.entry)
|
||||
entry.entry = nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c.selfCheckWaitGroup.Add(1)
|
||||
go c.checkIsSelf(context.Background(), target, client)
|
||||
|
||||
log.Printf("Adding %s as GRPC target", client.Target())
|
||||
entry, found := clientsMap[target]
|
||||
if !found {
|
||||
entry = &grpcClientsList{}
|
||||
}
|
||||
entry.clients = append(entry.clients, client)
|
||||
clients = append(clients, client)
|
||||
}
|
||||
|
||||
for target := range removeTargets {
|
||||
if clients, found := clientsMap[target]; found {
|
||||
for _, client := range clients {
|
||||
if entry, found := clientsMap[target]; found {
|
||||
for _, client := range entry.clients {
|
||||
log.Printf("Deleting GRPC target %s", client.Target())
|
||||
c.closeClient(client)
|
||||
}
|
||||
delete(clientsMap, target)
|
||||
}
|
||||
}
|
||||
|
||||
dnsDiscovery, _ := config.GetBool("grpc", "dnsdiscovery")
|
||||
if dnsDiscovery != c.dnsDiscovery {
|
||||
if !dnsDiscovery && fromReload {
|
||||
c.stopping <- struct{}{}
|
||||
<-c.stopped
|
||||
}
|
||||
c.dnsDiscovery = dnsDiscovery
|
||||
if dnsDiscovery && fromReload {
|
||||
go c.monitorGrpcIPs()
|
||||
if entry.entry != nil {
|
||||
c.dnsMonitor.Remove(entry.entry)
|
||||
entry.entry = nil
|
||||
}
|
||||
delete(clientsMap, target)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -492,91 +507,61 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *GrpcClients) monitorGrpcIPs() {
|
||||
log.Printf("Start monitoring GRPC client IPs")
|
||||
ticker := time.NewTicker(updateDnsInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.updateGrpcIPs()
|
||||
case <-c.stopping:
|
||||
c.stopped <- struct{}{}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClients) updateGrpcIPs() {
|
||||
func (c *GrpcClients) onLookup(entry *DnsMonitorEntry, all []net.IP, added []net.IP, keep []net.IP, removed []net.IP) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
target := entry.URL()
|
||||
e, found := c.clientsMap[target]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
opts := c.dialOptions.Load().([]grpc.DialOption)
|
||||
|
||||
mapModified := false
|
||||
for target, clients := range c.clientsMap {
|
||||
host := target
|
||||
if h, _, err := net.SplitHostPort(target); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
// No need to lookup endpoints that connect to IP addresses.
|
||||
continue
|
||||
}
|
||||
|
||||
ips, err := lookupGrpcIp(host)
|
||||
if err != nil {
|
||||
log.Printf("Could not lookup %s: %s", host, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var newClients []*GrpcClient
|
||||
changed := false
|
||||
for _, client := range clients {
|
||||
found := false
|
||||
for idx, ip := range ips {
|
||||
if ip.Equal(client.ip) {
|
||||
ips = append(ips[:idx], ips[idx+1:]...)
|
||||
found = true
|
||||
newClients = append(newClients, client)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
changed = true
|
||||
var newClients []*GrpcClient
|
||||
for _, ip := range removed {
|
||||
for _, client := range e.clients {
|
||||
if ip.Equal(client.ip) {
|
||||
mapModified = true
|
||||
log.Printf("Removing connection to %s", client.Target())
|
||||
c.closeClient(client)
|
||||
c.wakeupForTesting()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
client, err := NewGrpcClient(target, ip, opts...)
|
||||
if err != nil {
|
||||
log.Printf("Error creating client to %s with IP %s: %s", target, ip.String(), err)
|
||||
continue
|
||||
for _, ip := range keep {
|
||||
for _, client := range e.clients {
|
||||
if ip.Equal(client.ip) {
|
||||
newClients = append(newClients, client)
|
||||
}
|
||||
|
||||
c.selfCheckWaitGroup.Add(1)
|
||||
go c.checkIsSelf(context.Background(), target, client)
|
||||
|
||||
log.Printf("Adding %s as GRPC target", client.Target())
|
||||
newClients = append(newClients, client)
|
||||
changed = true
|
||||
c.wakeupForTesting()
|
||||
}
|
||||
|
||||
if changed {
|
||||
c.clientsMap[target] = newClients
|
||||
mapModified = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, ip := range added {
|
||||
client, err := NewGrpcClient(target, ip, opts...)
|
||||
if err != nil {
|
||||
log.Printf("Error creating client to %s with IP %s: %s", target, ip.String(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.selfCheckWaitGroup.Add(1)
|
||||
go c.checkIsSelf(context.Background(), target, client)
|
||||
|
||||
log.Printf("Adding %s as GRPC target", client.Target())
|
||||
newClients = append(newClients, client)
|
||||
mapModified = true
|
||||
c.wakeupForTesting()
|
||||
}
|
||||
|
||||
if mapModified {
|
||||
c.clientsMap[target].clients = newClients
|
||||
|
||||
c.clients = make([]*GrpcClient, 0, len(c.clientsMap))
|
||||
for _, clients := range c.clientsMap {
|
||||
c.clients = append(c.clients, clients...)
|
||||
for _, entry := range c.clientsMap {
|
||||
c.clients = append(c.clients, entry.clients...)
|
||||
}
|
||||
statsGrpcClients.Set(float64(len(c.clients)))
|
||||
}
|
||||
|
@ -684,9 +669,11 @@ func (c *GrpcClients) EtcdKeyUpdated(client *EtcdClient, key string, data []byte
|
|||
log.Printf("Adding %s as GRPC target", cl.Target())
|
||||
|
||||
if c.clientsMap == nil {
|
||||
c.clientsMap = make(map[string][]*GrpcClient)
|
||||
c.clientsMap = make(map[string]*grpcClientsList)
|
||||
}
|
||||
c.clientsMap[info.Address] = &grpcClientsList{
|
||||
clients: []*GrpcClient{cl},
|
||||
}
|
||||
c.clientsMap[info.Address] = []*GrpcClient{cl}
|
||||
c.clients = append(c.clients, cl)
|
||||
c.targetInformation[key] = &info
|
||||
statsGrpcClients.Inc()
|
||||
|
@ -709,19 +696,19 @@ func (c *GrpcClients) removeEtcdClientLocked(key string) {
|
|||
}
|
||||
|
||||
delete(c.targetInformation, key)
|
||||
clients, found := c.clientsMap[info.Address]
|
||||
entry, found := c.clientsMap[info.Address]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
for _, client := range clients {
|
||||
for _, client := range entry.clients {
|
||||
log.Printf("Removing connection to %s (from %s)", client.Target(), key)
|
||||
c.closeClient(client)
|
||||
}
|
||||
delete(c.clientsMap, info.Address)
|
||||
c.clients = make([]*GrpcClient, 0, len(c.clientsMap))
|
||||
for _, clients := range c.clientsMap {
|
||||
c.clients = append(c.clients, clients...)
|
||||
for _, entry := range c.clientsMap {
|
||||
c.clients = append(c.clients, entry.clients...)
|
||||
}
|
||||
statsGrpcClients.Dec()
|
||||
c.wakeupForTesting()
|
||||
|
@ -757,21 +744,22 @@ func (c *GrpcClients) Close() {
|
|||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, clients := range c.clientsMap {
|
||||
for _, client := range clients {
|
||||
for _, entry := range c.clientsMap {
|
||||
for _, client := range entry.clients {
|
||||
if err := client.Close(); err != nil {
|
||||
log.Printf("Error closing client to %s: %s", client.Target(), err)
|
||||
}
|
||||
}
|
||||
|
||||
if entry.entry != nil {
|
||||
c.dnsMonitor.Remove(entry.entry)
|
||||
entry.entry = nil
|
||||
}
|
||||
}
|
||||
|
||||
c.clients = nil
|
||||
c.clientsMap = nil
|
||||
if c.dnsDiscovery {
|
||||
c.stopping <- struct{}{}
|
||||
<-c.stopped
|
||||
c.dnsDiscovery = false
|
||||
}
|
||||
c.dnsDiscovery = false
|
||||
|
||||
if c.etcdClient != nil {
|
||||
c.etcdClient.RemoveListener(c)
|
||||
|
|
|
@ -37,6 +37,9 @@ import (
|
|||
)
|
||||
|
||||
func (c *GrpcClients) getWakeupChannelForTesting() <-chan struct{} {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.wakeupChanForTesting != nil {
|
||||
return c.wakeupChanForTesting
|
||||
}
|
||||
|
@ -46,8 +49,9 @@ func (c *GrpcClients) getWakeupChannelForTesting() <-chan struct{} {
|
|||
return ch
|
||||
}
|
||||
|
||||
func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) *GrpcClients {
|
||||
client, err := NewGrpcClients(config, etcdClient)
|
||||
func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, *DnsMonitor) {
|
||||
dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually
|
||||
client, err := NewGrpcClients(config, etcdClient, dnsMonitor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -55,10 +59,10 @@ func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, et
|
|||
client.Close()
|
||||
})
|
||||
|
||||
return client
|
||||
return client, dnsMonitor
|
||||
}
|
||||
|
||||
func NewGrpcClientsForTest(t *testing.T, addr string) *GrpcClients {
|
||||
func NewGrpcClientsForTest(t *testing.T, addr string) (*GrpcClients, *DnsMonitor) {
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("grpc", "targets", addr)
|
||||
config.AddOption("grpc", "dnsdiscovery", "true")
|
||||
|
@ -66,7 +70,7 @@ func NewGrpcClientsForTest(t *testing.T, addr string) *GrpcClients {
|
|||
return NewGrpcClientsForTestWithConfig(t, config, nil)
|
||||
}
|
||||
|
||||
func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) *GrpcClients {
|
||||
func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) (*GrpcClients, *DnsMonitor) {
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("etcd", "endpoints", etcd.Config().ListenClientUrls[0].String())
|
||||
|
||||
|
@ -116,7 +120,7 @@ func Test_GrpcClients_EtcdInitial(t *testing.T) {
|
|||
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
|
||||
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
|
||||
|
||||
client := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := client.WaitForInitialized(ctx); err != nil {
|
||||
|
@ -130,7 +134,7 @@ func Test_GrpcClients_EtcdInitial(t *testing.T) {
|
|||
|
||||
func Test_GrpcClients_EtcdUpdate(t *testing.T) {
|
||||
etcd := NewEtcdForTest(t)
|
||||
client := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
ch := client.getWakeupChannelForTesting()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
|
@ -184,7 +188,7 @@ func Test_GrpcClients_EtcdUpdate(t *testing.T) {
|
|||
|
||||
func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
|
||||
etcd := NewEtcdForTest(t)
|
||||
client := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
ch := client.getWakeupChannelForTesting()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
|
@ -227,26 +231,20 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
||||
var ipsResult []net.IP
|
||||
lookupGrpcIp = func(host string) ([]net.IP, error) {
|
||||
if host == "testgrpc" {
|
||||
return ipsResult, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown host")
|
||||
}
|
||||
lookup := newMockDnsLookupForTest(t)
|
||||
target := "testgrpc:12345"
|
||||
ip1 := net.ParseIP("192.168.0.1")
|
||||
ip2 := net.ParseIP("192.168.0.2")
|
||||
targetWithIp1 := fmt.Sprintf("%s (%s)", target, ip1)
|
||||
targetWithIp2 := fmt.Sprintf("%s (%s)", target, ip2)
|
||||
ipsResult = []net.IP{ip1}
|
||||
client := NewGrpcClientsForTest(t, target)
|
||||
lookup.Set("testgrpc", []net.IP{ip1})
|
||||
client, dnsMonitor := NewGrpcClientsForTest(t, target)
|
||||
ch := client.getWakeupChannelForTesting()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
dnsMonitor.checkHostnames()
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != targetWithIp1 {
|
||||
|
@ -255,9 +253,9 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
|||
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
}
|
||||
|
||||
ipsResult = []net.IP{ip1, ip2}
|
||||
lookup.Set("testgrpc", []net.IP{ip1, ip2})
|
||||
drainWakeupChannel(ch)
|
||||
client.updateGrpcIPs()
|
||||
dnsMonitor.checkHostnames()
|
||||
waitForEvent(ctx, t, ch)
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 2 {
|
||||
|
@ -272,9 +270,9 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
|||
t.Errorf("Expected IP %s, got %s", ip2, clients[1].ip)
|
||||
}
|
||||
|
||||
ipsResult = []net.IP{ip2}
|
||||
lookup.Set("testgrpc", []net.IP{ip2})
|
||||
drainWakeupChannel(ch)
|
||||
client.updateGrpcIPs()
|
||||
dnsMonitor.checkHostnames()
|
||||
waitForEvent(ctx, t, ch)
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
|
@ -287,22 +285,11 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
|
||||
var ipsResult []net.IP
|
||||
lookupGrpcIp = func(host string) ([]net.IP, error) {
|
||||
if host == "testgrpc" && len(ipsResult) > 0 {
|
||||
return ipsResult, nil
|
||||
}
|
||||
|
||||
return nil, &net.DNSError{
|
||||
Err: "no such host",
|
||||
Name: host,
|
||||
IsNotFound: true,
|
||||
}
|
||||
}
|
||||
lookup := newMockDnsLookupForTest(t)
|
||||
target := "testgrpc:12345"
|
||||
ip1 := net.ParseIP("192.168.0.1")
|
||||
targetWithIp1 := fmt.Sprintf("%s (%s)", target, ip1)
|
||||
client := NewGrpcClientsForTest(t, target)
|
||||
client, dnsMonitor := NewGrpcClientsForTest(t, target)
|
||||
ch := client.getWakeupChannelForTesting()
|
||||
|
||||
testCtx, testCtxCancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
|
@ -318,9 +305,9 @@ func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
|
|||
t.Errorf("Expected no client, got %+v", clients)
|
||||
}
|
||||
|
||||
ipsResult = []net.IP{ip1}
|
||||
lookup.Set("testgrpc", []net.IP{ip1})
|
||||
drainWakeupChannel(ch)
|
||||
client.updateGrpcIPs()
|
||||
dnsMonitor.checkHostnames()
|
||||
waitForEvent(testCtx, t, ch)
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
|
@ -370,7 +357,7 @@ func Test_GrpcClients_Encryption(t *testing.T) {
|
|||
clientConfig.AddOption("grpc", "clientcertificate", clientCertFile)
|
||||
clientConfig.AddOption("grpc", "clientkey", clientPrivkeyFile)
|
||||
clientConfig.AddOption("grpc", "serverca", serverCertFile)
|
||||
clients := NewGrpcClientsForTestWithConfig(t, clientConfig, nil)
|
||||
clients, _ := NewGrpcClientsForTestWithConfig(t, clientConfig, nil)
|
||||
|
||||
ctx, cancel1 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel1()
|
||||
|
|
|
@ -211,7 +211,7 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client1 := NewGrpcClientsForTest(t, addr2)
|
||||
client1, _ := NewGrpcClientsForTest(t, addr2)
|
||||
h1, err := NewHub(config1, events1, grpcServer1, client1, nil, r1, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -231,7 +231,7 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client2 := NewGrpcClientsForTest(t, addr1)
|
||||
client2, _ := NewGrpcClientsForTest(t, addr1)
|
||||
h2, err := NewHub(config2, events2, grpcServer2, client2, nil, r2, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -66,9 +66,6 @@ const (
|
|||
defaultProxyTimeoutSeconds = 2
|
||||
|
||||
rttLogDuration = 500 * time.Millisecond
|
||||
|
||||
// Update service IP addresses every 10 seconds.
|
||||
updateDnsInterval = 10 * time.Second
|
||||
)
|
||||
|
||||
type McuProxy interface {
|
||||
|
@ -1123,7 +1120,7 @@ type mcuProxy struct {
|
|||
rpcClients *GrpcClients
|
||||
}
|
||||
|
||||
func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients *GrpcClients) (Mcu, error) {
|
||||
func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients *GrpcClients, dnsMonitor *DnsMonitor) (Mcu, error) {
|
||||
urlType, _ := config.GetString("mcu", "urltype")
|
||||
if urlType == "" {
|
||||
urlType = proxyUrlTypeStatic
|
||||
|
@ -1196,7 +1193,7 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients *
|
|||
|
||||
switch urlType {
|
||||
case proxyUrlTypeStatic:
|
||||
mcu.config, err = NewProxyConfigStatic(config, mcu)
|
||||
mcu.config, err = NewProxyConfigStatic(config, mcu, dnsMonitor)
|
||||
case proxyUrlTypeEtcd:
|
||||
mcu.config, err = NewProxyConfigEtcd(config, etcdClient, mcu)
|
||||
default:
|
||||
|
|
|
@ -22,15 +22,9 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
)
|
||||
|
||||
var (
|
||||
lookupProxyIP = net.LookupIP
|
||||
)
|
||||
|
||||
type ProxyConfig interface {
|
||||
Start() error
|
||||
Stop()
|
||||
|
|
|
@ -28,8 +28,6 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
)
|
||||
|
@ -37,27 +35,24 @@ import (
|
|||
type ipList struct {
|
||||
hostname string
|
||||
|
||||
ips []net.IP
|
||||
entry *DnsMonitorEntry
|
||||
ips []net.IP
|
||||
}
|
||||
|
||||
type proxyConfigStatic struct {
|
||||
mu sync.Mutex
|
||||
proxy McuProxy
|
||||
|
||||
dnsDiscovery atomic.Bool
|
||||
stopping chan struct{}
|
||||
stopped chan struct{}
|
||||
dnsMonitor *DnsMonitor
|
||||
dnsDiscovery bool
|
||||
|
||||
connectionsMap map[string]*ipList
|
||||
}
|
||||
|
||||
func NewProxyConfigStatic(config *goconf.ConfigFile, proxy McuProxy) (ProxyConfig, error) {
|
||||
func NewProxyConfigStatic(config *goconf.ConfigFile, proxy McuProxy, dnsMonitor *DnsMonitor) (ProxyConfig, error) {
|
||||
result := &proxyConfigStatic{
|
||||
proxy: proxy,
|
||||
|
||||
stopping: make(chan struct{}, 1),
|
||||
stopped: make(chan struct{}, 1),
|
||||
|
||||
proxy: proxy,
|
||||
dnsMonitor: dnsMonitor,
|
||||
connectionsMap: make(map[string]*ipList),
|
||||
}
|
||||
if err := result.configure(config, false); err != nil {
|
||||
|
@ -70,19 +65,22 @@ func NewProxyConfigStatic(config *goconf.ConfigFile, proxy McuProxy) (ProxyConfi
|
|||
}
|
||||
|
||||
func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool) error {
|
||||
dnsDiscovery, _ := config.GetBool("mcu", "dnsdiscovery")
|
||||
if p.dnsDiscovery.CompareAndSwap(!dnsDiscovery, dnsDiscovery) && fromReload {
|
||||
if !dnsDiscovery {
|
||||
p.stopping <- struct{}{}
|
||||
<-p.stopped
|
||||
} else {
|
||||
go p.monitorProxyIPs()
|
||||
}
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
dnsDiscovery, _ := config.GetBool("mcu", "dnsdiscovery")
|
||||
if dnsDiscovery != p.dnsDiscovery {
|
||||
if !dnsDiscovery {
|
||||
for _, ips := range p.connectionsMap {
|
||||
if ips.entry != nil {
|
||||
p.dnsMonitor.Remove(ips.entry)
|
||||
ips.entry = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
p.dnsDiscovery = dnsDiscovery
|
||||
}
|
||||
|
||||
remove := make(map[string]*ipList)
|
||||
for u, ips := range p.connectionsMap {
|
||||
remove[u] = ips
|
||||
|
@ -116,18 +114,15 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool
|
|||
parsed.Host = host
|
||||
}
|
||||
|
||||
var ips []net.IP
|
||||
if dnsDiscovery {
|
||||
ips, err = lookupProxyIP(parsed.Host)
|
||||
if err != nil {
|
||||
// Will be retried later.
|
||||
log.Printf("Could not lookup %s: %s\n", parsed.Host, err)
|
||||
continue
|
||||
p.connectionsMap[u] = &ipList{
|
||||
hostname: parsed.Host,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if fromReload {
|
||||
if err := p.proxy.AddConnection(fromReload, u, ips...); err != nil {
|
||||
if err := p.proxy.AddConnection(fromReload, u); err != nil {
|
||||
if !fromReload {
|
||||
return err
|
||||
}
|
||||
|
@ -139,7 +134,6 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool
|
|||
|
||||
p.connectionsMap[u] = &ipList{
|
||||
hostname: parsed.Host,
|
||||
ips: ips,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -155,92 +149,53 @@ func (p *proxyConfigStatic) Start() error {
|
|||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for u, ipList := range p.connectionsMap {
|
||||
if err := p.proxy.AddConnection(false, u, ipList.ips...); err != nil {
|
||||
return err
|
||||
if p.dnsDiscovery {
|
||||
for u, ips := range p.connectionsMap {
|
||||
entry, err := p.dnsMonitor.Add(u, p.onLookup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ips.entry = entry
|
||||
}
|
||||
} else {
|
||||
for u, ipList := range p.connectionsMap {
|
||||
if err := p.proxy.AddConnection(false, u, ipList.ips...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p.dnsDiscovery.Load() {
|
||||
go p.monitorProxyIPs()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxyConfigStatic) Stop() {
|
||||
if p.dnsDiscovery.CompareAndSwap(true, false) {
|
||||
p.stopping <- struct{}{}
|
||||
<-p.stopped
|
||||
}
|
||||
}
|
||||
|
||||
func (p *proxyConfigStatic) Reload(config *goconf.ConfigFile) error {
|
||||
return p.configure(config, true)
|
||||
}
|
||||
|
||||
func (p *proxyConfigStatic) monitorProxyIPs() {
|
||||
log.Printf("Start monitoring proxy IPs")
|
||||
ticker := time.NewTicker(updateDnsInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
p.updateProxyIPs()
|
||||
case <-p.stopping:
|
||||
p.stopped <- struct{}{}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *proxyConfigStatic) updateProxyIPs() {
|
||||
func (p *proxyConfigStatic) onLookup(entry *DnsMonitorEntry, all []net.IP, added []net.IP, keep []net.IP, removed []net.IP) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for u, iplist := range p.connectionsMap {
|
||||
if len(iplist.ips) == 0 {
|
||||
continue
|
||||
}
|
||||
u := entry.URL()
|
||||
for _, ip := range keep {
|
||||
p.proxy.KeepConnection(u, ip)
|
||||
}
|
||||
|
||||
if net.ParseIP(iplist.hostname) != nil {
|
||||
// No need to lookup endpoints that connect to IP addresses.
|
||||
continue
|
||||
}
|
||||
|
||||
ips, err := lookupProxyIP(iplist.hostname)
|
||||
if err != nil {
|
||||
log.Printf("Could not lookup %s: %s", iplist.hostname, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var newIPs []net.IP
|
||||
var removedIPs []net.IP
|
||||
for _, oldIP := range iplist.ips {
|
||||
found := false
|
||||
for idx, newIP := range ips {
|
||||
if oldIP.Equal(newIP) {
|
||||
ips = append(ips[:idx], ips[idx+1:]...)
|
||||
found = true
|
||||
p.proxy.KeepConnection(u, oldIP)
|
||||
newIPs = append(newIPs, oldIP)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
removedIPs = append(removedIPs, oldIP)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
newIPs = append(newIPs, ips...)
|
||||
if err := p.proxy.AddConnection(true, u, ips...); err != nil {
|
||||
log.Printf("Could not add proxy connection to %s with %+v: %s", u, ips, err)
|
||||
}
|
||||
}
|
||||
iplist.ips = newIPs
|
||||
|
||||
if len(removedIPs) > 0 {
|
||||
p.proxy.RemoveConnection(u, removedIPs...)
|
||||
if len(added) > 0 {
|
||||
if err := p.proxy.AddConnection(true, u, added...); err != nil {
|
||||
log.Printf("Could not add proxy connection to %s with %+v: %s", u, added, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(removed) > 0 {
|
||||
p.proxy.RemoveConnection(u, removed...)
|
||||
}
|
||||
|
||||
if ipList, found := p.connectionsMap[u]; found {
|
||||
ipList.ips = all
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,24 +25,26 @@ import (
|
|||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
)
|
||||
|
||||
func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string) ProxyConfig {
|
||||
func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string) (ProxyConfig, *DnsMonitor) {
|
||||
cfg := goconf.NewConfigFile()
|
||||
cfg.AddOption("mcu", "url", strings.Join(urls, " "))
|
||||
if dns {
|
||||
cfg.AddOption("mcu", "dnsdiscovery", "true")
|
||||
}
|
||||
p, err := NewProxyConfigStatic(cfg, proxy)
|
||||
dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually
|
||||
p, err := NewProxyConfigStatic(cfg, proxy, dnsMonitor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
p.Stop()
|
||||
})
|
||||
return p
|
||||
return p, dnsMonitor
|
||||
}
|
||||
|
||||
func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls ...string) {
|
||||
|
@ -58,7 +60,7 @@ func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls ..
|
|||
|
||||
func TestProxyConfigStaticSimple(t *testing.T) {
|
||||
proxy := newMcuProxyForConfig(t)
|
||||
config := newProxyConfigStatic(t, proxy, false, "https://foo/")
|
||||
config, _ := newProxyConfigStatic(t, proxy, false, "https://foo/")
|
||||
proxy.Expect("add", "https://foo/")
|
||||
if err := config.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -75,38 +77,31 @@ func TestProxyConfigStaticSimple(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestProxyConfigStaticDNS(t *testing.T) {
|
||||
old := lookupProxyIP
|
||||
t.Cleanup(func() {
|
||||
lookupProxyIP = old
|
||||
})
|
||||
proxyIPs := make(map[string][]net.IP)
|
||||
lookupProxyIP = func(hostname string) ([]net.IP, error) {
|
||||
ips := append([]net.IP{}, proxyIPs[hostname]...)
|
||||
return ips, nil
|
||||
}
|
||||
proxyIPs["foo"] = []net.IP{
|
||||
lookup := newMockDnsLookupForTest(t)
|
||||
lookup.Set("foo", []net.IP{
|
||||
net.ParseIP("192.168.0.1"),
|
||||
net.ParseIP("10.1.2.3"),
|
||||
}
|
||||
})
|
||||
|
||||
proxy := newMcuProxyForConfig(t)
|
||||
config := newProxyConfigStatic(t, proxy, true, "https://foo/").(*proxyConfigStatic)
|
||||
proxy.Expect("add", "https://foo/", proxyIPs["foo"]...)
|
||||
config, dnsMonitor := newProxyConfigStatic(t, proxy, true, "https://foo/")
|
||||
proxy.Expect("add", "https://foo/", lookup.Get("foo")...)
|
||||
if err := config.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
proxyIPs["foo"] = []net.IP{
|
||||
dnsMonitor.checkHostnames()
|
||||
lookup.Set("foo", []net.IP{
|
||||
net.ParseIP("192.168.0.1"),
|
||||
net.ParseIP("192.168.1.1"),
|
||||
net.ParseIP("192.168.1.2"),
|
||||
}
|
||||
})
|
||||
proxy.Expect("keep", "https://foo/", net.ParseIP("192.168.0.1"))
|
||||
proxy.Expect("add", "https://foo/", net.ParseIP("192.168.1.1"), net.ParseIP("192.168.1.2"))
|
||||
proxy.Expect("remove", "https://foo/", net.ParseIP("10.1.2.3"))
|
||||
config.updateProxyIPs()
|
||||
dnsMonitor.checkHostnames()
|
||||
|
||||
proxy.Expect("add", "https://bar/")
|
||||
proxy.Expect("remove", "https://foo/", proxyIPs["foo"]...)
|
||||
proxy.Expect("remove", "https://foo/", lookup.Get("foo")...)
|
||||
updateProxyConfigStatic(t, config, false, "https://bar/")
|
||||
}
|
||||
|
|
|
@ -62,6 +62,8 @@ const (
|
|||
|
||||
initialMcuRetry = time.Second
|
||||
maxMcuRetry = time.Second * 16
|
||||
|
||||
dnsMonitorInterval = time.Second
|
||||
)
|
||||
|
||||
func createListener(addr string) (net.Listener, error) {
|
||||
|
@ -154,6 +156,12 @@ func main() {
|
|||
}
|
||||
defer events.Close()
|
||||
|
||||
dnsMonitor, err := signaling.NewDnsMonitor(dnsMonitorInterval)
|
||||
if err != nil {
|
||||
log.Fatal("Could not create DNS monitor: ", err)
|
||||
}
|
||||
defer dnsMonitor.Stop()
|
||||
|
||||
etcdClient, err := signaling.NewEtcdClient(config, "mcu")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not create etcd client: %s", err)
|
||||
|
@ -175,7 +183,7 @@ func main() {
|
|||
}()
|
||||
defer rpcServer.Close()
|
||||
|
||||
rpcClients, err := signaling.NewGrpcClients(config, etcdClient)
|
||||
rpcClients, err := signaling.NewGrpcClients(config, etcdClient, dnsMonitor)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not create RPC clients: %s", err)
|
||||
}
|
||||
|
@ -209,7 +217,7 @@ func main() {
|
|||
signaling.UnregisterProxyMcuStats()
|
||||
signaling.RegisterJanusMcuStats()
|
||||
case signaling.McuTypeProxy:
|
||||
mcu, err = signaling.NewMcuProxy(config, etcdClient, rpcClients)
|
||||
mcu, err = signaling.NewMcuProxy(config, etcdClient, rpcClients, dnsMonitor)
|
||||
signaling.UnregisterJanusMcuStats()
|
||||
signaling.RegisterProxyMcuStats()
|
||||
default:
|
||||
|
|
Loading…
Reference in a new issue