diff --git a/backend_storage_etcd.go b/backend_storage_etcd.go index 4f95043..e71cda9 100644 --- a/backend_storage_etcd.go +++ b/backend_storage_etcd.go @@ -27,6 +27,7 @@ import ( "fmt" "log" "net/url" + "sync" "time" "github.com/dlintw/goconf" @@ -42,6 +43,7 @@ type backendStorageEtcd struct { initializedCtx context.Context initializedFunc context.CancelFunc + initializedWg sync.WaitGroup wakeupChanForTesting chan bool } @@ -100,6 +102,7 @@ func (s *backendStorageEtcd) wakeupForTesting() { } func (s *backendStorageEtcd) EtcdClientCreated(client *EtcdClient) { + s.initializedWg.Add(1) go func() { if err := client.Watch(context.Background(), s.keyPrefix, s, clientv3.WithPrefix()); err != nil { log.Printf("Error processing watch for %s: %s", s.keyPrefix, err) @@ -130,12 +133,17 @@ func (s *backendStorageEtcd) EtcdClientCreated(client *EtcdClient) { for _, ev := range response.Kvs { s.EtcdKeyUpdated(client, string(ev.Key), ev.Value) } + s.initializedWg.Wait() s.initializedFunc() return } }() } +func (s *backendStorageEtcd) EtcdWatchCreated(client *EtcdClient, key string) { + s.initializedWg.Done() +} + func (s *backendStorageEtcd) getBackends(client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() diff --git a/etcd_client.go b/etcd_client.go index 7d8f1bd..8da453d 100644 --- a/etcd_client.go +++ b/etcd_client.go @@ -41,6 +41,7 @@ type EtcdClientListener interface { } type EtcdClientWatcher interface { + EtcdWatchCreated(client *EtcdClient, key string) EtcdKeyUpdated(client *EtcdClient, key string, value []byte) EtcdKeyDeleted(client *EtcdClient, key string) } @@ -242,6 +243,7 @@ func (c *EtcdClient) Watch(ctx context.Context, key string, watcher EtcdClientWa log.Printf("Wait for leader and start watching on %s", key) ch := c.getEtcdClient().Watch(clientv3.WithRequireLeader(ctx), key, opts...) log.Printf("Watch created for %s", key) + watcher.EtcdWatchCreated(c, key) for response := range ch { if err := response.Err(); err != nil { return err diff --git a/etcd_client_test.go b/etcd_client_test.go index e8bc046..2426f37 100644 --- a/etcd_client_test.go +++ b/etcd_client_test.go @@ -29,6 +29,7 @@ import ( "os" "runtime" "strconv" + "sync" "syscall" "testing" "time" @@ -200,8 +201,9 @@ type EtcdClientTestListener struct { ctx context.Context cancel context.CancelFunc - initial chan bool - events chan etcdEvent + initial chan bool + initialWg sync.WaitGroup + events chan etcdEvent } func NewEtcdClientTestListener(ctx context.Context, t *testing.T) *EtcdClientTestListener { @@ -222,6 +224,7 @@ func (l *EtcdClientTestListener) Close() { } func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) { + l.initialWg.Add(1) go func() { if err := client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", l, clientv3.WithPrefix()); err != nil { l.t.Error(err) @@ -243,10 +246,15 @@ func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) { } else if string(response.Kvs[0].Value) != "1" { l.t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value)) } + l.initialWg.Wait() l.initial <- true }() } +func (l *EtcdClientTestListener) EtcdWatchCreated(client *EtcdClient, key string) { + l.initialWg.Done() +} + func (l *EtcdClientTestListener) EtcdKeyUpdated(client *EtcdClient, key string, value []byte) { l.events <- etcdEvent{ t: clientv3.EventTypePut, diff --git a/grpc_client.go b/grpc_client.go index 48d162a..f68222e 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -259,6 +259,7 @@ type GrpcClients struct { initializedCtx context.Context initializedFunc context.CancelFunc + initializedWg sync.WaitGroup wakeupChanForTesting chan bool selfCheckWaitGroup sync.WaitGroup } @@ -584,6 +585,7 @@ func (c *GrpcClients) loadTargetsEtcd(config *goconf.ConfigFile, fromReload bool } func (c *GrpcClients) EtcdClientCreated(client *EtcdClient) { + c.initializedWg.Add(1) go func() { if err := client.Watch(context.Background(), c.targetPrefix, c, clientv3.WithPrefix()); err != nil { log.Printf("Error processing watch for %s: %s", c.targetPrefix, err) @@ -610,12 +612,17 @@ func (c *GrpcClients) EtcdClientCreated(client *EtcdClient) { for _, ev := range response.Kvs { c.EtcdKeyUpdated(client, string(ev.Key), ev.Value) } + c.initializedWg.Wait() c.initializedFunc() return } }() } +func (c *GrpcClients) EtcdWatchCreated(client *EtcdClient, key string) { + c.initializedWg.Done() +} + func (c *GrpcClients) getGrpcTargets(client *EtcdClient, targetPrefix string) (*clientv3.GetResponse, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() diff --git a/mcu_proxy.go b/mcu_proxy.go index befac70..99f0010 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -1561,6 +1561,9 @@ func (m *mcuProxy) EtcdClientCreated(client *EtcdClient) { }() } +func (m *mcuProxy) EtcdWatchCreated(client *EtcdClient, key string) { +} + func (m *mcuProxy) getProxyUrls(client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel()