Compare commits

...

5 Commits

Author SHA1 Message Date
Joachim Bauch b2da4002a4
grpc: Reload certificate if file has changed and support mutual authentication. 2022-07-04 11:05:21 +02:00
Joachim Bauch 06e9ae0644
Add certificate reloader class. 2022-07-04 10:50:44 +02:00
Joachim Bauch 44bf8b74c2
grpc: Make sure DNS discovery of clients continues if initial lookup failed. 2022-07-01 11:42:49 +02:00
Joachim Bauch 15dabeee1e
grpc: Check clients for own server id asychronously.
The external address of the (own) GRPC server might only be reachable after
some time, so performing the check only initially could fail but will
succeed later.
2022-07-01 10:22:16 +02:00
Joachim Bauch 715b2317df
Add helper to wait with exponential backoff. 2022-07-01 10:21:49 +02:00
12 changed files with 1118 additions and 110 deletions

76
backoff.go Normal file
View File

@ -0,0 +1,76 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"fmt"
"time"
)
type Backoff interface {
Reset()
NextWait() time.Duration
Wait(context.Context)
}
type exponentialBackoff struct {
initial time.Duration
maxWait time.Duration
nextWait time.Duration
}
func NewExponentialBackoff(initial time.Duration, maxWait time.Duration) (Backoff, error) {
if initial <= 0 {
return nil, fmt.Errorf("initial must be larger than 0")
}
if maxWait < initial {
return nil, fmt.Errorf("maxWait must be larger or equal to initial")
}
return &exponentialBackoff{
initial: initial,
maxWait: maxWait,
nextWait: initial,
}, nil
}
func (b *exponentialBackoff) Reset() {
b.nextWait = b.initial
}
func (b *exponentialBackoff) NextWait() time.Duration {
return b.nextWait
}
func (b *exponentialBackoff) Wait(ctx context.Context) {
waiter, cancel := context.WithTimeout(ctx, b.nextWait)
defer cancel()
b.nextWait = b.nextWait * 2
if b.nextWait > b.maxWait {
b.nextWait = b.maxWait
}
<-waiter.Done()
}

64
backoff_test.go Normal file
View File

@ -0,0 +1,64 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"testing"
"time"
)
func TestBackoff_Exponential(t *testing.T) {
backoff, err := NewExponentialBackoff(100*time.Millisecond, 500*time.Millisecond)
if err != nil {
t.Fatal(err)
}
waitTimes := []time.Duration{
100 * time.Millisecond,
200 * time.Millisecond,
400 * time.Millisecond,
500 * time.Millisecond,
500 * time.Millisecond,
}
for _, wait := range waitTimes {
if backoff.NextWait() != wait {
t.Errorf("Wait time should be %s, got %s", wait, backoff.NextWait())
}
a := time.Now()
backoff.Wait(context.Background())
b := time.Now()
if b.Sub(a) < wait {
t.Errorf("Should have waited %s, got %s", wait, b.Sub(a))
}
}
backoff.Reset()
a := time.Now()
backoff.Wait(context.Background())
b := time.Now()
if b.Sub(a) < 100*time.Millisecond {
t.Errorf("Should have waited %s, got %s", 100*time.Millisecond, b.Sub(a))
}
}

190
certificate_reloader.go Normal file
View File

