From 951532d3b34ddf44ab52b27b55f2aea3a65f7fdc Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 16 Apr 2025 13:57:13 +0200 Subject: [PATCH] Fix race condition in flaky certificate/CA reload tests. --- certificate_reloader_test.go | 6 ++---- grpc_common_test.go | 8 ++++---- grpc_server_test.go | 12 ++++++------ 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/certificate_reloader_test.go b/certificate_reloader_test.go index a180a9d..9e50fc6 100644 --- a/certificate_reloader_test.go +++ b/certificate_reloader_test.go @@ -39,8 +39,7 @@ func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) deduplicateWatchEvents.Store(int64(interval)) } -func (r *CertificateReloader) WaitForReload(ctx context.Context) error { - counter := r.GetReloadCounter() +func (r *CertificateReloader) WaitForReload(ctx context.Context, counter uint64) error { for counter == r.GetReloadCounter() { if err := ctx.Err(); err != nil { return err @@ -50,8 +49,7 @@ func (r *CertificateReloader) WaitForReload(ctx context.Context) error { return nil } -func (r *CertPoolReloader) WaitForReload(ctx context.Context) error { - counter := r.GetReloadCounter() +func (r *CertPoolReloader) WaitForReload(ctx context.Context, counter uint64) error { for counter == r.GetReloadCounter() { if err := ctx.Err(); err != nil { return err diff --git a/grpc_common_test.go b/grpc_common_test.go index 878efd0..a18deca 100644 --- a/grpc_common_test.go +++ b/grpc_common_test.go @@ -39,20 +39,20 @@ import ( "github.com/stretchr/testify/require" ) -func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error { +func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context, counter uint64) error { if c.loader == nil { return errors.New("no certificate loaded") } - return c.loader.WaitForReload(ctx) + return c.loader.WaitForReload(ctx, counter) } -func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context) error { +func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context, counter uint64) error { if c.pool == nil { return errors.New("no certificate pool loaded") } - return c.pool.WaitForReload(ctx) + return c.pool.WaitForReload(ctx, counter) } func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte { diff --git a/grpc_server_test.go b/grpc_server_test.go index e985fd2..8ebdc7c 100644 --- a/grpc_server_test.go +++ b/grpc_server_test.go @@ -43,22 +43,22 @@ import ( "google.golang.org/grpc/credentials" ) -func (s *GrpcServer) WaitForCertificateReload(ctx context.Context) error { +func (s *GrpcServer) WaitForCertificateReload(ctx context.Context, counter uint64) error { c, ok := s.creds.(*reloadableCredentials) if !ok { return errors.New("no reloadable credentials found") } - return c.WaitForCertificateReload(ctx) + return c.WaitForCertificateReload(ctx, counter) } -func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context) error { +func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context, counter uint64) error { c, ok := s.creds.(*reloadableCredentials) if !ok { return errors.New("no reloadable credentials found") } - return c.WaitForCertPoolReload(ctx) + return c.WaitForCertPoolReload(ctx, counter) } func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) { @@ -145,7 +145,7 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - require.NoError(server.WaitForCertificateReload(ctx)) + require.NoError(server.WaitForCertificateReload(ctx, 0)) cp2 := x509.NewCertPool() if !cp2.AppendCertsFromPEM(cert2) { @@ -225,7 +225,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey) replaceFile(t, caFile, clientCert2, 0755) - require.NoError(server.WaitForCertPoolReload(ctx1)) + require.NoError(server.WaitForCertPoolReload(ctx1, 0)) pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY",