mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-05-19 14:06:32 +02:00
Update tests to wait for certificate / pool reload to happen before continuing.
This commit is contained in:
parent
cc7625c544
commit
2ef9b39959
|
@ -38,6 +38,8 @@ type CertificateReloader struct {
|
|||
keyWatcher *FileWatcher
|
||||
|
||||
certificate atomic.Pointer[tls.Certificate]
|
||||
|
||||
reloadCounter atomic.Uint64
|
||||
}
|
||||
|
||||
func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) {
|
||||
|
@ -73,6 +75,7 @@ func (r *CertificateReloader) reload(filename string) {
|
|||
}
|
||||
|
||||
r.certificate.Store(&pair)
|
||||
r.reloadCounter.Add(1)
|
||||
}
|
||||
|
||||
func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) {
|
||||
|
@ -87,11 +90,17 @@ func (r *CertificateReloader) GetClientCertificate(i *tls.CertificateRequestInfo
|
|||
return r.getCertificate()
|
||||
}
|
||||
|
||||
func (r *CertificateReloader) GetReloadCounter() uint64 {
|
||||
return r.reloadCounter.Load()
|
||||
}
|
||||
|
||||
type CertPoolReloader struct {
|
||||
certFile string
|
||||
certWatcher *FileWatcher
|
||||
|
||||
pool atomic.Pointer[x509.CertPool]
|
||||
|
||||
reloadCounter atomic.Uint64
|
||||
}
|
||||
|
||||
func loadCertPool(filename string) (*x509.CertPool, error) {
|
||||
|
@ -135,8 +144,13 @@ func (r *CertPoolReloader) reload(filename string) {
|
|||
}
|
||||
|
||||
r.pool.Store(pool)
|
||||
r.reloadCounter.Add(1)
|
||||
}
|
||||
|
||||
func (r *CertPoolReloader) GetCertPool() *x509.CertPool {
|
||||
return r.pool.Load()
|
||||
}
|
||||
|
||||
func (r *CertPoolReloader) GetReloadCounter() uint64 {
|
||||
return r.reloadCounter.Load()
|
||||
}
|
||||
|
|
|
@ -22,15 +22,38 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) {
|
||||
old := deduplicateWatchEvents
|
||||
old := deduplicateWatchEvents.Load()
|
||||
t.Cleanup(func() {
|
||||
deduplicateWatchEvents = old
|
||||
deduplicateWatchEvents.Store(old)
|
||||
})
|
||||
|
||||
deduplicateWatchEvents = interval
|
||||
deduplicateWatchEvents.Store(int64(interval))
|
||||
}
|
||||
|
||||
func (r *CertificateReloader) WaitForReload(ctx context.Context) error {
|
||||
counter := r.GetReloadCounter()
|
||||
for counter == r.GetReloadCounter() {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *CertPoolReloader) WaitForReload(ctx context.Context) error {
|
||||
counter := r.GetReloadCounter()
|
||||
for counter == r.GetReloadCounter() {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -29,15 +29,24 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
var (
|
||||
deduplicateWatchEvents = 100 * time.Millisecond
|
||||
const (
|
||||
defaultDeduplicateWatchEvents = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
deduplicateWatchEvents atomic.Int64
|
||||
)
|
||||
|
||||
func init() {
|
||||
deduplicateWatchEvents.Store(int64(defaultDeduplicateWatchEvents))
|
||||
}
|
||||
|
||||
type FileWatcherCallback func(filename string)
|
||||
|
||||
type FileWatcher struct {
|
||||
|
@ -90,7 +99,8 @@ func (f *FileWatcher) run() {
|
|||
timers := make(map[string]*time.Timer)
|
||||
|
||||
triggerEvent := func(event fsnotify.Event) {
|
||||
if deduplicateWatchEvents <= 0 {
|
||||
deduplicate := time.Duration(deduplicateWatchEvents.Load())
|
||||
if deduplicate <= 0 {
|
||||
f.callback(f.filename)
|
||||
return
|
||||
}
|
||||
|
@ -100,7 +110,7 @@ func (f *FileWatcher) run() {
|
|||
t, found := timers[event.Name]
|
||||
mu.Unlock()
|
||||
if !found {
|
||||
t = time.AfterFunc(deduplicateWatchEvents, func() {
|
||||
t = time.AfterFunc(deduplicate, func() {
|
||||
f.callback(f.filename)
|
||||
|
||||
mu.Lock()
|
||||
|
@ -111,7 +121,7 @@ func (f *FileWatcher) run() {
|
|||
timers[event.Name] = t
|
||||
mu.Unlock()
|
||||
} else {
|
||||
t.Reset(deduplicateWatchEvents)
|
||||
t.Reset(deduplicate)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
testWatcherNoEventTimeout = 2 * deduplicateWatchEvents
|
||||
testWatcherNoEventTimeout = 2 * defaultDeduplicateWatchEvents
|
||||
)
|
||||
|
||||
func TestFileWatcher_NotExist(t *testing.T) {
|
||||
|
|
|
@ -22,11 +22,13 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"math/big"
|
||||
"net"
|
||||
|
@ -35,6 +37,22 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error {
|
||||
if c.loader == nil {
|
||||
return errors.New("no certificate loaded")
|
||||
}
|
||||
|
||||
return c.loader.WaitForReload(ctx)
|
||||
}
|
||||
|
||||
func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context) error {
|
||||
if c.pool == nil {
|
||||
return errors.New("no certificate pool loaded")
|
||||
}
|
||||
|
||||
return c.pool.WaitForReload(ctx)
|
||||
}
|
||||
|
||||
func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte {
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
|
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/dlintw/goconf"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
|
@ -60,6 +61,7 @@ type GrpcServer struct {
|
|||
UnimplementedRpcMcuServer
|
||||
UnimplementedRpcSessionsServer
|
||||
|
||||
creds credentials.TransportCredentials
|
||||
conn *grpc.Server
|
||||
listener net.Listener
|
||||
serverId string // can be overwritten from tests
|
||||
|
@ -84,6 +86,7 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) {
|
|||
|
||||
conn := grpc.NewServer(grpc.Creds(creds))
|
||||
result := &GrpcServer{
|
||||
creds: creds,
|
||||
conn: conn,
|
||||
listener: listener,
|
||||
serverId: GrpcServerId,
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
|
@ -40,6 +41,24 @@ import (
|
|||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
func (s *GrpcServer) WaitForCertificateReload(ctx context.Context) error {
|
||||
c, ok := s.creds.(*reloadableCredentials)
|
||||
if !ok {
|
||||
return errors.New("no reloadable credentials found")
|
||||
}
|
||||
|
||||
return c.WaitForCertificateReload(ctx)
|
||||
}
|
||||
|
||||
func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context) error {
|
||||
c, ok := s.creds.(*reloadableCredentials)
|
||||
if !ok {
|
||||
return errors.New("no reloadable credentials found")
|
||||
}
|
||||
|
||||
return c.WaitForCertPoolReload(ctx)
|
||||
}
|
||||
|
||||
func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) {
|
||||
for port := 50000; port < 50100; port++ {
|
||||
addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
|
||||
|
@ -100,7 +119,7 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
|
|||
config.AddOption("grpc", "serverkey", privkeyFile)
|
||||
|
||||
UpdateCertificateCheckIntervalForTest(t, 0)
|
||||
_, addr := NewGrpcServerForTestWithConfig(t, config)
|
||||
server, addr := NewGrpcServerForTestWithConfig(t, config)
|
||||
|
||||
cp1 := x509.NewCertPool()
|
||||
if !cp1.AppendCertsFromPEM(cert1) {
|
||||
|
@ -128,6 +147,13 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
|
|||
cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key)
|
||||
replaceFile(t, certFile, cert2, 0755)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.WaitForCertificateReload(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cp2 := x509.NewCertPool()
|
||||
if !cp2.AppendCertsFromPEM(cert2) {
|
||||
t.Fatalf("could not add certificate")
|
||||
|
@ -181,7 +207,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
|
|||
config.AddOption("grpc", "clientca", caFile)
|
||||
|
||||
UpdateCertificateCheckIntervalForTest(t, 0)
|
||||
_, addr := NewGrpcServerForTestWithConfig(t, config)
|
||||
server, addr := NewGrpcServerForTestWithConfig(t, config)
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(serverCert) {
|
||||
|
@ -217,6 +243,10 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
|
|||
clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey)
|
||||
replaceFile(t, caFile, clientCert2, 0755)
|
||||
|
||||
if err := server.WaitForCertPoolReload(ctx1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
|
||||
|
|
Loading…
Reference in a new issue