diff --git a/src/signaling/client.go b/src/signaling/client.go index e1535c4..3523792 100644 --- a/src/signaling/client.go +++ b/src/signaling/client.go @@ -25,7 +25,6 @@ import ( "bytes" "encoding/json" "log" - "net" "strconv" "strings" "sync" @@ -52,14 +51,11 @@ const ( ) var ( - _noCountry string = "no-country" - noCountry *string = &_noCountry + noCountry string = "no-country" - _loopback string = "loopback" - loopback *string = &_loopback + loopback string = "loopback" - _unknownCountry string = "unknown-country" - unknownCountry *string = &_unknownCountry + unknownCountry string = "unknown-country" ) var ( @@ -72,8 +68,13 @@ var ( } ) +type WritableClientMessage interface { + json.Marshaler + + CloseAfterSend(session Session) bool +} + type Client struct { - hub *Hub conn *websocket.Conn addr string agent string @@ -85,9 +86,13 @@ type Client struct { mu sync.Mutex closeChan chan bool + + OnLookupCountry func(*Client) string + OnClosed func(*Client) + OnMessageReceived func(*Client, []byte) } -func NewClient(hub *Hub, conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { +func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { remoteAddress = strings.TrimSpace(remoteAddress) if remoteAddress == "" { remoteAddress = "unknown remote address" @@ -97,15 +102,27 @@ func NewClient(hub *Hub, conn *websocket.Conn, remoteAddress string, agent strin agent = "unknown user agent" } client := &Client{ - hub: hub, conn: conn, addr: remoteAddress, agent: agent, closeChan: make(chan bool, 1), + + OnLookupCountry: func(client *Client) string { return unknownCountry }, + OnClosed: func(client *Client) {}, + OnMessageReceived: func(client *Client, data []byte) {}, } return client, nil } +func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string) { + c.conn = conn + c.addr = remoteAddress + c.closeChan = make(chan bool, 1) + c.OnLookupCountry = func(client *Client) string { return unknownCountry } + c.OnClosed = func(client *Client) {} + c.OnMessageReceived = func(client *Client, data []byte) {} +} + func (c *Client) IsConnected() bool { return atomic.LoadUint32(&c.closed) == 0 } @@ -132,25 +149,7 @@ func (c *Client) UserAgent() string { func (c *Client) Country() string { if c.country == nil { - if c.hub.geoip == nil { - c.country = unknownCountry - return *c.country - } - ip := net.ParseIP(c.RemoteAddr()) - if ip == nil { - c.country = noCountry - return *c.country - } else if ip.IsLoopback() { - c.country = loopback - return *c.country - } - - country, err := c.hub.geoip.LookupCountry(ip) - if err != nil { - log.Printf("Could not lookup country for %s", ip) - c.country = unknownCountry - return *c.country - } + country := c.OnLookupCountry(c) c.country = &country } @@ -164,7 +163,7 @@ func (c *Client) Close() { c.closeChan <- true - c.hub.processUnregister(c) + c.OnClosed(c) c.SetSession(nil) c.mu.Lock() @@ -183,41 +182,6 @@ func (c *Client) SendError(e *Error) bool { return c.SendMessage(message) } -func (c *Client) SendRoom(message *ClientMessage, room *Room) bool { - response := &ServerMessage{ - Type: "room", - } - if message != nil { - response.Id = message.Id - } - if room == nil { - response.Room = &RoomServerMessage{ - RoomId: "", - } - } else { - response.Room = &RoomServerMessage{ - RoomId: room.id, - Properties: room.properties, - } - } - return c.SendMessage(response) -} - -func (c *Client) SendHelloResponse(message *ClientMessage, session *ClientSession) bool { - response := &ServerMessage{ - Id: message.Id, - Type: "hello", - Hello: &HelloServerMessage{ - Version: HelloVersion, - SessionId: session.PublicId(), - ResumeId: session.PrivateId(), - UserId: session.UserId(), - Server: c.hub.GetServerInfo(), - }, - } - return c.SendMessage(response) -} - func (c *Client) SendByeResponse(message *ClientMessage) bool { return c.SendByeResponseWithReason(message, "") } @@ -236,11 +200,11 @@ func (c *Client) SendByeResponseWithReason(message *ClientMessage, reason string return c.SendMessage(response) } -func (c *Client) SendMessage(message *ServerMessage) bool { +func (c *Client) SendMessage(message WritableClientMessage) bool { return c.writeMessage(message) } -func (c *Client) readPump() { +func (c *Client) ReadPump() { defer func() { c.Close() }() @@ -312,28 +276,7 @@ func (c *Client) readPump() { break } - var message ClientMessage - if err := message.UnmarshalJSON(decodeBuffer.Bytes()); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) - } else { - log.Printf("Error decoding message from %s: %v", addr, err) - } - c.SendError(InvalidFormat) - continue - } - - if err := message.CheckValid(); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) - } else { - log.Printf("Invalid message %+v from %s: %v", message, addr, err) - } - c.SendMessage(message.NewErrorServerMessage(InvalidFormat)) - continue - } - - c.hub.processMessage(c, &message) + c.OnMessageReceived(c, decodeBuffer.Bytes()) } } @@ -407,7 +350,7 @@ func (c *Client) writeError(e error) bool { return false } -func (c *Client) writeMessage(message *ServerMessage) bool { +func (c *Client) writeMessage(message WritableClientMessage) bool { c.mu.Lock() defer c.mu.Unlock() if c.conn == nil { @@ -417,7 +360,7 @@ func (c *Client) writeMessage(message *ServerMessage) bool { return c.writeMessageLocked(message) } -func (c *Client) writeMessageLocked(message *ServerMessage) bool { +func (c *Client) writeMessageLocked(message WritableClientMessage) bool { if !c.writeInternal(message) { return false } @@ -458,7 +401,7 @@ func (c *Client) sendPing() bool { return true } -func (c *Client) writePump() { +func (c *Client) WritePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() diff --git a/src/signaling/hub.go b/src/signaling/hub.go index 1747081..0bbd173 100644 --- a/src/signaling/hub.go +++ b/src/signaling/hub.go @@ -30,6 +30,7 @@ import ( "fmt" "hash/fnv" "log" + "net" "net/http" "strings" "sync" @@ -633,7 +634,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B h.setDecodedSessionId(privateSessionId, privateSessionName, sessionIdData) h.setDecodedSessionId(publicSessionId, publicSessionName, sessionIdData) - client.SendHelloResponse(message, session) + h.sendHelloResponse(client, message, session) } func (h *Hub) processUnregister(client *Client) *ClientSession { @@ -656,7 +657,28 @@ func (h *Hub) processUnregister(client *Client) *ClientSession { return session } -func (h *Hub) processMessage(client *Client, message *ClientMessage) { +func (h *Hub) processMessage(client *Client, data []byte) { + var message ClientMessage + if err := message.UnmarshalJSON(data); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) + } else { + log.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) + } + client.SendError(InvalidFormat) + return + } + + if err := message.CheckValid(); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + } else { + log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) + } + client.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + return + } + session := client.GetSession() if session == nil { if message.Type != "hello" { @@ -664,19 +686,19 @@ func (h *Hub) processMessage(client *Client, message *ClientMessage) { return } - h.processHello(client, message) + h.processHello(client, &message) return } switch message.Type { case "room": - h.processRoom(client, message) + h.processRoom(client, &message) case "message": - h.processMessageMsg(client, message) + h.processMessageMsg(client, &message) case "control": - h.processControlMsg(client, message) + h.processControlMsg(client, &message) case "bye": - h.processByeMsg(client, message) + h.processByeMsg(client, &message) case "hello": log.Printf("Ignore hello %+v for already authenticated connection %s", message.Hello, session.PublicId()) default: @@ -684,6 +706,21 @@ func (h *Hub) processMessage(client *Client, message *ClientMessage) { } } +func (h *Hub) sendHelloResponse(client *Client, message *ClientMessage, session *ClientSession) bool { + response := &ServerMessage{ + Id: message.Id, + Type: "hello", + Hello: &HelloServerMessage{ + Version: HelloVersion, + SessionId: session.PublicId(), + ResumeId: session.PrivateId(), + UserId: session.UserId(), + Server: h.GetServerInfo(), + }, + } + return client.SendMessage(response) +} + func (h *Hub) processHello(client *Client, message *ClientMessage) { resumeId := message.Hello.ResumeId if resumeId != "" { @@ -728,7 +765,7 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) { log.Printf("Resume session from %s in %s (%s) %s (private=%s)", client.RemoteAddr(), client.Country(), client.UserAgent(), session.PublicId(), session.PrivateId()) - client.SendHelloResponse(message, clientSession) + h.sendHelloResponse(client, message, clientSession) clientSession.NotifySessionResumed(client) return } @@ -839,6 +876,26 @@ func (h *Hub) disconnectByRoomSessionId(roomSessionId string) { session.Close() } +func (h *Hub) sendRoom(client *Client, message *ClientMessage, room *Room) bool { + response := &ServerMessage{ + Type: "room", + } + if message != nil { + response.Id = message.Id + } + if room == nil { + response.Room = &RoomServerMessage{ + RoomId: "", + } + } else { + response.Room = &RoomServerMessage{ + RoomId: room.id, + Properties: room.properties, + } + } + return client.SendMessage(response) +} + func (h *Hub) processRoom(client *Client, message *ClientMessage) { session := client.GetSession() roomId := message.Room.RoomId @@ -850,7 +907,7 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { // We can handle leaving a room directly. if session.LeaveRoom(true) != nil { // User was in a room before, so need to notify about leaving it. - client.SendRoom(message, nil) + h.sendRoom(client, message, nil) } if session.UserId() == "" && session.ClientType() != HelloClientTypeInternal { h.startWaitAnonymousClientRoom(client) @@ -965,7 +1022,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back if err := session.SubscribeRoomNats(h.nats, roomId, message.Room.SessionId); err != nil { client.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. - client.SendRoom(nil, nil) + h.sendRoom(client, nil, nil) return } @@ -978,7 +1035,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back client.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. session.UnsubscribeRoomNats() - client.SendRoom(nil, nil) + h.sendRoom(client, nil, nil) return } } @@ -992,7 +1049,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back if room.Room.Permissions != nil { session.SetPermissions(*room.Room.Permissions) } - client.SendRoom(message, r) + h.sendRoom(client, message, r) h.notifyUserJoinedRoom(r, client, session, room.Room.Session) } @@ -1427,7 +1484,7 @@ func (h *Hub) processRoomDeleted(message *BackendServerRoomRequest) { switch sess := session.(type) { case *ClientSession: if client := sess.GetClient(); client != nil { - client.SendRoom(nil, nil) + h.sendRoom(client, nil, nil) } } } @@ -1477,6 +1534,23 @@ func getRealUserIP(r *http.Request) string { return r.RemoteAddr } +func (h *Hub) lookupClientCountry(client *Client) string { + ip := net.ParseIP(client.RemoteAddr()) + if ip == nil { + return noCountry + } else if ip.IsLoopback() { + return loopback + } + + country, err := h.geoip.LookupCountry(ip) + if err != nil { + log.Printf("Could not lookup country for %s", ip) + return unknownCountry + } + + return country +} + func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { addr := getRealUserIP(r) agent := r.Header.Get("User-Agent") @@ -1487,13 +1561,21 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { return } - client, err := NewClient(h, conn, addr, agent) + client, err := NewClient(conn, addr, agent) if err != nil { log.Printf("Could not create client for %s: %s", addr, err) return } + if h.geoip != nil { + client.OnLookupCountry = h.lookupClientCountry + } + client.OnMessageReceived = h.processMessage + client.OnClosed = func(client *Client) { + h.processUnregister(client) + } + h.processNewClient(client) - go client.writePump() - go client.readPump() + go client.WritePump() + go client.ReadPump() }