diff --git a/certificate_reloader.go b/certificate_reloader.go index 5bedc96..e9ac82d 100644 --- a/certificate_reloader.go +++ b/certificate_reloader.go @@ -38,6 +38,8 @@ type CertificateReloader struct { keyWatcher *FileWatcher certificate atomic.Pointer[tls.Certificate] + + reloadCounter atomic.Uint64 } func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) { @@ -73,6 +75,7 @@ func (r *CertificateReloader) reload(filename string) { } r.certificate.Store(&pair) + r.reloadCounter.Add(1) } func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) { @@ -87,11 +90,17 @@ func (r *CertificateReloader) GetClientCertificate(i *tls.CertificateRequestInfo return r.getCertificate() } +func (r *CertificateReloader) GetReloadCounter() uint64 { + return r.reloadCounter.Load() +} + type CertPoolReloader struct { certFile string certWatcher *FileWatcher pool atomic.Pointer[x509.CertPool] + + reloadCounter atomic.Uint64 } func loadCertPool(filename string) (*x509.CertPool, error) { @@ -135,8 +144,13 @@ func (r *CertPoolReloader) reload(filename string) { } r.pool.Store(pool) + r.reloadCounter.Add(1) } func (r *CertPoolReloader) GetCertPool() *x509.CertPool { return r.pool.Load() } + +func (r *CertPoolReloader) GetReloadCounter() uint64 { + return r.reloadCounter.Load() +} diff --git a/certificate_reloader_test.go b/certificate_reloader_test.go index 1c3d8cb..95f0ffb 100644 --- a/certificate_reloader_test.go +++ b/certificate_reloader_test.go @@ -22,15 +22,38 @@ package signaling import ( + "context" "testing" "time" ) func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) { - old := deduplicateWatchEvents + old := deduplicateWatchEvents.Load() t.Cleanup(func() { - deduplicateWatchEvents = old + deduplicateWatchEvents.Store(old) }) - deduplicateWatchEvents = interval + deduplicateWatchEvents.Store(int64(interval)) +} + +func (r *CertificateReloader) WaitForReload(ctx context.Context) error { + counter := r.GetReloadCounter() + for counter == r.GetReloadCounter() { + if err := ctx.Err(); err != nil { + return err + } + time.Sleep(time.Millisecond) + } + return nil +} + +func (r *CertPoolReloader) WaitForReload(ctx context.Context) error { + counter := r.GetReloadCounter() + for counter == r.GetReloadCounter() { + if err := ctx.Err(); err != nil { + return err + } + time.Sleep(time.Millisecond) + } + return nil } diff --git a/file_watcher.go b/file_watcher.go index cce7bb8..13d5c44 100644 --- a/file_watcher.go +++ b/file_watcher.go @@ -29,15 +29,24 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "time" "github.com/fsnotify/fsnotify" ) -var ( - deduplicateWatchEvents = 100 * time.Millisecond +const ( + defaultDeduplicateWatchEvents = 100 * time.Millisecond ) +var ( + deduplicateWatchEvents atomic.Int64 +) + +func init() { + deduplicateWatchEvents.Store(int64(defaultDeduplicateWatchEvents)) +} + type FileWatcherCallback func(filename string) type FileWatcher struct { @@ -90,7 +99,8 @@ func (f *FileWatcher) run() { timers := make(map[string]*time.Timer) triggerEvent := func(event fsnotify.Event) { - if deduplicateWatchEvents <= 0 { + deduplicate := time.Duration(deduplicateWatchEvents.Load()) + if deduplicate <= 0 { f.callback(f.filename) return } @@ -100,7 +110,7 @@ func (f *FileWatcher) run() { t, found := timers[event.Name] mu.Unlock() if !found { - t = time.AfterFunc(deduplicateWatchEvents, func() { + t = time.AfterFunc(deduplicate, func() { f.callback(f.filename) mu.Lock() @@ -111,7 +121,7 @@ func (f *FileWatcher) run() { timers[event.Name] = t mu.Unlock() } else { - t.Reset(deduplicateWatchEvents) + t.Reset(deduplicate) } } diff --git a/file_watcher_test.go b/file_watcher_test.go index f6cc90a..f64fdca 100644 --- a/file_watcher_test.go +++ b/file_watcher_test.go @@ -30,7 +30,7 @@ import ( ) var ( - testWatcherNoEventTimeout = 2 * deduplicateWatchEvents + testWatcherNoEventTimeout = 2 * defaultDeduplicateWatchEvents ) func TestFileWatcher_NotExist(t *testing.T) { diff --git a/grpc_common_test.go b/grpc_common_test.go index a525b49..f9bed4a 100644 --- a/grpc_common_test.go +++ b/grpc_common_test.go @@ -22,11 +22,13 @@ package signaling import ( + "context" "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "errors" "io/fs" "math/big" "net" @@ -35,6 +37,22 @@ import ( "time" ) +func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error { + if c.loader == nil { + return errors.New("no certificate loaded") + } + + return c.loader.WaitForReload(ctx) +} + +func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context) error { + if c.pool == nil { + return errors.New("no certificate pool loaded") + } + + return c.pool.WaitForReload(ctx) +} + func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte { template := x509.Certificate{ SerialNumber: big.NewInt(1), diff --git a/grpc_server.go b/grpc_server.go index 3108be9..6fd1069 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -35,6 +35,7 @@ import ( "github.com/dlintw/goconf" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" status "google.golang.org/grpc/status" ) @@ -60,6 +61,7 @@ type GrpcServer struct { UnimplementedRpcMcuServer UnimplementedRpcSessionsServer + creds credentials.TransportCredentials conn *grpc.Server listener net.Listener serverId string // can be overwritten from tests @@ -84,6 +86,7 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) { conn := grpc.NewServer(grpc.Creds(creds)) result := &GrpcServer{ + creds: creds, conn: conn, listener: listener, serverId: GrpcServerId, diff --git a/grpc_server_test.go b/grpc_server_test.go index 8464ef3..232309e 100644 --- a/grpc_server_test.go +++ b/grpc_server_test.go @@ -28,6 +28,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/pem" + "errors" "net" "os" "path" @@ -40,6 +41,24 @@ import ( "google.golang.org/grpc/credentials" ) +func (s *GrpcServer) WaitForCertificateReload(ctx context.Context) error { + c, ok := s.creds.(*reloadableCredentials) + if !ok { + return errors.New("no reloadable credentials found") + } + + return c.WaitForCertificateReload(ctx) +} + +func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context) error { + c, ok := s.creds.(*reloadableCredentials) + if !ok { + return errors.New("no reloadable credentials found") + } + + return c.WaitForCertPoolReload(ctx) +} + func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) { for port := 50000; port < 50100; port++ { addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) @@ -100,7 +119,7 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { config.AddOption("grpc", "serverkey", privkeyFile) UpdateCertificateCheckIntervalForTest(t, 0) - _, addr := NewGrpcServerForTestWithConfig(t, config) + server, addr := NewGrpcServerForTestWithConfig(t, config) cp1 := x509.NewCertPool() if !cp1.AppendCertsFromPEM(cert1) { @@ -128,6 +147,13 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key) replaceFile(t, certFile, cert2, 0755) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if err := server.WaitForCertificateReload(ctx); err != nil { + t.Fatal(err) + } + cp2 := x509.NewCertPool() if !cp2.AppendCertsFromPEM(cert2) { t.Fatalf("could not add certificate") @@ -181,7 +207,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { config.AddOption("grpc", "clientca", caFile) UpdateCertificateCheckIntervalForTest(t, 0) - _, addr := NewGrpcServerForTestWithConfig(t, config) + server, addr := NewGrpcServerForTestWithConfig(t, config) pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(serverCert) { @@ -217,6 +243,10 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey) replaceFile(t, caFile, clientCert2, 0755) + if err := server.WaitForCertPoolReload(ctx1); err != nil { + t.Fatal(err) + } + pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey),