From bf68a15943f57e9824062ec894344d96ade04420 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 10:37:15 +0200 Subject: [PATCH 1/3] Make sure "clientsMap" is updated so all clients are closed on shutdown. --- grpc_client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/grpc_client.go b/grpc_client.go index fccdd97..33d70ae 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -490,6 +490,7 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo entry, found := clientsMap[target] if !found { entry = &grpcClientsList{} + clientsMap[target] = entry } entry.clients = append(entry.clients, client) clients = append(clients, client) From 9adb762ccf8f52c3925445688b2a95b9bbc26eed Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 10:53:28 +0200 Subject: [PATCH 2/3] Close file watcher on shutdown to prevent goroutine leaks. --- certificate_reloader.go | 9 +++++++++ file_watcher.go | 13 ++++++++++++- grpc_client.go | 14 ++++++++++++++ grpc_common.go | 9 +++++++++ grpc_server.go | 3 +++ 5 files changed, 47 insertions(+), 1 deletion(-) diff --git a/certificate_reloader.go b/certificate_reloader.go index e9ac82d..3e23c96 100644 --- a/certificate_reloader.go +++ b/certificate_reloader.go @@ -66,6 +66,11 @@ func NewCertificateReloader(certFile string, keyFile string) (*CertificateReload return reloader, nil } +func (r *CertificateReloader) Close() { + r.keyWatcher.Close() + r.certWatcher.Close() +} + func (r *CertificateReloader) reload(filename string) { log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile) pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile) @@ -135,6 +140,10 @@ func NewCertPoolReloader(certFile string) (*CertPoolReloader, error) { return reloader, nil } +func (r *CertPoolReloader) Close() { + r.certWatcher.Close() +} + func (r *CertPoolReloader) reload(filename string) { log.Printf("reloading certificate pool from %s", r.certFile) pool, err := loadCertPool(r.certFile) diff --git a/file_watcher.go b/file_watcher.go index be4d375..6d3a923 100644 --- a/file_watcher.go +++ b/file_watcher.go @@ -22,6 +22,7 @@ package signaling import ( + "context" "errors" "log" "os" @@ -54,7 +55,9 @@ type FileWatcher struct { target string callback FileWatcherCallback - watcher *fsnotify.Watcher + watcher *fsnotify.Watcher + closeCtx context.Context + closeFunc context.CancelFunc } func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher, error) { @@ -78,17 +81,23 @@ func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher return nil, err } + closeCtx, closeFunc := context.WithCancel(context.Background()) + w := &FileWatcher{ filename: filename, target: realFilename, callback: callback, watcher: watcher, + + closeCtx: closeCtx, + closeFunc: closeFunc, } go w.run() return w, nil } func (f *FileWatcher) Close() error { + f.closeFunc() return f.watcher.Close() } @@ -152,6 +161,8 @@ func (f *FileWatcher) run() { } log.Printf("Error watching %s: %s", f.filename, err) + case <-f.closeCtx.Done(): + return } } } diff --git a/grpc_client.go b/grpc_client.go index 33d70ae..f2efa8d 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -38,6 +38,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "google.golang.org/grpc" codes "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/resolver" status "google.golang.org/grpc/status" ) @@ -275,6 +276,7 @@ type GrpcClients struct { targetPrefix string targetInformation map[string]*GrpcTargetInformationEtcd dialOptions atomic.Value // []grpc.DialOption + creds credentials.TransportCredentials initializedCtx context.Context initializedFunc context.CancelFunc @@ -308,6 +310,13 @@ func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error { return err } + if c.creds != nil { + if cr, ok := c.creds.(*reloadableCredentials); ok { + cr.Close() + } + } + c.creds = creds + opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)} c.dialOptions.Store(opts) @@ -795,6 +804,11 @@ func (c *GrpcClients) Close() { if c.etcdClient != nil { c.etcdClient.RemoveListener(c) } + if c.creds != nil { + if cr, ok := c.creds.(*reloadableCredentials); ok { + cr.Close() + } + } c.closeFunc() } diff --git a/grpc_common.go b/grpc_common.go index 4846179..b7df93e 100644 --- a/grpc_common.go +++ b/grpc_common.go @@ -125,6 +125,15 @@ func (c *reloadableCredentials) OverrideServerName(serverName string) error { return nil } +func (c *reloadableCredentials) Close() { + if c.loader != nil { + c.loader.Close() + } + if c.pool != nil { + c.pool.Close() + } +} + func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentials.TransportCredentials, error) { var prefix string var caPrefix string diff --git a/grpc_server.go b/grpc_server.go index 6fd1069..6dd01e9 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -108,6 +108,9 @@ func (s *GrpcServer) Run() error { func (s *GrpcServer) Close() { s.conn.GracefulStop() + if cr, ok := s.creds.(*reloadableCredentials); ok { + cr.Close() + } } func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) { From b77525603cd88955f979fb82a4c14151cd9d4aeb Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 10:53:55 +0200 Subject: [PATCH 3/3] Enable goroutine leak checks for more tests. --- file_watcher_test.go | 82 ++++++++-------- grpc_client_test.go | 222 ++++++++++++++++++++++--------------------- 2 files changed, 156 insertions(+), 148 deletions(-) diff --git a/file_watcher_test.go b/file_watcher_test.go index 268a68e..175844f 100644 --- a/file_watcher_test.go +++ b/file_watcher_test.go @@ -47,48 +47,50 @@ func TestFileWatcher_NotExist(t *testing.T) { } func TestFileWatcher_File(t *testing.T) { - tmpdir := t.TempDir() - filename := path.Join(tmpdir, "test.txt") - if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil { - t.Fatal(err) - } + ensureNoGoroutinesLeak(t, func(t *testing.T) { + tmpdir := t.TempDir() + filename := path.Join(tmpdir, "test.txt") + if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil { + t.Fatal(err) + } - modified := make(chan struct{}) - w, err := NewFileWatcher(filename, func(filename string) { - modified <- struct{}{} + modified := make(chan struct{}) + w, err := NewFileWatcher(filename, func(filename string) { + modified <- struct{}{} + }) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil { + t.Fatal(err) + } + <-modified + + ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) + defer cancel() + + select { + case <-modified: + t.Error("should not have received another event") + case <-ctxTimeout.Done(): + } + + if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil { + t.Fatal(err) + } + <-modified + + ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout) + defer cancel() + + select { + case <-modified: + t.Error("should not have received another event") + case <-ctxTimeout.Done(): + } }) - if err != nil { - t.Fatal(err) - } - defer w.Close() - - if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil { - t.Fatal(err) - } - <-modified - - ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) - defer cancel() - - select { - case <-modified: - t.Error("should not have received another event") - case <-ctxTimeout.Done(): - } - - if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil { - t.Fatal(err) - } - <-modified - - ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout) - defer cancel() - - select { - case <-modified: - t.Error("should not have received another event") - case <-ctxTimeout.Done(): - } } func TestFileWatcher_Rename(t *testing.T) { diff --git a/grpc_client_test.go b/grpc_client_test.go index 30e719e..faf5480 100644 --- a/grpc_client_test.go +++ b/grpc_client_test.go @@ -112,24 +112,26 @@ func waitForEvent(ctx context.Context, t *testing.T, ch <-chan struct{}) { } func Test_GrpcClients_EtcdInitial(t *testing.T) { - _, addr1 := NewGrpcServerForTest(t) - _, addr2 := NewGrpcServerForTest(t) + ensureNoGoroutinesLeak(t, func(t *testing.T) { + _, addr1 := NewGrpcServerForTest(t) + _, addr2 := NewGrpcServerForTest(t) - etcd := NewEtcdForTest(t) + etcd := NewEtcdForTest(t) - SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) - SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) - client, _ := NewGrpcClientsWithEtcdForTest(t, etcd) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if err := client.WaitForInitialized(ctx); err != nil { - t.Fatal(err) - } + client, _ := NewGrpcClientsWithEtcdForTest(t, etcd) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := client.WaitForInitialized(ctx); err != nil { + t.Fatal(err) + } - if clients := client.GetClients(); len(clients) != 2 { - t.Errorf("Expected two clients, got %+v", clients) - } + if clients := client.GetClients(); len(clients) != 2 { + t.Errorf("Expected two clients, got %+v", clients) + } + }) } func Test_GrpcClients_EtcdUpdate(t *testing.T) { @@ -231,57 +233,59 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) { } func Test_GrpcClients_DnsDiscovery(t *testing.T) { - lookup := newMockDnsLookupForTest(t) - target := "testgrpc:12345" - ip1 := net.ParseIP("192.168.0.1") - ip2 := net.ParseIP("192.168.0.2") - targetWithIp1 := fmt.Sprintf("%s (%s)", target, ip1) - targetWithIp2 := fmt.Sprintf("%s (%s)", target, ip2) - lookup.Set("testgrpc", []net.IP{ip1}) - client, dnsMonitor := NewGrpcClientsForTest(t, target) - ch := client.getWakeupChannelForTesting() + ensureNoGoroutinesLeak(t, func(t *testing.T) { + lookup := newMockDnsLookupForTest(t) + target := "testgrpc:12345" + ip1 := net.ParseIP("192.168.0.1") + ip2 := net.ParseIP("192.168.0.2") + targetWithIp1 := fmt.Sprintf("%s (%s)", target, ip1) + targetWithIp2 := fmt.Sprintf("%s (%s)", target, ip2) + lookup.Set("testgrpc", []net.IP{ip1}) + client, dnsMonitor := NewGrpcClientsForTest(t, target) + ch := client.getWakeupChannelForTesting() - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() - dnsMonitor.checkHostnames() - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != targetWithIp1 { - t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target()) - } else if !clients[0].ip.Equal(ip1) { - t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip) - } + dnsMonitor.checkHostnames() + if clients := client.GetClients(); len(clients) != 1 { + t.Errorf("Expected one client, got %+v", clients) + } else if clients[0].Target() != targetWithIp1 { + t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target()) + } else if !clients[0].ip.Equal(ip1) { + t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip) + } - lookup.Set("testgrpc", []net.IP{ip1, ip2}) - drainWakeupChannel(ch) - dnsMonitor.checkHostnames() - waitForEvent(ctx, t, ch) + lookup.Set("testgrpc", []net.IP{ip1, ip2}) + drainWakeupChannel(ch) + dnsMonitor.checkHostnames() + waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 2 { - t.Errorf("Expected two client, got %+v", clients) - } else if clients[0].Target() != targetWithIp1 { - t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target()) - } else if !clients[0].ip.Equal(ip1) { - t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip) - } else if clients[1].Target() != targetWithIp2 { - t.Errorf("Expected target %s, got %s", targetWithIp2, clients[1].Target()) - } else if !clients[1].ip.Equal(ip2) { - t.Errorf("Expected IP %s, got %s", ip2, clients[1].ip) - } + if clients := client.GetClients(); len(clients) != 2 { + t.Errorf("Expected two client, got %+v", clients) + } else if clients[0].Target() != targetWithIp1 { + t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target()) + } else if !clients[0].ip.Equal(ip1) { + t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip) + } else if clients[1].Target() != targetWithIp2 { + t.Errorf("Expected target %s, got %s", targetWithIp2, clients[1].Target()) + } else if !clients[1].ip.Equal(ip2) { + t.Errorf("Expected IP %s, got %s", ip2, clients[1].ip) + } - lookup.Set("testgrpc", []net.IP{ip2}) - drainWakeupChannel(ch) - dnsMonitor.checkHostnames() - waitForEvent(ctx, t, ch) + lookup.Set("testgrpc", []net.IP{ip2}) + drainWakeupChannel(ch) + dnsMonitor.checkHostnames() + waitForEvent(ctx, t, ch) - if clients := client.GetClients(); len(clients) != 1 { - t.Errorf("Expected one client, got %+v", clients) - } else if clients[0].Target() != targetWithIp2 { - t.Errorf("Expected target %s, got %s", targetWithIp2, clients[0].Target()) - } else if !clients[0].ip.Equal(ip2) { - t.Errorf("Expected IP %s, got %s", ip2, clients[0].ip) - } + if clients := client.GetClients(); len(clients) != 1 { + t.Errorf("Expected one client, got %+v", clients) + } else if clients[0].Target() != targetWithIp2 { + t.Errorf("Expected target %s, got %s", targetWithIp2, clients[0].Target()) + } else if !clients[0].ip.Equal(ip2) { + t.Errorf("Expected IP %s, got %s", ip2, clients[0].ip) + } + }) } func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) { @@ -320,55 +324,57 @@ func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) { } func Test_GrpcClients_Encryption(t *testing.T) { - serverKey, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } - clientKey, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - t.Fatal(err) - } - - serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey) - clientCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Testing client", clientKey) - - dir := t.TempDir() - serverPrivkeyFile := path.Join(dir, "server-privkey.pem") - serverPubkeyFile := path.Join(dir, "server-pubkey.pem") - serverCertFile := path.Join(dir, "server-cert.pem") - WritePrivateKey(serverKey, serverPrivkeyFile) // nolint - WritePublicKey(&serverKey.PublicKey, serverPubkeyFile) // nolint - os.WriteFile(serverCertFile, serverCert, 0755) // nolint - clientPrivkeyFile := path.Join(dir, "client-privkey.pem") - clientPubkeyFile := path.Join(dir, "client-pubkey.pem") - clientCertFile := path.Join(dir, "client-cert.pem") - WritePrivateKey(clientKey, clientPrivkeyFile) // nolint - WritePublicKey(&clientKey.PublicKey, clientPubkeyFile) // nolint - os.WriteFile(clientCertFile, clientCert, 0755) // nolint - - serverConfig := goconf.NewConfigFile() - serverConfig.AddOption("grpc", "servercertificate", serverCertFile) - serverConfig.AddOption("grpc", "serverkey", serverPrivkeyFile) - serverConfig.AddOption("grpc", "clientca", clientCertFile) - _, addr := NewGrpcServerForTestWithConfig(t, serverConfig) - - clientConfig := goconf.NewConfigFile() - clientConfig.AddOption("grpc", "targets", addr) - clientConfig.AddOption("grpc", "clientcertificate", clientCertFile) - clientConfig.AddOption("grpc", "clientkey", clientPrivkeyFile) - clientConfig.AddOption("grpc", "serverca", serverCertFile) - clients, _ := NewGrpcClientsForTestWithConfig(t, clientConfig, nil) - - ctx, cancel1 := context.WithTimeout(context.Background(), time.Second) - defer cancel1() - - if err := clients.WaitForInitialized(ctx); err != nil { - t.Fatal(err) - } - - for _, client := range clients.GetClients() { - if _, err := client.GetServerId(ctx); err != nil { + ensureNoGoroutinesLeak(t, func(t *testing.T) { + serverKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { t.Fatal(err) } - } + clientKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + + serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey) + clientCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Testing client", clientKey) + + dir := t.TempDir() + serverPrivkeyFile := path.Join(dir, "server-privkey.pem") + serverPubkeyFile := path.Join(dir, "server-pubkey.pem") + serverCertFile := path.Join(dir, "server-cert.pem") + WritePrivateKey(serverKey, serverPrivkeyFile) // nolint + WritePublicKey(&serverKey.PublicKey, serverPubkeyFile) // nolint + os.WriteFile(serverCertFile, serverCert, 0755) // nolint + clientPrivkeyFile := path.Join(dir, "client-privkey.pem") + clientPubkeyFile := path.Join(dir, "client-pubkey.pem") + clientCertFile := path.Join(dir, "client-cert.pem") + WritePrivateKey(clientKey, clientPrivkeyFile) // nolint + WritePublicKey(&clientKey.PublicKey, clientPubkeyFile) // nolint + os.WriteFile(clientCertFile, clientCert, 0755) // nolint + + serverConfig := goconf.NewConfigFile() + serverConfig.AddOption("grpc", "servercertificate", serverCertFile) + serverConfig.AddOption("grpc", "serverkey", serverPrivkeyFile) + serverConfig.AddOption("grpc", "clientca", clientCertFile) + _, addr := NewGrpcServerForTestWithConfig(t, serverConfig) + + clientConfig := goconf.NewConfigFile() + clientConfig.AddOption("grpc", "targets", addr) + clientConfig.AddOption("grpc", "clientcertificate", clientCertFile) + clientConfig.AddOption("grpc", "clientkey", clientPrivkeyFile) + clientConfig.AddOption("grpc", "serverca", serverCertFile) + clients, _ := NewGrpcClientsForTestWithConfig(t, clientConfig, nil) + + ctx, cancel1 := context.WithTimeout(context.Background(), time.Second) + defer cancel1() + + if err := clients.WaitForInitialized(ctx); err != nil { + t.Fatal(err) + } + + for _, client := range clients.GetClients() { + if _, err := client.GetServerId(ctx); err != nil { + t.Fatal(err) + } + } + }) }