diff --git a/grpc_client.go b/grpc_client.go index 871c5b3..48d162a 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -36,8 +36,6 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "google.golang.org/grpc" codes "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/resolver" status "google.golang.org/grpc/status" ) @@ -282,23 +280,12 @@ func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcCli } func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error { - var opts []grpc.DialOption - caFile, _ := config.GetString("grpc", "ca") - if caFile != "" { - creds, err := credentials.NewClientTLSFromFile(caFile, "") - if err != nil { - return fmt.Errorf("invalid GRPC CA in %s: %w", caFile, err) - } - - opts = append(opts, grpc.WithTransportCredentials(creds)) - } else { - log.Printf("WARNING: No GRPC CA configured, expecting unencrypted connections") - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + creds, err := NewReloadableCredentials(config, false) + if err != nil { + return err } - if opts == nil { - opts = make([]grpc.DialOption, 0) - } + opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)} c.dialOptions.Store(opts) targetType, _ := config.GetString("grpc", "targettype") @@ -306,7 +293,6 @@ func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error { targetType = DefaultGrpcTargetType } - var err error switch targetType { case GrpcTargetTypeStatic: err = c.loadTargetsStatic(config, fromReload, opts...) diff --git a/grpc_client_test.go b/grpc_client_test.go index 4821577..e3fdca6 100644 --- a/grpc_client_test.go +++ b/grpc_client_test.go @@ -23,8 +23,12 @@ package signaling import ( "context" + "crypto/rand" + "crypto/rsa" "fmt" "net" + "os" + "path" "testing" "time" @@ -32,12 +36,8 @@ import ( "go.etcd.io/etcd/server/v3/embed" ) -func NewGrpcClientsForTest(t *testing.T, addr string) *GrpcClients { - config := goconf.NewConfigFile() - config.AddOption("grpc", "targets", addr) - config.AddOption("grpc", "dnsdiscovery", "true") - - client, err := NewGrpcClients(config, nil) +func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) *GrpcClients { + client, err := NewGrpcClients(config, etcdClient) if err != nil { t.Fatal(err) } @@ -48,6 +48,14 @@ func NewGrpcClientsForTest(t *testing.T, addr string) *GrpcClients { return client } +func NewGrpcClientsForTest(t *testing.T, addr string) *GrpcClients { + config := goconf.NewConfigFile() + config.AddOption("grpc", "targets", addr) + config.AddOption("grpc", "dnsdiscovery", "true") + + return NewGrpcClientsForTestWithConfig(t, config, nil) +} + func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) *GrpcClients { config := goconf.NewConfigFile() config.AddOption("etcd", "endpoints", etcd.Config().LCUrls[0].String()) @@ -65,15 +73,7 @@ func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) *GrpcClients } }) - client, err := NewGrpcClients(config, etcdClient) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - client.Close() - }) - - return client + return NewGrpcClientsForTestWithConfig(t, config, etcdClient) } func drainWakeupChannel(ch chan bool) { @@ -302,3 +302,57 @@ func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) { t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip) } } + +func Test_GrpcClients_Encryption(t *testing.T) { + serverKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + clientKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + + serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey) + clientCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Testing client", clientKey) + + dir := t.TempDir() + serverPrivkeyFile := path.Join(dir, "server-privkey.pem") + serverPubkeyFile := path.Join(dir, "server-pubkey.pem") + serverCertFile := path.Join(dir, "server-cert.pem") + WritePrivateKey(serverKey, serverPrivkeyFile) // nolint + WritePublicKey(&serverKey.PublicKey, serverPubkeyFile) // nolint + os.WriteFile(serverCertFile, serverCert, 0755) // nolint + clientPrivkeyFile := path.Join(dir, "client-privkey.pem") + clientPubkeyFile := path.Join(dir, "client-pubkey.pem") + clientCertFile := path.Join(dir, "client-cert.pem") + WritePrivateKey(clientKey, clientPrivkeyFile) // nolint + WritePublicKey(&clientKey.PublicKey, clientPubkeyFile) // nolint + os.WriteFile(clientCertFile, clientCert, 0755) // nolint + + serverConfig := goconf.NewConfigFile() + serverConfig.AddOption("grpc", "servercertificate", serverCertFile) + serverConfig.AddOption("grpc", "serverkey", serverPrivkeyFile) + serverConfig.AddOption("grpc", "clientca", clientCertFile) + _, addr := NewGrpcServerForTestWithConfig(t, serverConfig) + + clientConfig := goconf.NewConfigFile() + clientConfig.AddOption("grpc", "targets", addr) + clientConfig.AddOption("grpc", "clientcertificate", clientCertFile) + clientConfig.AddOption("grpc", "clientkey", clientPrivkeyFile) + clientConfig.AddOption("grpc", "serverca", serverCertFile) + clients := NewGrpcClientsForTestWithConfig(t, clientConfig, nil) + + ctx, cancel1 := context.WithTimeout(context.Background(), time.Second) + defer cancel1() + + if err := clients.WaitForInitialized(ctx); err != nil { + t.Fatal(err) + } + + for _, client := range clients.GetClients() { + if _, err := client.GetServerId(ctx); err != nil { + t.Fatal(err) + } + } +} diff --git a/grpc_common.go b/grpc_common.go new file mode 100644 index 0000000..62bd437 --- /dev/null +++ b/grpc_common.go @@ -0,0 +1,172 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2022 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "crypto/tls" + "fmt" + "log" + "net" + + "github.com/dlintw/goconf" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" +) + +type reloadableCredentials struct { + config *tls.Config + + pool *CertPoolReloader +} + +func (c *reloadableCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + // use local cfg to avoid clobbering ServerName if using multiple endpoints + cfg := c.config.Clone() + cfg.RootCAs = c.pool.GetCertPool() + if cfg.ServerName == "" { + serverName, _, err := net.SplitHostPort(authority) + if err != nil { + // If the authority had no host port or if the authority cannot be parsed, use it as-is. + serverName = authority + } + cfg.ServerName = serverName + } + conn := tls.Client(rawConn, cfg) + errChannel := make(chan error, 1) + go func() { + errChannel <- conn.Handshake() + close(errChannel) + }() + select { + case err := <-errChannel: + if err != nil { + conn.Close() + return nil, nil, err + } + case <-ctx.Done(): + conn.Close() + return nil, nil, ctx.Err() + } + tlsInfo := credentials.TLSInfo{ + State: conn.ConnectionState(), + CommonAuthInfo: credentials.CommonAuthInfo{ + SecurityLevel: credentials.PrivacyAndIntegrity, + }, + } + return WrapSyscallConn(rawConn, conn), tlsInfo, nil +} + +func (c *reloadableCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + cfg := c.config.Clone() + cfg.ClientCAs = c.pool.GetCertPool() + + conn := tls.Server(rawConn, cfg) + if err := conn.Handshake(); err != nil { + conn.Close() + return nil, nil, err + } + tlsInfo := credentials.TLSInfo{ + State: conn.ConnectionState(), + CommonAuthInfo: credentials.CommonAuthInfo{ + SecurityLevel: credentials.PrivacyAndIntegrity, + }, + } + return WrapSyscallConn(rawConn, conn), tlsInfo, nil +} + +func (c *reloadableCredentials) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{ + SecurityProtocol: "tls", + SecurityVersion: "1.2", + ServerName: c.config.ServerName, + } +} + +func (c *reloadableCredentials) Clone() credentials.TransportCredentials { + return &reloadableCredentials{ + config: c.config.Clone(), + pool: c.pool, + } +} + +func (c *reloadableCredentials) OverrideServerName(serverName string) error { + c.config.ServerName = serverName + return nil +} + +func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentials.TransportCredentials, error) { + var prefix string + var caPrefix string + if server { + prefix = "server" + caPrefix = "client" + } else { + prefix = "client" + caPrefix = "server" + } + certificateFile, _ := config.GetString("grpc", prefix+"certificate") + keyFile, _ := config.GetString("grpc", prefix+"key") + caFile, _ := config.GetString("grpc", caPrefix+"ca") + cfg := &tls.Config{ + NextProtos: []string{"h2"}, + } + if certificateFile != "" && keyFile != "" { + loader, err := NewCertificateReloader(certificateFile, keyFile) + if err != nil { + return nil, fmt.Errorf("invalid GRPC %s certificate / key in %s / %s: %w", prefix, certificateFile, keyFile, err) + } + + if server { + cfg.GetCertificate = loader.GetCertificate + } else { + cfg.GetClientCertificate = loader.GetClientCertificate + } + } + + if caFile != "" { + pool, err := NewCertPoolReloader(caFile) + if err != nil { + return nil, err + } + + if server { + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + creds := &reloadableCredentials{ + config: cfg, + pool: pool, + } + return creds, nil + } + + if cfg.GetCertificate == nil { + if server { + log.Printf("WARNING: No GRPC server certificate and/or key configured, running unencrypted") + } else { + log.Printf("WARNING: No GRPC CA configured, expecting unencrypted connections") + } + return insecure.NewCredentials(), nil + } + + return credentials.NewTLS(cfg), nil +} diff --git a/grpc_common_test.go b/grpc_common_test.go new file mode 100644 index 0000000..038859b --- /dev/null +++ b/grpc_common_test.go @@ -0,0 +1,88 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2022 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "testing" + "time" +) + +func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte { + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{organization}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 180), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + data, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + t.Fatal(err) + } + + data = pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: data, + }) + return data +} + +func WritePrivateKey(key *rsa.PrivateKey, filename string) error { + data := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + return os.WriteFile(filename, data, 0600) +} + +func WritePublicKey(key *rsa.PublicKey, filename string) error { + data, err := x509.MarshalPKIXPublicKey(key) + if err != nil { + return err + } + + data = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: data, + }) + + return os.WriteFile(filename, data, 0755) +} diff --git a/grpc_server.go b/grpc_server.go index 03c9a34..ac654f9 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -34,7 +34,6 @@ import ( "github.com/dlintw/goconf" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" status "google.golang.org/grpc/status" ) @@ -76,21 +75,12 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) { } } - var opts []grpc.ServerOption - certificateFile, _ := config.GetString("grpc", "certificate") - keyFile, _ := config.GetString("grpc", "key") - if certificateFile != "" && keyFile != "" { - creds, err := credentials.NewServerTLSFromFile(certificateFile, keyFile) - if err != nil { - return nil, fmt.Errorf("invalid GRPC server certificate / key in %s / %s: %w", certificateFile, keyFile, err) - } - - opts = append(opts, grpc.Creds(creds)) - } else { - log.Printf("WARNING: No GRPC server certificate and/or key configured, running unencrypted") + creds, err := NewReloadableCredentials(config, true) + if err != nil { + return nil, err } - conn := grpc.NewServer(opts...) + conn := grpc.NewServer(grpc.Creds(creds)) result := &GrpcServer{ conn: conn, listener: listener, diff --git a/grpc_server_test.go b/grpc_server_test.go index 7292ffc..4ce17a4 100644 --- a/grpc_server_test.go +++ b/grpc_server_test.go @@ -22,15 +22,25 @@ package signaling import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" "net" + "os" + "path" "strconv" "testing" + "time" "github.com/dlintw/goconf" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) -func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) { - config := goconf.NewConfigFile() +func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) { for port := 50000; port < 50100; port++ { addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) config.AddOption("grpc", "listen", addr) @@ -62,3 +72,174 @@ func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) { }) return server, addr } + +func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) { + config := goconf.NewConfigFile() + return NewGrpcServerForTestWithConfig(t, config) +} + +func Test_GrpcServer_ReloadCerts(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + + org1 := "Testing certificate" + cert1 := GenerateSelfSignedCertificateForTesting(t, 1024, org1, key) + + dir := t.TempDir() + privkeyFile := path.Join(dir, "privkey.pem") + pubkeyFile := path.Join(dir, "pubkey.pem") + certFile := path.Join(dir, "cert.pem") + WritePrivateKey(key, privkeyFile) // nolint + WritePublicKey(&key.PublicKey, pubkeyFile) // nolint + os.WriteFile(certFile, cert1, 0755) // nolint + + config := goconf.NewConfigFile() + config.AddOption("grpc", "servercertificate", certFile) + config.AddOption("grpc", "serverkey", privkeyFile) + + UpdateCertificateCheckIntervalForTest(t, time.Millisecond) + _, addr := NewGrpcServerForTestWithConfig(t, config) + + cp1 := x509.NewCertPool() + if !cp1.AppendCertsFromPEM(cert1) { + t.Fatalf("could not add certificate") + } + + cfg1 := &tls.Config{ + RootCAs: cp1, + } + conn1, err := tls.Dial("tcp", addr, cfg1) + if err != nil { + t.Fatal(err) + } + defer conn1.Close() // nolint + state1 := conn1.ConnectionState() + if certs := state1.PeerCertificates; len(certs) == 0 { + t.Errorf("expected certificates, got %+v", state1) + } else if len(certs[0].Subject.Organization) == 0 { + t.Errorf("expected organization, got %s", certs[0].Subject) + } else if certs[0].Subject.Organization[0] != org1 { + t.Errorf("expected organization %s, got %s", org1, certs[0].Subject) + } + + org2 := "Updated certificate" + cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key) + os.WriteFile(certFile, cert2, 0755) // nolint + + cp2 := x509.NewCertPool() + if !cp2.AppendCertsFromPEM(cert2) { + t.Fatalf("could not add certificate") + } + + cfg2 := &tls.Config{ + RootCAs: cp2, + } + conn2, err := tls.Dial("tcp", addr, cfg2) + if err != nil { + t.Fatal(err) + } + defer conn2.Close() // nolint + state2 := conn2.ConnectionState() + if certs := state2.PeerCertificates; len(certs) == 0 { + t.Errorf("expected certificates, got %+v", state2) + } else if len(certs[0].Subject.Organization) == 0 { + t.Errorf("expected organization, got %s", certs[0].Subject) + } else if certs[0].Subject.Organization[0] != org2 { + t.Errorf("expected organization %s, got %s", org2, certs[0].Subject) + } +} + +func Test_GrpcServer_ReloadCA(t *testing.T) { + serverKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + clientKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + + serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey) + org1 := "Testing client" + clientCert1 := GenerateSelfSignedCertificateForTesting(t, 1024, org1, clientKey) + + dir := t.TempDir() + privkeyFile := path.Join(dir, "privkey.pem") + pubkeyFile := path.Join(dir, "pubkey.pem") + certFile := path.Join(dir, "cert.pem") + caFile := path.Join(dir, "ca.pem") + WritePrivateKey(serverKey, privkeyFile) // nolint + WritePublicKey(&serverKey.PublicKey, pubkeyFile) // nolint + os.WriteFile(certFile, serverCert, 0755) // nolint + os.WriteFile(caFile, clientCert1, 0755) // nolint + + config := goconf.NewConfigFile() + config.AddOption("grpc", "servercertificate", certFile) + config.AddOption("grpc", "serverkey", privkeyFile) + config.AddOption("grpc", "clientca", caFile) + + UpdateCertificateCheckIntervalForTest(t, time.Millisecond) + _, addr := NewGrpcServerForTestWithConfig(t, config) + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(serverCert) { + t.Fatalf("could not add certificate") + } + + pair1, err := tls.X509KeyPair(clientCert1, pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(clientKey), + })) + if err != nil { + t.Fatal(err) + } + + cfg1 := &tls.Config{ + RootCAs: pool, + Certificates: []tls.Certificate{pair1}, + } + client1, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1))) + if err != nil { + t.Fatal(err) + } + defer client1.Close() // nolint + + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) + defer cancel1() + + if _, err := client1.GetServerId(ctx1); err != nil { + t.Fatal(err) + } + + org2 := "Updated client" + clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey) + os.WriteFile(caFile, clientCert2, 0755) // nolint + + pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(clientKey), + })) + if err != nil { + t.Fatal(err) + } + + cfg2 := &tls.Config{ + RootCAs: pool, + Certificates: []tls.Certificate{pair2}, + } + client2, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2))) + if err != nil { + t.Fatal(err) + } + defer client2.Close() // nolint + + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + defer cancel2() + + // This will fail if the CA certificate has not been reloaded by the server. + if _, err := client2.GetServerId(ctx2); err != nil { + t.Fatal(err) + } +} diff --git a/server.conf.in b/server.conf.in index 2f9123a..ba3221b 100644 --- a/server.conf.in +++ b/server.conf.in @@ -268,12 +268,21 @@ connectionsperhost = 8 # Certificate / private key to use for the GRPC server. # Omit to use unencrypted connections. -#certificate = /path/to/grpc-server.crt -#key = /path/to/grpc-server.key +#servercertificate = /path/to/grpc-server.crt +#serverkey = /path/to/grpc-server.key # CA certificate that is allowed to issue certificates of GRPC servers. # Omit to expect unencrypted connections. -#ca = /path/to/grpc-ca.crt +#serverca = /path/to/grpc-ca.crt + +# Certificate / private key to use for the GRPC client. +# Omit if clients don't need to authenticate on the server. +#clientcertificate = /path/to/grpc-client.crt +#clientkey = /path/to/grpc-client.key + +# CA certificate that is allowed to issue certificates of GRPC clients. +# Omit to allow any clients to connect. +#clientca = /path/to/grpc-ca.crt # Type of GRPC target configuration. # Defaults to "static". diff --git a/syscallconn.go b/syscallconn.go new file mode 100755 index 0000000..4bec68c --- /dev/null +++ b/syscallconn.go @@ -0,0 +1,58 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package signaling + +import ( + "net" + "syscall" +) + +type sysConn = syscall.Conn + +// syscallConn keeps reference of rawConn to support syscall.Conn for channelz. +// SyscallConn() (the method in interface syscall.Conn) is explicitly +// implemented on this type, +// +// Interface syscall.Conn is implemented by most net.Conn implementations (e.g. +// TCPConn, UnixConn), but is not part of net.Conn interface. So wrapper conns +// that embed net.Conn don't implement syscall.Conn. (Side note: tls.Conn +// doesn't embed net.Conn, so even if syscall.Conn is part of net.Conn, it won't +// help here). +type syscallConn struct { + net.Conn + // sysConn is a type alias of syscall.Conn. It's necessary because the name + // `Conn` collides with `net.Conn`. + sysConn +} + +// WrapSyscallConn tries to wrap rawConn and newConn into a net.Conn that +// implements syscall.Conn. rawConn will be used to support syscall, and newConn +// will be used for read/write. +// +// This function returns newConn if rawConn doesn't implement syscall.Conn. +func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn { + sysConn, ok := rawConn.(syscall.Conn) + if !ok { + return newConn + } + return &syscallConn{ + Conn: newConn, + sysConn: sysConn, + } +}