From 9adb762ccf8f52c3925445688b2a95b9bbc26eed Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 10:53:28 +0200 Subject: [PATCH] Close file watcher on shutdown to prevent goroutine leaks. --- certificate_reloader.go | 9 +++++++++ file_watcher.go | 13 ++++++++++++- grpc_client.go | 14 ++++++++++++++ grpc_common.go | 9 +++++++++ grpc_server.go | 3 +++ 5 files changed, 47 insertions(+), 1 deletion(-) diff --git a/certificate_reloader.go b/certificate_reloader.go index e9ac82d..3e23c96 100644 --- a/certificate_reloader.go +++ b/certificate_reloader.go @@ -66,6 +66,11 @@ func NewCertificateReloader(certFile string, keyFile string) (*CertificateReload return reloader, nil } +func (r *CertificateReloader) Close() { + r.keyWatcher.Close() + r.certWatcher.Close() +} + func (r *CertificateReloader) reload(filename string) { log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile) pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile) @@ -135,6 +140,10 @@ func NewCertPoolReloader(certFile string) (*CertPoolReloader, error) { return reloader, nil } +func (r *CertPoolReloader) Close() { + r.certWatcher.Close() +} + func (r *CertPoolReloader) reload(filename string) { log.Printf("reloading certificate pool from %s", r.certFile) pool, err := loadCertPool(r.certFile) diff --git a/file_watcher.go b/file_watcher.go index be4d375..6d3a923 100644 --- a/file_watcher.go +++ b/file_watcher.go @@ -22,6 +22,7 @@ package signaling import ( + "context" "errors" "log" "os" @@ -54,7 +55,9 @@ type FileWatcher struct { target string callback FileWatcherCallback - watcher *fsnotify.Watcher + watcher *fsnotify.Watcher + closeCtx context.Context + closeFunc context.CancelFunc } func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher, error) { @@ -78,17 +81,23 @@ func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher return nil, err } + closeCtx, closeFunc := context.WithCancel(context.Background()) + w := &FileWatcher{ filename: filename, target: realFilename, callback: callback, watcher: watcher, + + closeCtx: closeCtx, + closeFunc: closeFunc, } go w.run() return w, nil } func (f *FileWatcher) Close() error { + f.closeFunc() return f.watcher.Close() } @@ -152,6 +161,8 @@ func (f *FileWatcher) run() { } log.Printf("Error watching %s: %s", f.filename, err) + case <-f.closeCtx.Done(): + return } } } diff --git a/grpc_client.go b/grpc_client.go index 33d70ae..f2efa8d 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -38,6 +38,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "google.golang.org/grpc" codes "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/resolver" status "google.golang.org/grpc/status" ) @@ -275,6 +276,7 @@ type GrpcClients struct { targetPrefix string targetInformation map[string]*GrpcTargetInformationEtcd dialOptions atomic.Value // []grpc.DialOption + creds credentials.TransportCredentials initializedCtx context.Context initializedFunc context.CancelFunc @@ -308,6 +310,13 @@ func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error { return err } + if c.creds != nil { + if cr, ok := c.creds.(*reloadableCredentials); ok { + cr.Close() + } + } + c.creds = creds + opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)} c.dialOptions.Store(opts) @@ -795,6 +804,11 @@ func (c *GrpcClients) Close() { if c.etcdClient != nil { c.etcdClient.RemoveListener(c) } + if c.creds != nil { + if cr, ok := c.creds.(*reloadableCredentials); ok { + cr.Close() + } + } c.closeFunc() } diff --git a/grpc_common.go b/grpc_common.go index 4846179..b7df93e 100644 --- a/grpc_common.go +++ b/grpc_common.go @@ -125,6 +125,15 @@ func (c *reloadableCredentials) OverrideServerName(serverName string) error { return nil } +func (c *reloadableCredentials) Close() { + if c.loader != nil { + c.loader.Close() + } + if c.pool != nil { + c.pool.Close() + } +} + func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentials.TransportCredentials, error) { var prefix string var caPrefix string diff --git a/grpc_server.go b/grpc_server.go index 6fd1069..6dd01e9 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -108,6 +108,9 @@ func (s *GrpcServer) Run() error { func (s *GrpcServer) Close() { s.conn.GracefulStop() + if cr, ok := s.creds.(*reloadableCredentials); ok { + cr.Close() + } } func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) {