@ -0,0 +1,190 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"crypto/tls"
"crypto/x509"
"fmt"
"log"
"os"
"sync"
"time"
)
var (
// CertificateCheckInterval defines the interval in which certificate files
// are checked for modifications.
CertificateCheckInterval = time.Minute
)
type CertificateReloader struct {
mu sync.Mutex
certFile string
keyFile string
certificate *tls.Certificate
lastModified time.Time
nextCheck time.Time
}
func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) {
pair, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, fmt.Errorf("could not load certificate / key: %w", err)
}
stat, err := os.Stat(certFile)
if err != nil {
return nil, fmt.Errorf("could not stat %s: %w", certFile, err)
}
return &CertificateReloader{
certFile: certFile,
keyFile: keyFile,
certificate: &pair,
lastModified: stat.ModTime(),
nextCheck: time.Now().Add(CertificateCheckInterval),
}, nil
}
func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
if now.Before(r.nextCheck) {
return r.certificate, nil
}
r.nextCheck = now.Add(CertificateCheckInterval)
stat, err := os.Stat(r.certFile)
if err != nil {
log.Printf("could not stat %s: %s", r.certFile, err)
return r.certificate, nil
}
if !stat.ModTime().Equal(r.lastModified) {
log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile)
pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile)
if err != nil {
log.Printf("could not load certificate / key: %s", err)
return r.certificate, nil
}
r.certificate = &pair
r.lastModified = stat.ModTime()
}
return r.certificate, nil
}
func (r *CertificateReloader) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
return r.getCertificate()
}
func (r *CertificateReloader) GetClientCertificate(i *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return r.getCertificate()
}
type CertPoolReloader struct {
mu sync.Mutex
certFile string
pool *x509.CertPool
lastModified time.Time
nextCheck time.Time
}
func loadCertPool(filename string) (*x509.CertPool, error) {
cert, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(cert) {
return nil, fmt.Errorf("invalid CA in %s: %w", filename, err)
}
return pool, nil
}
func NewCertPoolReloader(certFile string) (*CertPoolReloader, error) {
pool, err := loadCertPool(certFile)
if err != nil {
return nil, err
}
stat, err := os.Stat(certFile)
if err != nil {
return nil, fmt.Errorf("could not stat %s: %w", certFile, err)
}
return &CertPoolReloader{
certFile: certFile,
pool: pool,
lastModified: stat.ModTime(),
nextCheck: time.Now().Add(CertificateCheckInterval),
}, nil
}
func (r *CertPoolReloader) GetCertPool() *x509.CertPool {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
if now.Before(r.nextCheck) {
return r.pool
}
r.nextCheck = now.Add(CertificateCheckInterval)
stat, err := os.Stat(r.certFile)
if err != nil {
log.Printf("could not stat %s: %s", r.certFile, err)
return r.pool
}
if !stat.ModTime().Equal(r.lastModified) {
log.Printf("reloading certificate pool from %s", r.certFile)
pool, err := loadCertPool(r.certFile)
if err != nil {
log.Printf("could not load certificate pool: %s", err)
return r.pool
}
r.pool = pool
r.lastModified = stat.ModTime()
}
return r.pool
}

View File

@ -0,0 +1,36 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"testing"
"time"
)
func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) {
old := CertificateCheckInterval
t.Cleanup(func() {
CertificateCheckInterval = old
})
CertificateCheckInterval = interval
}

View File

@ -36,8 +36,6 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
"google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
status "google.golang.org/grpc/status"
)
@ -74,6 +72,8 @@ func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl {
}
type GrpcClient struct {
isSelf uint32
ip net.IP
target string
conn *grpc.ClientConn
@ -164,6 +164,18 @@ func (c *GrpcClient) Close() error {
return c.conn.Close()
}
func (c *GrpcClient) IsSelf() bool {
return atomic.LoadUint32(&c.isSelf) != 0
}
func (c *GrpcClient) SetSelf(self bool) {
if self {
atomic.StoreUint32(&c.isSelf, 1)
} else {
atomic.StoreUint32(&c.isSelf, 0)
}
}
func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) {
statsGrpcClientCalls.WithLabelValues("GetServerId").Inc()
response, err := c.impl.GetServerId(ctx, &GetServerIdRequest{}, grpc.WaitForReady(true))
@ -248,6 +260,7 @@ type GrpcClients struct {
initializedCtx context.Context
initializedFunc context.CancelFunc
wakeupChanForTesting chan bool
selfCheckWaitGroup sync.WaitGroup
}
func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, error) {
@ -267,23 +280,12 @@ func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcCli
}
func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error {
var opts []grpc.DialOption
caFile, _ := config.GetString("grpc", "ca")
if caFile != "" {
creds, err := credentials.NewClientTLSFromFile(caFile, "")
if err != nil {
return fmt.Errorf("invalid GRPC CA in %s: %w", caFile, err)
}
opts = append(opts, grpc.WithTransportCredentials(creds))
} else {
log.Printf("WARNING: No GRPC CA configured, expecting unencrypted connections")
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
creds, err := NewReloadableCredentials(config, false)
if err != nil {
return err
}
if opts == nil {
opts = make([]grpc.DialOption, 0)
}
opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)}
c.dialOptions.Store(opts)
targetType, _ := config.GetString("grpc", "targettype")
@ -291,7 +293,6 @@ func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error {
targetType = DefaultGrpcTargetType
}
var err error
switch targetType {
case GrpcTargetTypeStatic:
err = c.loadTargetsStatic(config, fromReload, opts...)
@ -306,6 +307,79 @@ func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error {
return err
}
func (c *GrpcClients) closeClient(client *GrpcClient) {
if client.IsSelf() {
// Already closed.
return
}
if err := client.Close(); err != nil {
log.Printf("Error closing client to %s: %s", client.Target(), err)
}
}
func (c *GrpcClients) isClientAvailable(target string, client *GrpcClient) bool {
c.mu.RLock()
defer c.mu.RUnlock()
entries, found := c.clientsMap[target]
if !found {
return false
}
for _, entry := range entries {
if entry == client {
return true
}
}
return false
}
func (c *GrpcClients) getServerIdWithTimeout(ctx context.Context, client *GrpcClient) (string, error) {
ctx2, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
id, err := client.GetServerId(ctx2)
return id, err
}
func (c *GrpcClients) checkIsSelf(ctx context.Context, target string, client *GrpcClient) {
backoff, _ := NewExponentialBackoff(initialWaitDelay, maxWaitDelay)
defer c.selfCheckWaitGroup.Done()
loop:
for {
select {
case <-ctx.Done():
// Cancelled
return
default:
if !c.isClientAvailable(target, client) {
return
}
id, err := c.getServerIdWithTimeout(ctx, client)
if err != nil {
if status.Code(err) != codes.Canceled {
log.Printf("Error checking GRPC server id of %s, retrying in %s: %s", client.Target(), backoff.NextWait(), err)
}
backoff.Wait(ctx)
continue
}
if id == GrpcServerId {
log.Printf("GRPC target %s is this server, removing", client.Target())
c.closeClient(client)
client.SetSelf(true)
} else {
log.Printf("Checked GRPC server id of %s", client.Target())
}
break loop
}
}
}
func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bool, opts ...grpc.DialOption) error {
c.mu.Lock()
defer c.mu.Unlock()
@ -343,6 +417,8 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo
ips, err = lookupGrpcIp(host)
if err != nil {
log.Printf("Could not lookup %s: %s", host, err)
// Make sure updating continues even if initial lookup failed.
clientsMap[target] = nil
continue
}
} else {
@ -355,26 +431,14 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo
if err != nil {
for _, clients := range clientsMap {
for _, client := range clients {
if closeerr := client.Close(); closeerr != nil {
log.Printf("Error closing client to %s: %s", client.Target(), closeerr)
}
c.closeClient(client)
}
}
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if id, err := client.GetServerId(ctx); err != nil {
log.Printf("Error checking server id of %s: %s", client.Target(), err)
} else if id == GrpcServerId {
log.Printf("GRPC target %s is this server, ignoring", client.Target())
if err := client.Close(); err != nil {
log.Printf("Error closing client to %s: %s", client.Target(), err)
}
continue
}
c.selfCheckWaitGroup.Add(1)
go c.checkIsSelf(context.Background(), target, client)
log.Printf("Adding %s as GRPC target", client.Target())
clientsMap[target] = append(clientsMap[target], client)
@ -386,9 +450,7 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo
if clients, found := clientsMap[target]; found {
for _, client := range clients {
log.Printf("Deleting GRPC target %s", client.Target())
if err := client.Close(); err != nil {
log.Printf("Error closing client to %s: %s", client.Target(), err)
}
c.closeClient(client)
}
delete(clientsMap, target)
}
@ -467,9 +529,7 @@ func (c *GrpcClients) updateGrpcIPs() {
if !found {
changed = true
log.Printf("Removing connection to %s", client.Target())
if err := client.Close(); err != nil {
log.Printf("Error closing client to %s: %s", client.Target(), err)
}
c.closeClient(client)
c.wakeupForTesting()
}
}
@ -481,18 +541,8 @@ func (c *GrpcClients) updateGrpcIPs() {
continue
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if id, err := client.GetServerId(ctx); err != nil {
log.Printf("Error checking server id of %s: %s", client.Target(), err)
} else if id == GrpcServerId {
//log.Printf("GRPC target %s is this server, ignoring", client.Target())
if err := client.Close(); err != nil {
log.Printf("Error closing client to %s: %s", client.Target(), err)
}
continue
}
c.selfCheckWaitGroup.Add(1)
go c.checkIsSelf(context.Background(), target, client)
log.Printf("Adding %s as GRPC target", client.Target())
newClients = append(newClients, client)
@ -543,21 +593,17 @@ func (c *GrpcClients) EtcdClientCreated(client *EtcdClient) {
go func() {
client.WaitForConnection()
waitDelay := initialWaitDelay
backoff, _ := NewExponentialBackoff(initialWaitDelay, maxWaitDelay)
for {
response, err := c.getGrpcTargets(client, c.targetPrefix)
if err != nil {
if err == context.DeadlineExceeded {
log.Printf("Timeout getting initial list of GRPC targets, retry in %s", waitDelay)
log.Printf("Timeout getting initial list of GRPC targets, retry in %s", backoff.NextWait())
} else {
log.Printf("Could not get initial list of GRPC targets, retry in %s: %s", waitDelay, err)
log.Printf("Could not get initial list of GRPC targets, retry in %s: %s", backoff.NextWait(), err)
}
time.Sleep(waitDelay)
waitDelay = waitDelay * 2
if waitDelay > maxWaitDelay {
waitDelay = maxWaitDelay
}
backoff.Wait(context.Background())
continue
}
@ -609,19 +655,8 @@ func (c *GrpcClients) EtcdKeyUpdated(client *EtcdClient, key string, data []byte
return
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if id, err := cl.GetServerId(ctx); err != nil {
log.Printf("Error checking server id of %s: %s", cl.Target(), err)
} else if id == GrpcServerId {
log.Printf("GRPC target %s is this server, ignoring %s", cl.Target(), key)
if err := cl.Close(); err != nil {
log.Printf("Error closing client to %s: %s", cl.Target(), err)
}
c.wakeupForTesting()
return
}
c.selfCheckWaitGroup.Add(1)
go c.checkIsSelf(context.Background(), info.Address, cl)
log.Printf("Adding %s as GRPC target", cl.Target())
@ -658,9 +693,7 @@ func (c *GrpcClients) removeEtcdClientLocked(key string) {
for _, client := range clients {
log.Printf("Removing connection to %s (from %s)", client.Target(), key)
if err := client.Close(); err != nil {
log.Printf("Error closing client to %s: %s", client.Target(), err)
}
c.closeClient(client)
}
delete(c.clientsMap, info.Address)
c.clients = make([]*GrpcClient, 0, len(c.clientsMap))
@ -726,5 +759,17 @@ func (c *GrpcClients) GetClients() []*GrpcClient {
c.mu.RLock()
defer c.mu.RUnlock()
return c.clients
if len(c.clients) == 0 {
return c.clients
}
result := make([]*GrpcClient, 0, len(c.clients)-1)
for _, client := range c.clients {
if client.IsSelf() {
continue
}
result = append(result, client)
}
return result
}

View File

@ -23,8 +23,12 @@ package signaling
import (
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
"net"
"os"
"path"
"testing"
"time"
@ -32,12 +36,8 @@ import (
"go.etcd.io/etcd/server/v3/embed"
)
func NewGrpcClientsForTest(t *testing.T, addr string) *GrpcClients {
config := goconf.NewConfigFile()
config.AddOption("grpc", "targets", addr)
config.AddOption("grpc", "dnsdiscovery", "true")
client, err := NewGrpcClients(config, nil)
func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) *GrpcClients {
client, err := NewGrpcClients(config, etcdClient)
if err != nil {
t.Fatal(err)
}
@ -48,6 +48,14 @@ func NewGrpcClientsForTest(t *testing.T, addr string) *GrpcClients {
return client
}
func NewGrpcClientsForTest(t *testing.T, addr string) *GrpcClients {
config := goconf.NewConfigFile()
config.AddOption("grpc", "targets", addr)
config.AddOption("grpc", "dnsdiscovery", "true")
return NewGrpcClientsForTestWithConfig(t, config, nil)
}
func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) *GrpcClients {
config := goconf.NewConfigFile()
config.AddOption("etcd", "endpoints", etcd.Config().LCUrls[0].String())
@ -65,15 +73,7 @@ func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) *GrpcClients
}
})
client, err := NewGrpcClients(config, etcdClient)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
client.Close()
})
return client
return NewGrpcClientsForTestWithConfig(t, config, etcdClient)
}
func drainWakeupChannel(ch chan bool) {
@ -184,6 +184,7 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
server2.serverId = GrpcServerId
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
<-ch
client.selfCheckWaitGroup.Wait()
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != addr1 {
@ -257,3 +258,101 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
t.Errorf("Expected IP %s, got %s", ip2, clients[0].ip)
}
}
func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
var ipsResult []net.IP
lookupGrpcIp = func(host string) ([]net.IP, error) {
if host == "testgrpc" && len(ipsResult) > 0 {
return ipsResult, nil
}
return nil, &net.DNSError{
Err: "no such host",
Name: host,
IsNotFound: true,
}
}
target := "testgrpc:12345"
ip1 := net.ParseIP("192.168.0.1")
targetWithIp1 := fmt.Sprintf("%s (%s)", target, ip1)
client := NewGrpcClientsForTest(t, target)
ch := make(chan bool, 1)
client.wakeupChanForTesting = ch
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := client.WaitForInitialized(ctx); err != nil {
t.Fatal(err)
}
if clients := client.GetClients(); len(clients) != 0 {
t.Errorf("Expected no client, got %+v", clients)
}
ipsResult = []net.IP{ip1}
drainWakeupChannel(ch)
client.updateGrpcIPs()
<-ch
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != targetWithIp1 {
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
} else if !clients[0].ip.Equal(ip1) {
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
}
}
func Test_GrpcClients_Encryption(t *testing.T) {
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey)
clientCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Testing client", clientKey)
dir := t.TempDir()
serverPrivkeyFile := path.Join(dir, "server-privkey.pem")
serverPubkeyFile := path.Join(dir, "server-pubkey.pem")
serverCertFile := path.Join(dir, "server-cert.pem")
WritePrivateKey(serverKey, serverPrivkeyFile) // nolint
WritePublicKey(&serverKey.PublicKey, serverPubkeyFile) // nolint
os.WriteFile(serverCertFile, serverCert, 0755) // nolint
clientPrivkeyFile := path.Join(dir, "client-privkey.pem")
clientPubkeyFile := path.Join(dir, "client-pubkey.pem")
clientCertFile := path.Join(dir, "client-cert.pem")
WritePrivateKey(clientKey, clientPrivkeyFile) // nolint
WritePublicKey(&clientKey.PublicKey, clientPubkeyFile) // nolint
os.WriteFile(clientCertFile, clientCert, 0755) // nolint
serverConfig := goconf.NewConfigFile()
serverConfig.AddOption("grpc", "servercertificate", serverCertFile)
serverConfig.AddOption("grpc", "serverkey", serverPrivkeyFile)
serverConfig.AddOption("grpc", "clientca", clientCertFile)
_, addr := NewGrpcServerForTestWithConfig(t, serverConfig)
clientConfig := goconf.NewConfigFile()
clientConfig.AddOption("grpc", "targets", addr)
clientConfig.AddOption("grpc", "clientcertificate", clientCertFile)
clientConfig.AddOption("grpc", "clientkey", clientPrivkeyFile)
clientConfig.AddOption("grpc", "serverca", serverCertFile)
clients := NewGrpcClientsForTestWithConfig(t, clientConfig, nil)
ctx, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
if err := clients.WaitForInitialized(ctx); err != nil {
t.Fatal(err)
}
for _, client := range clients.GetClients() {
if _, err := client.GetServerId(ctx); err != nil {
t.Fatal(err)
}
}
}

172
grpc_common.go Normal file
View File

@ -0,0 +1,172 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
"github.com/dlintw/goconf"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
type reloadableCredentials struct {
config *tls.Config
pool *CertPoolReloader
}
func (c *reloadableCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := c.config.Clone()
cfg.RootCAs = c.pool.GetCertPool()
if cfg.ServerName == "" {
serverName, _, err := net.SplitHostPort(authority)
if err != nil {
// If the authority had no host port or if the authority cannot be parsed, use it as-is.
serverName = authority
}
cfg.ServerName = serverName
}
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
go func() {
errChannel <- conn.Handshake()
close(errChannel)
}()
select {
case err := <-errChannel:
if err != nil {
conn.Close()
return nil, nil, err
}
case <-ctx.Done():
conn.Close()
return nil, nil, ctx.Err()
}
tlsInfo := credentials.TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: credentials.CommonAuthInfo{
SecurityLevel: credentials.PrivacyAndIntegrity,
},
}
return WrapSyscallConn(rawConn, conn), tlsInfo, nil
}
func (c *reloadableCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
cfg := c.config.Clone()
cfg.ClientCAs = c.pool.GetCertPool()
conn := tls.Server(rawConn, cfg)
if err := conn.Handshake(); err != nil {
conn.Close()
return nil, nil, err
}
tlsInfo := credentials.TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: credentials.CommonAuthInfo{
SecurityLevel: credentials.PrivacyAndIntegrity,
},
}
return WrapSyscallConn(rawConn, conn), tlsInfo, nil
}
func (c *reloadableCredentials) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{
SecurityProtocol: "tls",
SecurityVersion: "1.2",
ServerName: c.config.ServerName,
}
}
func (c *reloadableCredentials) Clone() credentials.TransportCredentials {
return &reloadableCredentials{
config: c.config.Clone(),
pool: c.pool,
}
}
func (c *reloadableCredentials) OverrideServerName(serverName string) error {
c.config.ServerName = serverName
return nil
}
func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentials.TransportCredentials, error) {
var prefix string
var caPrefix string
if server {
prefix = "server"
caPrefix = "client"
} else {
prefix = "client"
caPrefix = "server"
}
certificateFile, _ := config.GetString("grpc", prefix+"certificate")
keyFile, _ := config.GetString("grpc", prefix+"key")
caFile, _ := config.GetString("grpc", caPrefix+"ca")
cfg := &tls.Config{
NextProtos: []string{"h2"},
}
if certificateFile != "" && keyFile != "" {
loader, err := NewCertificateReloader(certificateFile, keyFile)
if err != nil {
return nil, fmt.Errorf("invalid GRPC %s certificate / key in %s / %s: %w", prefix, certificateFile, keyFile, err)
}
if server {
cfg.GetCertificate = loader.GetCertificate
} else {
cfg.GetClientCertificate = loader.GetClientCertificate
}
}
if caFile != "" {
pool, err := NewCertPoolReloader(caFile)
if err != nil {
return nil, err
}
if server {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
creds := &reloadableCredentials{
config: cfg,
pool: pool,
}
return creds, nil
}
if cfg.GetCertificate == nil {
if server {
log.Printf("WARNING: No GRPC server certificate and/or key configured, running unencrypted")
} else {
log.Printf("WARNING: No GRPC CA configured, expecting unencrypted connections")
}
return insecure.NewCredentials(), nil
}
return credentials.NewTLS(cfg), nil
}

88
grpc_common_test.go Normal file
View File

@ -0,0 +1,88 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"os"
"testing"
"time"
)
func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte {
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{organization},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
data, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
t.Fatal(err)
}
data = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: data,
})
return data
}
func WritePrivateKey(key *rsa.PrivateKey, filename string) error {
data := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
return os.WriteFile(filename, data, 0600)
}
func WritePublicKey(key *rsa.PublicKey, filename string) error {
data, err := x509.MarshalPKIXPublicKey(key)
if err != nil {
return err
}
data = pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: data,
})
return os.WriteFile(filename, data, 0755)
}

