diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index a12da24b..8eaf6a1e 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -120,3 +120,21 @@ func (br *Bridge) Start() error { br.Log.Info().Msg("Bridge started") return nil } + +func (br *Bridge) Stop() { + br.Log.Info().Msg("Shutting down bridge") + br.Matrix.Stop() + br.cacheLock.Lock() + var wg sync.WaitGroup + wg.Add(len(br.userLoginsByID)) + for _, login := range br.userLoginsByID { + go login.Disconnect(wg.Done) + } + wg.Wait() + br.cacheLock.Unlock() + err := br.DB.Close() + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to close database") + } + br.Log.Info().Msg("Shutdown complete") +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 3493e7a3..5773a1bf 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -141,6 +141,14 @@ func (br *Connector) Start(ctx context.Context) error { return nil } +func (br *Connector) Stop() { + br.AS.Stop() + br.EventProcessor.Stop() + if br.Crypto != nil { + br.Crypto.Stop() + } +} + var MinSpecVersion = mautrix.SpecV14 func (br *Connector) ensureConnection(ctx context.Context) { diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index dd4c0328..3b2bc460 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -354,8 +354,7 @@ func (br *BridgeMain) WaitForInterrupt() { // Stop cleanly stops the bridge. This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Stop() { - br.Log.Info().Msg("Shutting down bridge") - // TODO actually stop cleanly + br.Bridge.Stop() } // InitVersion formats the bridge version and build time nicely for things like diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index da8c92a8..56fbde38 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -20,6 +20,7 @@ import ( type MatrixConnector interface { Init(*Bridge) Start(ctx context.Context) error + Stop() ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) FormatGhostMXID(userID networkid.UserID) id.UserID diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 48781b89..d82dd7ff 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -141,6 +141,7 @@ type MaxFileSizeingNetwork interface { // NetworkAPI is an interface representing a remote network client for a single user login. type NetworkAPI interface { Connect(ctx context.Context) error + Disconnect() IsLoggedIn() bool LogoutRemote(ctx context.Context) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 3a980c01..beaa03c7 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "fmt" + "time" "github.com/rs/zerolog" @@ -161,3 +162,20 @@ func (ul *UserLogin) GetRemoteName() string { name, _ := ul.Metadata["remote_name"].(string) return name } + +func (ul *UserLogin) Disconnect(done func()) { + defer done() + if ul.Client != nil { + disconnected := make(chan struct{}) + go func() { + ul.Client.Disconnect() + ul.Client = nil + close(disconnected) + }() + select { + case <-disconnected: + case <-time.After(5 * time.Second): + ul.Log.Warn().Msg("Client disconnection timed out") + } + } +}