Move some signaling-specific functions of client to hub to allow reuse.

This commit is contained in:
Joachim Bauch 2020-08-07 10:22:27 +02:00
parent e14ba2f39f
commit 5a553fcc2d
Failed to extract signature
2 changed files with 133 additions and 108 deletions

View file

@ -25,7 +25,6 @@ import (
"bytes"
"encoding/json"
"log"
"net"
"strconv"
"strings"
"sync"
@ -52,14 +51,11 @@ const (
)
var (
_noCountry string = "no-country"
noCountry *string = &_noCountry
noCountry string = "no-country"
_loopback string = "loopback"
loopback *string = &_loopback
loopback string = "loopback"
_unknownCountry string = "unknown-country"
unknownCountry *string = &_unknownCountry
unknownCountry string = "unknown-country"
)
var (
@ -72,8 +68,13 @@ var (
}
)
type WritableClientMessage interface {
json.Marshaler
CloseAfterSend(session Session) bool
}
type Client struct {
hub *Hub
conn *websocket.Conn
addr string
agent string
@ -85,9 +86,13 @@ type Client struct {
mu sync.Mutex
closeChan chan bool
OnLookupCountry func(*Client) string
OnClosed func(*Client)
OnMessageReceived func(*Client, []byte)
}
func NewClient(hub *Hub, conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) {
func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) {
remoteAddress = strings.TrimSpace(remoteAddress)
if remoteAddress == "" {
remoteAddress = "unknown remote address"
@ -97,15 +102,27 @@ func NewClient(hub *Hub, conn *websocket.Conn, remoteAddress string, agent strin
agent = "unknown user agent"
}
client := &Client{
hub: hub,
conn: conn,
addr: remoteAddress,
agent: agent,
closeChan: make(chan bool, 1),
OnLookupCountry: func(client *Client) string { return unknownCountry },
OnClosed: func(client *Client) {},
OnMessageReceived: func(client *Client, data []byte) {},
}
return client, nil
}
func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string) {
c.conn = conn
c.addr = remoteAddress
c.closeChan = make(chan bool, 1)
c.OnLookupCountry = func(client *Client) string { return unknownCountry }
c.OnClosed = func(client *Client) {}
c.OnMessageReceived = func(client *Client, data []byte) {}
}
func (c *Client) IsConnected() bool {
return atomic.LoadUint32(&c.closed) == 0
}
@ -132,25 +149,7 @@ func (c *Client) UserAgent() string {
func (c *Client) Country() string {
if c.country == nil {
if c.hub.geoip == nil {
c.country = unknownCountry
return *c.country
}
ip := net.ParseIP(c.RemoteAddr())
if ip == nil {
c.country = noCountry
return *c.country
} else if ip.IsLoopback() {
c.country = loopback
return *c.country
}
country, err := c.hub.geoip.LookupCountry(ip)
if err != nil {
log.Printf("Could not lookup country for %s", ip)
c.country = unknownCountry
return *c.country
}
country := c.OnLookupCountry(c)
c.country = &country
}
@ -164,7 +163,7 @@ func (c *Client) Close() {
c.closeChan <- true
c.hub.processUnregister(c)
c.OnClosed(c)
c.SetSession(nil)
c.mu.Lock()
@ -183,41 +182,6 @@ func (c *Client) SendError(e *Error) bool {
return c.SendMessage(message)
}
func (c *Client) SendRoom(message *ClientMessage, room *Room) bool {
response := &ServerMessage{
Type: "room",
}
if message != nil {
response.Id = message.Id
}
if room == nil {
response.Room = &RoomServerMessage{
RoomId: "",
}
} else {
response.Room = &RoomServerMessage{
RoomId: room.id,
Properties: room.properties,
}
}
return c.SendMessage(response)
}
func (c *Client) SendHelloResponse(message *ClientMessage, session *ClientSession) bool {
response := &ServerMessage{
Id: message.Id,
Type: "hello",
Hello: &HelloServerMessage{
Version: HelloVersion,
SessionId: session.PublicId(),
ResumeId: session.PrivateId(),
UserId: session.UserId(),
Server: c.hub.GetServerInfo(),
},
}
return c.SendMessage(response)
}
func (c *Client) SendByeResponse(message *ClientMessage) bool {
return c.SendByeResponseWithReason(message, "")
}
@ -236,11 +200,11 @@ func (c *Client) SendByeResponseWithReason(message *ClientMessage, reason string
return c.SendMessage(response)
}
func (c *Client) SendMessage(message *ServerMessage) bool {
func (c *Client) SendMessage(message WritableClientMessage) bool {
return c.writeMessage(message)
}
func (c *Client) readPump() {
func (c *Client) ReadPump() {
defer func() {
c.Close()
}()
@ -312,28 +276,7 @@ func (c *Client) readPump() {
break
}
var message ClientMessage
if err := message.UnmarshalJSON(decodeBuffer.Bytes()); err != nil {
if session := c.GetSession(); session != nil {
log.Printf("Error decoding message from client %s: %v", session.PublicId(), err)
} else {
log.Printf("Error decoding message from %s: %v", addr, err)
}
c.SendError(InvalidFormat)
continue
}
if err := message.CheckValid(); err != nil {
if session := c.GetSession(); session != nil {
log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err)
} else {
log.Printf("Invalid message %+v from %s: %v", message, addr, err)
}
c.SendMessage(message.NewErrorServerMessage(InvalidFormat))
continue
}
c.hub.processMessage(c, &message)
c.OnMessageReceived(c, decodeBuffer.Bytes())
}
}
@ -407,7 +350,7 @@ func (c *Client) writeError(e error) bool {
return false
}
func (c *Client) writeMessage(message *ServerMessage) bool {
func (c *Client) writeMessage(message WritableClientMessage) bool {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
@ -417,7 +360,7 @@ func (c *Client) writeMessage(message *ServerMessage) bool {
return c.writeMessageLocked(message)
}
func (c *Client) writeMessageLocked(message *ServerMessage) bool {
func (c *Client) writeMessageLocked(message WritableClientMessage) bool {
if !c.writeInternal(message) {
return false
}
@ -458,7 +401,7 @@ func (c *Client) sendPing() bool {
return true
}
func (c *Client) writePump() {
func (c *Client) WritePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()

View file

@ -30,6 +30,7 @@ import (
"fmt"
"hash/fnv"
"log"
"net"
"net/http"
"strings"
"sync"
@ -633,7 +634,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
h.setDecodedSessionId(privateSessionId, privateSessionName, sessionIdData)
h.setDecodedSessionId(publicSessionId, publicSessionName, sessionIdData)
client.SendHelloResponse(message, session)
h.sendHelloResponse(client, message, session)
}
func (h *Hub) processUnregister(client *Client) *ClientSession {
@ -656,7 +657,28 @@ func (h *Hub) processUnregister(client *Client) *ClientSession {
return session
}
func (h *Hub) processMessage(client *Client, message *ClientMessage) {
func (h *Hub) processMessage(client *Client, data []byte) {
var message ClientMessage
if err := message.UnmarshalJSON(data); err != nil {
if session := client.GetSession(); session != nil {
log.Printf("Error decoding message from client %s: %v", session.PublicId(), err)
} else {
log.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err)
}
client.SendError(InvalidFormat)
return
}
if err := message.CheckValid(); err != nil {
if session := client.GetSession(); session != nil {
log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err)
} else {
log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err)
}
client.SendMessage(message.NewErrorServerMessage(InvalidFormat))
return
}
session := client.GetSession()
if session == nil {
if message.Type != "hello" {
@ -664,19 +686,19 @@ func (h *Hub) processMessage(client *Client, message *ClientMessage) {
return
}
h.processHello(client, message)
h.processHello(client, &message)
return
}
switch message.Type {
case "room":
h.processRoom(client, message)
h.processRoom(client, &message)
case "message":
h.processMessageMsg(client, message)
h.processMessageMsg(client, &message)
case "control":
h.processControlMsg(client, message)
h.processControlMsg(client, &message)
case "bye":
h.processByeMsg(client, message)
h.processByeMsg(client, &message)
case "hello":
log.Printf("Ignore hello %+v for already authenticated connection %s", message.Hello, session.PublicId())
default:
@ -684,6 +706,21 @@ func (h *Hub) processMessage(client *Client, message *ClientMessage) {
}
}
func (h *Hub) sendHelloResponse(client *Client, message *ClientMessage, session *ClientSession) bool {
response := &ServerMessage{
Id: message.Id,
Type: "hello",
Hello: &HelloServerMessage{
Version: HelloVersion,
SessionId: session.PublicId(),
ResumeId: session.PrivateId(),
UserId: session.UserId(),
Server: h.GetServerInfo(),
},
}
return client.SendMessage(response)
}
func (h *Hub) processHello(client *Client, message *ClientMessage) {
resumeId := message.Hello.ResumeId
if resumeId != "" {
@ -728,7 +765,7 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) {
log.Printf("Resume session from %s in %s (%s) %s (private=%s)", client.RemoteAddr(), client.Country(), client.UserAgent(), session.PublicId(), session.PrivateId())
client.SendHelloResponse(message, clientSession)
h.sendHelloResponse(client, message, clientSession)
clientSession.NotifySessionResumed(client)
return
}
@ -839,6 +876,26 @@ func (h *Hub) disconnectByRoomSessionId(roomSessionId string) {
session.Close()
}
func (h *Hub) sendRoom(client *Client, message *ClientMessage, room *Room) bool {
response := &ServerMessage{
Type: "room",
}
if message != nil {
response.Id = message.Id
}
if room == nil {
response.Room = &RoomServerMessage{
RoomId: "",
}
} else {
response.Room = &RoomServerMessage{
RoomId: room.id,
Properties: room.properties,
}
}
return client.SendMessage(response)
}
func (h *Hub) processRoom(client *Client, message *ClientMessage) {
session := client.GetSession()
roomId := message.Room.RoomId
@ -850,7 +907,7 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) {
// We can handle leaving a room directly.
if session.LeaveRoom(true) != nil {
// User was in a room before, so need to notify about leaving it.
client.SendRoom(message, nil)
h.sendRoom(client, message, nil)
}
if session.UserId() == "" && session.ClientType() != HelloClientTypeInternal {
h.startWaitAnonymousClientRoom(client)
@ -965,7 +1022,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back
if err := session.SubscribeRoomNats(h.nats, roomId, message.Room.SessionId); err != nil {
client.SendMessage(message.NewWrappedErrorServerMessage(err))
// The client (implicitly) left the room due to an error.
client.SendRoom(nil, nil)
h.sendRoom(client, nil, nil)
return
}
@ -978,7 +1035,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back
client.SendMessage(message.NewWrappedErrorServerMessage(err))
// The client (implicitly) left the room due to an error.
session.UnsubscribeRoomNats()
client.SendRoom(nil, nil)
h.sendRoom(client, nil, nil)
return
}
}
@ -992,7 +1049,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back
if room.Room.Permissions != nil {
session.SetPermissions(*room.Room.Permissions)
}
client.SendRoom(message, r)
h.sendRoom(client, message, r)
h.notifyUserJoinedRoom(r, client, session, room.Room.Session)
}
@ -1427,7 +1484,7 @@ func (h *Hub) processRoomDeleted(message *BackendServerRoomRequest) {
switch sess := session.(type) {
case *ClientSession:
if client := sess.GetClient(); client != nil {
client.SendRoom(nil, nil)
h.sendRoom(client, nil, nil)
}
}
}
@ -1477,6 +1534,23 @@ func getRealUserIP(r *http.Request) string {
return r.RemoteAddr
}
func (h *Hub) lookupClientCountry(client *Client) string {
ip := net.ParseIP(client.RemoteAddr())
if ip == nil {
return noCountry
} else if ip.IsLoopback() {
return loopback
}
country, err := h.geoip.LookupCountry(ip)
if err != nil {
log.Printf("Could not lookup country for %s", ip)
return unknownCountry
}
return country
}
func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
addr := getRealUserIP(r)
agent := r.Header.Get("User-Agent")
@ -1487,13 +1561,21 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
return
}
client, err := NewClient(h, conn, addr, agent)
client, err := NewClient(conn, addr, agent)
if err != nil {
log.Printf("Could not create client for %s: %s", addr, err)
return
}
if h.geoip != nil {
client.OnLookupCountry = h.lookupClientCountry
}
client.OnMessageReceived = h.processMessage
client.OnClosed = func(client *Client) {
h.processUnregister(client)
}
h.processNewClient(client)
go client.writePump()
go client.readPump()
go client.WritePump()
go client.ReadPump()
}