View File

@ -34,7 +34,6 @@ import (
"github.com/dlintw/goconf"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
status "google.golang.org/grpc/status"
)
@ -76,21 +75,12 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) {
}
}
var opts []grpc.ServerOption
certificateFile, _ := config.GetString("grpc", "certificate")
keyFile, _ := config.GetString("grpc", "key")
if certificateFile != "" && keyFile != "" {
creds, err := credentials.NewServerTLSFromFile(certificateFile, keyFile)
if err != nil {
return nil, fmt.Errorf("invalid GRPC server certificate / key in %s / %s: %w", certificateFile, keyFile, err)
}
opts = append(opts, grpc.Creds(creds))
} else {
log.Printf("WARNING: No GRPC server certificate and/or key configured, running unencrypted")
creds, err := NewReloadableCredentials(config, true)
if err != nil {
return nil, err
}
conn := grpc.NewServer(opts...)
conn := grpc.NewServer(grpc.Creds(creds))
result := &GrpcServer{
conn: conn,
listener: listener,

View File

@ -22,15 +22,25 @@
package signaling
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"net"
"os"
"path"
"strconv"
"testing"
"time"
"github.com/dlintw/goconf"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) {
config := goconf.NewConfigFile()
func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) {
for port := 50000; port < 50100; port++ {
addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
config.AddOption("grpc", "listen", addr)
@ -62,3 +72,174 @@ func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) {
})
return server, addr
}
func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) {
config := goconf.NewConfigFile()
return NewGrpcServerForTestWithConfig(t, config)
}
func Test_GrpcServer_ReloadCerts(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
org1 := "Testing certificate"
cert1 := GenerateSelfSignedCertificateForTesting(t, 1024, org1, key)
dir := t.TempDir()
privkeyFile := path.Join(dir, "privkey.pem")
pubkeyFile := path.Join(dir, "pubkey.pem")
certFile := path.Join(dir, "cert.pem")
WritePrivateKey(key, privkeyFile) // nolint
WritePublicKey(&key.PublicKey, pubkeyFile) // nolint
os.WriteFile(certFile, cert1, 0755) // nolint
config := goconf.NewConfigFile()
config.AddOption("grpc", "servercertificate", certFile)
config.AddOption("grpc", "serverkey", privkeyFile)
UpdateCertificateCheckIntervalForTest(t, time.Millisecond)
_, addr := NewGrpcServerForTestWithConfig(t, config)
cp1 := x509.NewCertPool()
if !cp1.AppendCertsFromPEM(cert1) {
t.Fatalf("could not add certificate")
}
cfg1 := &tls.Config{
RootCAs: cp1,
}
conn1, err := tls.Dial("tcp", addr, cfg1)
if err != nil {
t.Fatal(err)
}
defer conn1.Close() // nolint
state1 := conn1.ConnectionState()
if certs := state1.PeerCertificates; len(certs) == 0 {
t.Errorf("expected certificates, got %+v", state1)
} else if len(certs[0].Subject.Organization) == 0 {
t.Errorf("expected organization, got %s", certs[0].Subject)
} else if certs[0].Subject.Organization[0] != org1 {
t.Errorf("expected organization %s, got %s", org1, certs[0].Subject)
}
org2 := "Updated certificate"
cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key)
os.WriteFile(certFile, cert2, 0755) // nolint
cp2 := x509.NewCertPool()
if !cp2.AppendCertsFromPEM(cert2) {
t.Fatalf("could not add certificate")
}
cfg2 := &tls.Config{
RootCAs: cp2,
}
conn2, err := tls.Dial("tcp", addr, cfg2)
if err != nil {
t.Fatal(err)
}
defer conn2.Close() // nolint
state2 := conn2.ConnectionState()
if certs := state2.PeerCertificates; len(certs) == 0 {
t.Errorf("expected certificates, got %+v", state2)
} else if len(certs[0].Subject.Organization) == 0 {
t.Errorf("expected organization, got %s", certs[0].Subject)
} else if certs[0].Subject.Organization[0] != org2 {
t.Errorf("expected organization %s, got %s", org2, certs[0].Subject)
}
}
func Test_GrpcServer_ReloadCA(t *testing.T) {
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey)
org1 := "Testing client"
clientCert1 := GenerateSelfSignedCertificateForTesting(t, 1024, org1, clientKey)
dir := t.TempDir()
privkeyFile := path.Join(dir, "privkey.pem")
pubkeyFile := path.Join(dir, "pubkey.pem")
certFile := path.Join(dir, "cert.pem")
caFile := path.Join(dir, "ca.pem")
WritePrivateKey(serverKey, privkeyFile) // nolint
WritePublicKey(&serverKey.PublicKey, pubkeyFile) // nolint
os.WriteFile(certFile, serverCert, 0755) // nolint
os.WriteFile(caFile, clientCert1, 0755) // nolint
config := goconf.NewConfigFile()
config.AddOption("grpc", "servercertificate", certFile)
config.AddOption("grpc", "serverkey", privkeyFile)
config.AddOption("grpc", "clientca", caFile)
UpdateCertificateCheckIntervalForTest(t, time.Millisecond)
_, addr := NewGrpcServerForTestWithConfig(t, config)
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(serverCert) {
t.Fatalf("could not add certificate")
}
pair1, err := tls.X509KeyPair(clientCert1, pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
}))
if err != nil {
t.Fatal(err)
}
cfg1 := &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{pair1},
}
client1, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1)))
if err != nil {
t.Fatal(err)
}
defer client1.Close() // nolint
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
if _, err := client1.GetServerId(ctx1); err != nil {
t.Fatal(err)
}
org2 := "Updated client"
clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey)
os.WriteFile(caFile, clientCert2, 0755) // nolint
pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
}))
if err != nil {
t.Fatal(err)
}
cfg2 := &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{pair2},
}
client2, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2)))
if err != nil {
t.Fatal(err)
}
defer client2.Close() // nolint
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
defer cancel2()
// This will fail if the CA certificate has not been reloaded by the server.
if _, err := client2.GetServerId(ctx2); err != nil {
t.Fatal(err)
}
}

