From cc7625c544283d9f3ec142acb26185b2337559a1 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 7 Mar 2024 16:54:12 +0100 Subject: [PATCH] Use new file watcher to detect changed files. --- certificate_reloader.go | 148 ++++++++++++----------------------- certificate_reloader_test.go | 6 +- grpc_server_test.go | 4 +- 3 files changed, 55 insertions(+), 103 deletions(-) diff --git a/certificate_reloader.go b/certificate_reloader.go index d836be1..5bedc96 100644 --- a/certificate_reloader.go +++ b/certificate_reloader.go @@ -27,26 +27,17 @@ import ( "fmt" "log" "os" - "sync" - "time" -) - -var ( - // CertificateCheckInterval defines the interval in which certificate files - // are checked for modifications. - CertificateCheckInterval = time.Minute + "sync/atomic" ) type CertificateReloader struct { - mu sync.Mutex + certFile string + certWatcher *FileWatcher - certFile string - keyFile string + keyFile string + keyWatcher *FileWatcher - certificate *tls.Certificate - lastModified time.Time - - nextCheck time.Time + certificate atomic.Pointer[tls.Certificate] } func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) { @@ -55,52 +46,37 @@ func NewCertificateReloader(certFile string, keyFile string) (*CertificateReload return nil, fmt.Errorf("could not load certificate / key: %w", err) } - stat, err := os.Stat(certFile) - if err != nil { - return nil, fmt.Errorf("could not stat %s: %w", certFile, err) - } - - return &CertificateReloader{ + reloader := &CertificateReloader{ certFile: certFile, keyFile: keyFile, + } + reloader.certificate.Store(&pair) + reloader.certWatcher, err = NewFileWatcher(certFile, reloader.reload) + if err != nil { + return nil, err + } + reloader.keyWatcher, err = NewFileWatcher(keyFile, reloader.reload) + if err != nil { + reloader.certWatcher.Close() // nolint + return nil, err + } - certificate: &pair, - lastModified: stat.ModTime(), + return reloader, nil +} - nextCheck: time.Now().Add(CertificateCheckInterval), - }, nil +func (r *CertificateReloader) reload(filename string) { + log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile) + pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile) + if err != nil { + log.Printf("could not load certificate / key: %s", err) + return + } + + r.certificate.Store(&pair) } func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) { - r.mu.Lock() - defer r.mu.Unlock() - - now := time.Now() - if now.Before(r.nextCheck) { - return r.certificate, nil - } - - r.nextCheck = now.Add(CertificateCheckInterval) - - stat, err := os.Stat(r.certFile) - if err != nil { - log.Printf("could not stat %s: %s", r.certFile, err) - return r.certificate, nil - } - - if !stat.ModTime().Equal(r.lastModified) { - log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile) - pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile) - if err != nil { - log.Printf("could not load certificate / key: %s", err) - return r.certificate, nil - } - - r.certificate = &pair - r.lastModified = stat.ModTime() - } - - return r.certificate, nil + return r.certificate.Load(), nil } func (r *CertificateReloader) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -112,14 +88,10 @@ func (r *CertificateReloader) GetClientCertificate(i *tls.CertificateRequestInfo } type CertPoolReloader struct { - mu sync.Mutex + certFile string + certWatcher *FileWatcher - certFile string - - pool *x509.CertPool - lastModified time.Time - - nextCheck time.Time + pool atomic.Pointer[x509.CertPool] } func loadCertPool(filename string) (*x509.CertPool, error) { @@ -142,49 +114,29 @@ func NewCertPoolReloader(certFile string) (*CertPoolReloader, error) { return nil, err } - stat, err := os.Stat(certFile) + reloader := &CertPoolReloader{ + certFile: certFile, + } + reloader.pool.Store(pool) + reloader.certWatcher, err = NewFileWatcher(certFile, reloader.reload) if err != nil { - return nil, fmt.Errorf("could not stat %s: %w", certFile, err) + return nil, err } - return &CertPoolReloader{ - certFile: certFile, + return reloader, nil +} - pool: pool, - lastModified: stat.ModTime(), +func (r *CertPoolReloader) reload(filename string) { + log.Printf("reloading certificate pool from %s", r.certFile) + pool, err := loadCertPool(r.certFile) + if err != nil { + log.Printf("could not load certificate pool: %s", err) + return + } - nextCheck: time.Now().Add(CertificateCheckInterval), - }, nil + r.pool.Store(pool) } func (r *CertPoolReloader) GetCertPool() *x509.CertPool { - r.mu.Lock() - defer r.mu.Unlock() - - now := time.Now() - if now.Before(r.nextCheck) { - return r.pool - } - - r.nextCheck = now.Add(CertificateCheckInterval) - - stat, err := os.Stat(r.certFile) - if err != nil { - log.Printf("could not stat %s: %s", r.certFile, err) - return r.pool - } - - if !stat.ModTime().Equal(r.lastModified) { - log.Printf("reloading certificate pool from %s", r.certFile) - pool, err := loadCertPool(r.certFile) - if err != nil { - log.Printf("could not load certificate pool: %s", err) - return r.pool - } - - r.pool = pool - r.lastModified = stat.ModTime() - } - - return r.pool + return r.pool.Load() } diff --git a/certificate_reloader_test.go b/certificate_reloader_test.go index d282a61..1c3d8cb 100644 --- a/certificate_reloader_test.go +++ b/certificate_reloader_test.go @@ -27,10 +27,10 @@ import ( ) func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) { - old := CertificateCheckInterval + old := deduplicateWatchEvents t.Cleanup(func() { - CertificateCheckInterval = old + deduplicateWatchEvents = old }) - CertificateCheckInterval = interval + deduplicateWatchEvents = interval } diff --git a/grpc_server_test.go b/grpc_server_test.go index 4c4abed..8464ef3 100644 --- a/grpc_server_test.go +++ b/grpc_server_test.go @@ -99,7 +99,7 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { config.AddOption("grpc", "servercertificate", certFile) config.AddOption("grpc", "serverkey", privkeyFile) - UpdateCertificateCheckIntervalForTest(t, time.Millisecond) + UpdateCertificateCheckIntervalForTest(t, 0) _, addr := NewGrpcServerForTestWithConfig(t, config) cp1 := x509.NewCertPool() @@ -180,7 +180,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { config.AddOption("grpc", "serverkey", privkeyFile) config.AddOption("grpc", "clientca", caFile) - UpdateCertificateCheckIntervalForTest(t, time.Millisecond) + UpdateCertificateCheckIntervalForTest(t, 0) _, addr := NewGrpcServerForTestWithConfig(t, config) pool := x509.NewCertPool()