Update tests to wait for certificate / pool reload to happen before continuing.

This commit is contained in:
Joachim Bauch 2024-04-03 09:41:38 +02:00
parent cc7625c544
commit 2ef9b39959
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
7 changed files with 109 additions and 11 deletions

View file

@ -38,6 +38,8 @@ type CertificateReloader struct {
keyWatcher *FileWatcher keyWatcher *FileWatcher
certificate atomic.Pointer[tls.Certificate] certificate atomic.Pointer[tls.Certificate]
reloadCounter atomic.Uint64
} }
func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) { func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) {
@ -73,6 +75,7 @@ func (r *CertificateReloader) reload(filename string) {
} }
r.certificate.Store(&pair) r.certificate.Store(&pair)
r.reloadCounter.Add(1)
} }
func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) { func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) {
@ -87,11 +90,17 @@ func (r *CertificateReloader) GetClientCertificate(i *tls.CertificateRequestInfo
return r.getCertificate() return r.getCertificate()
} }
func (r *CertificateReloader) GetReloadCounter() uint64 {
return r.reloadCounter.Load()
}
type CertPoolReloader struct { type CertPoolReloader struct {
certFile string certFile string
certWatcher *FileWatcher certWatcher *FileWatcher
pool atomic.Pointer[x509.CertPool] pool atomic.Pointer[x509.CertPool]
reloadCounter atomic.Uint64
} }
func loadCertPool(filename string) (*x509.CertPool, error) { func loadCertPool(filename string) (*x509.CertPool, error) {
@ -135,8 +144,13 @@ func (r *CertPoolReloader) reload(filename string) {
} }
r.pool.Store(pool) r.pool.Store(pool)
r.reloadCounter.Add(1)
} }
func (r *CertPoolReloader) GetCertPool() *x509.CertPool { func (r *CertPoolReloader) GetCertPool() *x509.CertPool {
return r.pool.Load() return r.pool.Load()
} }
func (r *CertPoolReloader) GetReloadCounter() uint64 {
return r.reloadCounter.Load()
}

View file

@ -22,15 +22,38 @@
package signaling package signaling
import ( import (
"context"
"testing" "testing"
"time" "time"
) )
func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) { func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) {
old := deduplicateWatchEvents old := deduplicateWatchEvents.Load()
t.Cleanup(func() { t.Cleanup(func() {
deduplicateWatchEvents = old deduplicateWatchEvents.Store(old)
}) })
deduplicateWatchEvents = interval deduplicateWatchEvents.Store(int64(interval))
}
func (r *CertificateReloader) WaitForReload(ctx context.Context) error {
counter := r.GetReloadCounter()
for counter == r.GetReloadCounter() {
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
}
func (r *CertPoolReloader) WaitForReload(ctx context.Context) error {
counter := r.GetReloadCounter()
for counter == r.GetReloadCounter() {
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
} }

View file

@ -29,15 +29,24 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
) )
var ( const (
deduplicateWatchEvents = 100 * time.Millisecond defaultDeduplicateWatchEvents = 100 * time.Millisecond
) )
var (
deduplicateWatchEvents atomic.Int64
)
func init() {
deduplicateWatchEvents.Store(int64(defaultDeduplicateWatchEvents))
}
type FileWatcherCallback func(filename string) type FileWatcherCallback func(filename string)
type FileWatcher struct { type FileWatcher struct {
@ -90,7 +99,8 @@ func (f *FileWatcher) run() {
timers := make(map[string]*time.Timer) timers := make(map[string]*time.Timer)
triggerEvent := func(event fsnotify.Event) { triggerEvent := func(event fsnotify.Event) {
if deduplicateWatchEvents <= 0 { deduplicate := time.Duration(deduplicateWatchEvents.Load())
if deduplicate <= 0 {
f.callback(f.filename) f.callback(f.filename)
return return
} }
@ -100,7 +110,7 @@ func (f *FileWatcher) run() {
t, found := timers[event.Name] t, found := timers[event.Name]
mu.Unlock() mu.Unlock()
if !found { if !found {
t = time.AfterFunc(deduplicateWatchEvents, func() { t = time.AfterFunc(deduplicate, func() {
f.callback(f.filename) f.callback(f.filename)
mu.Lock() mu.Lock()
@ -111,7 +121,7 @@ func (f *FileWatcher) run() {
timers[event.Name] = t timers[event.Name] = t
mu.Unlock() mu.Unlock()
} else { } else {
t.Reset(deduplicateWatchEvents) t.Reset(deduplicate)
} }
} }

View file

@ -30,7 +30,7 @@ import (
) )
var ( var (
testWatcherNoEventTimeout = 2 * deduplicateWatchEvents testWatcherNoEventTimeout = 2 * defaultDeduplicateWatchEvents
) )
func TestFileWatcher_NotExist(t *testing.T) { func TestFileWatcher_NotExist(t *testing.T) {

View file

@ -22,11 +22,13 @@
package signaling package signaling
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/pem" "encoding/pem"
"errors"
"io/fs" "io/fs"
"math/big" "math/big"
"net" "net"
@ -35,6 +37,22 @@ import (
"time" "time"
) )
func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error {
if c.loader == nil {
return errors.New("no certificate loaded")
}
return c.loader.WaitForReload(ctx)
}
func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context) error {
if c.pool == nil {
return errors.New("no certificate pool loaded")
}
return c.pool.WaitForReload(ctx)
}
func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte { func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte {
template := x509.Certificate{ template := x509.Certificate{
SerialNumber: big.NewInt(1), SerialNumber: big.NewInt(1),

View file

@ -35,6 +35,7 @@ import (
"github.com/dlintw/goconf" "github.com/dlintw/goconf"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
status "google.golang.org/grpc/status" status "google.golang.org/grpc/status"
) )
@ -60,6 +61,7 @@ type GrpcServer struct {
UnimplementedRpcMcuServer UnimplementedRpcMcuServer
UnimplementedRpcSessionsServer UnimplementedRpcSessionsServer
creds credentials.TransportCredentials
conn *grpc.Server conn *grpc.Server
listener net.Listener listener net.Listener
serverId string // can be overwritten from tests serverId string // can be overwritten from tests
@ -84,6 +86,7 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) {
conn := grpc.NewServer(grpc.Creds(creds)) conn := grpc.NewServer(grpc.Creds(creds))
result := &GrpcServer{ result := &GrpcServer{
creds: creds,
conn: conn, conn: conn,
listener: listener, listener: listener,
serverId: GrpcServerId, serverId: GrpcServerId,

View file

@ -28,6 +28,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors"
"net" "net"
"os" "os"
"path" "path"
@ -40,6 +41,24 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
func (s *GrpcServer) WaitForCertificateReload(ctx context.Context) error {
c, ok := s.creds.(*reloadableCredentials)
if !ok {
return errors.New("no reloadable credentials found")
}
return c.WaitForCertificateReload(ctx)
}
func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context) error {
c, ok := s.creds.(*reloadableCredentials)
if !ok {
return errors.New("no reloadable credentials found")
}
return c.WaitForCertPoolReload(ctx)
}
func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) { func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) {
for port := 50000; port < 50100; port++ { for port := 50000; port < 50100; port++ {
addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
@ -100,7 +119,7 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
config.AddOption("grpc", "serverkey", privkeyFile) config.AddOption("grpc", "serverkey", privkeyFile)
UpdateCertificateCheckIntervalForTest(t, 0) UpdateCertificateCheckIntervalForTest(t, 0)
_, addr := NewGrpcServerForTestWithConfig(t, config) server, addr := NewGrpcServerForTestWithConfig(t, config)
cp1 := x509.NewCertPool() cp1 := x509.NewCertPool()
if !cp1.AppendCertsFromPEM(cert1) { if !cp1.AppendCertsFromPEM(cert1) {
@ -128,6 +147,13 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key) cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key)
replaceFile(t, certFile, cert2, 0755) replaceFile(t, certFile, cert2, 0755)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := server.WaitForCertificateReload(ctx); err != nil {
t.Fatal(err)
}
cp2 := x509.NewCertPool() cp2 := x509.NewCertPool()
if !cp2.AppendCertsFromPEM(cert2) { if !cp2.AppendCertsFromPEM(cert2) {
t.Fatalf("could not add certificate") t.Fatalf("could not add certificate")
@ -181,7 +207,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
config.AddOption("grpc", "clientca", caFile) config.AddOption("grpc", "clientca", caFile)
UpdateCertificateCheckIntervalForTest(t, 0) UpdateCertificateCheckIntervalForTest(t, 0)
_, addr := NewGrpcServerForTestWithConfig(t, config) server, addr := NewGrpcServerForTestWithConfig(t, config)
pool := x509.NewCertPool() pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(serverCert) { if !pool.AppendCertsFromPEM(serverCert) {
@ -217,6 +243,10 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey) clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey)
replaceFile(t, caFile, clientCert2, 0755) replaceFile(t, caFile, clientCert2, 0755)
if err := server.WaitForCertPoolReload(ctx1); err != nil {
t.Fatal(err)
}
pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{ pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(clientKey), Bytes: x509.MarshalPKCS1PrivateKey(clientKey),