Merge pull request #716 from strukturag/leak-grpc-goroutines

Prevent goroutine leaks in GRPC tests.
This commit is contained in:
Joachim Bauch 2024-04-23 10:59:23 +02:00 committed by GitHub
commit 6960912681
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 204 additions and 149 deletions

View file

@ -66,6 +66,11 @@ func NewCertificateReloader(certFile string, keyFile string) (*CertificateReload
return reloader, nil return reloader, nil
} }
func (r *CertificateReloader) Close() {
r.keyWatcher.Close()
r.certWatcher.Close()
}
func (r *CertificateReloader) reload(filename string) { func (r *CertificateReloader) reload(filename string) {
log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile) log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile)
pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile) pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile)
@ -135,6 +140,10 @@ func NewCertPoolReloader(certFile string) (*CertPoolReloader, error) {
return reloader, nil return reloader, nil
} }
func (r *CertPoolReloader) Close() {
r.certWatcher.Close()
}
func (r *CertPoolReloader) reload(filename string) { func (r *CertPoolReloader) reload(filename string) {
log.Printf("reloading certificate pool from %s", r.certFile) log.Printf("reloading certificate pool from %s", r.certFile)
pool, err := loadCertPool(r.certFile) pool, err := loadCertPool(r.certFile)

View file

@ -22,6 +22,7 @@
package signaling package signaling
import ( import (
"context"
"errors" "errors"
"log" "log"
"os" "os"
@ -55,6 +56,8 @@ type FileWatcher struct {
callback FileWatcherCallback callback FileWatcherCallback
watcher *fsnotify.Watcher watcher *fsnotify.Watcher
closeCtx context.Context
closeFunc context.CancelFunc
} }
func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher, error) { func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher, error) {
@ -78,17 +81,23 @@ func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher
return nil, err return nil, err
} }
closeCtx, closeFunc := context.WithCancel(context.Background())
w := &FileWatcher{ w := &FileWatcher{
filename: filename, filename: filename,
target: realFilename, target: realFilename,
callback: callback, callback: callback,
watcher: watcher, watcher: watcher,
closeCtx: closeCtx,
closeFunc: closeFunc,
} }
go w.run() go w.run()
return w, nil return w, nil
} }
func (f *FileWatcher) Close() error { func (f *FileWatcher) Close() error {
f.closeFunc()
return f.watcher.Close() return f.watcher.Close()
} }
@ -152,6 +161,8 @@ func (f *FileWatcher) run() {
} }
log.Printf("Error watching %s: %s", f.filename, err) log.Printf("Error watching %s: %s", f.filename, err)
case <-f.closeCtx.Done():
return
} }
} }
} }

View file

@ -47,6 +47,7 @@ func TestFileWatcher_NotExist(t *testing.T) {
} }
func TestFileWatcher_File(t *testing.T) { func TestFileWatcher_File(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
tmpdir := t.TempDir() tmpdir := t.TempDir()
filename := path.Join(tmpdir, "test.txt") filename := path.Join(tmpdir, "test.txt")
if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil { if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil {
@ -89,6 +90,7 @@ func TestFileWatcher_File(t *testing.T) {
t.Error("should not have received another event") t.Error("should not have received another event")
case <-ctxTimeout.Done(): case <-ctxTimeout.Done():
} }
})
} }
func TestFileWatcher_Rename(t *testing.T) { func TestFileWatcher_Rename(t *testing.T) {

View file

@ -38,6 +38,7 @@ import (
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
"google.golang.org/grpc" "google.golang.org/grpc"
codes "google.golang.org/grpc/codes" codes "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
status "google.golang.org/grpc/status" status "google.golang.org/grpc/status"
) )
@ -275,6 +276,7 @@ type GrpcClients struct {
targetPrefix string targetPrefix string
targetInformation map[string]*GrpcTargetInformationEtcd targetInformation map[string]*GrpcTargetInformationEtcd
dialOptions atomic.Value // []grpc.DialOption dialOptions atomic.Value // []grpc.DialOption
creds credentials.TransportCredentials
initializedCtx context.Context initializedCtx context.Context
initializedFunc context.CancelFunc initializedFunc context.CancelFunc
@ -308,6 +310,13 @@ func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error {
return err return err
} }
if c.creds != nil {
if cr, ok := c.creds.(*reloadableCredentials); ok {
cr.Close()
}
}
c.creds = creds
opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)} opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)}
c.dialOptions.Store(opts) c.dialOptions.Store(opts)
@ -490,6 +499,7 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo
entry, found := clientsMap[target] entry, found := clientsMap[target]
if !found { if !found {
entry = &grpcClientsList{} entry = &grpcClientsList{}
clientsMap[target] = entry
} }
entry.clients = append(entry.clients, client) entry.clients = append(entry.clients, client)
clients = append(clients, client) clients = append(clients, client)
@ -794,6 +804,11 @@ func (c *GrpcClients) Close() {
if c.etcdClient != nil { if c.etcdClient != nil {
c.etcdClient.RemoveListener(c) c.etcdClient.RemoveListener(c)
} }
if c.creds != nil {
if cr, ok := c.creds.(*reloadableCredentials); ok {
cr.Close()
}
}
c.closeFunc() c.closeFunc()
} }

