diff --git a/hub.go b/hub.go
index f26bd23..bdc0370 100644
--- a/hub.go
+++ b/hub.go
@@ -126,6 +126,11 @@ type Hub struct {
 	readPumpActive  atomic.Int32
 	writePumpActive atomic.Int32
 
+	// shutdown is signalled once a scheduled shutdown may proceed, i.e. after
+	// ScheduleShutdown was called and no non-internal sessions are left.
+	shutdown          *Closer
+	shutdownScheduled atomic.Bool
+
 	roomUpdated      chan *BackendServerRoomRequest
 	roomDeleted      chan *BackendServerRoomRequest
 	roomInCall       chan *BackendServerRoomRequest
@@ -318,7 +323,8 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
 		info:         NewWelcomeServerMessage(version, DefaultFeatures...),
 		infoInternal: NewWelcomeServerMessage(version, DefaultFeaturesInternal...),
 
-		closer: NewCloser(),
+		closer:   NewCloser(),
+		shutdown: NewCloser(),
 
 		roomUpdated:      make(chan *BackendServerRoomRequest),
 		roomDeleted:      make(chan *BackendServerRoomRequest),
@@ -685,10 +691,32 @@ func (h *Hub) removeSession(session Session) (removed bool) {
 		delete(h.anonymousSessions, session)
 		delete(h.dialoutSessions, session)
 	}
+	// If this was the last non-internal session, a pending graceful shutdown
+	// may proceed now.
+	if h.IsShutdownScheduled() && !h.hasSessionsLocked(false) {
+		go h.shutdown.Close()
+	}
 	h.mu.Unlock()
 	return
 }
 
+// hasSessionsLocked reports whether the hub still has sessions. Sessions of
+// internal clients are only counted if withInternal is true. The caller must
+// hold h.mu.
+func (h *Hub) hasSessionsLocked(withInternal bool) bool {
+	if withInternal {
+		return len(h.sessions) > 0
+	}
+
+	for _, s := range h.sessions {
+		if s.ClientType() != HelloClientTypeInternal {
+			return true
+		}
+	}
+
+	return false
+}
+
 func (h *Hub) startWaitAnonymousSessionRoom(session *ClientSession) {
 	h.mu.Lock()
 	defer h.mu.Unlock()
@@ -2604,3 +2632,29 @@ func (h *Hub) OnMessageReceived(client HandlerClient, data []byte) {
 func (h *Hub) OnRTTReceived(client HandlerClient, rtt time.Duration) {
 	// Ignore
 }
+
+// ShutdownChannel returns the channel that is signalled once a scheduled
+// shutdown may proceed because no non-internal sessions are left.
+func (h *Hub) ShutdownChannel() <-chan struct{} {
+	return h.shutdown.C
+}
+
+// IsShutdownScheduled reports whether ScheduleShutdown has been called.
+func (h *Hub) IsShutdownScheduled() bool {
+	return h.shutdownScheduled.Load()
+}
+
+// ScheduleShutdown requests a graceful shutdown of the hub. The shutdown
+// channel is signalled right away if no non-internal sessions exist, otherwise
+// when the last one gets removed. Only the first call has an effect.
+func (h *Hub) ScheduleShutdown() {
+	if !h.shutdownScheduled.CompareAndSwap(false, true) {
+		return
+	}
+
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	if !h.hasSessionsLocked(false) {
+		go h.shutdown.Close()
+	}
+}
diff --git a/hub_test.go b/hub_test.go
index 8419ac9..acc488a 100644
--- a/hub_test.go
+++ b/hub_test.go
@@ -5866,3 +5866,97 @@ func TestDialoutStatus(t *testing.T) {
 		}
 	}
 }
+
+// TestGracefulShutdownInitial checks that scheduling a shutdown on a hub
+// without sessions signals the shutdown channel.
+func TestGracefulShutdownInitial(t *testing.T) {
+	hub, _, _, _ := CreateHubForTest(t)
+
+	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
+	defer cancel()
+
+	hub.ScheduleShutdown()
+	select {
+	case <-hub.ShutdownChannel():
+	case <-ctx.Done():
+		t.Error("should have shutdown")
+	}
+}
+
+// TestGracefulShutdownOnBye checks that a scheduled shutdown is delayed while
+// a client is connected and proceeds once the client said "bye".
+func TestGracefulShutdownOnBye(t *testing.T) {
+	hub, _, _, server := CreateHubForTest(t)
+
+	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
+	defer cancel()
+
+	client := NewTestClient(t, server, hub)
+	defer client.CloseWithBye()
+
+	if err := client.SendHello(testDefaultUserId); err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := client.RunUntilHello(ctx); err != nil {
+		t.Fatal(err)
+	}
+
+	hub.ScheduleShutdown()
+	select {
+	case <-hub.ShutdownChannel():
+		t.Error("should not have shutdown")
+	case <-time.After(100 * time.Millisecond):
+	}
+
+	client.CloseWithBye()
+
+	select {
+	case <-hub.ShutdownChannel():
+	case <-ctx.Done():
+		t.Error("should have shutdown")
+	}
+}
+
+// TestGracefulShutdownOnExpiration checks that a scheduled shutdown is not
+// triggered by a client merely closing its connection, but only after the
+// housekeeping ran past the session expiration.
+func TestGracefulShutdownOnExpiration(t *testing.T) {
+	hub, _, _, server := CreateHubForTest(t)
+
+	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
+	defer cancel()
+
+	client := NewTestClient(t, server, hub)
+	defer client.CloseWithBye()
+
+	if err := client.SendHello(testDefaultUserId); err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := client.RunUntilHello(ctx); err != nil {
+		t.Fatal(err)
+	}
+
+	hub.ScheduleShutdown()
+	select {
+	case <-hub.ShutdownChannel():
+		t.Error("should not have shutdown")
+	case <-time.After(100 * time.Millisecond):
+	}
+
+	client.Close()
+	select {
+	case <-hub.ShutdownChannel():
+		t.Error("should not have shutdown")
+	case <-time.After(100 * time.Millisecond):
+	}
+
+	performHousekeeping(hub, time.Now().Add(sessionExpireDuration+time.Second))
+
+	select {
+	case <-hub.ShutdownChannel():
+	case <-ctx.Done():
+		t.Error("should have shutdown")
+	}
+}
diff --git a/server/main.go b/server/main.go
index 5d058e3..05f32c3 100644
--- a/server/main.go
+++ b/server/main.go
@@ -23,6 +23,7 @@ package main
 
 import (
 	"crypto/tls"
+	"errors"
 	"flag"
 	"fmt"
 	"log"
@@ -34,6 +35,7 @@ import (
 	"runtime"
 	runtimepprof "runtime/pprof"
 	"strings"
+	"sync"
 	"syscall"
 	"time"
 
@@ -91,6 +93,50 @@ func createTLSListener(addr string, certFile, keyFile string) (net.Listener, err
 	return tls.Listen("tcp", addr, &config)
 }
 
+// Listeners keeps track of the listeners opened by the server so all of them
+// can be closed at once when a graceful shutdown gets scheduled.
+type Listeners struct {
+	mu        sync.Mutex
+	closed    bool
+	listeners []net.Listener
+}
+
+// Add registers a listener to be closed by Close. Listeners are added right
+// before they start serving, which may happen after Close was already called;
+// in that case the listener is closed immediately instead of being left open.
+func (l *Listeners) Add(listener net.Listener) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	if l.closed {
+		closeListener(listener)
+		return
+	}
+
+	l.listeners = append(l.listeners, listener)
+}
+
+// Close closes all registered listeners. It may be called multiple times,
+// each listener is only closed once.
+func (l *Listeners) Close() {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	l.closed = true
+	for _, listener := range l.listeners {
+		closeListener(listener)
+	}
+	l.listeners = nil
+}
+
+// closeListener closes a single listener and logs unexpected errors. A
+// listener that was closed before is not reported as an error.
+func closeListener(listener net.Listener) {
+	if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
+		log.Printf("Error closing listener %s: %s", listener.Addr(), err)
+	}
+}
+
 func main() {
 	log.SetFlags(log.Lshortfile)
 	flag.Parse()
@@ -103,6 +149,7 @@ func main() {
 	sigChan := make(chan os.Signal, 1)
 	signal.Notify(sigChan, os.Interrupt)
 	signal.Notify(sigChan, syscall.SIGHUP)
+	signal.Notify(sigChan, syscall.SIGUSR1)
 
 	if *cpuprofile != "" {
 		f, err := os.Create(*cpuprofile)
@@ -300,6 +347,8 @@ func main() {
 		}
 	}
 
+	var listeners Listeners
+
 	if saddr, _ := config.GetString("https", "listen"); saddr != "" {
 		cert, _ := config.GetString("https", "certificate")
 		key, _ := config.GetString("https", "key")
@@ -328,8 +377,11 @@ func main() {
 					ReadTimeout:  time.Duration(readTimeout) * time.Second,
 					WriteTimeout: time.Duration(writeTimeout) * time.Second,
 				}
+				listeners.Add(listener)
 				if err := srv.Serve(listener); err != nil {
-					log.Fatal("Could not start server: ", err)
+					if !hub.IsShutdownScheduled() || !errors.Is(err, net.ErrClosed) {
+						log.Fatal("Could not start server: ", err)
+					}
 				}
 			}(address)
 		}
@@ -359,26 +411,41 @@ func main() {
 					ReadTimeout:  time.Duration(readTimeout) * time.Second,
 					WriteTimeout: time.Duration(writeTimeout) * time.Second,
 				}
+				listeners.Add(listener)
 				if err := srv.Serve(listener); err != nil {
-					log.Fatal("Could not start server: ", err)
+					if !hub.IsShutdownScheduled() || !errors.Is(err, net.ErrClosed) {
+						log.Fatal("Could not start server: ", err)
+					}
 				}
 			}(address)
 		}
 	}
 
 loop:
-	for sig := range sigChan {
-		switch sig {
-		case os.Interrupt:
-			log.Println("Interrupted")
-			break loop
-		case syscall.SIGHUP:
-			log.Printf("Received SIGHUP, reloading %s", *configFlag)
-			if config, err := goconf.ReadConfigFile(*configFlag); err != nil {
-				log.Printf("Could not read configuration from %s: %s", *configFlag, err)
-			} else {
-				hub.Reload(config)
+	for {
+		select {
+		case sig := <-sigChan:
+			switch sig {
+			case os.Interrupt:
+				log.Println("Interrupted")
+				break loop
+			case syscall.SIGHUP:
+				log.Printf("Received SIGHUP, reloading %s", *configFlag)
+				if config, err := goconf.ReadConfigFile(*configFlag); err != nil {
+					log.Printf("Could not read configuration from %s: %s", *configFlag, err)
+				} else {
+					hub.Reload(config)
+				}
+			case syscall.SIGUSR1:
+				log.Printf("Received SIGUSR1, scheduling server to shutdown")
+				// Schedule the shutdown before closing the listeners so the
+				// serving goroutines don't treat the resulting error as fatal.
+				hub.ScheduleShutdown()
+				listeners.Close()
 			}
+		case <-hub.ShutdownChannel():
+			log.Printf("All clients disconnected, shutting down")
+			break loop
 		}
 	}
 }