View File

@ -268,12 +268,21 @@ connectionsperhost = 8
# Certificate / private key to use for the GRPC server.
# Omit to use unencrypted connections.
#certificate = /path/to/grpc-server.crt
#key = /path/to/grpc-server.key
#servercertificate = /path/to/grpc-server.crt
#serverkey = /path/to/grpc-server.key
# CA certificate that is allowed to issue certificates of GRPC servers.
# Omit to expect unencrypted connections.
#ca = /path/to/grpc-ca.crt
#serverca = /path/to/grpc-ca.crt
# Certificate / private key to use for the GRPC client.
# Omit if clients don't need to authenticate on the server.
#clientcertificate = /path/to/grpc-client.crt
#clientkey = /path/to/grpc-client.key
# CA certificate that is allowed to issue certificates of GRPC clients.
# Omit to allow any clients to connect.
#clientca = /path/to/grpc-ca.crt
# Type of GRPC target configuration.
# Defaults to "static".

58
syscallconn.go Executable file
View File

@ -0,0 +1,58 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package signaling
import (
"net"
"syscall"
)
type sysConn = syscall.Conn
// syscallConn keeps reference of rawConn to support syscall.Conn for channelz.
// SyscallConn() (the method in interface syscall.Conn) is explicitly
// implemented on this type,
//
// Interface syscall.Conn is implemented by most net.Conn implementations (e.g.
// TCPConn, UnixConn), but is not part of net.Conn interface. So wrapper conns
// that embed net.Conn don't implement syscall.Conn. (Side note: tls.Conn
// doesn't embed net.Conn, so even if syscall.Conn is part of net.Conn, it won't
// help here).
type syscallConn struct {
net.Conn
// sysConn is a type alias of syscall.Conn. It's necessary because the name
// `Conn` collides with `net.Conn`.
sysConn
}
// WrapSyscallConn tries to wrap rawConn and newConn into a net.Conn that
// implements syscall.Conn. rawConn will be used to support syscall, and newConn
// will be used for read/write.
//
// This function returns newConn if rawConn doesn't implement syscall.Conn.
func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn {
sysConn, ok := rawConn.(syscall.Conn)
if !ok {
return newConn
}
return &syscallConn{
Conn: newConn,
sysConn: sysConn,
}
}