mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
appservice/websocket: switch from gorilla to coder
Some checks failed
Some checks failed
This commit is contained in:
parent
80befbf8e1
commit
1dfc6280d5
6 changed files with 58 additions and 54 deletions
|
|
@ -19,7 +19,7 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/coder/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
|
@ -178,7 +178,6 @@ type AppService struct {
|
|||
intentsLock sync.RWMutex
|
||||
|
||||
ws *websocket.Conn
|
||||
wsWriteLock sync.Mutex
|
||||
StopWebsocket func(error)
|
||||
websocketHandlers map[string]WebsocketHandler
|
||||
websocketHandlersLock sync.RWMutex
|
||||
|
|
|
|||
|
|
@ -17,9 +17,8 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/coder/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
|
@ -28,11 +27,9 @@ import (
|
|||
)
|
||||
|
||||
type WebsocketRequest struct {
|
||||
ReqID int `json:"id,omitempty"`
|
||||
Command string `json:"command"`
|
||||
Data interface{} `json:"data"`
|
||||
|
||||
Deadline time.Duration `json:"-"`
|
||||
ReqID int `json:"id,omitempty"`
|
||||
Command string `json:"command"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
||||
type WebsocketCommand struct {
|
||||
|
|
@ -43,7 +40,7 @@ type WebsocketCommand struct {
|
|||
Ctx context.Context `json:"-"`
|
||||
}
|
||||
|
||||
func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest {
|
||||
func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
|
||||
if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -100,8 +97,8 @@ type WebsocketMessage struct {
|
|||
}
|
||||
|
||||
const (
|
||||
WebsocketCloseConnReplaced = 4001
|
||||
WebsocketCloseTxnNotAcknowledged = 4002
|
||||
WebsocketCloseConnReplaced websocket.StatusCode = 4001
|
||||
WebsocketCloseTxnNotAcknowledged websocket.StatusCode = 4002
|
||||
)
|
||||
|
||||
type MeowWebsocketCloseCode string
|
||||
|
|
@ -135,7 +132,7 @@ func (mwcc MeowWebsocketCloseCode) String() string {
|
|||
}
|
||||
|
||||
type CloseCommand struct {
|
||||
Code int `json:"-"`
|
||||
Code websocket.StatusCode `json:"-"`
|
||||
Command string `json:"command"`
|
||||
Status MeowWebsocketCloseCode `json:"status"`
|
||||
}
|
||||
|
|
@ -145,15 +142,15 @@ func (cc CloseCommand) Error() string {
|
|||
}
|
||||
|
||||
func parseCloseError(err error) error {
|
||||
closeError := &websocket.CloseError{}
|
||||
var closeError websocket.CloseError
|
||||
if !errors.As(err, &closeError) {
|
||||
return err
|
||||
}
|
||||
var closeCommand CloseCommand
|
||||
closeCommand.Code = closeError.Code
|
||||
closeCommand.Command = "disconnect"
|
||||
if len(closeError.Text) > 0 {
|
||||
jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand)
|
||||
if len(closeError.Reason) > 0 {
|
||||
jsonErr := json.Unmarshal([]byte(closeError.Reason), &closeCommand)
|
||||
if jsonErr != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -161,7 +158,7 @@ func parseCloseError(err error) error {
|
|||
if len(closeCommand.Status) == 0 {
|
||||
if closeCommand.Code == WebsocketCloseConnReplaced {
|
||||
closeCommand.Status = MeowConnectionReplaced
|
||||
} else if closeCommand.Code == websocket.CloseServiceRestart {
|
||||
} else if closeCommand.Code == websocket.StatusServiceRestart {
|
||||
closeCommand.Status = MeowServerShuttingDown
|
||||
}
|
||||
}
|
||||
|
|
@ -172,20 +169,22 @@ func (as *AppService) HasWebsocket() bool {
|
|||
return as.ws != nil
|
||||
}
|
||||
|
||||
func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error {
|
||||
func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) error {
|
||||
ws := as.ws
|
||||
if cmd == nil {
|
||||
return nil
|
||||
} else if ws == nil {
|
||||
return ErrWebsocketNotConnected
|
||||
}
|
||||
as.wsWriteLock.Lock()
|
||||
defer as.wsWriteLock.Unlock()
|
||||
if cmd.Deadline == 0 {
|
||||
cmd.Deadline = 3 * time.Minute
|
||||
wr, err := ws.Writer(ctx, websocket.MessageText)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline))
|
||||
return ws.WriteJSON(cmd)
|
||||
err = json.NewEncoder(wr).Encode(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (as *AppService) clearWebsocketResponseWaiters() {
|
||||
|
|
@ -222,12 +221,12 @@ func (er *ErrorResponse) Error() string {
|
|||
return fmt.Sprintf("%s: %s", er.Code, er.Message)
|
||||
}
|
||||
|
||||
func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error {
|
||||
func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response any) error {
|
||||
cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1))
|
||||
respChan := make(chan *WebsocketCommand, 1)
|
||||
as.addWebsocketResponseWaiter(cmd.ReqID, respChan)
|
||||
defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan)
|
||||
err := as.SendWebsocket(cmd)
|
||||
err := as.SendWebsocket(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -256,7 +255,7 @@ func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketReques
|
|||
}
|
||||
}
|
||||
|
||||
func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) {
|
||||
func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, any) {
|
||||
zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command")
|
||||
return false, fmt.Errorf("unknown request type")
|
||||
}
|
||||
|
|
@ -280,14 +279,22 @@ func (as *AppService) defaultHandleWebsocketTransaction(ctx context.Context, msg
|
|||
return true, &WebsocketTransactionResponse{TxnID: msg.TxnID}
|
||||
}
|
||||
|
||||
func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) {
|
||||
func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error), ws *websocket.Conn) {
|
||||
defer stopFunc(ErrWebsocketUnknownError)
|
||||
ctx := context.Background()
|
||||
for {
|
||||
var msg WebsocketMessage
|
||||
err := ws.ReadJSON(&msg)
|
||||
msgType, reader, err := ws.Reader(ctx)
|
||||
if err != nil {
|
||||
as.Log.Debug().Err(err).Msg("Error reading from websocket")
|
||||
as.Log.Debug().Err(err).Msg("Error getting reader from websocket")
|
||||
stopFunc(parseCloseError(err))
|
||||
return
|
||||
} else if msgType != websocket.MessageText {
|
||||
as.Log.Debug().Msg("Ignoring non-text message from websocket")
|
||||
continue
|
||||
}
|
||||
var msg WebsocketMessage
|
||||
err = json.NewDecoder(reader).Decode(&msg)
|
||||
if err != nil {
|
||||
as.Log.Debug().Err(err).Msg("Error reading JSON from websocket")
|
||||
stopFunc(parseCloseError(err))
|
||||
return
|
||||
}
|
||||
|
|
@ -298,11 +305,11 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn)
|
|||
with = with.Str("transaction_id", msg.TxnID)
|
||||
}
|
||||
log := with.Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
ctx := log.WithContext(ctx)
|
||||
if msg.Command == "" || msg.Command == "transaction" {
|
||||
ok, resp := as.WebsocketTransactionHandler(ctx, msg)
|
||||
go func() {
|
||||
err := as.SendWebsocket(msg.MakeResponse(ok, resp))
|
||||
err := as.SendWebsocket(ctx, msg.MakeResponse(ok, resp))
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to send response to websocket transaction")
|
||||
} else {
|
||||
|
|
@ -334,7 +341,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn)
|
|||
}
|
||||
go func() {
|
||||
okResp, data := handler(msg.WebsocketCommand)
|
||||
err := as.SendWebsocket(msg.MakeResponse(okResp, data))
|
||||
err := as.SendWebsocket(ctx, msg.MakeResponse(okResp, data))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to send response to websocket command")
|
||||
} else if okResp {
|
||||
|
|
@ -347,7 +354,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn)
|
|||
}
|
||||
}
|
||||
|
||||
func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
|
||||
func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConnect func()) error {
|
||||
var parsed *url.URL
|
||||
if baseURL != "" {
|
||||
var err error
|
||||
|
|
@ -365,12 +372,15 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
|
|||
} else if parsed.Scheme == "https" {
|
||||
parsed.Scheme = "wss"
|
||||
}
|
||||
ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{
|
||||
"Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
|
||||
"User-Agent": []string{as.BotClient().UserAgent},
|
||||
ws, resp, err := websocket.Dial(ctx, parsed.String(), &websocket.DialOptions{
|
||||
HTTPClient: as.HTTPClient,
|
||||
HTTPHeader: http.Header{
|
||||
"Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
|
||||
"User-Agent": []string{as.BotClient().UserAgent},
|
||||
|
||||
"X-Mautrix-Process-ID": []string{as.ProcessID},
|
||||
"X-Mautrix-Websocket-Version": []string{"3"},
|
||||
"X-Mautrix-Process-ID": []string{as.ProcessID},
|
||||
"X-Mautrix-Websocket-Version": []string{"3"},
|
||||
},
|
||||
})
|
||||
if resp != nil && resp.StatusCode >= 400 {
|
||||
var errResp mautrix.RespError
|
||||
|
|
@ -406,7 +416,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
|
|||
as.PrepareWebsocket()
|
||||
as.Log.Debug().Msg("Appservice transaction websocket opened")
|
||||
|
||||
go as.consumeWebsocket(stopFunc, ws)
|
||||
go as.consumeWebsocket(ctx, stopFunc, ws)
|
||||
|
||||
var onConnectDone atomic.Bool
|
||||
if onConnect != nil {
|
||||
|
|
@ -428,12 +438,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
|
|||
as.ws = nil
|
||||
}
|
||||
|
||||
_ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second))
|
||||
err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""))
|
||||
if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
|
||||
as.Log.Warn().Err(err).Msg("Error writing close message to websocket")
|
||||
}
|
||||
err = ws.Close()
|
||||
err = ws.Close(websocket.StatusGoingAway, "")
|
||||
if err != nil {
|
||||
as.Log.Warn().Err(err).Msg("Error closing websocket")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -413,7 +413,7 @@ func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI {
|
|||
func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error {
|
||||
if br.Websocket {
|
||||
br.hasSentAnyStates = true
|
||||
return br.AS.SendWebsocket(&appservice.WebsocketRequest{
|
||||
return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{
|
||||
Command: "bridge_status",
|
||||
Data: state,
|
||||
})
|
||||
|
|
@ -484,7 +484,7 @@ func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []*
|
|||
checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints}
|
||||
|
||||
if br.Websocket {
|
||||
return br.AS.SendWebsocket(&appservice.WebsocketRequest{
|
||||
return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{
|
||||
Command: "message_checkpoint",
|
||||
Data: checkpointsJSON,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) {
|
|||
addr = br.Config.Homeserver.Address
|
||||
}
|
||||
for {
|
||||
err := br.AS.StartWebsocket(addr, onConnect)
|
||||
err := br.AS.StartWebsocket(br.Bridge.BackgroundCtx, addr, onConnect)
|
||||
if errors.Is(err, appservice.ErrWebsocketManualStop) {
|
||||
return
|
||||
} else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced {
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -7,7 +7,7 @@ toolchain go1.24.5
|
|||
require (
|
||||
filippo.io/edwards25519 v1.1.0
|
||||
github.com/chzyer/readline v1.5.1
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-sqlite3 v1.14.28
|
||||
github.com/rs/xid v1.6.0
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -8,13 +8,13 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
|
|||
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
|
||||
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
|
||||
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue