Move etcd client code to "etcd" package.

This commit is contained in:
Joachim Bauch 2025-12-15 10:04:32 +01:00
commit 6756520447
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
21 changed files with 1004 additions and 188 deletions

View file

@ -22,6 +22,7 @@
package etcd
import (
"context"
"errors"
"fmt"
"net/url"
@ -29,8 +30,32 @@ import (
"github.com/strukturag/nextcloud-spreed-signaling/api"
"github.com/strukturag/nextcloud-spreed-signaling/internal"
clientv3 "go.etcd.io/etcd/client/v3"
)
type ClientListener interface {
EtcdClientCreated(client Client)
}
type ClientWatcher interface {
EtcdWatchCreated(client Client, key string)
EtcdKeyUpdated(client Client, key string, value []byte, prevValue []byte)
EtcdKeyDeleted(client Client, key string, prevValue []byte)
}
type Client interface {
IsConfigured() bool
WaitForConnection(ctx context.Context) error
GetServerInfoEtcd() *BackendServerInfoEtcd
Close() error
AddListener(listener ClientListener)
RemoveListener(listener ClientListener)
Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error)
Watch(ctx context.Context, key string, nextRevision int64, watcher ClientWatcher, opts ...clientv3.OpOption) (int64, error)
}
// Information on a backend in the etcd cluster.
type BackendInformationEtcd struct {
@ -97,3 +122,10 @@ func (p *BackendInformationEtcd) CheckValid() (err error) {
return nil
}
type BackendServerInfoEtcd struct {
Endpoints []string `json:"endpoints"`
Active string `json:"active,omitempty"`
Connected *bool `json:"connected,omitempty"`
}

311
etcd/client.go Normal file
View file

@ -0,0 +1,311 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package etcd
import (
"context"
"errors"
"fmt"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/dlintw/goconf"
"go.etcd.io/etcd/client/pkg/v3/srv"
"go.etcd.io/etcd/client/pkg/v3/transport"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"google.golang.org/grpc/connectivity"
"github.com/strukturag/nextcloud-spreed-signaling/async"
"github.com/strukturag/nextcloud-spreed-signaling/internal"
"github.com/strukturag/nextcloud-spreed-signaling/log"
)
var (
initialWaitDelay = time.Second
maxWaitDelay = 8 * time.Second
)
type etcdClient struct {
logger log.Logger
compatSection string
mu sync.Mutex
client atomic.Value
// +checklocks:mu
listeners map[ClientListener]bool
}
func NewClient(logger log.Logger, config *goconf.ConfigFile, compatSection string) (Client, error) {
result := &etcdClient{
logger: logger,
compatSection: compatSection,
}
if err := result.load(config, false); err != nil {
return nil, err
}
return result, nil
}
func (c *etcdClient) GetServerInfoEtcd() *BackendServerInfoEtcd {
client := c.getEtcdClient()
if client == nil {
return nil
}
result := &BackendServerInfoEtcd{
Endpoints: client.Endpoints(),
}
conn := client.ActiveConnection()
if conn != nil {
result.Active = conn.Target()
result.Connected = internal.MakePtr(conn.GetState() == connectivity.Ready)
}
return result
}
func (c *etcdClient) getConfigStringWithFallback(config *goconf.ConfigFile, option string) string {
value, _ := config.GetString("etcd", option)
if value == "" && c.compatSection != "" {
value, _ = config.GetString(c.compatSection, option)
if value != "" {
c.logger.Printf("WARNING: Configuring etcd option \"%s\" in section \"%s\" is deprecated, use section \"etcd\" instead", option, c.compatSection)
}
}
return value
}
func (c *etcdClient) load(config *goconf.ConfigFile, ignoreErrors bool) error {
var endpoints []string
if endpointsString := c.getConfigStringWithFallback(config, "endpoints"); endpointsString != "" {
endpoints = slices.Collect(internal.SplitEntries(endpointsString, ","))
} else if discoverySrv := c.getConfigStringWithFallback(config, "discoverysrv"); discoverySrv != "" {
discoveryService := c.getConfigStringWithFallback(config, "discoveryservice")
clients, err := srv.GetClient("etcd-client", discoverySrv, discoveryService)
if err != nil {
if !ignoreErrors {
return fmt.Errorf("could not discover etcd endpoints for %s: %w", discoverySrv, err)
}
} else {
endpoints = clients.Endpoints
}
}
if len(endpoints) == 0 {
if !ignoreErrors {
return nil
}
c.logger.Printf("No etcd endpoints configured, not changing client")
} else {
cfg := clientv3.Config{
Endpoints: endpoints,
// set timeout per request to fail fast when the target endpoint is unavailable
DialTimeout: time.Second,
}
if logLevel, _ := config.GetString("etcd", "loglevel"); logLevel != "" {
var l zapcore.Level
if err := l.Set(logLevel); err != nil {
return fmt.Errorf("unsupported etcd log level %s: %w", logLevel, err)
}
logConfig := zap.NewProductionConfig()
logConfig.Level = zap.NewAtomicLevelAt(l)
cfg.LogConfig = &logConfig
}
clientKey := c.getConfigStringWithFallback(config, "clientkey")
clientCert := c.getConfigStringWithFallback(config, "clientcert")
caCert := c.getConfigStringWithFallback(config, "cacert")
if clientKey != "" && clientCert != "" && caCert != "" {
tlsInfo := transport.TLSInfo{
CertFile: clientCert,
KeyFile: clientKey,
TrustedCAFile: caCert,
}
tlsConfig, err := tlsInfo.ClientConfig()
if err != nil {
if !ignoreErrors {
return fmt.Errorf("could not setup etcd TLS configuration: %w", err)
}
c.logger.Printf("Could not setup TLS configuration, will be disabled (%s)", err)
} else {
cfg.TLS = tlsConfig
}
}
client, err := clientv3.New(cfg)
if err != nil {
if !ignoreErrors {
return err
}
c.logger.Printf("Could not create new client from etd endpoints %+v: %s", endpoints, err)
} else {
prev := c.getEtcdClient()
if prev != nil {
prev.Close()
}
c.client.Store(client)
c.logger.Printf("Using etcd endpoints %+v", endpoints)
c.notifyListeners()
}
}
return nil
}
func (c *etcdClient) Close() error {
client := c.getEtcdClient()
if client != nil {
return client.Close()
}
return nil
}
func (c *etcdClient) IsConfigured() bool {
return c.getEtcdClient() != nil
}
func (c *etcdClient) getEtcdClient() *clientv3.Client {
client := c.client.Load()
if client == nil {
return nil
}
return client.(*clientv3.Client)
}
func (c *etcdClient) syncClient(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
return c.getEtcdClient().Sync(ctx)
}
func (c *etcdClient) notifyListeners() {
c.mu.Lock()
defer c.mu.Unlock()
for listener := range c.listeners {
listener.EtcdClientCreated(c)
}
}
func (c *etcdClient) AddListener(listener ClientListener) {
c.mu.Lock()
defer c.mu.Unlock()
if c.listeners == nil {
c.listeners = make(map[ClientListener]bool)
}
c.listeners[listener] = true
if client := c.getEtcdClient(); client != nil {
go listener.EtcdClientCreated(c)
}
}
func (c *etcdClient) RemoveListener(listener ClientListener) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.listeners, listener)
}
func (c *etcdClient) WaitForConnection(ctx context.Context) error {
backoff, err := async.NewExponentialBackoff(initialWaitDelay, maxWaitDelay)
if err != nil {
return err
}
for {
if err := ctx.Err(); err != nil {
return err
}
if err := c.syncClient(ctx); err != nil {
if errors.Is(err, context.Canceled) {
return err
} else if errors.Is(err, context.DeadlineExceeded) {
c.logger.Printf("Timeout waiting for etcd client to connect to the cluster, retry in %s", backoff.NextWait())
} else {
c.logger.Printf("Could not sync etcd client with the cluster, retry in %s: %s", backoff.NextWait(), err)
}
backoff.Wait(ctx)
continue
}
c.logger.Printf("Client synced, using endpoints %+v", c.getEtcdClient().Endpoints())
return nil
}
}
func (c *etcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
return c.getEtcdClient().Get(ctx, key, opts...)
}
func (c *etcdClient) Watch(ctx context.Context, key string, nextRevision int64, watcher ClientWatcher, opts ...clientv3.OpOption) (int64, error) {
c.logger.Printf("Wait for leader and start watching on %s (rev=%d)", key, nextRevision)
opts = append(opts, clientv3.WithRev(nextRevision), clientv3.WithPrevKV())
ch := c.getEtcdClient().Watch(clientv3.WithRequireLeader(ctx), key, opts...)
c.logger.Printf("Watch created for %s", key)
watcher.EtcdWatchCreated(c, key)
for response := range ch {
if err := response.Err(); err != nil {
return nextRevision, err
}
nextRevision = response.Header.Revision + 1
for _, ev := range response.Events {
switch ev.Type {
case clientv3.EventTypePut:
var prevValue []byte
if ev.PrevKv != nil {
prevValue = ev.PrevKv.Value
}
watcher.EtcdKeyUpdated(c, string(ev.Kv.Key), ev.Kv.Value, prevValue)
case clientv3.EventTypeDelete:
var prevValue []byte
if ev.PrevKv != nil {
prevValue = ev.PrevKv.Value
}
watcher.EtcdKeyDeleted(c, string(ev.Kv.Key), prevValue)
default:
c.logger.Printf("Unsupported watch event %s %q -> %q", ev.Type, ev.Kv.Key, ev.Kv.Value)
}
}
}
return nextRevision, nil
}

421
etcd/client_test.go Normal file
View file

@ -0,0 +1,421 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package etcd
import (
"context"
"crypto/rand"
"crypto/rsa"
"net"
"net/url"
"os"
"path"
"strconv"
"testing"
"time"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/server/v3/embed"
"go.etcd.io/etcd/server/v3/lease"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
"github.com/strukturag/nextcloud-spreed-signaling/internal"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/test"
)
const (
testTimeout = 10 * time.Second
)
var (
etcdListenUrl = "http://localhost:8080"
)
func NewEtcdForTestWithTls(t *testing.T, withTLS bool) (*embed.Etcd, string, string) {
t.Helper()
require := require.New(t)
cfg := embed.NewConfig()
cfg.Dir = t.TempDir()
os.Chmod(cfg.Dir, 0700) // nolint
cfg.LogLevel = "warn"
cfg.Name = "signalingtest"
cfg.ZapLoggerBuilder = embed.NewZapLoggerBuilder(zaptest.NewLogger(t, zaptest.Level(zap.WarnLevel)))
u, err := url.Parse(etcdListenUrl)
require.NoError(err)
var keyfile string
var certfile string
if withTLS {
u.Scheme = "https"
tmpdir := t.TempDir()
key, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(err)
keyfile = path.Join(tmpdir, "etcd.key")
require.NoError(internal.WritePrivateKey(key, keyfile))
cfg.ClientTLSInfo.KeyFile = keyfile
cfg.PeerTLSInfo.KeyFile = keyfile
cert := internal.GenerateSelfSignedCertificateForTesting(t, "etcd", key)
certfile = path.Join(tmpdir, "etcd.pem")
require.NoError(internal.WriteCertificate(cert, certfile))
cfg.ClientTLSInfo.CertFile = certfile
cfg.ClientTLSInfo.TrustedCAFile = certfile
cfg.PeerTLSInfo.CertFile = certfile
cfg.PeerTLSInfo.TrustedCAFile = certfile
}
// Find a free port to bind the server to.
var etcd *embed.Etcd
for port := 50000; port < 50100; port++ {
u.Host = net.JoinHostPort("localhost", strconv.Itoa(port))
cfg.ListenClientUrls = []url.URL{*u}
cfg.AdvertiseClientUrls = []url.URL{*u}
httpListener := u
httpListener.Host = net.JoinHostPort("localhost", strconv.Itoa(port+1))
cfg.ListenClientHttpUrls = []url.URL{*httpListener}
peerListener := u
peerListener.Host = net.JoinHostPort("localhost", strconv.Itoa(port+2))
cfg.ListenPeerUrls = []url.URL{*peerListener}
cfg.AdvertisePeerUrls = []url.URL{*peerListener}
cfg.InitialCluster = "signalingtest=" + peerListener.String()
etcd, err = embed.StartEtcd(cfg)
if test.IsErrorAddressAlreadyInUse(err) {
continue
}
require.NoError(err)
break
}
require.NotNil(etcd, "could not find free port")
t.Cleanup(func() {
etcd.Close()
<-etcd.Server.StopNotify()
})
// Wait for server to be ready.
<-etcd.Server.ReadyNotify()
return etcd, keyfile, certfile
}
func NewEtcdForTest(t *testing.T) *embed.Etcd {
t.Helper()
etcd, _, _ := NewEtcdForTestWithTls(t, false)
return etcd
}
func NewClientForTest(t *testing.T) (*embed.Etcd, Client) {
etcd := NewEtcdForTest(t)
config := goconf.NewConfigFile()
config.AddOption("etcd", "endpoints", etcd.Config().ListenClientUrls[0].String())
config.AddOption("etcd", "loglevel", "error")
logger := log.NewLoggerForTest(t)
client, err := NewClient(logger, config, "")
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, client.Close())
})
return etcd, client
}
func NewEtcdClientWithTLSForTest(t *testing.T) (*embed.Etcd, Client) {
etcd, keyfile, certfile := NewEtcdForTestWithTls(t, true)
config := goconf.NewConfigFile()
config.AddOption("etcd", "endpoints", etcd.Config().ListenClientUrls[0].String())
config.AddOption("etcd", "loglevel", "error")
config.AddOption("etcd", "clientkey", keyfile)
config.AddOption("etcd", "clientcert", certfile)
config.AddOption("etcd", "cacert", certfile)
logger := log.NewLoggerForTest(t)
client, err := NewClient(logger, config, "")
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, client.Close())
})
return etcd, client
}
func SetValue(etcd *embed.Etcd, key string, value []byte) {
if kv := etcd.Server.KV(); kv != nil {
kv.Put([]byte(key), value, lease.NoLease)
kv.Commit()
}
}
func DeleteValue(etcd *embed.Etcd, key string) {
if kv := etcd.Server.KV(); kv != nil {
kv.DeleteRange([]byte(key), nil)
kv.Commit()
}
}
func Test_EtcdClient_Get(t *testing.T) {
t.Parallel()
logger := log.NewLoggerForTest(t)
ctx := log.NewLoggerContext(t.Context(), logger)
assert := assert.New(t)
require := require.New(t)
etcd, client := NewClientForTest(t)
ctx, cancel := context.WithTimeout(ctx, testTimeout)
defer cancel()
if info := client.GetServerInfoEtcd(); assert.NotNil(info) {
assert.NotEmpty(info.Active)
assert.Equal([]string{
etcd.Config().ListenClientUrls[0].String(),
}, info.Endpoints)
assert.NotNil(info.Connected)
}
require.NoError(client.WaitForConnection(ctx))
if info := client.GetServerInfoEtcd(); assert.NotNil(info) {
assert.NotEmpty(info.Active)
assert.Equal([]string{
etcd.Config().ListenClientUrls[0].String(),
}, info.Endpoints)
if connected := info.Connected; assert.NotNil(connected) {
assert.True(*connected)
}
}
if response, err := client.Get(ctx, "foo"); assert.NoError(err) {
assert.EqualValues(0, response.Count)
}
SetValue(etcd, "foo", []byte("bar"))
if response, err := client.Get(ctx, "foo"); assert.NoError(err) {
if assert.EqualValues(1, response.Count) {
assert.Equal("foo", string(response.Kvs[0].Key))
assert.Equal("bar", string(response.Kvs[0].Value))
}
}
}
func Test_EtcdClientTLS_Get(t *testing.T) {
t.Parallel()
logger := log.NewLoggerForTest(t)
ctx := log.NewLoggerContext(t.Context(), logger)
assert := assert.New(t)
require := require.New(t)
etcd, client := NewEtcdClientWithTLSForTest(t)
ctx, cancel := context.WithTimeout(ctx, testTimeout)
defer cancel()
if info := client.GetServerInfoEtcd(); assert.NotNil(info) {
assert.NotEmpty(info.Active)
assert.Equal([]string{
etcd.Config().ListenClientUrls[0].String(),
}, info.Endpoints)
assert.NotNil(info.Connected)
}
require.NoError(client.WaitForConnection(ctx))
if info := client.GetServerInfoEtcd(); assert.NotNil(info) {
assert.NotEmpty(info.Active)
assert.Equal([]string{
etcd.Config().ListenClientUrls[0].String(),
}, info.Endpoints)
if connected := info.Connected; assert.NotNil(connected) {
assert.True(*connected)
}
}
if response, err := client.Get(ctx, "foo"); assert.NoError(err) {
assert.EqualValues(0, response.Count)
}
SetValue(etcd, "foo", []byte("bar"))
if response, err := client.Get(ctx, "foo"); assert.NoError(err) {
if assert.EqualValues(1, response.Count) {
assert.Equal("foo", string(response.Kvs[0].Key))
assert.Equal("bar", string(response.Kvs[0].Value))
}
}
}
func Test_EtcdClient_GetPrefix(t *testing.T) {
t.Parallel()
logger := log.NewLoggerForTest(t)
ctx := log.NewLoggerContext(t.Context(), logger)
assert := assert.New(t)
etcd, client := NewClientForTest(t)
if response, err := client.Get(ctx, "foo"); assert.NoError(err) {
assert.EqualValues(0, response.Count)
}
SetValue(etcd, "foo", []byte("1"))
SetValue(etcd, "foo/lala", []byte("2"))
SetValue(etcd, "lala/foo", []byte("3"))
if response, err := client.Get(ctx, "foo", clientv3.WithPrefix()); assert.NoError(err) {
if assert.EqualValues(2, response.Count) {
assert.Equal("foo", string(response.Kvs[0].Key))
assert.Equal("1", string(response.Kvs[0].Value))
assert.Equal("foo/lala", string(response.Kvs[1].Key))
assert.Equal("2", string(response.Kvs[1].Value))
}
}
}
type etcdEvent struct {
t mvccpb.Event_EventType
key string
value string
prevValue string
}
type EtcdClientTestListener struct {
t *testing.T
ctx context.Context
cancel context.CancelFunc
initial chan struct{}
events chan etcdEvent
}
func NewEtcdClientTestListener(ctx context.Context, t *testing.T) *EtcdClientTestListener {
ctx, cancel := context.WithCancel(ctx)
return &EtcdClientTestListener{
t: t,
ctx: ctx,
cancel: cancel,
initial: make(chan struct{}),
events: make(chan etcdEvent),
}
}
func (l *EtcdClientTestListener) Close() {
l.cancel()
}
func (l *EtcdClientTestListener) EtcdClientCreated(client Client) {
go func() {
assert := assert.New(l.t)
if err := client.WaitForConnection(l.ctx); !assert.NoError(err) {
return
}
ctx, cancel := context.WithTimeout(l.ctx, time.Second)
defer cancel()
response, err := client.Get(ctx, "foo", clientv3.WithPrefix())
if assert.NoError(err) && assert.EqualValues(1, response.Count) {
assert.Equal("foo/a", string(response.Kvs[0].Key))
assert.Equal("1", string(response.Kvs[0].Value))
}
close(l.initial)
nextRevision := response.Header.Revision + 1
for l.ctx.Err() == nil {
var err error
nextRevision, err = client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", nextRevision, l, clientv3.WithPrefix())
assert.NoError(err)
}
}()
}
func (l *EtcdClientTestListener) EtcdWatchCreated(client Client, key string) {
}
func (l *EtcdClientTestListener) EtcdKeyUpdated(client Client, key string, value []byte, prevValue []byte) {
evt := etcdEvent{
t: clientv3.EventTypePut,
key: string(key),
value: string(value),
}
if len(prevValue) > 0 {
evt.prevValue = string(prevValue)
}
l.events <- evt
}
func (l *EtcdClientTestListener) EtcdKeyDeleted(client Client, key string, prevValue []byte) {
evt := etcdEvent{
t: clientv3.EventTypeDelete,
key: string(key),
}
if len(prevValue) > 0 {
evt.prevValue = string(prevValue)
}
l.events <- evt
}
func Test_EtcdClient_Watch(t *testing.T) {
t.Parallel()
logger := log.NewLoggerForTest(t)
ctx := log.NewLoggerContext(t.Context(), logger)
assert := assert.New(t)
etcd, client := NewClientForTest(t)
SetValue(etcd, "foo/a", []byte("1"))
listener := NewEtcdClientTestListener(ctx, t)
defer listener.Close()
client.AddListener(listener)
defer client.RemoveListener(listener)
<-listener.initial
SetValue(etcd, "foo/b", []byte("2"))
event := <-listener.events
assert.Equal(clientv3.EventTypePut, event.t)
assert.Equal("foo/b", event.key)
assert.Equal("2", event.value)
SetValue(etcd, "foo/a", []byte("3"))
event = <-listener.events
assert.Equal(clientv3.EventTypePut, event.t)
assert.Equal("foo/a", event.key)
assert.Equal("3", event.value)
DeleteValue(etcd, "foo/a")
event = <-listener.events
assert.Equal(clientv3.EventTypeDelete, event.t)
assert.Equal("foo/a", event.key)
assert.Equal("3", event.prevValue)
}

466
etcd/etcdtest/etcdtest.go Normal file
View file

@ -0,0 +1,466 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2025 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package etcdtest
import (
"bytes"
"context"
"errors"
"net"
"net/url"
"os"
"slices"
"strconv"
"strings"
"sync"
"testing"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/api/v3/etcdserverpb"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/server/v3/embed"
"go.etcd.io/etcd/server/v3/lease"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
"github.com/strukturag/nextcloud-spreed-signaling/etcd"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/test"
)
var (
etcdListenUrl = "http://localhost:8080"
)
type Server struct {
embed *embed.Etcd
}
func (s *Server) URL() *url.URL {
return &s.embed.Config().ListenClientUrls[0]
}
func (s *Server) SetValue(key string, value []byte) {
if kv := s.embed.Server.KV(); kv != nil {
kv.Put([]byte(key), value, lease.NoLease)
kv.Commit()
}
}
func (s *Server) DeleteValue(key string) {
if kv := s.embed.Server.KV(); kv != nil {
kv.DeleteRange([]byte(key), nil)
kv.Commit()
}
}
func NewServerForTest(t *testing.T) *Server {
t.Helper()
require := require.New(t)
cfg := embed.NewConfig()
cfg.Dir = t.TempDir()
os.Chmod(cfg.Dir, 0700) // nolint
cfg.LogLevel = "warn"
cfg.Name = "signalingtest"
cfg.ZapLoggerBuilder = embed.NewZapLoggerBuilder(zaptest.NewLogger(t, zaptest.Level(zap.WarnLevel)))
u, err := url.Parse(etcdListenUrl)
require.NoError(err)
// Find a free port to bind the server to.
var etcd *embed.Etcd
for port := 50000; port < 50100; port++ {
u.Host = net.JoinHostPort("localhost", strconv.Itoa(port))
cfg.ListenClientUrls = []url.URL{*u}
cfg.AdvertiseClientUrls = []url.URL{*u}
httpListener := u
httpListener.Host = net.JoinHostPort("localhost", strconv.Itoa(port+1))
cfg.ListenClientHttpUrls = []url.URL{*httpListener}
peerListener := u
peerListener.Host = net.JoinHostPort("localhost", strconv.Itoa(port+2))
cfg.ListenPeerUrls = []url.URL{*peerListener}
cfg.AdvertisePeerUrls = []url.URL{*peerListener}
cfg.InitialCluster = "signalingtest=" + peerListener.String()
etcd, err = embed.StartEtcd(cfg)
if test.IsErrorAddressAlreadyInUse(err) {
continue
}
require.NoError(err)
break
}
require.NotNil(etcd, "could not find free port")
t.Cleanup(func() {
etcd.Close()
<-etcd.Server.StopNotify()
})
// Wait for server to be ready.
<-etcd.Server.ReadyNotify()
server := &Server{
embed: etcd,
}
return server
}
func NewEtcdClientForTest(t *testing.T, server *Server) etcd.Client {
t.Helper()
logger := log.NewLoggerForTest(t)
config := goconf.NewConfigFile()
config.AddOption("etcd", "endpoints", server.URL().String())
client, err := etcd.NewClient(logger, config, "")
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, client.Close())
})
return client
}
type testWatch struct {
key string
op clientv3.Op
rev int64
watcher etcd.ClientWatcher
}
type testClient struct {
mu sync.Mutex
server *TestServer
// +checklocks:mu
closed bool
closeCh chan struct{}
processCh chan func()
// +checklocks:mu
listeners []etcd.ClientListener
// +checklocks:mu
watchers []*testWatch
}
func newTestClient(server *TestServer) *testClient {
client := &testClient{
server: server,
closeCh: make(chan struct{}),
processCh: make(chan func(), 1),
}
go func() {
defer close(client.closeCh)
for {
f := <-client.processCh
if f == nil {
return
}
f()
}
}()
return client
}
func (c *testClient) IsConfigured() bool {
return true
}
func (c *testClient) WaitForConnection(ctx context.Context) error {
return nil
}
func (c *testClient) GetServerInfoEtcd() *etcd.BackendServerInfoEtcd {
return &etcd.BackendServerInfoEtcd{}
}
func (c *testClient) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil
}
c.closed = true
c.server.removeClient(c)
close(c.processCh)
<-c.closeCh
return nil
}
func (c *testClient) AddListener(listener etcd.ClientListener) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
c.listeners = append(c.listeners, listener)
c.processCh <- func() {
listener.EtcdClientCreated(c)
}
}
func (c *testClient) RemoveListener(listener etcd.ClientListener) {
c.mu.Lock()
defer c.mu.Unlock()
c.listeners = slices.DeleteFunc(c.listeners, func(l etcd.ClientListener) bool {
return l == listener
})
}
func (c *testClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
keys, values, revision := c.server.getValues(key, 0, opts...)
response := &clientv3.GetResponse{
Count: int64(len(values)),
Header: &etcdserverpb.ResponseHeader{
Revision: revision,
},
}
for idx, key := range keys {
response.Kvs = append(response.Kvs, &mvccpb.KeyValue{
Key: []byte(key),
Value: values[idx],
})
}
return response, nil
}
func (c *testClient) notifyUpdated(key string, oldValue []byte, newValue []byte) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
for _, w := range c.watchers {
if withPrefix := w.op.IsOptsWithPrefix(); (withPrefix && strings.HasPrefix(key, w.key)) || (!withPrefix && key == w.key) {
c.processCh <- func() {
w.watcher.EtcdKeyUpdated(c, key, newValue, oldValue)
}
}
}
}
func (c *testClient) notifyDeleted(key string, oldValue []byte) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
for _, w := range c.watchers {
if withPrefix := w.op.IsOptsWithPrefix(); (withPrefix && strings.HasPrefix(key, w.key)) || (!withPrefix && key == w.key) {
c.processCh <- func() {
w.watcher.EtcdKeyDeleted(c, key, oldValue)
}
}
}
}
func (c *testClient) addWatcher(w *testWatch, opts ...clientv3.OpOption) error {
keys, values, _ := c.server.getValues(w.key, w.rev, opts...)
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return errors.New("closed")
}
c.watchers = append(c.watchers, w)
c.processCh <- func() {
w.watcher.EtcdWatchCreated(c, w.key)
}
for idx, key := range keys {
c.processCh <- func() {
w.watcher.EtcdKeyUpdated(c, key, values[idx], nil)
}
}
return nil
}
func (c *testClient) Watch(ctx context.Context, key string, nextRevision int64, watcher etcd.ClientWatcher, opts ...clientv3.OpOption) (int64, error) {
w := &testWatch{
key: key,
rev: nextRevision,
watcher: watcher,
}
for _, o := range opts {
o(&w.op)
}
if err := c.addWatcher(w, opts...); err != nil {
return 0, err
}
select {
case <-c.closeCh:
// Client is closed.
case <-ctx.Done():
// Watch context was cancelled / timed out.
}
return c.server.getRevision(), nil
}
type testServerValue struct {
value []byte
revision int64
}
type TestServer struct {
t *testing.T
mu sync.Mutex
// +checklocks:mu
clients []*testClient
// +checklocks:mu
values map[string]*testServerValue
// +checklocks:mu
revision int64
}
func (s *TestServer) newClient() *testClient {
client := newTestClient(s)
s.addClient(client)
return client
}
func (s *TestServer) addClient(client *testClient) {
s.mu.Lock()
defer s.mu.Unlock()
s.clients = append(s.clients, client)
}
func (s *TestServer) removeClient(client *testClient) {
s.mu.Lock()
defer s.mu.Unlock()
s.clients = slices.DeleteFunc(s.clients, func(c *testClient) bool {
return c == client
})
}
func (s *TestServer) getRevision() int64 {
s.mu.Lock()
defer s.mu.Unlock()
return s.revision
}
func (s *TestServer) getValues(key string, minRevision int64, opts ...clientv3.OpOption) (keys []string, values [][]byte, revision int64) {
s.mu.Lock()
defer s.mu.Unlock()
var op clientv3.Op
for _, o := range opts {
o(&op)
}
if op.IsOptsWithPrefix() {
for k, value := range s.values {
if minRevision > 0 && value.revision < minRevision {
continue
}
if strings.HasPrefix(k, key) {
keys = append(keys, k)
values = append(values, value.value)
}
}
} else {
if value, found := s.values[key]; found && (minRevision == 0 || value.revision >= minRevision) {
keys = append(keys, key)
values = append(values, value.value)
}
}
revision = s.revision
return
}
func (s *TestServer) SetValue(key string, value []byte) {
s.mu.Lock()
defer s.mu.Unlock()
prev, found := s.values[key]
if found && bytes.Equal(prev.value, value) {
return
}
if s.values == nil {
s.values = make(map[string]*testServerValue)
}
if prev == nil {
prev = &testServerValue{}
s.values[key] = prev
}
s.revision++
prevValue := prev.value
prev.value = value
prev.revision = s.revision
for _, c := range s.clients {
c.notifyUpdated(key, prevValue, value)
}
}
func (s *TestServer) DeleteValue(key string) {
s.mu.Lock()
defer s.mu.Unlock()
prev, found := s.values[key]
if !found {
return
}
delete(s.values, key)
s.revision++
for _, c := range s.clients {
c.notifyDeleted(key, prev.value)
}
}
func NewClientForTest(t *testing.T) (*TestServer, etcd.Client) {
t.Helper()
server := &TestServer{
t: t,
revision: 1,
}
client := server.newClient()
t.Cleanup(func() {
assert.NoError(t, client.Close())
})
return server, client
}

View file

@ -0,0 +1,319 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2025 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package etcdtest
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/strukturag/nextcloud-spreed-signaling/etcd"
)
var (
testTimeout = 10 * time.Second
)
type updateEvent struct {
key string
value string
prev []byte
}
type deleteEvent struct {
key string
prev []byte
}
type testWatcher struct {
created chan struct{}
updated chan updateEvent
deleted chan deleteEvent
}
func newTestWatcher() *testWatcher {
return &testWatcher{
created: make(chan struct{}),
updated: make(chan updateEvent),
deleted: make(chan deleteEvent),
}
}
func (w *testWatcher) EtcdWatchCreated(client etcd.Client, key string) {
close(w.created)
}
func (w *testWatcher) EtcdKeyUpdated(client etcd.Client, key string, value []byte, prevValue []byte) {
w.updated <- updateEvent{
key: key,
value: string(value),
prev: prevValue,
}
}
func (w *testWatcher) EtcdKeyDeleted(client etcd.Client, key string, prevValue []byte) {
w.deleted <- deleteEvent{
key: key,
prev: prevValue,
}
}
type serverInterface interface {
SetValue(key string, value []byte)
DeleteValue(key string)
}
type testClientListener struct {
called chan struct{}
}
func (l *testClientListener) EtcdClientCreated(c etcd.Client) {
close(l.called)
}
func testServerWatch(t *testing.T, server serverInterface, client etcd.Client) {
require := require.New(t)
assert := assert.New(t)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
assert.True(client.IsConfigured(), "should be configured")
require.NoError(client.WaitForConnection(ctx))
listener := &testClientListener{
called: make(chan struct{}),
}
client.AddListener(listener)
defer client.RemoveListener(listener)
select {
case <-listener.called:
case <-ctx.Done():
require.NoError(ctx.Err())
}
watcher := newTestWatcher()
go func() {
if _, err := client.Watch(cancelCtx, "foo", 0, watcher); err != nil {
assert.ErrorIs(err, context.Canceled)
}
}()
select {
case <-watcher.created:
case <-ctx.Done():
require.NoError(ctx.Err())
}
key := "foo"
value := "bar"
server.SetValue("foo", []byte(value))
select {
case evt := <-watcher.updated:
assert.Equal(key, evt.key)
assert.Equal(value, evt.value)
assert.Empty(evt.prev)
case <-ctx.Done():
require.NoError(ctx.Err())
}
if response, err := client.Get(ctx, "foo"); assert.NoError(err) {
assert.EqualValues(1, response.Count)
if assert.Len(response.Kvs, 1) {
assert.Equal(key, string(response.Kvs[0].Key))
assert.Equal(value, string(response.Kvs[0].Value))
}
}
if response, err := client.Get(ctx, "f"); assert.NoError(err) {
assert.EqualValues(0, response.Count)
assert.Empty(response.Kvs)
}
if response, err := client.Get(ctx, "f", clientv3.WithPrefix()); assert.NoError(err) {
assert.EqualValues(1, response.Count)
if assert.Len(response.Kvs, 1) {
assert.Equal(key, string(response.Kvs[0].Key))
assert.Equal(value, string(response.Kvs[0].Value))
}
}
server.DeleteValue("foo")
select {
case evt := <-watcher.deleted:
assert.Equal(key, evt.key)
assert.Equal(value, string(evt.prev))
case <-ctx.Done():
require.NoError(ctx.Err())
}
select {
case evt := <-watcher.updated:
assert.Fail("unexpected update event", "got %+v", evt)
case evt := <-watcher.deleted:
assert.Fail("unexpected deleted event", "got %+v", evt)
default:
}
}
func TestServerWatch_Mock(t *testing.T) {
t.Parallel()
server, client := NewClientForTest(t)
testServerWatch(t, server, client)
}
func TestServerWatch_Real(t *testing.T) {
t.Parallel()
server := NewServerForTest(t)
client := NewEtcdClientForTest(t, server)
testServerWatch(t, server, client)
}
func testServerWatchInitialData(t *testing.T, server serverInterface, client etcd.Client) {
require := require.New(t)
assert := assert.New(t)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
key := "foo"
value := "bar"
server.SetValue("foo", []byte(value))
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
watcher := newTestWatcher()
go func() {
if _, err := client.Watch(cancelCtx, "foo", 1, watcher); err != nil {
assert.ErrorIs(err, context.Canceled)
}
}()
select {
case <-watcher.created:
case <-ctx.Done():
require.NoError(ctx.Err())
}
select {
case evt := <-watcher.updated:
assert.Equal(key, evt.key)
assert.Equal(value, evt.value)
assert.Empty(evt.prev)
case <-ctx.Done():
require.NoError(ctx.Err())
}
select {
case evt := <-watcher.updated:
assert.Fail("unexpected update event", "got %+v", evt)
case evt := <-watcher.deleted:
assert.Fail("unexpected deleted event", "got %+v", evt)
default:
}
}
func TestServerWatchInitialData_Mock(t *testing.T) {
t.Parallel()
server, client := NewClientForTest(t)
testServerWatchInitialData(t, server, client)
}
func TestServerWatchInitialData_Real(t *testing.T) {
t.Parallel()
server := NewServerForTest(t)
client := NewEtcdClientForTest(t, server)
testServerWatchInitialData(t, server, client)
}
func testServerWatchInitialOldData(t *testing.T, server serverInterface, client etcd.Client) {
require := require.New(t)
assert := assert.New(t)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
key := "foo"
value := "bar"
server.SetValue("foo", []byte(value))
response, err := client.Get(ctx, key)
require.NoError(err)
if assert.EqualValues(1, response.Count) && assert.Len(response.Kvs, 1) {
assert.Equal(key, string(response.Kvs[0].Key))
assert.Equal(value, string(response.Kvs[0].Value))
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
watcher := newTestWatcher()
go func() {
if _, err := client.Watch(cancelCtx, "foo", response.Header.GetRevision()+1, watcher); err != nil {
assert.ErrorIs(err, context.Canceled)
}
}()
select {
case <-watcher.created:
case <-ctx.Done():
require.NoError(ctx.Err())
}
select {
case evt := <-watcher.updated:
assert.Fail("unexpected update event", "got %+v", evt)
case evt := <-watcher.deleted:
assert.Fail("unexpected deleted event", "got %+v", evt)
default:
}
}
func TestServerWatchInitialOldData_Mock(t *testing.T) {
t.Parallel()
server, client := NewClientForTest(t)
testServerWatchInitialOldData(t, server, client)
}
func TestServerWatchInitialOldData_Real(t *testing.T) {
t.Parallel()
server := NewServerForTest(t)
client := NewEtcdClientForTest(t, server)
testServerWatchInitialOldData(t, server, client)
}