View file

@ -112,6 +112,7 @@ func waitForEvent(ctx context.Context, t *testing.T, ch <-chan struct{}) {
} }
func Test_GrpcClients_EtcdInitial(t *testing.T) { func Test_GrpcClients_EtcdInitial(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
_, addr1 := NewGrpcServerForTest(t) _, addr1 := NewGrpcServerForTest(t)
_, addr2 := NewGrpcServerForTest(t) _, addr2 := NewGrpcServerForTest(t)
@ -130,6 +131,7 @@ func Test_GrpcClients_EtcdInitial(t *testing.T) {
if clients := client.GetClients(); len(clients) != 2 { if clients := client.GetClients(); len(clients) != 2 {
t.Errorf("Expected two clients, got %+v", clients) t.Errorf("Expected two clients, got %+v", clients)
} }
})
} }
func Test_GrpcClients_EtcdUpdate(t *testing.T) { func Test_GrpcClients_EtcdUpdate(t *testing.T) {
@ -231,6 +233,7 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
} }
func Test_GrpcClients_DnsDiscovery(t *testing.T) { func Test_GrpcClients_DnsDiscovery(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
lookup := newMockDnsLookupForTest(t) lookup := newMockDnsLookupForTest(t)
target := "testgrpc:12345" target := "testgrpc:12345"
ip1 := net.ParseIP("192.168.0.1") ip1 := net.ParseIP("192.168.0.1")
@ -282,6 +285,7 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
} else if !clients[0].ip.Equal(ip2) { } else if !clients[0].ip.Equal(ip2) {
t.Errorf("Expected IP %s, got %s", ip2, clients[0].ip) t.Errorf("Expected IP %s, got %s", ip2, clients[0].ip)
} }
})
} }
func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) { func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
@ -320,6 +324,7 @@ func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
} }
func Test_GrpcClients_Encryption(t *testing.T) { func Test_GrpcClients_Encryption(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
serverKey, err := rsa.GenerateKey(rand.Reader, 1024) serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -371,4 +376,5 @@ func Test_GrpcClients_Encryption(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
})
} }

View file

@ -125,6 +125,15 @@ func (c *reloadableCredentials) OverrideServerName(serverName string) error {
return nil return nil
} }
func (c *reloadableCredentials) Close() {
if c.loader != nil {
c.loader.Close()
}
if c.pool != nil {
c.pool.Close()
}
}
func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentials.TransportCredentials, error) { func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentials.TransportCredentials, error) {
var prefix string var prefix string
var caPrefix string var caPrefix string

View file

@ -108,6 +108,9 @@ func (s *GrpcServer) Run() error {
func (s *GrpcServer) Close() { func (s *GrpcServer) Close() {
s.conn.GracefulStop() s.conn.GracefulStop()
if cr, ok := s.creds.(*reloadableCredentials); ok {
cr.Close()
}
} }
func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) { func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) {