mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-05-17 13:06:33 +02:00
Update tests to wait for certificate / pool reload to happen before continuing.
This commit is contained in:
parent
cc7625c544
commit
2ef9b39959
|
@ -38,6 +38,8 @@ type CertificateReloader struct {
|
||||||
keyWatcher *FileWatcher
|
keyWatcher *FileWatcher
|
||||||
|
|
||||||
certificate atomic.Pointer[tls.Certificate]
|
certificate atomic.Pointer[tls.Certificate]
|
||||||
|
|
||||||
|
reloadCounter atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) {
|
func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) {
|
||||||
|
@ -73,6 +75,7 @@ func (r *CertificateReloader) reload(filename string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
r.certificate.Store(&pair)
|
r.certificate.Store(&pair)
|
||||||
|
r.reloadCounter.Add(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) {
|
func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) {
|
||||||
|
@ -87,11 +90,17 @@ func (r *CertificateReloader) GetClientCertificate(i *tls.CertificateRequestInfo
|
||||||
return r.getCertificate()
|
return r.getCertificate()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *CertificateReloader) GetReloadCounter() uint64 {
|
||||||
|
return r.reloadCounter.Load()
|
||||||
|
}
|
||||||
|
|
||||||
type CertPoolReloader struct {
|
type CertPoolReloader struct {
|
||||||
certFile string
|
certFile string
|
||||||
certWatcher *FileWatcher
|
certWatcher *FileWatcher
|
||||||
|
|
||||||
pool atomic.Pointer[x509.CertPool]
|
pool atomic.Pointer[x509.CertPool]
|
||||||
|
|
||||||
|
reloadCounter atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadCertPool(filename string) (*x509.CertPool, error) {
|
func loadCertPool(filename string) (*x509.CertPool, error) {
|
||||||
|
@ -135,8 +144,13 @@ func (r *CertPoolReloader) reload(filename string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
r.pool.Store(pool)
|
r.pool.Store(pool)
|
||||||
|
r.reloadCounter.Add(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *CertPoolReloader) GetCertPool() *x509.CertPool {
|
func (r *CertPoolReloader) GetCertPool() *x509.CertPool {
|
||||||
return r.pool.Load()
|
return r.pool.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *CertPoolReloader) GetReloadCounter() uint64 {
|
||||||
|
return r.reloadCounter.Load()
|
||||||
|
}
|
||||||
|
|
|
@ -22,15 +22,38 @@
|
||||||
package signaling
|
package signaling
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) {
|
func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) {
|
||||||
old := deduplicateWatchEvents
|
old := deduplicateWatchEvents.Load()
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
deduplicateWatchEvents = old
|
deduplicateWatchEvents.Store(old)
|
||||||
})
|
})
|
||||||
|
|
||||||
deduplicateWatchEvents = interval
|
deduplicateWatchEvents.Store(int64(interval))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *CertificateReloader) WaitForReload(ctx context.Context) error {
|
||||||
|
counter := r.GetReloadCounter()
|
||||||
|
for counter == r.GetReloadCounter() {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *CertPoolReloader) WaitForReload(ctx context.Context) error {
|
||||||
|
counter := r.GetReloadCounter()
|
||||||
|
for counter == r.GetReloadCounter() {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,15 +29,24 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const (
|
||||||
deduplicateWatchEvents = 100 * time.Millisecond
|
defaultDeduplicateWatchEvents = 100 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
deduplicateWatchEvents atomic.Int64
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
deduplicateWatchEvents.Store(int64(defaultDeduplicateWatchEvents))
|
||||||
|
}
|
||||||
|
|
||||||
type FileWatcherCallback func(filename string)
|
type FileWatcherCallback func(filename string)
|
||||||
|
|
||||||
type FileWatcher struct {
|
type FileWatcher struct {
|
||||||
|
@ -90,7 +99,8 @@ func (f *FileWatcher) run() {
|
||||||
timers := make(map[string]*time.Timer)
|
timers := make(map[string]*time.Timer)
|
||||||
|
|
||||||
triggerEvent := func(event fsnotify.Event) {
|
triggerEvent := func(event fsnotify.Event) {
|
||||||
if deduplicateWatchEvents <= 0 {
|
deduplicate := time.Duration(deduplicateWatchEvents.Load())
|
||||||
|
if deduplicate <= 0 {
|
||||||
f.callback(f.filename)
|
f.callback(f.filename)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -100,7 +110,7 @@ func (f *FileWatcher) run() {
|
||||||
t, found := timers[event.Name]
|
t, found := timers[event.Name]
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
if !found {
|
if !found {
|
||||||
t = time.AfterFunc(deduplicateWatchEvents, func() {
|
t = time.AfterFunc(deduplicate, func() {
|
||||||
f.callback(f.filename)
|
f.callback(f.filename)
|
||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
|
@ -111,7 +121,7 @@ func (f *FileWatcher) run() {
|
||||||
timers[event.Name] = t
|
timers[event.Name] = t
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
} else {
|
} else {
|
||||||
t.Reset(deduplicateWatchEvents)
|
t.Reset(deduplicate)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testWatcherNoEventTimeout = 2 * deduplicateWatchEvents
|
testWatcherNoEventTimeout = 2 * defaultDeduplicateWatchEvents
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFileWatcher_NotExist(t *testing.T) {
|
func TestFileWatcher_NotExist(t *testing.T) {
|
||||||
|
|
|
@ -22,11 +22,13 @@
|
||||||
package signaling
|
package signaling
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
|
@ -35,6 +37,22 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error {
|
||||||
|
if c.loader == nil {
|
||||||
|
return errors.New("no certificate loaded")
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.loader.WaitForReload(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context) error {
|
||||||
|
if c.pool == nil {
|
||||||
|
return errors.New("no certificate pool loaded")
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.pool.WaitForReload(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte {
|
func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte {
|
||||||
template := x509.Certificate{
|
template := x509.Certificate{
|
||||||
SerialNumber: big.NewInt(1),
|
SerialNumber: big.NewInt(1),
|
||||||
|
|
|
@ -35,6 +35,7 @@ import (
|
||||||
"github.com/dlintw/goconf"
|
"github.com/dlintw/goconf"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
status "google.golang.org/grpc/status"
|
status "google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,6 +61,7 @@ type GrpcServer struct {
|
||||||
UnimplementedRpcMcuServer
|
UnimplementedRpcMcuServer
|
||||||
UnimplementedRpcSessionsServer
|
UnimplementedRpcSessionsServer
|
||||||
|
|
||||||
|
creds credentials.TransportCredentials
|
||||||
conn *grpc.Server
|
conn *grpc.Server
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
serverId string // can be overwritten from tests
|
serverId string // can be overwritten from tests
|
||||||
|
@ -84,6 +86,7 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) {
|
||||||
|
|
||||||
conn := grpc.NewServer(grpc.Creds(creds))
|
conn := grpc.NewServer(grpc.Creds(creds))
|
||||||
result := &GrpcServer{
|
result := &GrpcServer{
|
||||||
|
creds: creds,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
listener: listener,
|
listener: listener,
|
||||||
serverId: GrpcServerId,
|
serverId: GrpcServerId,
|
||||||
|
|
|
@ -28,6 +28,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
@ -40,6 +41,24 @@ import (
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s *GrpcServer) WaitForCertificateReload(ctx context.Context) error {
|
||||||
|
c, ok := s.creds.(*reloadableCredentials)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("no reloadable credentials found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.WaitForCertificateReload(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context) error {
|
||||||
|
c, ok := s.creds.(*reloadableCredentials)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("no reloadable credentials found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.WaitForCertPoolReload(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) {
|
func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) {
|
||||||
for port := 50000; port < 50100; port++ {
|
for port := 50000; port < 50100; port++ {
|
||||||
addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
|
addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
|
||||||
|
@ -100,7 +119,7 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
|
||||||
config.AddOption("grpc", "serverkey", privkeyFile)
|
config.AddOption("grpc", "serverkey", privkeyFile)
|
||||||
|
|
||||||
UpdateCertificateCheckIntervalForTest(t, 0)
|
UpdateCertificateCheckIntervalForTest(t, 0)
|
||||||
_, addr := NewGrpcServerForTestWithConfig(t, config)
|
server, addr := NewGrpcServerForTestWithConfig(t, config)
|
||||||
|
|
||||||
cp1 := x509.NewCertPool()
|
cp1 := x509.NewCertPool()
|
||||||
if !cp1.AppendCertsFromPEM(cert1) {
|
if !cp1.AppendCertsFromPEM(cert1) {
|
||||||
|
@ -128,6 +147,13 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
|
||||||
cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key)
|
cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key)
|
||||||
replaceFile(t, certFile, cert2, 0755)
|
replaceFile(t, certFile, cert2, 0755)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := server.WaitForCertificateReload(ctx); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
cp2 := x509.NewCertPool()
|
cp2 := x509.NewCertPool()
|
||||||
if !cp2.AppendCertsFromPEM(cert2) {
|
if !cp2.AppendCertsFromPEM(cert2) {
|
||||||
t.Fatalf("could not add certificate")
|
t.Fatalf("could not add certificate")
|
||||||
|
@ -181,7 +207,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
|
||||||
config.AddOption("grpc", "clientca", caFile)
|
config.AddOption("grpc", "clientca", caFile)
|
||||||
|
|
||||||
UpdateCertificateCheckIntervalForTest(t, 0)
|
UpdateCertificateCheckIntervalForTest(t, 0)
|
||||||
_, addr := NewGrpcServerForTestWithConfig(t, config)
|
server, addr := NewGrpcServerForTestWithConfig(t, config)
|
||||||
|
|
||||||
pool := x509.NewCertPool()
|
pool := x509.NewCertPool()
|
||||||
if !pool.AppendCertsFromPEM(serverCert) {
|
if !pool.AppendCertsFromPEM(serverCert) {
|
||||||
|
@ -217,6 +243,10 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
|
||||||
clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey)
|
clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey)
|
||||||
replaceFile(t, caFile, clientCert2, 0755)
|
replaceFile(t, caFile, clientCert2, 0755)
|
||||||
|
|
||||||
|
if err := server.WaitForCertPoolReload(ctx1); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{
|
pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
|
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
|
||||||
|
|
Loading…
Reference in a new issue