Add Context to clients / sessions.

The Context will be closed when the client disconnects / the session is removed,
so any pending requests can be cancelled.
This commit is contained in:
Joachim Bauch 2024-05-14 17:02:51 +02:00
parent 94a8f0f02b
commit 0ee976d377
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
8 changed files with 89 additions and 73 deletions

View file

@ -23,6 +23,7 @@ package signaling
import (
"bytes"
"context"
"encoding/json"
"log"
"strconv"
@ -93,6 +94,7 @@ type WritableClientMessage interface {
}
type HandlerClient interface {
Context() context.Context
RemoteAddr() string
Country() string
UserAgent() string
@ -121,6 +123,7 @@ type ClientGeoIpHandler interface {
}
type Client struct {
ctx context.Context
conn *websocket.Conn
addr string
agent string
@ -142,7 +145,7 @@ type Client struct {
messageChan chan *bytes.Buffer
}
func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) {
func NewClient(ctx context.Context, conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) {
remoteAddress = strings.TrimSpace(remoteAddress)
if remoteAddress == "" {
remoteAddress = "unknown remote address"
@ -153,6 +156,7 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler
}
client := &Client{
ctx: ctx,
agent: agent,
logRTT: true,
}
@ -181,6 +185,10 @@ func (c *Client) getHandler() ClientHandler {
return c.handler
}
func (c *Client) Context() context.Context {
return c.ctx
}
func (c *Client) IsConnected() bool {
return c.closed.Load() == 0
}

View file

@ -51,6 +51,8 @@ type ClientSession struct {
privateId string
publicId string
data *SessionIdData
ctx context.Context
closeFunc context.CancelFunc
clientType string
features []string
@ -91,12 +93,15 @@ type ClientSession struct {
}
func NewClientSession(hub *Hub, privateId string, publicId string, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) {
ctx, closeFunc := context.WithCancel(context.Background())
s := &ClientSession{
hub: hub,
events: hub.events,
privateId: privateId,
publicId: publicId,
data: data,
ctx: ctx,
closeFunc: closeFunc,
clientType: hello.Auth.Type,
features: hello.Features,
@ -140,6 +145,10 @@ func NewClientSession(hub *Hub, privateId string, publicId string, data *Session
return s, nil
}
func (s *ClientSession) Context() context.Context {
return s.ctx
}
func (s *ClientSession) PrivateId() string {
return s.privateId
}
@ -337,7 +346,7 @@ func (s *ClientSession) getRoomJoinTime() time.Time {
func (s *ClientSession) releaseMcuObjects() {
if len(s.publishers) > 0 {
go func(publishers map[StreamType]McuPublisher) {
ctx := context.TODO()
ctx := context.Background()
for _, publisher := range publishers {
publisher.Close(ctx)
}
@ -346,7 +355,7 @@ func (s *ClientSession) releaseMcuObjects() {
}
if len(s.subscribers) > 0 {
go func(subscribers map[string]McuSubscriber) {
ctx := context.TODO()
ctx := context.Background()
for _, subscriber := range subscribers {
subscriber.Close(ctx)
}
@ -360,6 +369,7 @@ func (s *ClientSession) Close() {
}
func (s *ClientSession) closeAndWait(wait bool) {
s.closeFunc()
s.hub.removeSession(s)
s.mu.Lock()
@ -885,7 +895,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
if prev, found := s.publishers[streamType]; found {
// Another thread created the publisher while we were waiting.
go func(pub McuPublisher) {
closeCtx := context.TODO()
closeCtx := context.Background()
pub.Close(closeCtx)
}(publisher)
publisher = prev
@ -962,7 +972,7 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s
if prev, found := s.subscribers[getStreamId(id, streamType)]; found {
// Another thread created the subscriber while we were waiting.
go func(sub McuSubscriber) {
closeCtx := context.TODO()
closeCtx := context.Background()
sub.Close(closeCtx)
}(subscriber)
subscriber = prev
@ -1036,7 +1046,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
case "sendoffer":
// Process asynchronously to not block other messages received.
go func() {
ctx, cancel := context.WithTimeout(context.Background(), s.hub.mcuTimeout)
ctx, cancel := context.WithTimeout(s.Context(), s.hub.mcuTimeout)
defer cancel()
mc, err := s.GetOrCreateSubscriber(ctx, s.hub.mcu, message.SendOffer.SessionId, StreamType(message.SendOffer.Data.RoomType))
@ -1068,7 +1078,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
return
}
mc.SendMessage(context.TODO(), nil, message.SendOffer.Data, func(err error, response map[string]interface{}) {
mc.SendMessage(s.Context(), nil, message.SendOffer.Data, func(err error, response map[string]interface{}) {
if err != nil {
log.Printf("Could not send MCU message %+v for session %s to %s: %s", message.SendOffer.Data, message.SendOffer.SessionId, s.PublicId(), err)
if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{

View file

@ -115,6 +115,10 @@ func (c *remoteGrpcClient) readPump() {
}
}
func (c *remoteGrpcClient) Context() context.Context {
return c.client.Context()
}
func (c *remoteGrpcClient) RemoteAddr() string {
return c.remoteAddr
}

113
hub.go
View file

@ -850,7 +850,7 @@ func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend *
var totalCount atomic.Uint32
totalCount.Add(uint32(backend.Len()))
var wg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(client.Context(), time.Second)
defer cancel()
for _, client := range h.rpcClients.GetClients() {
wg.Add(1)
@ -983,15 +983,15 @@ func (h *Hub) processMessage(client HandlerClient, data []byte) {
switch message.Type {
case "room":
h.processRoom(client, &message)
h.processRoom(session, &message)
case "message":
h.processMessageMsg(client, &message)
h.processMessageMsg(session, &message)
case "control":
h.processControlMsg(client, &message)
h.processControlMsg(session, &message)
case "internal":
h.processInternalMsg(client, &message)
h.processInternalMsg(session, &message)
case "transient":
h.processTransientMsg(client, &message)
h.processTransientMsg(session, &message)
case "bye":
h.processByeMsg(client, &message)
case "hello":
@ -1035,7 +1035,7 @@ func (h *Hub) tryProxyResume(c HandlerClient, resumeId string, message *ClientMe
return false
}
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
rpcCtx, rpcCancel := context.WithTimeout(c.Context(), 5*time.Second)
defer rpcCancel()
var wg sync.WaitGroup
@ -1174,7 +1174,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
}
}
func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
func (h *Hub) processHelloV1(ctx context.Context, client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
url := message.Hello.Auth.parsedUrl
backend := h.backend.GetBackend(url)
if backend == nil {
@ -1182,7 +1182,7 @@ func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Bac
}
// Run in timeout context to prevent blocking too long.
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
ctx, cancel := context.WithTimeout(ctx, h.backendTimeout)
defer cancel()
var auth BackendClientResponse
@ -1196,7 +1196,7 @@ func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Bac
return backend, &auth, nil
}
func (h *Hub) processHelloV2(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
func (h *Hub) processHelloV2(ctx context.Context, client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
url := message.Hello.Auth.parsedUrl
backend := h.backend.GetBackend(url)
if backend == nil {
@ -1243,16 +1243,16 @@ func (h *Hub) processHelloV2(client HandlerClient, message *ClientMessage) (*Bac
}
// Run in timeout context to prevent blocking too long.
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
backendCtx, cancel := context.WithTimeout(ctx, h.backendTimeout)
defer cancel()
keyData, cached, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
keyData, cached, found := h.backend.capabilities.GetStringConfig(backendCtx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
if !found {
if cached {
// The Nextcloud instance might just have enabled JWT but we probably use
// the cached capabilities without the public key. Make sure to re-fetch.
h.backend.capabilities.InvalidateCapabilities(url)
keyData, _, found = h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
keyData, _, found = h.backend.capabilities.GetStringConfig(backendCtx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
}
if !found {
return nil, fmt.Errorf("No key found for issuer")
@ -1306,7 +1306,7 @@ func (h *Hub) processHelloClient(client HandlerClient, message *ClientMessage) {
// Make sure the client must send another "hello" in case of errors.
defer h.startExpectHello(client)
var authFunc func(HandlerClient, *ClientMessage) (*Backend, *BackendClientResponse, error)
var authFunc func(context.Context, HandlerClient, *ClientMessage) (*Backend, *BackendClientResponse, error)
switch message.Hello.Version {
case HelloVersionV1:
// Auth information contains a ticket that must be validated against the
@ -1320,7 +1320,7 @@ func (h *Hub) processHelloClient(client HandlerClient, message *ClientMessage) {
return
}
backend, auth, err := authFunc(client, message)
backend, auth, err := authFunc(client.Context(), client, message)
if err != nil {
if e, ok := err.(*Error); ok {
client.SendMessage(message.NewErrorServerMessage(e))
@ -1422,18 +1422,14 @@ func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Roo
return session.SendMessage(response)
}
func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
session, ok := client.GetSession().(*ClientSession)
func (h *Hub) processRoom(sess Session, message *ClientMessage) {
session, ok := sess.(*ClientSession)
if !ok {
return
}
roomId := message.Room.RoomId
if roomId == "" {
if session == nil {
return
}
// We can handle leaving a room directly.
if session.LeaveRoom(true) != nil {
// User was in a room before, so need to notify about leaving it.
@ -1446,13 +1442,6 @@ func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
return
}
if session == nil {
session.SendMessage(message.NewErrorServerMessage(
NewError("not_authenticated", "Need to authenticate before joining rooms."),
))
return
}
if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) {
// Session already is in that room, no action needed.
roomSessionId := message.Room.SessionId
@ -1487,7 +1476,7 @@ func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
}
} else {
// Run in timeout context to prevent blocking too long.
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
ctx, cancel := context.WithTimeout(session.Context(), h.backendTimeout)
defer cancel()
sessionId := message.Room.SessionId
@ -1507,7 +1496,7 @@ func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
if message.Room.SessionId != "" {
// There can only be one connection per Nextcloud Talk session,
// disconnect any other connections without sending a "leave" event.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(session.Context(), time.Second)
defer cancel()
h.disconnectByRoomSessionId(ctx, message.Room.SessionId, session.Backend())
@ -1600,9 +1589,9 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro
r.AddSession(session, room.Room.Session)
}
func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
session, ok := client.GetSession().(*ClientSession)
if session == nil || !ok {
func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) {
session, ok := sess.(*ClientSession)
if !ok {
// Client is not connected yet.
return
}
@ -1654,10 +1643,13 @@ func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
// User is stopping to share his screen. Firefox doesn't properly clean
// up the peer connections in all cases, so make sure to stop publishing
// in the MCU.
go func(c HandlerClient) {
time.Sleep(cleanupScreenPublisherDelay)
session, ok := c.GetSession().(*ClientSession)
if session == nil || !ok {
go func(session *ClientSession) {
sleepCtx, cancel := context.WithTimeout(session.Context(), cleanupScreenPublisherDelay)
defer cancel()
<-sleepCtx.Done()
if session.Context().Err() != nil {
// Session was closed while waiting.
return
}
@ -1670,7 +1662,7 @@ func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout)
defer cancel()
publisher.Close(ctx)
}(client)
}(session)
}
}
}
@ -1778,7 +1770,7 @@ func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
// client) to start his stream, so we must not block the active
// goroutine.
go func() {
ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout)
ctx, cancel := context.WithTimeout(session.Context(), h.mcuTimeout)
defer cancel()
mc, err := recipient.GetOrCreateSubscriber(ctx, h.mcu, session.PublicId(), StreamType(clientData.RoomType))
@ -1792,7 +1784,7 @@ func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
return
}
mc.SendMessage(context.TODO(), msg, clientData, func(err error, response map[string]interface{}) {
mc.SendMessage(session.Context(), msg, clientData, func(err error, response map[string]interface{}) {
if err != nil {
log.Printf("Could not send MCU message %+v for session %s to %s: %s", clientData, session.PublicId(), recipient.PublicId(), err)
sendMcuProcessingFailed(session, message)
@ -1870,13 +1862,9 @@ func isAllowedToControl(session Session) bool {
return false
}
func (h *Hub) processControlMsg(client HandlerClient, message *ClientMessage) {
func (h *Hub) processControlMsg(session Session, message *ClientMessage) {
msg := message.Control
session := client.GetSession()
if session == nil {
// Client is not connected yet.
return
} else if !isAllowedToControl(session) {
if !isAllowedToControl(session) {
log.Printf("Ignore control message %+v from %s", msg, session.PublicId())
return
}
@ -1983,10 +1971,10 @@ func (h *Hub) processControlMsg(client HandlerClient, message *ClientMessage) {
}
}
func (h *Hub) processInternalMsg(client HandlerClient, message *ClientMessage) {
func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) {
msg := message.Internal
session, ok := client.GetSession().(*ClientSession)
if session == nil || !ok {
session, ok := sess.(*ClientSession)
if !ok {
// Client is not connected yet.
return
} else if session.ClientType() != HelloClientTypeInternal {
@ -2019,7 +2007,7 @@ func (h *Hub) processInternalMsg(client HandlerClient, message *ClientMessage) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
ctx, cancel := context.WithTimeout(session.Context(), h.backendTimeout)
defer cancel()
virtualSessionId := GetVirtualSessionId(session, msg.SessionId)
@ -2200,14 +2188,7 @@ func isAllowedToUpdateTransientData(session Session) bool {
return false
}
func (h *Hub) processTransientMsg(client HandlerClient, message *ClientMessage) {
msg := message.TransientData
session := client.GetSession()
if session == nil {
// Client is not connected yet.
return
}
func (h *Hub) processTransientMsg(session Session, message *ClientMessage) {
room := session.GetRoom()
if room == nil {
response := message.NewErrorServerMessage(NewError("not_in_room", "No room joined yet."))
@ -2215,6 +2196,7 @@ func (h *Hub) processTransientMsg(client HandlerClient, message *ClientMessage)
return
}
msg := message.TransientData
switch msg.Type {
case "set":
if !isAllowedToUpdateTransientData(session) {
@ -2318,7 +2300,7 @@ func (h *Hub) isInSameCall(ctx context.Context, senderSession *ClientSession, re
}
func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMessage, message *MessageClientMessage, data *MessageClientMessageData) {
ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout)
ctx, cancel := context.WithTimeout(session.Context(), h.mcuTimeout)
defer cancel()
var mc McuClient
@ -2390,7 +2372,7 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe
return
}
mc.SendMessage(context.TODO(), message, data, func(err error, response map[string]interface{}) {
mc.SendMessage(session.Context(), message, data, func(err error, response map[string]interface{}) {
if err != nil {
log.Printf("Could not send MCU message %+v for session %s to %s: %s", data, session.PublicId(), message.Recipient.SessionId, err)
sendMcuProcessingFailed(session, client_message)
@ -2563,7 +2545,7 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
return
}
client, err := NewClient(conn, addr, agent, h)
client, err := NewClient(r.Context(), conn, addr, agent, h)
if err != nil {
log.Printf("Could not create client for %s: %s", addr, err)
return
@ -2575,11 +2557,10 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
defer h.writePumpActive.Add(-1)
client.WritePump()
}(h)
go func(h *Hub) {
h.readPumpActive.Add(1)
defer h.readPumpActive.Add(-1)
client.ReadPump()
}(h)
h.readPumpActive.Add(1)
defer h.readPumpActive.Add(-1)
client.ReadPump()
}
func (h *Hub) OnLookupCountry(client HandlerClient) string {

View file

@ -51,6 +51,8 @@ func NewRemoteSession(hub *Hub, client *Client, remoteClient *GrpcClient, sessio
client.SetSessionId(sessionId)
client.SetHandler(remoteSession)
// Don't use "client.Context()" here as it could close the proxy connection
// before any final messages are forwarded to the remote end.
proxy, err := remoteClient.ProxySession(context.Background(), sessionId, remoteSession)
if err != nil {
return nil, err

View file

@ -22,6 +22,7 @@
package signaling
import (
"context"
"encoding/json"
"errors"
"net/url"
@ -32,6 +33,10 @@ type DummySession struct {
publicId string
}
func (s *DummySession) Context() context.Context {
return context.Background()
}
func (s *DummySession) PrivateId() string {
return ""
}

View file

@ -22,6 +22,7 @@
package signaling
import (
"context"
"encoding/json"
"net/url"
"time"
@ -53,6 +54,7 @@ type SessionIdData struct {
}
type Session interface {
Context() context.Context
PrivateId() string
PublicId() string
ClientType() string

View file

@ -85,6 +85,10 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
return result, nil
}
func (s *VirtualSession) Context() context.Context {
return s.session.Context()
}
func (s *VirtualSession) PrivateId() string {
return s.privateId
}