Move certificate / pool reloader to "security" package.

This commit is contained in:
Joachim Bauch 2025-12-15 10:56:39 +01:00
commit 2275a5542e
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
7 changed files with 49 additions and 77 deletions

View file

@ -74,6 +74,10 @@ component_management:
name: proxy
paths:
- proxy/**
- component_id: module_security
name: security
paths:
- security/**
- component_id: module_server
name: server
paths:

View file

@ -1,47 +0,0 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"time"
)
func (r *CertificateReloader) WaitForReload(ctx context.Context, counter uint64) error {
for counter == r.GetReloadCounter() {
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
}
func (r *CertPoolReloader) WaitForReload(ctx context.Context, counter uint64) error {
for counter == r.GetReloadCounter() {
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
}

View file

@ -32,13 +32,14 @@ import (
"google.golang.org/grpc/credentials/insecure"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/security"
)
type reloadableCredentials struct {
config *tls.Config
loader *CertificateReloader
pool *CertPoolReloader
loader *security.CertificateReloader
pool *security.CertPoolReloader
}
func (c *reloadableCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
@ -151,18 +152,18 @@ func NewReloadableCredentials(logger log.Logger, config *goconf.ConfigFile, serv
cfg := &tls.Config{
NextProtos: []string{"h2"},
}
var loader *CertificateReloader
var loader *security.CertificateReloader
var err error
if certificateFile != "" && keyFile != "" {
loader, err = NewCertificateReloader(logger, certificateFile, keyFile)
loader, err = security.NewCertificateReloader(logger, certificateFile, keyFile)
if err != nil {
return nil, fmt.Errorf("invalid GRPC %s certificate / key in %s / %s: %w", prefix, certificateFile, keyFile, err)
}
}
var pool *CertPoolReloader
var pool *security.CertPoolReloader
if caFile != "" {
pool, err = NewCertPoolReloader(logger, caFile)
pool, err = security.NewCertPoolReloader(logger, caFile)
if err != nil {
return nil, err
}

View file

@ -24,6 +24,7 @@ package signaling
import (
"context"
"errors"
"time"
)
func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context, counter uint64) error {
@ -31,7 +32,13 @@ func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context, co
return errors.New("no certificate loaded")
}
return c.loader.WaitForReload(ctx, counter)
for counter == c.loader.GetReloadCounter() {
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
}
func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context, counter uint64) error {
@ -39,5 +46,11 @@ func (c *reloadableCredentials) WaitForCertPoolReload(ctx context.Context, count
return errors.New("no certificate pool loaded")
}
return c.pool.WaitForReload(ctx, counter)
for counter == c.pool.GetReloadCounter() {
if err := ctx.Err(); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
}

View file

@ -19,7 +19,7 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
package security
import (
"crypto/tls"
@ -30,16 +30,17 @@ import (
"testing"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/security/internal"
)
type CertificateReloader struct {
logger log.Logger
certFile string
certWatcher *FileWatcher
certWatcher *internal.FileWatcher
keyFile string
keyWatcher *FileWatcher
keyWatcher *internal.FileWatcher
certificate atomic.Pointer[tls.Certificate]
@ -52,7 +53,7 @@ func NewCertificateReloader(logger log.Logger, certFile string, keyFile string)
return nil, fmt.Errorf("could not load certificate / key: %w", err)
}
deduplicate := defaultDeduplicateWatchEvents
deduplicate := internal.DefaultDeduplicateWatchEvents
if testing.Testing() {
deduplicate = 0
}
@ -63,11 +64,11 @@ func NewCertificateReloader(logger log.Logger, certFile string, keyFile string)
keyFile: keyFile,
}
reloader.certificate.Store(&pair)
reloader.certWatcher, err = NewFileWatcher(reloader.logger, certFile, reloader.reload, deduplicate)
reloader.certWatcher, err = internal.NewFileWatcher(reloader.logger, certFile, reloader.reload, deduplicate)
if err != nil {
return nil, err
}
reloader.keyWatcher, err = NewFileWatcher(reloader.logger, keyFile, reloader.reload, deduplicate)
reloader.keyWatcher, err = internal.NewFileWatcher(reloader.logger, keyFile, reloader.reload, deduplicate)
if err != nil {
reloader.certWatcher.Close() // nolint
return nil, err
@ -113,7 +114,7 @@ type CertPoolReloader struct {
logger log.Logger
certFile string
certWatcher *FileWatcher
certWatcher *internal.FileWatcher
pool atomic.Pointer[x509.CertPool]
@ -140,7 +141,7 @@ func NewCertPoolReloader(logger log.Logger, certFile string) (*CertPoolReloader,
return nil, err
}
deduplicate := defaultDeduplicateWatchEvents
deduplicate := internal.DefaultDeduplicateWatchEvents
if testing.Testing() {
deduplicate = 0
}
@ -150,7 +151,7 @@ func NewCertPoolReloader(logger log.Logger, certFile string) (*CertPoolReloader,
certFile: certFile,
}
reloader.pool.Store(pool)
reloader.certWatcher, err = NewFileWatcher(reloader.logger, certFile, reloader.reload, deduplicate)
reloader.certWatcher, err = internal.NewFileWatcher(reloader.logger, certFile, reloader.reload, deduplicate)
if err != nil {
return nil, err
}

View file

@ -19,7 +19,7 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
package internal
import (
"context"
@ -37,7 +37,7 @@ import (
)
const (
defaultDeduplicateWatchEvents = 100 * time.Millisecond
DefaultDeduplicateWatchEvents = 100 * time.Millisecond
)
type FileWatcherCallback func(filename string)

View file

@ -19,7 +19,7 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
package internal
import (
"context"
@ -35,7 +35,7 @@ import (
)
var (
testWatcherNoEventTimeout = 2 * defaultDeduplicateWatchEvents
testWatcherNoEventTimeout = 2 * DefaultDeduplicateWatchEvents
)
func TestFileWatcher_NotExist(t *testing.T) {
@ -43,7 +43,7 @@ func TestFileWatcher_NotExist(t *testing.T) {
assert := assert.New(t)
tmpdir := t.TempDir()
logger := log.NewLoggerForTest(t)
if w, err := NewFileWatcher(logger, path.Join(tmpdir, "test.txt"), func(filename string) {}, defaultDeduplicateWatchEvents); !assert.ErrorIs(err, os.ErrNotExist) {
if w, err := NewFileWatcher(logger, path.Join(tmpdir, "test.txt"), func(filename string) {}, DefaultDeduplicateWatchEvents); !assert.ErrorIs(err, os.ErrNotExist) {
if w != nil {
assert.NoError(w.Close())
}
@ -62,7 +62,7 @@ func TestFileWatcher_File(t *testing.T) { // nolint:paralleltest
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
}, defaultDeduplicateWatchEvents)
}, DefaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -105,7 +105,7 @@ func TestFileWatcher_CurrentDir(t *testing.T) { // nolint:paralleltest
modified := make(chan struct{})
w, err := NewFileWatcher(logger, "./"+path.Base(filename), func(filename string) {
modified <- struct{}{}
}, defaultDeduplicateWatchEvents)
}, DefaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -147,7 +147,7 @@ func TestFileWatcher_Rename(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
}, defaultDeduplicateWatchEvents)
}, DefaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -191,7 +191,7 @@ func TestFileWatcher_Symlink(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
}, defaultDeduplicateWatchEvents)
}, DefaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -226,7 +226,7 @@ func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
}, defaultDeduplicateWatchEvents)
}, DefaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -263,7 +263,7 @@ func TestFileWatcher_OtherSymlink(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
}, defaultDeduplicateWatchEvents)
}, DefaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -294,7 +294,7 @@ func TestFileWatcher_RenameSymlinkTarget(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
}, defaultDeduplicateWatchEvents)
}, DefaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -348,7 +348,7 @@ func TestFileWatcher_UpdateSymlinkFolder(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
}, defaultDeduplicateWatchEvents)
}, DefaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()