Update tests to wait for certificate / pool reload to happen before continuing.

This commit is contained in:
Joachim Bauch 2024-04-03 09:41:38 +02:00
parent cc7625c544
commit 2ef9b39959
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
7 changed files with 109 additions and 11 deletions

View file

@ -38,6 +38,8 @@ type CertificateReloader struct {
keyWatcher *FileWatcher
certificate atomic.Pointer[tls.Certificate]
reloadCounter atomic.Uint64
}
func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) {
@ -73,6 +75,7 @@ func (r *CertificateReloader) reload(filename string) {
}
r.certificate.Store(&pair)
r.reloadCounter.Add(1)
}
func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) {
@ -87,11 +90,17 @@ func (r *CertificateReloader) GetClientCertificate(i *tls.CertificateRequestInfo
return r.getCertificate()
}
func (r *CertificateReloader) GetReloadCounter() uint64 {
return r.reloadCounter.Load()
}
type CertPoolReloader struct {
certFile string
certWatcher *FileWatcher
pool atomic.Pointer[x509.CertPool]
reloadCounter atomic.Uint64
}
func loadCertPool(filename string) (*x509.CertPool, error) {
@ -135,8 +144,13 @@ func (r *CertPoolReloader) reload(filename string) {
}
r.pool.Store(pool)
r.reloadCounter.Add(1)
}
func (r *CertPoolReloader) GetCertPool() *x509.CertPool {
return r.pool.Load()
}
func (r *CertPoolReloader) GetReloadCounter() uint64 {
return r.reloadCounter.Load()
}

View file

@ -22,15 +22,38 @@
package signaling
import (
"context"
"testing"
"time"
)
func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) {
old := deduplicateWatchEvents
old := deduplicateWatchEvents.Load()
t.Cleanup(func() {
deduplicateWatchEvents = old
deduplicateWatchEvents.Store(old)
})
deduplicateWatchEvents = interval
deduplicateWatchEvents.Store(int64(interval))
}
func (r *CertificateReloader) WaitForReload(ctx context.Context) error {
counter := r.GetReloadCounter()
for counter == r.GetReloadCounter() {
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
}
func (r *CertPoolReloader) WaitForReload(ctx context.Context) error {
counter := r.GetReloadCounter()
for counter == r.GetReloadCounter() {
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
}

View file

@ -29,15 +29,24 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/fsnotify/fsnotify"
)
var (
deduplicateWatchEvents = 100 * time.Millisecond
const (
defaultDeduplicateWatchEvents = 100 * time.Millisecond
)
var (
deduplicateWatchEvents atomic.Int64
)
func init() {
deduplicateWatchEvents.Store(int64(defaultDeduplicateWatchEvents))
}
type FileWatcherCallback func(filename string)
type FileWatcher struct {
@ -90,7 +99,8 @@ func (f *FileWatcher) run() {
timers := make(map[string]*time.Timer)
triggerEvent := func(event fsnotify.Event) {
if deduplicateWatchEvents <= 0 {
deduplicate := time.Duration(deduplicateWatchEvents.Load())
if deduplicate <= 0 {
f.callback(f.filename)
return
}
@ -100,7 +110,7 @@ func (f *FileWatcher) run() {
t, found := timers[event.Name]
mu.Unlock()
if !found {
t = time.AfterFunc(deduplicateWatchEvents, func() {
t = time.AfterFunc(deduplicate, func() {
f.callback(f.filename)
mu.Lock()
@ -111,7 +121,7 @@ func (f *FileWatcher) run() {
timers[event.Name] = t
mu.Unlock()
} else {
t.Reset(deduplicateWatchEvents)
t.Reset(deduplicate)
}
}

View file

@ -30,7 +30,7 @@ import (
)
var (
testWatcherNoEventTimeout = 2 * deduplicateWatchEvents
testWatcherNoEventTimeout = 2 * defaultDeduplicateWatchEvents
)
func TestFileWatcher_NotExist(t *testing.T) {

View file

@ -22,11 +22,13 @@
package signaling
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"io/fs"
"math/big"
"net"
@ -35,6 +37,22 @@ import (
"time"
)
func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error {
if c.loader == nil {
return errors.New("no certificate loaded")
}
return c.loader.WaitForReload(ctx)
}
func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context) error {
if c.pool == nil {
return errors.New("no certificate pool loaded")
}
return c.pool.WaitForReload(ctx)
}
func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte {
template := x509.Certificate{
SerialNumber: big.NewInt(1),

View file

@ -35,6 +35,7 @@ import (
"github.com/dlintw/goconf"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
status "google.golang.org/grpc/status"
)
@ -60,6 +61,7 @@ type GrpcServer struct {
UnimplementedRpcMcuServer
UnimplementedRpcSessionsServer
creds credentials.TransportCredentials
conn *grpc.Server
listener net.Listener
serverId string // can be overwritten from tests
@ -84,6 +86,7 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) {
conn := grpc.NewServer(grpc.Creds(creds))
result := &GrpcServer{
creds: creds,
conn: conn,
listener: listener,
serverId: GrpcServerId,

View file

@ -28,6 +28,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"net"
"os"
"path"
@ -40,6 +41,24 @@ import (
"google.golang.org/grpc/credentials"
)
func (s *GrpcServer) WaitForCertificateReload(ctx context.Context) error {
c, ok := s.creds.(*reloadableCredentials)
if !ok {
return errors.New("no reloadable credentials found")
}
return c.WaitForCertificateReload(ctx)
}
func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context) error {
c, ok := s.creds.(*reloadableCredentials)
if !ok {
return errors.New("no reloadable credentials found")
}
return c.WaitForCertPoolReload(ctx)
}
func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) {
for port := 50000; port < 50100; port++ {
addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
@ -100,7 +119,7 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
config.AddOption("grpc", "serverkey", privkeyFile)
UpdateCertificateCheckIntervalForTest(t, 0)
_, addr := NewGrpcServerForTestWithConfig(t, config)
server, addr := NewGrpcServerForTestWithConfig(t, config)
cp1 := x509.NewCertPool()
if !cp1.AppendCertsFromPEM(cert1) {
@ -128,6 +147,13 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key)
replaceFile(t, certFile, cert2, 0755)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := server.WaitForCertificateReload(ctx); err != nil {
t.Fatal(err)
}
cp2 := x509.NewCertPool()
if !cp2.AppendCertsFromPEM(cert2) {
t.Fatalf("could not add certificate")
@ -181,7 +207,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
config.AddOption("grpc", "clientca", caFile)
UpdateCertificateCheckIntervalForTest(t, 0)
_, addr := NewGrpcServerForTestWithConfig(t, config)
server, addr := NewGrpcServerForTestWithConfig(t, config)
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(serverCert) {
@ -217,6 +243,10 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey)
replaceFile(t, caFile, clientCert2, 0755)
if err := server.WaitForCertPoolReload(ctx1); err != nil {
t.Fatal(err)
}
pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),