diff --git a/certificate_reloader.go b/certificate_reloader.go index d836be1..e9ac82d 100644 --- a/certificate_reloader.go +++ b/certificate_reloader.go @@ -27,26 +27,19 @@ import ( "fmt" "log" "os" - "sync" - "time" -) - -var ( - // CertificateCheckInterval defines the interval in which certificate files - // are checked for modifications. - CertificateCheckInterval = time.Minute + "sync/atomic" ) type CertificateReloader struct { - mu sync.Mutex + certFile string + certWatcher *FileWatcher - certFile string - keyFile string + keyFile string + keyWatcher *FileWatcher - certificate *tls.Certificate - lastModified time.Time + certificate atomic.Pointer[tls.Certificate] - nextCheck time.Time + reloadCounter atomic.Uint64 } func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) { @@ -55,52 +48,38 @@ func NewCertificateReloader(certFile string, keyFile string) (*CertificateReload return nil, fmt.Errorf("could not load certificate / key: %w", err) } - stat, err := os.Stat(certFile) - if err != nil { - return nil, fmt.Errorf("could not stat %s: %w", certFile, err) - } - - return &CertificateReloader{ + reloader := &CertificateReloader{ certFile: certFile, keyFile: keyFile, + } + reloader.certificate.Store(&pair) + reloader.certWatcher, err = NewFileWatcher(certFile, reloader.reload) + if err != nil { + return nil, err + } + reloader.keyWatcher, err = NewFileWatcher(keyFile, reloader.reload) + if err != nil { + reloader.certWatcher.Close() // nolint + return nil, err + } - certificate: &pair, - lastModified: stat.ModTime(), + return reloader, nil +} - nextCheck: time.Now().Add(CertificateCheckInterval), - }, nil +func (r *CertificateReloader) reload(filename string) { + log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile) + pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile) + if err != nil { + log.Printf("could not load certificate / key: %s", err) + return + } + + r.certificate.Store(&pair) + r.reloadCounter.Add(1) } func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) { - r.mu.Lock() - defer r.mu.Unlock() - - now := time.Now() - if now.Before(r.nextCheck) { - return r.certificate, nil - } - - r.nextCheck = now.Add(CertificateCheckInterval) - - stat, err := os.Stat(r.certFile) - if err != nil { - log.Printf("could not stat %s: %s", r.certFile, err) - return r.certificate, nil - } - - if !stat.ModTime().Equal(r.lastModified) { - log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile) - pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile) - if err != nil { - log.Printf("could not load certificate / key: %s", err) - return r.certificate, nil - } - - r.certificate = &pair - r.lastModified = stat.ModTime() - } - - return r.certificate, nil + return r.certificate.Load(), nil } func (r *CertificateReloader) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -111,15 +90,17 @@ func (r *CertificateReloader) GetClientCertificate(i *tls.CertificateRequestInfo return r.getCertificate() } +func (r *CertificateReloader) GetReloadCounter() uint64 { + return r.reloadCounter.Load() +} + type CertPoolReloader struct { - mu sync.Mutex + certFile string + certWatcher *FileWatcher - certFile string + pool atomic.Pointer[x509.CertPool] - pool *x509.CertPool - lastModified time.Time - - nextCheck time.Time + reloadCounter atomic.Uint64 } func loadCertPool(filename string) (*x509.CertPool, error) { @@ -142,49 +123,34 @@ func NewCertPoolReloader(certFile string) (*CertPoolReloader, error) { return nil, err } - stat, err := os.Stat(certFile) + reloader := &CertPoolReloader{ + certFile: certFile, + } + reloader.pool.Store(pool) + reloader.certWatcher, err = NewFileWatcher(certFile, reloader.reload) if err != nil { - return nil, fmt.Errorf("could not stat %s: %w", certFile, err) + return nil, err } - return &CertPoolReloader{ - certFile: certFile, + return reloader, nil +} - pool: pool, - lastModified: stat.ModTime(), +func (r *CertPoolReloader) reload(filename string) { + log.Printf("reloading certificate pool from %s", r.certFile) + pool, err := loadCertPool(r.certFile) + if err != nil { + log.Printf("could not load certificate pool: %s", err) + return + } - nextCheck: time.Now().Add(CertificateCheckInterval), - }, nil + r.pool.Store(pool) + r.reloadCounter.Add(1) } func (r *CertPoolReloader) GetCertPool() *x509.CertPool { - r.mu.Lock() - defer r.mu.Unlock() - - now := time.Now() - if now.Before(r.nextCheck) { - return r.pool - } - - r.nextCheck = now.Add(CertificateCheckInterval) - - stat, err := os.Stat(r.certFile) - if err != nil { - log.Printf("could not stat %s: %s", r.certFile, err) - return r.pool - } - - if !stat.ModTime().Equal(r.lastModified) { - log.Printf("reloading certificate pool from %s", r.certFile) - pool, err := loadCertPool(r.certFile) - if err != nil { - log.Printf("could not load certificate pool: %s", err) - return r.pool - } - - r.pool = pool - r.lastModified = stat.ModTime() - } - - return r.pool + return r.pool.Load() +} + +func (r *CertPoolReloader) GetReloadCounter() uint64 { + return r.reloadCounter.Load() } diff --git a/certificate_reloader_test.go b/certificate_reloader_test.go index d282a61..95f0ffb 100644 --- a/certificate_reloader_test.go +++ b/certificate_reloader_test.go @@ -22,15 +22,38 @@ package signaling import ( + "context" "testing" "time" ) func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) { - old := CertificateCheckInterval + old := deduplicateWatchEvents.Load() t.Cleanup(func() { - CertificateCheckInterval = old + deduplicateWatchEvents.Store(old) }) - CertificateCheckInterval = interval + deduplicateWatchEvents.Store(int64(interval)) +} + +func (r *CertificateReloader) WaitForReload(ctx context.Context) error { + counter := r.GetReloadCounter() + for counter == r.GetReloadCounter() { + if err := ctx.Err(); err != nil { + return err + } + time.Sleep(time.Millisecond) + } + return nil +} + +func (r *CertPoolReloader) WaitForReload(ctx context.Context) error { + counter := r.GetReloadCounter() + for counter == r.GetReloadCounter() { + if err := ctx.Err(); err != nil { + return err + } + time.Sleep(time.Millisecond) + } + return nil } diff --git a/file_watcher.go b/file_watcher.go new file mode 100644 index 0000000..13d5c44 --- /dev/null +++ b/file_watcher.go @@ -0,0 +1,159 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "errors" + "log" + "os" + "path" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/fsnotify/fsnotify" +) + +const ( + defaultDeduplicateWatchEvents = 100 * time.Millisecond +) + +var ( + deduplicateWatchEvents atomic.Int64 +) + +func init() { + deduplicateWatchEvents.Store(int64(defaultDeduplicateWatchEvents)) +} + +type FileWatcherCallback func(filename string) + +type FileWatcher struct { + filename string + target string + callback FileWatcherCallback + + watcher *fsnotify.Watcher +} + +func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher, error) { + realFilename, err := filepath.EvalSymlinks(filename) + if err != nil { + return nil, err + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + if err := watcher.Add(realFilename); err != nil { + watcher.Close() // nolint + return nil, err + } + + if filename != realFilename { + if err := watcher.Add(path.Dir(filename)); err != nil { + watcher.Close() // nolint + return nil, err + } + } + + w := &FileWatcher{ + filename: filename, + target: realFilename, + callback: callback, + watcher: watcher, + } + go w.run() + return w, nil +} + +func (f *FileWatcher) Close() error { + return f.watcher.Close() +} + +func (f *FileWatcher) run() { + var mu sync.Mutex + timers := make(map[string]*time.Timer) + + triggerEvent := func(event fsnotify.Event) { + deduplicate := time.Duration(deduplicateWatchEvents.Load()) + if deduplicate <= 0 { + f.callback(f.filename) + return + } + + // Use timer to deduplicate multiple events for the same file. + mu.Lock() + t, found := timers[event.Name] + mu.Unlock() + if !found { + t = time.AfterFunc(deduplicate, func() { + f.callback(f.filename) + + mu.Lock() + delete(timers, event.Name) + mu.Unlock() + }) + mu.Lock() + timers[event.Name] = t + mu.Unlock() + } else { + t.Reset(deduplicate) + } + } + + for { + select { + case event := <-f.watcher.Events: + if !event.Has(fsnotify.Write) && !event.Has(fsnotify.Create) && !event.Has(fsnotify.Rename) { + continue + } + + if stat, err := os.Lstat(event.Name); err != nil { + if !errors.Is(err, os.ErrNotExist) { + log.Printf("Could not lstat %s: %s", event.Name, err) + } + } else if stat.Mode()&os.ModeSymlink != 0 { + target, err := filepath.EvalSymlinks(event.Name) + if err == nil && target != f.target && strings.HasSuffix(event.Name, f.filename) { + f.target = target + triggerEvent(event) + } + continue + } + + if strings.HasSuffix(event.Name, f.filename) || strings.HasSuffix(event.Name, f.target) { + triggerEvent(event) + } + case err := <-f.watcher.Errors: + if err == nil { + return + } + + log.Printf("Error watching %s: %s", f.filename, err) + } + } +} diff --git a/file_watcher_test.go b/file_watcher_test.go new file mode 100644 index 0000000..f64fdca --- /dev/null +++ b/file_watcher_test.go @@ -0,0 +1,213 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "errors" + "os" + "path" + "testing" +) + +var ( + testWatcherNoEventTimeout = 2 * defaultDeduplicateWatchEvents +) + +func TestFileWatcher_NotExist(t *testing.T) { + tmpdir := t.TempDir() + w, err := NewFileWatcher(path.Join(tmpdir, "test.txt"), func(filename string) {}) + if err == nil { + t.Error("should not be able to watch non-existing files") + if err := w.Close(); err != nil { + t.Error(err) + } + } else if !errors.Is(err, os.ErrNotExist) { + t.Error(err) + } +} + +func TestFileWatcher_File(t *testing.T) { + tmpdir := t.TempDir() + filename := path.Join(tmpdir, "test.txt") + if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil { + t.Fatal(err) + } + + modified := make(chan struct{}) + w, err := NewFileWatcher(filename, func(filename string) { + modified <- struct{}{} + }) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil { + t.Fatal(err) + } + <-modified + + ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) + defer cancel() + + select { + case <-modified: + t.Error("should not have received another event") + case <-ctxTimeout.Done(): + } + + if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil { + t.Fatal(err) + } + <-modified + + ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout) + defer cancel() + + select { + case <-modified: + t.Error("should not have received another event") + case <-ctxTimeout.Done(): + } +} + +func TestFileWatcher_Symlink(t *testing.T) { + tmpdir := t.TempDir() + sourceFilename := path.Join(tmpdir, "test1.txt") + if err := os.WriteFile(sourceFilename, []byte("Hello world!"), 0644); err != nil { + t.Fatal(err) + } + + filename := path.Join(tmpdir, "symlink.txt") + if err := os.Symlink(sourceFilename, filename); err != nil { + t.Fatal(err) + } + + modified := make(chan struct{}) + w, err := NewFileWatcher(filename, func(filename string) { + modified <- struct{}{} + }) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + if err := os.WriteFile(sourceFilename, []byte("Updated"), 0644); err != nil { + t.Fatal(err) + } + <-modified + + ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) + defer cancel() + + select { + case <-modified: + t.Error("should not have received another event") + case <-ctxTimeout.Done(): + } +} + +func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) { + tmpdir := t.TempDir() + sourceFilename1 := path.Join(tmpdir, "test1.txt") + if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil { + t.Fatal(err) + } + + sourceFilename2 := path.Join(tmpdir, "test2.txt") + if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil { + t.Fatal(err) + } + + filename := path.Join(tmpdir, "symlink.txt") + if err := os.Symlink(sourceFilename1, filename); err != nil { + t.Fatal(err) + } + + modified := make(chan struct{}) + w, err := NewFileWatcher(filename, func(filename string) { + modified <- struct{}{} + }) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + // Replace symlink by creating new one and rename it to the original target. + if err := os.Symlink(sourceFilename2, filename+".tmp"); err != nil { + t.Fatal(err) + } + if err := os.Rename(filename+".tmp", filename); err != nil { + t.Fatal(err) + } + <-modified + + ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) + defer cancel() + + select { + case <-modified: + t.Error("should not have received another event") + case <-ctxTimeout.Done(): + } +} + +func TestFileWatcher_OtherSymlink(t *testing.T) { + tmpdir := t.TempDir() + sourceFilename1 := path.Join(tmpdir, "test1.txt") + if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil { + t.Fatal(err) + } + + sourceFilename2 := path.Join(tmpdir, "test2.txt") + if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil { + t.Fatal(err) + } + + filename := path.Join(tmpdir, "symlink.txt") + if err := os.Symlink(sourceFilename1, filename); err != nil { + t.Fatal(err) + } + + modified := make(chan struct{}) + w, err := NewFileWatcher(filename, func(filename string) { + modified <- struct{}{} + }) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + if err := os.Symlink(sourceFilename2, filename+".tmp"); err != nil { + t.Fatal(err) + } + + ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout) + defer cancel() + + select { + case <-modified: + t.Error("should not have received event for other symlink") + case <-ctxTimeout.Done(): + } +} diff --git a/go.mod b/go.mod index c7ffed3..1dc1803 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.0 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-logr/logr v1.3.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/go.sum b/go.sum index 7b05a55..adb069c 100644 --- a/go.sum +++ b/go.sum @@ -33,6 +33,8 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= diff --git a/grpc_common_test.go b/grpc_common_test.go index a525b49..f9bed4a 100644 --- a/grpc_common_test.go +++ b/grpc_common_test.go @@ -22,11 +22,13 @@ package signaling import ( + "context" "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "errors" "io/fs" "math/big" "net" @@ -35,6 +37,22 @@ import ( "time" ) +func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error { + if c.loader == nil { + return errors.New("no certificate loaded") + } + + return c.loader.WaitForReload(ctx) +} + +func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context) error { + if c.pool == nil { + return errors.New("no certificate pool loaded") + } + + return c.pool.WaitForReload(ctx) +} + func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte { template := x509.Certificate{ SerialNumber: big.NewInt(1), diff --git a/grpc_server.go b/grpc_server.go index 3108be9..6fd1069 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -35,6 +35,7 @@ import ( "github.com/dlintw/goconf" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" status "google.golang.org/grpc/status" ) @@ -60,6 +61,7 @@ type GrpcServer struct { UnimplementedRpcMcuServer UnimplementedRpcSessionsServer + creds credentials.TransportCredentials conn *grpc.Server listener net.Listener serverId string // can be overwritten from tests @@ -84,6 +86,7 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) { conn := grpc.NewServer(grpc.Creds(creds)) result := &GrpcServer{ + creds: creds, conn: conn, listener: listener, serverId: GrpcServerId, diff --git a/grpc_server_test.go b/grpc_server_test.go index 4c4abed..232309e 100644 --- a/grpc_server_test.go +++ b/grpc_server_test.go @@ -28,6 +28,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/pem" + "errors" "net" "os" "path" @@ -40,6 +41,24 @@ import ( "google.golang.org/grpc/credentials" ) +func (s *GrpcServer) WaitForCertificateReload(ctx context.Context) error { + c, ok := s.creds.(*reloadableCredentials) + if !ok { + return errors.New("no reloadable credentials found") + } + + return c.WaitForCertificateReload(ctx) +} + +func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context) error { + c, ok := s.creds.(*reloadableCredentials) + if !ok { + return errors.New("no reloadable credentials found") + } + + return c.WaitForCertPoolReload(ctx) +} + func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) { for port := 50000; port < 50100; port++ { addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) @@ -99,8 +118,8 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { config.AddOption("grpc", "servercertificate", certFile) config.AddOption("grpc", "serverkey", privkeyFile) - UpdateCertificateCheckIntervalForTest(t, time.Millisecond) - _, addr := NewGrpcServerForTestWithConfig(t, config) + UpdateCertificateCheckIntervalForTest(t, 0) + server, addr := NewGrpcServerForTestWithConfig(t, config) cp1 := x509.NewCertPool() if !cp1.AppendCertsFromPEM(cert1) { @@ -128,6 +147,13 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key) replaceFile(t, certFile, cert2, 0755) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if err := server.WaitForCertificateReload(ctx); err != nil { + t.Fatal(err) + } + cp2 := x509.NewCertPool() if !cp2.AppendCertsFromPEM(cert2) { t.Fatalf("could not add certificate") @@ -180,8 +206,8 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { config.AddOption("grpc", "serverkey", privkeyFile) config.AddOption("grpc", "clientca", caFile) - UpdateCertificateCheckIntervalForTest(t, time.Millisecond) - _, addr := NewGrpcServerForTestWithConfig(t, config) + UpdateCertificateCheckIntervalForTest(t, 0) + server, addr := NewGrpcServerForTestWithConfig(t, config) pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(serverCert) { @@ -217,6 +243,10 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey) replaceFile(t, caFile, clientCert2, 0755) + if err := server.WaitForCertPoolReload(ctx1); err != nil { + t.Fatal(err) + } + pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey),