From 1dfc6280d5f41f37295f2e1a13c699080fd3d2cc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Jul 2025 22:59:10 +0300 Subject: [PATCH] appservice/websocket: switch from gorilla to coder --- appservice/appservice.go | 3 +- appservice/websocket.go | 97 +++++++++++++++++++----------------- bridgev2/matrix/connector.go | 4 +- bridgev2/matrix/websocket.go | 2 +- go.mod | 2 +- go.sum | 4 +- 6 files changed, 58 insertions(+), 54 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 5dd067c0..b0af02cd 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -19,7 +19,7 @@ import ( "syscall" "time" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" "gopkg.in/yaml.v3" @@ -178,7 +178,6 @@ type AppService struct { intentsLock sync.RWMutex ws *websocket.Conn - wsWriteLock sync.Mutex StopWebsocket func(error) websocketHandlers map[string]WebsocketHandler websocketHandlersLock sync.RWMutex diff --git a/appservice/websocket.go b/appservice/websocket.go index 3d5bd232..62f4370c 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -17,9 +17,8 @@ import ( "strings" "sync" "sync/atomic" - "time" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/rs/zerolog" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -28,11 +27,9 @@ import ( ) type WebsocketRequest struct { - ReqID int `json:"id,omitempty"` - Command string `json:"command"` - Data interface{} `json:"data"` - - Deadline time.Duration `json:"-"` + ReqID int `json:"id,omitempty"` + Command string `json:"command"` + Data any `json:"data"` } type WebsocketCommand struct { @@ -43,7 +40,7 @@ type WebsocketCommand struct { Ctx context.Context `json:"-"` } -func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest { +func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" { return nil } @@ -100,8 +97,8 @@ type WebsocketMessage struct { } const ( - WebsocketCloseConnReplaced = 4001 - WebsocketCloseTxnNotAcknowledged = 4002 + WebsocketCloseConnReplaced websocket.StatusCode = 4001 + WebsocketCloseTxnNotAcknowledged websocket.StatusCode = 4002 ) type MeowWebsocketCloseCode string @@ -135,7 +132,7 @@ func (mwcc MeowWebsocketCloseCode) String() string { } type CloseCommand struct { - Code int `json:"-"` + Code websocket.StatusCode `json:"-"` Command string `json:"command"` Status MeowWebsocketCloseCode `json:"status"` } @@ -145,15 +142,15 @@ func (cc CloseCommand) Error() string { } func parseCloseError(err error) error { - closeError := &websocket.CloseError{} + var closeError websocket.CloseError if !errors.As(err, &closeError) { return err } var closeCommand CloseCommand closeCommand.Code = closeError.Code closeCommand.Command = "disconnect" - if len(closeError.Text) > 0 { - jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand) + if len(closeError.Reason) > 0 { + jsonErr := json.Unmarshal([]byte(closeError.Reason), &closeCommand) if jsonErr != nil { return err } @@ -161,7 +158,7 @@ func parseCloseError(err error) error { if len(closeCommand.Status) == 0 { if closeCommand.Code == WebsocketCloseConnReplaced { closeCommand.Status = MeowConnectionReplaced - } else if closeCommand.Code == websocket.CloseServiceRestart { + } else if closeCommand.Code == websocket.StatusServiceRestart { closeCommand.Status = MeowServerShuttingDown } } @@ -172,20 +169,22 @@ func (as *AppService) HasWebsocket() bool { return as.ws != nil } -func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error { +func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) error { ws := as.ws if cmd == nil { return nil } else if ws == nil { return ErrWebsocketNotConnected } - as.wsWriteLock.Lock() - defer as.wsWriteLock.Unlock() - if cmd.Deadline == 0 { - cmd.Deadline = 3 * time.Minute + wr, err := ws.Writer(ctx, websocket.MessageText) + if err != nil { + return err } - _ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline)) - return ws.WriteJSON(cmd) + err = json.NewEncoder(wr).Encode(cmd) + if err != nil { + return err + } + return nil } func (as *AppService) clearWebsocketResponseWaiters() { @@ -222,12 +221,12 @@ func (er *ErrorResponse) Error() string { return fmt.Sprintf("%s: %s", er.Code, er.Message) } -func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error { +func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response any) error { cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1)) respChan := make(chan *WebsocketCommand, 1) as.addWebsocketResponseWaiter(cmd.ReqID, respChan) defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan) - err := as.SendWebsocket(cmd) + err := as.SendWebsocket(ctx, cmd) if err != nil { return err } @@ -256,7 +255,7 @@ func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketReques } } -func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) { +func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, any) { zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command") return false, fmt.Errorf("unknown request type") } @@ -280,14 +279,22 @@ func (as *AppService) defaultHandleWebsocketTransaction(ctx context.Context, msg return true, &WebsocketTransactionResponse{TxnID: msg.TxnID} } -func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) { +func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error), ws *websocket.Conn) { defer stopFunc(ErrWebsocketUnknownError) - ctx := context.Background() for { - var msg WebsocketMessage - err := ws.ReadJSON(&msg) + msgType, reader, err := ws.Reader(ctx) if err != nil { - as.Log.Debug().Err(err).Msg("Error reading from websocket") + as.Log.Debug().Err(err).Msg("Error getting reader from websocket") + stopFunc(parseCloseError(err)) + return + } else if msgType != websocket.MessageText { + as.Log.Debug().Msg("Ignoring non-text message from websocket") + continue + } + var msg WebsocketMessage + err = json.NewDecoder(reader).Decode(&msg) + if err != nil { + as.Log.Debug().Err(err).Msg("Error reading JSON from websocket") stopFunc(parseCloseError(err)) return } @@ -298,11 +305,11 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) with = with.Str("transaction_id", msg.TxnID) } log := with.Logger() - ctx = log.WithContext(ctx) + ctx := log.WithContext(ctx) if msg.Command == "" || msg.Command == "transaction" { ok, resp := as.WebsocketTransactionHandler(ctx, msg) go func() { - err := as.SendWebsocket(msg.MakeResponse(ok, resp)) + err := as.SendWebsocket(ctx, msg.MakeResponse(ok, resp)) if err != nil { log.Warn().Err(err).Msg("Failed to send response to websocket transaction") } else { @@ -334,7 +341,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) } go func() { okResp, data := handler(msg.WebsocketCommand) - err := as.SendWebsocket(msg.MakeResponse(okResp, data)) + err := as.SendWebsocket(ctx, msg.MakeResponse(okResp, data)) if err != nil { log.Error().Err(err).Msg("Failed to send response to websocket command") } else if okResp { @@ -347,7 +354,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) } } -func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { +func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConnect func()) error { var parsed *url.URL if baseURL != "" { var err error @@ -365,12 +372,15 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { } else if parsed.Scheme == "https" { parsed.Scheme = "wss" } - ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{ - "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, - "User-Agent": []string{as.BotClient().UserAgent}, + ws, resp, err := websocket.Dial(ctx, parsed.String(), &websocket.DialOptions{ + HTTPClient: as.HTTPClient, + HTTPHeader: http.Header{ + "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, + "User-Agent": []string{as.BotClient().UserAgent}, - "X-Mautrix-Process-ID": []string{as.ProcessID}, - "X-Mautrix-Websocket-Version": []string{"3"}, + "X-Mautrix-Process-ID": []string{as.ProcessID}, + "X-Mautrix-Websocket-Version": []string{"3"}, + }, }) if resp != nil && resp.StatusCode >= 400 { var errResp mautrix.RespError @@ -406,7 +416,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { as.PrepareWebsocket() as.Log.Debug().Msg("Appservice transaction websocket opened") - go as.consumeWebsocket(stopFunc, ws) + go as.consumeWebsocket(ctx, stopFunc, ws) var onConnectDone atomic.Bool if onConnect != nil { @@ -428,12 +438,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { as.ws = nil } - _ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second)) - err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "")) - if err != nil && !errors.Is(err, websocket.ErrCloseSent) { - as.Log.Warn().Err(err).Msg("Error writing close message to websocket") - } - err = ws.Close() + err = ws.Close(websocket.StatusGoingAway, "") if err != nil { as.Log.Warn().Err(err).Msg("Error closing websocket") } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 0a859e42..158148f3 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -413,7 +413,7 @@ func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI { func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error { if br.Websocket { br.hasSentAnyStates = true - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ Command: "bridge_status", Data: state, }) @@ -484,7 +484,7 @@ func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []* checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} if br.Websocket { - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ Command: "message_checkpoint", Data: checkpointsJSON, }) diff --git a/bridgev2/matrix/websocket.go b/bridgev2/matrix/websocket.go index c679f960..b498cacd 100644 --- a/bridgev2/matrix/websocket.go +++ b/bridgev2/matrix/websocket.go @@ -57,7 +57,7 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) { addr = br.Config.Homeserver.Address } for { - err := br.AS.StartWebsocket(addr, onConnect) + err := br.AS.StartWebsocket(br.Bridge.BackgroundCtx, addr, onConnect) if errors.Is(err, appservice.ErrWebsocketManualStop) { return } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { diff --git a/go.mod b/go.mod index d71e86ab..1133313f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.24.5 require ( filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 - github.com/gorilla/websocket v1.5.0 + github.com/coder/websocket v1.8.13 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.28 github.com/rs/xid v1.6.0 diff --git a/go.sum b/go.sum index eaa97cc8..461ee542 100644 --- a/go.sum +++ b/go.sum @@ -8,13 +8,13 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=