appservice/websocket: switch from gorilla to coder
Some checks failed
Go / Lint (latest) (push) Has been cancelled
Go / Build (old, libolm) (push) Has been cancelled
Go / Build (latest, libolm) (push) Has been cancelled
Go / Build (old, goolm) (push) Has been cancelled
Go / Build (latest, goolm) (push) Has been cancelled

This commit is contained in:
Tulir Asokan 2025-07-23 22:59:10 +03:00
commit 1dfc6280d5
6 changed files with 58 additions and 54 deletions

View file

@ -19,7 +19,7 @@ import (
"syscall"
"time"
"github.com/gorilla/websocket"
"github.com/coder/websocket"
"github.com/rs/zerolog"
"golang.org/x/net/publicsuffix"
"gopkg.in/yaml.v3"
@ -178,7 +178,6 @@ type AppService struct {
intentsLock sync.RWMutex
ws *websocket.Conn
wsWriteLock sync.Mutex
StopWebsocket func(error)
websocketHandlers map[string]WebsocketHandler
websocketHandlersLock sync.RWMutex

View file

@ -17,9 +17,8 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/coder/websocket"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@ -28,11 +27,9 @@ import (
)
type WebsocketRequest struct {
ReqID int `json:"id,omitempty"`
Command string `json:"command"`
Data interface{} `json:"data"`
Deadline time.Duration `json:"-"`
ReqID int `json:"id,omitempty"`
Command string `json:"command"`
Data any `json:"data"`
}
type WebsocketCommand struct {
@ -43,7 +40,7 @@ type WebsocketCommand struct {
Ctx context.Context `json:"-"`
}
func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest {
func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" {
return nil
}
@ -100,8 +97,8 @@ type WebsocketMessage struct {
}
const (
WebsocketCloseConnReplaced = 4001
WebsocketCloseTxnNotAcknowledged = 4002
WebsocketCloseConnReplaced websocket.StatusCode = 4001
WebsocketCloseTxnNotAcknowledged websocket.StatusCode = 4002
)
type MeowWebsocketCloseCode string
@ -135,7 +132,7 @@ func (mwcc MeowWebsocketCloseCode) String() string {
}
type CloseCommand struct {
Code int `json:"-"`
Code websocket.StatusCode `json:"-"`
Command string `json:"command"`
Status MeowWebsocketCloseCode `json:"status"`
}
@ -145,15 +142,15 @@ func (cc CloseCommand) Error() string {
}
func parseCloseError(err error) error {
closeError := &websocket.CloseError{}
var closeError websocket.CloseError
if !errors.As(err, &closeError) {
return err
}
var closeCommand CloseCommand
closeCommand.Code = closeError.Code
closeCommand.Command = "disconnect"
if len(closeError.Text) > 0 {
jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand)
if len(closeError.Reason) > 0 {
jsonErr := json.Unmarshal([]byte(closeError.Reason), &closeCommand)
if jsonErr != nil {
return err
}
@ -161,7 +158,7 @@ func parseCloseError(err error) error {
if len(closeCommand.Status) == 0 {
if closeCommand.Code == WebsocketCloseConnReplaced {
closeCommand.Status = MeowConnectionReplaced
} else if closeCommand.Code == websocket.CloseServiceRestart {
} else if closeCommand.Code == websocket.StatusServiceRestart {
closeCommand.Status = MeowServerShuttingDown
}
}
@ -172,20 +169,22 @@ func (as *AppService) HasWebsocket() bool {
return as.ws != nil
}
func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error {
func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) error {
ws := as.ws
if cmd == nil {
return nil
} else if ws == nil {
return ErrWebsocketNotConnected
}
as.wsWriteLock.Lock()
defer as.wsWriteLock.Unlock()
if cmd.Deadline == 0 {
cmd.Deadline = 3 * time.Minute
wr, err := ws.Writer(ctx, websocket.MessageText)
if err != nil {
return err
}
_ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline))
return ws.WriteJSON(cmd)
err = json.NewEncoder(wr).Encode(cmd)
if err != nil {
return err
}
return nil
}
func (as *AppService) clearWebsocketResponseWaiters() {
@ -222,12 +221,12 @@ func (er *ErrorResponse) Error() string {
return fmt.Sprintf("%s: %s", er.Code, er.Message)
}
func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error {
func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response any) error {
cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1))
respChan := make(chan *WebsocketCommand, 1)
as.addWebsocketResponseWaiter(cmd.ReqID, respChan)
defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan)
err := as.SendWebsocket(cmd)
err := as.SendWebsocket(ctx, cmd)
if err != nil {
return err
}
@ -256,7 +255,7 @@ func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketReques
}
}
func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) {
func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, any) {
zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command")
return false, fmt.Errorf("unknown request type")
}
@ -280,14 +279,22 @@ func (as *AppService) defaultHandleWebsocketTransaction(ctx context.Context, msg
return true, &WebsocketTransactionResponse{TxnID: msg.TxnID}
}
func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) {
func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error), ws *websocket.Conn) {
defer stopFunc(ErrWebsocketUnknownError)
ctx := context.Background()
for {
var msg WebsocketMessage
err := ws.ReadJSON(&msg)
msgType, reader, err := ws.Reader(ctx)
if err != nil {
as.Log.Debug().Err(err).Msg("Error reading from websocket")
as.Log.Debug().Err(err).Msg("Error getting reader from websocket")
stopFunc(parseCloseError(err))
return
} else if msgType != websocket.MessageText {
as.Log.Debug().Msg("Ignoring non-text message from websocket")
continue
}
var msg WebsocketMessage
err = json.NewDecoder(reader).Decode(&msg)
if err != nil {
as.Log.Debug().Err(err).Msg("Error reading JSON from websocket")
stopFunc(parseCloseError(err))
return
}
@ -298,11 +305,11 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn)
with = with.Str("transaction_id", msg.TxnID)
}
log := with.Logger()
ctx = log.WithContext(ctx)
ctx := log.WithContext(ctx)
if msg.Command == "" || msg.Command == "transaction" {
ok, resp := as.WebsocketTransactionHandler(ctx, msg)
go func() {
err := as.SendWebsocket(msg.MakeResponse(ok, resp))
err := as.SendWebsocket(ctx, msg.MakeResponse(ok, resp))
if err != nil {
log.Warn().Err(err).Msg("Failed to send response to websocket transaction")
} else {
@ -334,7 +341,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn)
}
go func() {
okResp, data := handler(msg.WebsocketCommand)
err := as.SendWebsocket(msg.MakeResponse(okResp, data))
err := as.SendWebsocket(ctx, msg.MakeResponse(okResp, data))
if err != nil {
log.Error().Err(err).Msg("Failed to send response to websocket command")
} else if okResp {
@ -347,7 +354,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn)
}
}
func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConnect func()) error {
var parsed *url.URL
if baseURL != "" {
var err error
@ -365,12 +372,15 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
} else if parsed.Scheme == "https" {
parsed.Scheme = "wss"
}
ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{
"Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
"User-Agent": []string{as.BotClient().UserAgent},
ws, resp, err := websocket.Dial(ctx, parsed.String(), &websocket.DialOptions{
HTTPClient: as.HTTPClient,
HTTPHeader: http.Header{
"Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
"User-Agent": []string{as.BotClient().UserAgent},
"X-Mautrix-Process-ID": []string{as.ProcessID},
"X-Mautrix-Websocket-Version": []string{"3"},
"X-Mautrix-Process-ID": []string{as.ProcessID},
"X-Mautrix-Websocket-Version": []string{"3"},
},
})
if resp != nil && resp.StatusCode >= 400 {
var errResp mautrix.RespError
@ -406,7 +416,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
as.PrepareWebsocket()
as.Log.Debug().Msg("Appservice transaction websocket opened")
go as.consumeWebsocket(stopFunc, ws)
go as.consumeWebsocket(ctx, stopFunc, ws)
var onConnectDone atomic.Bool
if onConnect != nil {
@ -428,12 +438,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
as.ws = nil
}
_ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second))
err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""))
if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
as.Log.Warn().Err(err).Msg("Error writing close message to websocket")
}
err = ws.Close()
err = ws.Close(websocket.StatusGoingAway, "")
if err != nil {
as.Log.Warn().Err(err).Msg("Error closing websocket")
}

View file

@ -413,7 +413,7 @@ func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI {
func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error {
if br.Websocket {
br.hasSentAnyStates = true
return br.AS.SendWebsocket(&appservice.WebsocketRequest{
return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{
Command: "bridge_status",
Data: state,
})
@ -484,7 +484,7 @@ func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []*
checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints}
if br.Websocket {
return br.AS.SendWebsocket(&appservice.WebsocketRequest{
return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{
Command: "message_checkpoint",
Data: checkpointsJSON,
})

View file

@ -57,7 +57,7 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) {
addr = br.Config.Homeserver.Address
}
for {
err := br.AS.StartWebsocket(addr, onConnect)
err := br.AS.StartWebsocket(br.Bridge.BackgroundCtx, addr, onConnect)
if errors.Is(err, appservice.ErrWebsocketManualStop) {
return
} else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced {

2
go.mod
View file

@ -7,7 +7,7 @@ toolchain go1.24.5
require (
filippo.io/edwards25519 v1.1.0
github.com/chzyer/readline v1.5.1
github.com/gorilla/websocket v1.5.0
github.com/coder/websocket v1.8.13
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.28
github.com/rs/xid v1.6.0

4
go.sum
View file

@ -8,13 +8,13 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=