diff --git a/grpc_client.go b/grpc_client.go index ecb4f04..9146857 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -74,6 +74,8 @@ func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl { } type GrpcClient struct { + isSelf uint32 + ip net.IP target string conn *grpc.ClientConn @@ -164,6 +166,18 @@ func (c *GrpcClient) Close() error { return c.conn.Close() } +func (c *GrpcClient) IsSelf() bool { + return atomic.LoadUint32(&c.isSelf) != 0 +} + +func (c *GrpcClient) SetSelf(self bool) { + if self { + atomic.StoreUint32(&c.isSelf, 1) + } else { + atomic.StoreUint32(&c.isSelf, 0) + } +} + func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) { statsGrpcClientCalls.WithLabelValues("GetServerId").Inc() response, err := c.impl.GetServerId(ctx, &GetServerIdRequest{}, grpc.WaitForReady(true)) @@ -248,6 +262,7 @@ type GrpcClients struct { initializedCtx context.Context initializedFunc context.CancelFunc wakeupChanForTesting chan bool + selfCheckWaitGroup sync.WaitGroup } func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, error) { @@ -306,6 +321,79 @@ func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error { return err } +func (c *GrpcClients) closeClient(client *GrpcClient) { + if client.IsSelf() { + // Already closed. + return + } + + if err := client.Close(); err != nil { + log.Printf("Error closing client to %s: %s", client.Target(), err) + } +} + +func (c *GrpcClients) isClientAvailable(target string, client *GrpcClient) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + entries, found := c.clientsMap[target] + if !found { + return false + } + + for _, entry := range entries { + if entry == client { + return true + } + } + + return false +} + +func (c *GrpcClients) getServerIdWithTimeout(ctx context.Context, client *GrpcClient) (string, error) { + ctx2, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + id, err := client.GetServerId(ctx2) + return id, err +} + +func (c *GrpcClients) checkIsSelf(ctx context.Context, target string, client *GrpcClient) { + backoff, _ := NewExponentialBackoff(initialWaitDelay, maxWaitDelay) + defer c.selfCheckWaitGroup.Done() + +loop: + for { + select { + case <-ctx.Done(): + // Cancelled + return + default: + if !c.isClientAvailable(target, client) { + return + } + + id, err := c.getServerIdWithTimeout(ctx, client) + if err != nil { + if status.Code(err) != codes.Canceled { + log.Printf("Error checking GRPC server id of %s, retrying in %s: %s", client.Target(), backoff.NextWait(), err) + } + backoff.Wait(ctx) + continue + } + + if id == GrpcServerId { + log.Printf("GRPC target %s is this server, removing", client.Target()) + c.closeClient(client) + client.SetSelf(true) + } else { + log.Printf("Checked GRPC server id of %s", client.Target()) + } + break loop + } + } +} + func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bool, opts ...grpc.DialOption) error { c.mu.Lock() defer c.mu.Unlock() @@ -355,26 +443,14 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo if err != nil { for _, clients := range clientsMap { for _, client := range clients { - if closeerr := client.Close(); closeerr != nil { - log.Printf("Error closing client to %s: %s", client.Target(), closeerr) - } + c.closeClient(client) } } return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - if id, err := client.GetServerId(ctx); err != nil { - log.Printf("Error checking server id of %s: %s", client.Target(), err) - } else if id == GrpcServerId { - log.Printf("GRPC target %s is this server, ignoring", client.Target()) - if err := client.Close(); err != nil { - log.Printf("Error closing client to %s: %s", client.Target(), err) - } - continue - } + c.selfCheckWaitGroup.Add(1) + go c.checkIsSelf(context.Background(), target, client) log.Printf("Adding %s as GRPC target", client.Target()) clientsMap[target] = append(clientsMap[target], client) @@ -386,9 +462,7 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo if clients, found := clientsMap[target]; found { for _, client := range clients { log.Printf("Deleting GRPC target %s", client.Target()) - if err := client.Close(); err != nil { - log.Printf("Error closing client to %s: %s", client.Target(), err) - } + c.closeClient(client) } delete(clientsMap, target) } @@ -467,9 +541,7 @@ func (c *GrpcClients) updateGrpcIPs() { if !found { changed = true log.Printf("Removing connection to %s", client.Target()) - if err := client.Close(); err != nil { - log.Printf("Error closing client to %s: %s", client.Target(), err) - } + c.closeClient(client) c.wakeupForTesting() } } @@ -481,18 +553,8 @@ func (c *GrpcClients) updateGrpcIPs() { continue } - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - if id, err := client.GetServerId(ctx); err != nil { - log.Printf("Error checking server id of %s: %s", client.Target(), err) - } else if id == GrpcServerId { - //log.Printf("GRPC target %s is this server, ignoring", client.Target()) - if err := client.Close(); err != nil { - log.Printf("Error closing client to %s: %s", client.Target(), err) - } - continue - } + c.selfCheckWaitGroup.Add(1) + go c.checkIsSelf(context.Background(), target, client) log.Printf("Adding %s as GRPC target", client.Target()) newClients = append(newClients, client) @@ -543,21 +605,17 @@ func (c *GrpcClients) EtcdClientCreated(client *EtcdClient) { go func() { client.WaitForConnection() - waitDelay := initialWaitDelay + backoff, _ := NewExponentialBackoff(initialWaitDelay, maxWaitDelay) for { response, err := c.getGrpcTargets(client, c.targetPrefix) if err != nil { if err == context.DeadlineExceeded { - log.Printf("Timeout getting initial list of GRPC targets, retry in %s", waitDelay) + log.Printf("Timeout getting initial list of GRPC targets, retry in %s", backoff.NextWait()) } else { - log.Printf("Could not get initial list of GRPC targets, retry in %s: %s", waitDelay, err) + log.Printf("Could not get initial list of GRPC targets, retry in %s: %s", backoff.NextWait(), err) } - time.Sleep(waitDelay) - waitDelay = waitDelay * 2 - if waitDelay > maxWaitDelay { - waitDelay = maxWaitDelay - } + backoff.Wait(context.Background()) continue } @@ -609,19 +667,8 @@ func (c *GrpcClients) EtcdKeyUpdated(client *EtcdClient, key string, data []byte return } - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - if id, err := cl.GetServerId(ctx); err != nil { - log.Printf("Error checking server id of %s: %s", cl.Target(), err) - } else if id == GrpcServerId { - log.Printf("GRPC target %s is this server, ignoring %s", cl.Target(), key) - if err := cl.Close(); err != nil { - log.Printf("Error closing client to %s: %s", cl.Target(), err) - } - c.wakeupForTesting() - return - } + c.selfCheckWaitGroup.Add(1) + go c.checkIsSelf(context.Background(), info.Address, cl) log.Printf("Adding %s as GRPC target", cl.Target()) @@ -658,9 +705,7 @@ func (c *GrpcClients) removeEtcdClientLocked(key string) { for _, client := range clients { log.Printf("Removing connection to %s (from %s)", client.Target(), key) - if err := client.Close(); err != nil { - log.Printf("Error closing client to %s: %s", client.Target(), err) - } + c.closeClient(client) } delete(c.clientsMap, info.Address) c.clients = make([]*GrpcClient, 0, len(c.clientsMap)) @@ -726,5 +771,17 @@ func (c *GrpcClients) GetClients() []*GrpcClient { c.mu.RLock() defer c.mu.RUnlock() - return c.clients + if len(c.clients) == 0 { + return c.clients + } + + result := make([]*GrpcClient, 0, len(c.clients)-1) + for _, client := range c.clients { + if client.IsSelf() { + continue + } + + result = append(result, client) + } + return result } diff --git a/grpc_client_test.go b/grpc_client_test.go index 8ad4f92..b773db5 100644 --- a/grpc_client_test.go +++ b/grpc_client_test.go @@ -184,6 +184,7 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) { server2.serverId = GrpcServerId SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) <-ch + client.selfCheckWaitGroup.Wait() if clients := client.GetClients(); len(clients) != 1 { t.Errorf("Expected one client, got %+v", clients) } else if clients[0].Target() != addr1 {