diff --git a/backend_storage_etcd.go b/backend_storage_etcd.go index de3e66d..a88f71a 100644 --- a/backend_storage_etcd.go +++ b/backend_storage_etcd.go @@ -24,10 +24,10 @@ package signaling import ( "context" "encoding/json" + "errors" "fmt" "log" "net/url" - "sync" "time" "github.com/dlintw/goconf" @@ -43,8 +43,10 @@ type backendStorageEtcd struct { initializedCtx context.Context initializedFunc context.CancelFunc - initializedWg sync.WaitGroup wakeupChanForTesting chan struct{} + + closeCtx context.Context + closeFunc context.CancelFunc } func NewBackendStorageEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient) (BackendStorage, error) { @@ -58,6 +60,7 @@ func NewBackendStorageEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient) (B } initializedCtx, initializedFunc := context.WithCancel(context.Background()) + closeCtx, closeFunc := context.WithCancel(context.Background()) result := &backendStorageEtcd{ backendStorageCommon: backendStorageCommon{ backends: make(map[string][]*Backend), @@ -68,6 +71,8 @@ func NewBackendStorageEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient) (B initializedCtx: initializedCtx, initializedFunc: initializedFunc, + closeCtx: closeCtx, + closeFunc: closeFunc, } etcdClient.AddListener(result) @@ -95,15 +100,12 @@ func (s *backendStorageEtcd) wakeupForTesting() { } func (s *backendStorageEtcd) EtcdClientCreated(client *EtcdClient) { - s.initializedWg.Add(1) go func() { - if err := client.Watch(context.Background(), s.keyPrefix, s, clientv3.WithPrefix()); err != nil { - log.Printf("Error processing watch for %s: %s", s.keyPrefix, err) - } - }() + if err := client.WaitForConnection(s.closeCtx); err != nil { + if errors.Is(err, context.Canceled) { + return + } - go func() { - if err := client.WaitForConnection(context.Background()); err != nil { panic(err) } @@ -111,35 +113,43 @@ func (s *backendStorageEtcd) EtcdClientCreated(client *EtcdClient) { if err != nil { panic(err) } - for { - response, err := s.getBackends(client, s.keyPrefix) + for s.closeCtx.Err() == nil { + response, err := s.getBackends(s.closeCtx, client, s.keyPrefix) if err != nil { - if err == context.DeadlineExceeded { + if errors.Is(err, context.Canceled) { + return + } else if errors.Is(err, context.DeadlineExceeded) { log.Printf("Timeout getting initial list of backends, retry in %s", backoff.NextWait()) } else { log.Printf("Could not get initial list of backends, retry in %s: %s", backoff.NextWait(), err) } - backoff.Wait(context.Background()) + backoff.Wait(s.closeCtx) continue } for _, ev := range response.Kvs { s.EtcdKeyUpdated(client, string(ev.Key), ev.Value) } - s.initializedWg.Wait() s.initializedFunc() + + nextRevision := response.Header.Revision + 1 + for s.closeCtx.Err() == nil { + var err error + if nextRevision, err = client.Watch(s.closeCtx, s.keyPrefix, nextRevision, s, clientv3.WithPrefix()); err != nil { + log.Printf("Error processing watch for %s: %s", s.keyPrefix, err) + } + } return } }() } func (s *backendStorageEtcd) EtcdWatchCreated(client *EtcdClient, key string) { - s.initializedWg.Done() } -func (s *backendStorageEtcd) getBackends(client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) +func (s *backendStorageEtcd) getBackends(ctx context.Context, client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) { + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() return client.Get(ctx, keyPrefix, clientv3.WithPrefix()) @@ -241,6 +251,7 @@ func (s *backendStorageEtcd) EtcdKeyDeleted(client *EtcdClient, key string) { func (s *backendStorageEtcd) Close() { s.etcdClient.RemoveListener(s) + s.closeFunc() } func (s *backendStorageEtcd) Reload(config *goconf.ConfigFile) { diff --git a/backend_storage_etcd_test.go b/backend_storage_etcd_test.go index bc1f83d..5ab6549 100644 --- a/backend_storage_etcd_test.go +++ b/backend_storage_etcd_test.go @@ -21,6 +21,13 @@ */ package signaling +import ( + "testing" + + "github.com/dlintw/goconf" + "go.etcd.io/etcd/server/v3/embed" +) + func (s *backendStorageEtcd) getWakeupChannelForTesting() <-chan struct{} { s.mu.Lock() defer s.mu.Unlock() @@ -33,3 +40,37 @@ func (s *backendStorageEtcd) getWakeupChannelForTesting() <-chan struct{} { s.wakeupChanForTesting = ch return ch } + +type testListener struct { + etcd *embed.Etcd + closed chan struct{} +} + +func (tl *testListener) EtcdClientCreated(client *EtcdClient) { + tl.etcd.Server.Stop() + close(tl.closed) +} + +func Test_BackendStorageEtcdNoLeak(t *testing.T) { + ensureNoGoroutinesLeak(t, func(t *testing.T) { + etcd, client := NewEtcdClientForTest(t) + tl := &testListener{ + etcd: etcd, + closed: make(chan struct{}), + } + client.AddListener(tl) + defer client.RemoveListener(tl) + + config := goconf.NewConfigFile() + config.AddOption("backend", "backendtype", "etcd") + config.AddOption("backend", "backendprefix", "/backends") + + cfg, err := NewBackendConfiguration(config, client) + if err != nil { + t.Fatal(err) + } + + <-tl.closed + cfg.Close() + }) +} diff --git a/etcd_client.go b/etcd_client.go index 815a20f..6443701 100644 --- a/etcd_client.go +++ b/etcd_client.go @@ -23,6 +23,7 @@ package signaling import ( "context" + "errors" "fmt" "log" "strings" @@ -34,6 +35,8 @@ import ( "go.etcd.io/etcd/client/pkg/v3/srv" "go.etcd.io/etcd/client/pkg/v3/transport" clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) type EtcdClientListener interface { @@ -112,6 +115,17 @@ func (c *EtcdClient) load(config *goconf.ConfigFile, ignoreErrors bool) error { DialTimeout: time.Second, } + if logLevel, _ := config.GetString("etcd", "loglevel"); logLevel != "" { + var l zapcore.Level + if err := l.Set(logLevel); err != nil { + return fmt.Errorf("Unsupported etcd log level %s: %w", logLevel, err) + } + + logConfig := zap.NewProductionConfig() + logConfig.Level = zap.NewAtomicLevelAt(l) + cfg.LogConfig = &logConfig + } + clientKey := c.getConfigStringWithFallback(config, "clientkey") clientCert := c.getConfigStringWithFallback(config, "clientcert") caCert := c.getConfigStringWithFallback(config, "cacert") @@ -176,8 +190,8 @@ func (c *EtcdClient) getEtcdClient() *clientv3.Client { return client.(*clientv3.Client) } -func (c *EtcdClient) syncClient() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) +func (c *EtcdClient) syncClient(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() return c.getEtcdClient().Sync(ctx) @@ -223,8 +237,10 @@ func (c *EtcdClient) WaitForConnection(ctx context.Context) error { return err } - if err := c.syncClient(); err != nil { - if err == context.DeadlineExceeded { + if err := c.syncClient(ctx); err != nil { + if errors.Is(err, context.Canceled) { + return err + } else if errors.Is(err, context.DeadlineExceeded) { log.Printf("Timeout waiting for etcd client to connect to the cluster, retry in %s", backoff.NextWait()) } else { log.Printf("Could not sync etcd client with the cluster, retry in %s: %s", backoff.NextWait(), err) @@ -243,16 +259,18 @@ func (c *EtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOpt return c.getEtcdClient().Get(ctx, key, opts...) } -func (c *EtcdClient) Watch(ctx context.Context, key string, watcher EtcdClientWatcher, opts ...clientv3.OpOption) error { - log.Printf("Wait for leader and start watching on %s", key) +func (c *EtcdClient) Watch(ctx context.Context, key string, nextRevision int64, watcher EtcdClientWatcher, opts ...clientv3.OpOption) (int64, error) { + log.Printf("Wait for leader and start watching on %s (rev=%d)", key, nextRevision) + opts = append(opts, clientv3.WithRev(nextRevision)) ch := c.getEtcdClient().Watch(clientv3.WithRequireLeader(ctx), key, opts...) log.Printf("Watch created for %s", key) watcher.EtcdWatchCreated(c, key) for response := range ch { if err := response.Err(); err != nil { - return err + return nextRevision, err } + nextRevision = response.Header.Revision + 1 for _, ev := range response.Events { switch ev.Type { case clientv3.EventTypePut: @@ -265,5 +283,5 @@ func (c *EtcdClient) Watch(ctx context.Context, key string, watcher EtcdClientWa } } - return nil + return nextRevision, nil } diff --git a/etcd_client_test.go b/etcd_client_test.go index d40bf28..14f718b 100644 --- a/etcd_client_test.go +++ b/etcd_client_test.go @@ -29,7 +29,6 @@ import ( "os" "runtime" "strconv" - "sync" "syscall" "testing" "time" @@ -115,6 +114,7 @@ func NewEtcdClientForTest(t *testing.T) (*embed.Etcd, *EtcdClient) { config := goconf.NewConfigFile() config.AddOption("etcd", "endpoints", etcd.Config().ListenClientUrls[0].String()) + config.AddOption("etcd", "loglevel", "error") client, err := NewEtcdClient(config, "") if err != nil { @@ -204,9 +204,8 @@ type EtcdClientTestListener struct { ctx context.Context cancel context.CancelFunc - initial chan struct{} - initialWg sync.WaitGroup - events chan etcdEvent + initial chan struct{} + events chan etcdEvent } func NewEtcdClientTestListener(ctx context.Context, t *testing.T) *EtcdClientTestListener { @@ -227,15 +226,7 @@ func (l *EtcdClientTestListener) Close() { } func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) { - l.initialWg.Add(1) go func() { - if err := client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", l, clientv3.WithPrefix()); err != nil { - l.t.Error(err) - } - }() - - go func() { - defer close(l.initial) if err := client.WaitForConnection(l.ctx); err != nil { l.t.Errorf("error waiting for connection: %s", err) return @@ -244,7 +235,8 @@ func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) { ctx, cancel := context.WithTimeout(l.ctx, time.Second) defer cancel() - if response, err := client.Get(ctx, "foo", clientv3.WithPrefix()); err != nil { + response, err := client.Get(ctx, "foo", clientv3.WithPrefix()) + if err != nil { l.t.Error(err) } else if response.Count != 1 { l.t.Errorf("expected 1 responses, got %d", response.Count) @@ -253,12 +245,19 @@ func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) { } else if string(response.Kvs[0].Value) != "1" { l.t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value)) } - l.initialWg.Wait() + + close(l.initial) + nextRevision := response.Header.Revision + 1 + for l.ctx.Err() == nil { + var err error + if nextRevision, err = client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", nextRevision, l, clientv3.WithPrefix()); err != nil { + l.t.Error(err) + } + } }() } func (l *EtcdClientTestListener) EtcdWatchCreated(client *EtcdClient, key string) { - l.initialWg.Done() } func (l *EtcdClientTestListener) EtcdKeyUpdated(client *EtcdClient, key string, value []byte) { diff --git a/go.mod b/go.mod index 73d6db2..79c79e7 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( go.etcd.io/etcd/client/pkg/v3 v3.5.12 go.etcd.io/etcd/client/v3 v3.5.12 go.etcd.io/etcd/server/v3 v3.5.12 + go.uber.org/zap v1.17.0 google.golang.org/grpc v1.63.2 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 google.golang.org/protobuf v1.33.0 @@ -76,7 +77,6 @@ require ( go.opentelemetry.io/proto/otlp v1.0.0 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect - go.uber.org/zap v1.17.0 // indirect golang.org/x/crypto v0.22.0 // indirect golang.org/x/net v0.21.0 // indirect golang.org/x/sys v0.19.0 // indirect diff --git a/grpc_client.go b/grpc_client.go index f9be030..1f62ad5 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -24,6 +24,7 @@ package signaling import ( "context" "encoding/json" + "errors" "fmt" "log" "net" @@ -277,18 +278,23 @@ type GrpcClients struct { initializedCtx context.Context initializedFunc context.CancelFunc - initializedWg sync.WaitGroup wakeupChanForTesting chan struct{} selfCheckWaitGroup sync.WaitGroup + + closeCtx context.Context + closeFunc context.CancelFunc } func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient, dnsMonitor *DnsMonitor) (*GrpcClients, error) { initializedCtx, initializedFunc := context.WithCancel(context.Background()) + closeCtx, closeFunc := context.WithCancel(context.Background()) result := &GrpcClients{ dnsMonitor: dnsMonitor, etcdClient: etcdClient, initializedCtx: initializedCtx, initializedFunc: initializedFunc, + closeCtx: closeCtx, + closeFunc: closeFunc, } if err := result.load(config, false); err != nil { return nil, err @@ -586,48 +592,54 @@ func (c *GrpcClients) loadTargetsEtcd(config *goconf.ConfigFile, fromReload bool } func (c *GrpcClients) EtcdClientCreated(client *EtcdClient) { - c.initializedWg.Add(1) go func() { - if err := client.Watch(context.Background(), c.targetPrefix, c, clientv3.WithPrefix()); err != nil { - log.Printf("Error processing watch for %s: %s", c.targetPrefix, err) - } - }() + if err := client.WaitForConnection(c.closeCtx); err != nil { + if errors.Is(err, context.Canceled) { + return + } - go func() { - if err := client.WaitForConnection(context.Background()); err != nil { panic(err) } backoff, _ := NewExponentialBackoff(initialWaitDelay, maxWaitDelay) - for { - response, err := c.getGrpcTargets(client, c.targetPrefix) + var nextRevision int64 + for c.closeCtx.Err() == nil { + response, err := c.getGrpcTargets(c.closeCtx, client, c.targetPrefix) if err != nil { - if err == context.DeadlineExceeded { + if errors.Is(err, context.Canceled) { + return + } else if errors.Is(err, context.DeadlineExceeded) { log.Printf("Timeout getting initial list of GRPC targets, retry in %s", backoff.NextWait()) } else { log.Printf("Could not get initial list of GRPC targets, retry in %s: %s", backoff.NextWait(), err) } - backoff.Wait(context.Background()) + backoff.Wait(c.closeCtx) continue } for _, ev := range response.Kvs { c.EtcdKeyUpdated(client, string(ev.Key), ev.Value) } - c.initializedWg.Wait() c.initializedFunc() - return + nextRevision = response.Header.Revision + 1 + break + } + + for c.closeCtx.Err() == nil { + var err error + if nextRevision, err = client.Watch(c.closeCtx, c.targetPrefix, nextRevision, c, clientv3.WithPrefix()); err != nil { + log.Printf("Error processing watch for %s: %s", c.targetPrefix, err) + } } }() } func (c *GrpcClients) EtcdWatchCreated(client *EtcdClient, key string) { - c.initializedWg.Done() } -func (c *GrpcClients) getGrpcTargets(client *EtcdClient, targetPrefix string) (*clientv3.GetResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) +func (c *GrpcClients) getGrpcTargets(ctx context.Context, client *EtcdClient, targetPrefix string) (*clientv3.GetResponse, error) { + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() return client.Get(ctx, targetPrefix, clientv3.WithPrefix()) @@ -766,6 +778,7 @@ func (c *GrpcClients) Close() { if c.etcdClient != nil { c.etcdClient.RemoveListener(c) } + c.closeFunc() } func (c *GrpcClients) GetClients() []*GrpcClient { diff --git a/proxy_config_etcd.go b/proxy_config_etcd.go index b03ee0b..3d40ef1 100644 --- a/proxy_config_etcd.go +++ b/proxy_config_etcd.go @@ -41,6 +41,9 @@ type proxyConfigEtcd struct { keyPrefix string keyInfos map[string]*ProxyInformationEtcd urlToKey map[string]string + + closeCtx context.Context + closeFunc context.CancelFunc } func NewProxyConfigEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient, proxy McuProxy) (ProxyConfig, error) { @@ -48,12 +51,17 @@ func NewProxyConfigEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient, proxy return nil, errors.New("No etcd endpoints configured") } + closeCtx, closeFunc := context.WithCancel(context.Background()) + result := &proxyConfigEtcd{ proxy: proxy, client: etcdClient, keyInfos: make(map[string]*ProxyInformationEtcd), urlToKey: make(map[string]string), + + closeCtx: closeCtx, + closeFunc: closeFunc, } if err := result.configure(config, false); err != nil { return nil, err @@ -83,17 +91,16 @@ func (p *proxyConfigEtcd) Reload(config *goconf.ConfigFile) error { func (p *proxyConfigEtcd) Stop() { p.client.RemoveListener(p) + p.closeFunc() } func (p *proxyConfigEtcd) EtcdClientCreated(client *EtcdClient) { go func() { - if err := client.Watch(context.Background(), p.keyPrefix, p, clientv3.WithPrefix()); err != nil { - log.Printf("Error processing watch for %s: %s", p.keyPrefix, err) - } - }() + if err := client.WaitForConnection(p.closeCtx); err != nil { + if errors.Is(err, context.Canceled) { + return + } - go func() { - if err := client.WaitForConnection(context.Background()); err != nil { panic(err) } @@ -101,23 +108,35 @@ func (p *proxyConfigEtcd) EtcdClientCreated(client *EtcdClient) { if err != nil { panic(err) } - for { - response, err := p.getProxyUrls(client, p.keyPrefix) + + var nextRevision int64 + for p.closeCtx.Err() == nil { + response, err := p.getProxyUrls(p.closeCtx, client, p.keyPrefix) if err != nil { - if err == context.DeadlineExceeded { + if errors.Is(err, context.Canceled) { + return + } else if errors.Is(err, context.DeadlineExceeded) { log.Printf("Timeout getting initial list of proxy URLs, retry in %s", backoff.NextWait()) } else { log.Printf("Could not get initial list of proxy URLs, retry in %s: %s", backoff.NextWait(), err) } - backoff.Wait(context.Background()) + backoff.Wait(p.closeCtx) continue } for _, ev := range response.Kvs { p.EtcdKeyUpdated(client, string(ev.Key), ev.Value) } - return + nextRevision = response.Header.Revision + 1 + break + } + + for p.closeCtx.Err() == nil { + var err error + if nextRevision, err = client.Watch(p.closeCtx, p.keyPrefix, nextRevision, p, clientv3.WithPrefix()); err != nil { + log.Printf("Error processing watch for %s: %s", p.keyPrefix, err) + } } }() } @@ -125,8 +144,8 @@ func (p *proxyConfigEtcd) EtcdClientCreated(client *EtcdClient) { func (p *proxyConfigEtcd) EtcdWatchCreated(client *EtcdClient, key string) { } -func (p *proxyConfigEtcd) getProxyUrls(client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) +func (p *proxyConfigEtcd) getProxyUrls(ctx context.Context, client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) { + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() return client.Get(ctx, keyPrefix, clientv3.WithPrefix())