From 3c41dcdbce0486ac69a5d44a44751a1e59534fb9 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 8 Jan 2026 14:00:41 +0100 Subject: [PATCH] Move common client code to separate package. --- .codecov.yml | 4 + client.go => client/client.go | 133 ++++---- client/client_test.go | 315 ++++++++++++++++++ client/ip.go | 93 ++++++ client/ip_test.go | 278 ++++++++++++++++ .../stats_prometheus.go | 9 +- clientsession.go | 14 +- cmd/proxy/proxy_client.go | 28 +- cmd/proxy/proxy_server.go | 21 +- federation.go | 12 + grpc_remote_client.go | 19 +- hub.go | 154 +++------ hub_client.go | 128 +++++++ ..._test.go => hub_client_stats_prometheus.go | 34 +- hub_test.go | 237 +------------ remotesession.go | 26 +- testclient_test.go | 8 +- 17 files changed, 1024 insertions(+), 489 deletions(-) rename client.go => client/client.go (86%) create mode 100644 client/client_test.go create mode 100644 client/ip.go create mode 100644 client/ip_test.go rename client_stats_prometheus.go => client/stats_prometheus.go (87%) create mode 100644 hub_client.go rename client_test.go => hub_client_stats_prometheus.go (62%) diff --git a/.codecov.yml b/.codecov.yml index e3aba2f..7960edb 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -28,6 +28,10 @@ component_management: name: async paths: - async/** + - component_id: module_client + name: client + paths: + - client/** - component_id: module_cmd_client name: cmd/client paths: diff --git a/client.go b/client/client.go similarity index 86% rename from client.go rename to client/client.go index 0c024b2..9679526 100644 --- a/client.go +++ b/client/client.go @@ -19,7 +19,7 @@ * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ -package signaling +package client import ( "bytes" @@ -80,10 +80,6 @@ type HandlerClient interface { Country() geoip.Country UserAgent() string IsConnected() bool - IsAuthenticated() bool - - GetSession() Session - SetSession(session Session) SendError(e *api.Error) bool SendByeResponse(message *api.ClientMessage) bool @@ -93,19 +89,30 @@ type HandlerClient interface { Close() } -type ClientHandler interface { - OnClosed(HandlerClient) - OnMessageReceived(HandlerClient, []byte) - OnRTTReceived(HandlerClient, time.Duration) +type Handler interface { + GetSessionId() api.PublicSessionId + + OnClosed() + OnMessageReceived([]byte) + OnRTTReceived(time.Duration) } -type ClientGeoIpHandler interface { - OnLookupCountry(HandlerClient) geoip.Country +type GeoIpHandler interface { + OnLookupCountry(addr string) geoip.Country +} + +type InRoomHandler interface { + IsInRoom(string) bool +} + +type SessionCloserHandler interface { + CloseSession() } type Client struct { - logger log.Logger - ctx context.Context + logger log.Logger + ctx context.Context + // +checklocks:mu conn *websocket.Conn addr string agent string @@ -115,9 +122,8 @@ type Client struct { handlerMu sync.RWMutex // +checklocks:handlerMu - handler ClientHandler + handler Handler - session atomic.Pointer[Session] sessionId atomic.Pointer[api.PublicSessionId] mu sync.Mutex @@ -128,42 +134,36 @@ type Client struct { messageChan chan *bytes.Buffer } -func NewClient(ctx context.Context, conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) { - remoteAddress = strings.TrimSpace(remoteAddress) - if remoteAddress == "" { - remoteAddress = "unknown remote address" - } - agent = strings.TrimSpace(agent) - if agent == "" { - agent = "unknown user agent" - } +func (c *Client) SetConn(ctx context.Context, conn *websocket.Conn, remoteAddress string, agent string, logRTT bool, handler Handler) { + c.mu.Lock() + defer c.mu.Unlock() - client := &Client{ - agent: agent, - logRTT: true, - } - client.SetConn(ctx, conn, remoteAddress, handler) - return client, nil -} - -func (c *Client) SetConn(ctx context.Context, conn *websocket.Conn, remoteAddress string, handler ClientHandler) { c.logger = log.LoggerFromContext(ctx) c.ctx = ctx c.conn = conn c.addr = remoteAddress + c.agent = agent + c.logRTT = logRTT c.SetHandler(handler) c.closer = internal.NewCloser() c.messageChan = make(chan *bytes.Buffer, 16) c.messagesDone = make(chan struct{}) } -func (c *Client) SetHandler(handler ClientHandler) { +func (c *Client) GetConn() *websocket.Conn { + c.mu.Lock() + defer c.mu.Unlock() + + return c.conn +} + +func (c *Client) SetHandler(handler Handler) { c.handlerMu.Lock() defer c.handlerMu.Unlock() c.handler = handler } -func (c *Client) getHandler() ClientHandler { +func (c *Client) getHandler() Handler { c.handlerMu.RLock() defer c.handlerMu.RUnlock() return c.handler @@ -177,27 +177,6 @@ func (c *Client) IsConnected() bool { return c.closed.Load() == 0 } -func (c *Client) IsAuthenticated() bool { - return c.GetSession() != nil -} - -func (c *Client) GetSession() Session { - session := c.session.Load() - if session == nil { - return nil - } - - return *session -} - -func (c *Client) SetSession(session Session) { - if session == nil { - c.session.Store(nil) - } else { - c.session.Store(&session) - } -} - func (c *Client) SetSessionId(sessionId api.PublicSessionId) { c.sessionId.Store(&sessionId) } @@ -205,12 +184,12 @@ func (c *Client) SetSessionId(sessionId api.PublicSessionId) { func (c *Client) GetSessionId() api.PublicSessionId { sessionId := c.sessionId.Load() if sessionId == nil { - session := c.GetSession() - if session == nil { + sessionId := c.getHandler().GetSessionId() + if sessionId == "" { return "" } - return session.PublicId() + return sessionId } return *sessionId @@ -227,8 +206,8 @@ func (c *Client) UserAgent() string { func (c *Client) Country() geoip.Country { if c.country == nil { var country geoip.Country - if handler, ok := c.getHandler().(ClientGeoIpHandler); ok { - country = handler.OnLookupCountry(c) + if handler, ok := c.getHandler().(GeoIpHandler); ok { + country = handler.OnLookupCountry(c.addr) } else { country = geoip.UnknownCountry } @@ -238,6 +217,14 @@ func (c *Client) Country() geoip.Country { return *c.country } +func (c *Client) IsInRoom(id string) bool { + if handler, ok := c.getHandler().(InRoomHandler); ok { + return handler.IsInRoom(id) + } + + return false +} + func (c *Client) Close() { if c.closed.Load() >= 2 { // Prevent reentrant call in case this was the second closing @@ -267,8 +254,7 @@ func (c *Client) doClose() { c.closer.Close() <-c.messagesDone - c.getHandler().OnClosed(c) - c.SetSession(nil) + c.getHandler().OnClosed() } } @@ -340,7 +326,7 @@ func (c *Client) ReadPump() { } } statsClientRTT.Observe(float64(rtt.Milliseconds())) - c.getHandler().OnRTTReceived(c, rtt) + c.getHandler().OnRTTReceived(rtt) } return nil }) @@ -404,7 +390,7 @@ func (c *Client) processMessages() { break } - c.getHandler().OnMessageReceived(c, buffer.Bytes()) + c.getHandler().OnMessageReceived(buffer.Bytes()) bufferPool.Put(buffer) } @@ -425,6 +411,7 @@ func (w *counterWriter) Write(p []byte) (int, error) { return written, err } +// +checklocks:c.mu func (c *Client) writeInternal(message json.Marshaler) bool { var closeData []byte @@ -512,19 +499,19 @@ func (c *Client) writeMessage(message WritableClientMessage) bool { return c.writeMessageLocked(message) } +// +checklocks:c.mu func (c *Client) writeMessageLocked(message WritableClientMessage) bool { if !c.writeInternal(message) { return false } - session := c.GetSession() - if message.CloseAfterSend(session) { - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint - c.conn.WriteMessage(websocket.CloseMessage, []byte{}) // nolint - if session != nil { - go session.Close() - } - go c.Close() + if message.CloseAfterSend(c) { + go func() { + if sc, ok := c.getHandler().(SessionCloserHandler); ok { + sc.CloseSession() + } + c.Close() + }() } return true diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 0000000..31e8c46 --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,315 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2025 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "net/http/httptest" + "slices" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" + "github.com/strukturag/nextcloud-spreed-signaling/log" +) + +func TestCounterWriter(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + var b bytes.Buffer + var written int + w := &counterWriter{ + w: &b, + counter: &written, + } + if count, err := w.Write(nil); assert.NoError(err) && assert.Equal(0, count) { + assert.Equal(0, written) + } + if count, err := w.Write([]byte("foo")); assert.NoError(err) && assert.Equal(3, count) { + assert.Equal(3, written) + } +} + +type serverClient struct { + Client + + t *testing.T + handler *testHandler + + id string + received atomic.Uint32 + sessionClosed atomic.Bool +} + +func newTestClient(h *testHandler, r *http.Request, conn *websocket.Conn, id uint64) *serverClient { + result := &serverClient{ + t: h.t, + handler: h, + id: fmt.Sprintf("session-%d", id), + } + + addr := r.RemoteAddr + if host, _, err := net.SplitHostPort(addr); err == nil { + addr = host + } + + logger := log.NewLoggerForTest(h.t) + ctx := log.NewLoggerContext(r.Context(), logger) + result.SetConn(ctx, conn, addr, r.Header.Get("User-Agent"), false, result) + return result +} + +func (c *serverClient) WaitReceived(ctx context.Context, count uint32) error { + for { + if err := ctx.Err(); err != nil { + return err + } else if c.received.Load() >= count { + return nil + } + + time.Sleep(time.Millisecond) + } +} + +func (c *serverClient) GetSessionId() api.PublicSessionId { + return api.PublicSessionId(c.id) +} + +func (c *serverClient) OnClosed() { + c.handler.removeClient(c) +} + +func (c *serverClient) OnMessageReceived(message []byte) { + switch c.received.Add(1) { + case 1: + var s string + if err := json.Unmarshal(message, &s); assert.NoError(c.t, err) { + assert.Equal(c.t, "Hello world!", s) + c.SendMessage(&api.ServerMessage{ + Type: "welcome", + Welcome: &api.WelcomeServerMessage{ + Version: "1.0", + }, + }) + } + case 2: + var s string + if err := json.Unmarshal(message, &s); assert.NoError(c.t, err) { + assert.Equal(c.t, "Send error", s) + c.SendError(api.NewError("test_error", "This is a test error.")) + } + case 3: + var s string + if err := json.Unmarshal(message, &s); assert.NoError(c.t, err) { + assert.Equal(c.t, "Send bye", s) + c.SendByeResponseWithReason(nil, "Go away!") + } + } +} + +func (c *serverClient) OnRTTReceived(rtt time.Duration) { + +} + +func (c *serverClient) OnLookupCountry(addr string) geoip.Country { + return "DE" +} + +func (c *serverClient) IsInRoom(roomId string) bool { + return false +} + +func (c *serverClient) CloseSession() { + if c.sessionClosed.Swap(true) { + assert.Fail(c.t, "session closed more than once") + } +} + +type testHandler struct { + mu sync.Mutex + + t *testing.T + + upgrader websocket.Upgrader + + id atomic.Uint64 + // +checklocks:mu + activeClients map[string]*serverClient + // +checklocks:mu + allClients []*serverClient +} + +func newTestHandler(t *testing.T) *testHandler { + return &testHandler{ + t: t, + activeClients: make(map[string]*serverClient), + } +} + +func (h *testHandler) addClient(client *serverClient) { + h.mu.Lock() + defer h.mu.Unlock() + + h.activeClients[client.id] = client + h.allClients = append(h.allClients, client) +} + +func (h *testHandler) removeClient(client *serverClient) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.activeClients, client.id) +} + +func (h *testHandler) getClients() []*serverClient { + h.mu.Lock() + defer h.mu.Unlock() + + return slices.Clone(h.allClients) +} + +func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + conn, err := h.upgrader.Upgrade(w, r, nil) + if !assert.NoError(h.t, err) { + return + } + + id := h.id.Add(1) + client := newTestClient(h, r, conn, id) + h.addClient(client) + go client.WritePump() + client.ReadPump() +} + +type localClient struct { + t *testing.T + + conn *websocket.Conn +} + +func newLocalClient(t *testing.T, url string) *localClient { + t.Helper() + + conn, _, err := websocket.DefaultDialer.DialContext(t.Context(), url, nil) + require.NoError(t, err) + return &localClient{ + t: t, + + conn: conn, + } +} + +func (c *localClient) Close() error { + err := c.conn.Close() + if errors.Is(err, net.ErrClosed) { + err = nil + } + return err +} + +func (c *localClient) WriteJSON(v any) error { + return c.conn.WriteJSON(v) +} + +func (c *localClient) ReadJSON(v any) error { + return c.conn.ReadJSON(v) +} + +func TestClient(t *testing.T) { + t.Parallel() + + require := require.New(t) + assert := assert.New(t) + + serverHandler := newTestHandler(t) + + server := httptest.NewServer(serverHandler) + t.Cleanup(func() { + server.Close() + }) + + client := newLocalClient(t, strings.ReplaceAll(server.URL, "http://", "ws://")) + t.Cleanup(func() { + assert.NoError(client.Close()) + }) + + var msg api.ServerMessage + + require.NoError(client.WriteJSON("Hello world!")) + if assert.NoError(client.ReadJSON(&msg)) && + assert.Equal("welcome", msg.Type) && + assert.NotNil(msg.Welcome) { + assert.Equal("1.0", msg.Welcome.Version) + } + if clients := serverHandler.getClients(); assert.Len(clients, 1) { + assert.False(clients[0].sessionClosed.Load()) + assert.EqualValues(1, clients[0].received.Load()) + } + + require.NoError(client.WriteJSON("Send error")) + if assert.NoError(client.ReadJSON(&msg)) && + assert.Equal("error", msg.Type) && + assert.NotNil(msg.Error) { + assert.Equal("test_error", msg.Error.Code) + assert.Equal("This is a test error.", msg.Error.Message) + } + if clients := serverHandler.getClients(); assert.Len(clients, 1) { + assert.False(clients[0].sessionClosed.Load()) + assert.EqualValues(2, clients[0].received.Load()) + } + + require.NoError(client.WriteJSON("Send bye")) + if assert.NoError(client.ReadJSON(&msg)) && + assert.Equal("bye", msg.Type) && + assert.NotNil(msg.Bye) { + assert.Equal("Go away!", msg.Bye.Reason) + } + if clients := serverHandler.getClients(); assert.Len(clients, 1) { + assert.EqualValues(3, clients[0].received.Load()) + } + + // Sending a "bye" will close the connection. + var we *websocket.CloseError + if err := client.ReadJSON(&msg); assert.ErrorAs(err, &we) { + assert.Equal(websocket.CloseNormalClosure, we.Code) + assert.Empty(we.Text) + } + if clients := serverHandler.getClients(); assert.Len(clients, 1) { + assert.True(clients[0].sessionClosed.Load()) + assert.EqualValues(3, clients[0].received.Load()) + } +} diff --git a/client/ip.go b/client/ip.go new file mode 100644 index 0000000..d897278 --- /dev/null +++ b/client/ip.go @@ -0,0 +1,93 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2026 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package client + +import ( + "net" + "net/http" + "slices" + "strings" + + "github.com/strukturag/nextcloud-spreed-signaling/container" +) + +var ( + DefaultTrustedProxies = container.DefaultPrivateIPs() +) + +func GetRealUserIP(r *http.Request, trusted *container.IPList) string { + addr := r.RemoteAddr + if host, _, err := net.SplitHostPort(addr); err == nil { + addr = host + } + + ip := net.ParseIP(addr) + if len(ip) == 0 { + return addr + } + + // Don't check any headers if the server can be reached by untrusted clients directly. + if trusted == nil || !trusted.Contains(ip) { + return addr + } + + if realIP := r.Header.Get("X-Real-IP"); realIP != "" { + if ip := net.ParseIP(realIP); len(ip) > 0 { + return realIP + } + } + + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For#selecting_an_ip_address + forwarded := strings.Split(strings.Join(r.Header.Values("X-Forwarded-For"), ","), ",") + if len(forwarded) > 0 { + slices.Reverse(forwarded) + var lastTrusted string + for _, hop := range forwarded { + hop = strings.TrimSpace(hop) + // Make sure to remove any port. + if host, _, err := net.SplitHostPort(hop); err == nil { + hop = host + } + + ip := net.ParseIP(hop) + if len(ip) == 0 { + continue + } + + if trusted.Contains(ip) { + lastTrusted = hop + continue + } + + return hop + } + + // If all entries in the "X-Forwarded-For" list are trusted, the left-most + // will be the client IP. This can happen if a subnet is trusted and the + // client also has an IP from this subnet. + if lastTrusted != "" { + return lastTrusted + } + } + + return addr +} diff --git a/client/ip_test.go b/client/ip_test.go new file mode 100644 index 0000000..f12b62e --- /dev/null +++ b/client/ip_test.go @@ -0,0 +1,278 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2026 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package client + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/strukturag/nextcloud-spreed-signaling/container" +) + +func TestGetRealUserIP(t *testing.T) { + t.Parallel() + testcases := []struct { + expected string + headers http.Header + trusted string + addr string + }{ + { + "192.168.1.2", + nil, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "invalid-ip", + nil, + "192.168.0.0/16", + "invalid-ip", + }, + { + "invalid-ip", + nil, + "192.168.0.0/16", + "invalid-ip:12345", + }, + { + "10.11.12.13", + nil, + "192.168.0.0/16", + "10.11.12.13:23456", + }, + { + "10.11.12.13", + http.Header{ + http.CanonicalHeaderKey("x-real-ip"): []string{"10.11.12.13"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "2002:db8::1", + http.Header{ + http.CanonicalHeaderKey("x-real-ip"): []string{"2002:db8::1"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "11.12.13.14", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "11.12.13.14", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14:1234, 192.168.30.32:2345"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "10.11.12.13", + http.Header{ + http.CanonicalHeaderKey("x-real-ip"): []string{"10.11.12.13"}, + }, + "2001:db8::/48", + "[2001:db8::1]:23456", + }, + { + "2002:db8::1", + http.Header{ + http.CanonicalHeaderKey("x-real-ip"): []string{"2002:db8::1"}, + }, + "2001:db8::/48", + "[2001:db8::1]:23456", + }, + { + "2002:db8::1", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"2002:db8::1, 192.168.30.32"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "2002:db8::1", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"2002:db8::1, 2001:db8::1"}, + }, + "192.168.0.0/16, 2001:db8::/48", + "192.168.1.2:23456", + }, + { + "2002:db8::1", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"2002:db8::1, 192.168.30.32"}, + }, + "192.168.0.0/16, 2001:db8::/48", + "[2001:db8::1]:23456", + }, + { + "2002:db8::1", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"2002:db8::1, 2001:db8::2"}, + }, + "2001:db8::/48", + "[2001:db8::1]:23456", + }, + // "X-Real-IP" has preference before "X-Forwarded-For" + { + "10.11.12.13", + http.Header{ + http.CanonicalHeaderKey("x-real-ip"): []string{"10.11.12.13"}, + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + // Multiple "X-Forwarded-For" headers are merged. + { + "11.12.13.14", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14", "192.168.30.32"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "11.12.13.14", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"1.2.3.4", "11.12.13.14", "192.168.30.32"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "11.12.13.14", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"1.2.3.4", "2.3.4.5", "11.12.13.14", "192.168.31.32", "192.168.30.32"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + // Headers are ignored if coming from untrusted clients. + { + "10.11.12.13", + http.Header{ + http.CanonicalHeaderKey("x-real-ip"): []string{"11.12.13.14"}, + }, + "192.168.0.0/16", + "10.11.12.13:23456", + }, + { + "10.11.12.13", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32"}, + }, + "192.168.0.0/16", + "10.11.12.13:23456", + }, + // X-Forwarded-For is filtered for trusted proxies. + { + "1.2.3.4", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 1.2.3.4"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "1.2.3.4", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 1.2.3.4, 192.168.2.3"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "10.11.12.13", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 1.2.3.4"}, + }, + "192.168.0.0/16", + "10.11.12.13:23456", + }, + // Invalid IPs are ignored. + { + "192.168.1.2", + http.Header{ + http.CanonicalHeaderKey("x-real-ip"): []string{"this-is-not-an-ip"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "11.12.13.14", + http.Header{ + http.CanonicalHeaderKey("x-real-ip"): []string{"this-is-not-an-ip"}, + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "11.12.13.14", + http.Header{ + http.CanonicalHeaderKey("x-real-ip"): []string{"this-is-not-an-ip"}, + http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32, proxy1"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "192.168.1.2", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"this-is-not-an-ip"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + { + "192.168.2.3", + http.Header{ + http.CanonicalHeaderKey("x-forwarded-for"): []string{"this-is-not-an-ip, 192.168.2.3"}, + }, + "192.168.0.0/16", + "192.168.1.2:23456", + }, + } + + for _, tc := range testcases { + trustedProxies, err := container.ParseIPList(tc.trusted) + if !assert.NoError(t, err, "invalid trusted proxies in %+v", tc) { + continue + } + request := &http.Request{ + RemoteAddr: tc.addr, + Header: tc.headers, + } + assert.Equal(t, tc.expected, GetRealUserIP(request, trustedProxies), "failed for %+v", tc) + } +} diff --git a/client_stats_prometheus.go b/client/stats_prometheus.go similarity index 87% rename from client_stats_prometheus.go rename to client/stats_prometheus.go index f27248f..86ae928 100644 --- a/client_stats_prometheus.go +++ b/client/stats_prometheus.go @@ -19,7 +19,7 @@ * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ -package signaling +package client import ( "github.com/prometheus/client_golang/prometheus" @@ -28,12 +28,6 @@ import ( ) var ( - statsClientCountries = prometheus.NewCounterVec(prometheus.CounterOpts{ - Namespace: "signaling", - Subsystem: "client", - Name: "countries_total", - Help: "The total number of connections by country", - }, []string{"country"}) statsClientRTT = prometheus.NewHistogram(prometheus.HistogramOpts{ Namespace: "signaling", Subsystem: "client", @@ -55,7 +49,6 @@ var ( }, []string{"direction"}) clientStats = []prometheus.Collector{ - statsClientCountries, statsClientRTT, statsClientBytesTotal, statsClientMessagesTotal, diff --git a/clientsession.go b/clientsession.go index e19b4b5..b601f73 100644 --- a/clientsession.go +++ b/clientsession.go @@ -88,7 +88,7 @@ type ClientSession struct { asyncCh events.AsyncChannel // +checklocks:mu - client HandlerClient + client ClientWithSession room atomic.Pointer[Room] roomJoinTime atomic.Int64 federation atomic.Pointer[FederationClient] @@ -604,7 +604,7 @@ func (s *ClientSession) doUnsubscribeRoomEvents(notify bool) { s.roomSessionId = "" } -func (s *ClientSession) ClearClient(client HandlerClient) { +func (s *ClientSession) ClearClient(client ClientWithSession) { s.mu.Lock() defer s.mu.Unlock() @@ -612,7 +612,7 @@ func (s *ClientSession) ClearClient(client HandlerClient) { } // +checklocks:s.mu -func (s *ClientSession) clearClientLocked(client HandlerClient) { +func (s *ClientSession) clearClientLocked(client ClientWithSession) { if s.client == nil { return } else if client != nil && s.client != client { @@ -625,7 +625,7 @@ func (s *ClientSession) clearClientLocked(client HandlerClient) { prevClient.SetSession(nil) } -func (s *ClientSession) GetClient() HandlerClient { +func (s *ClientSession) GetClient() ClientWithSession { s.mu.Lock() defer s.mu.Unlock() @@ -633,11 +633,11 @@ func (s *ClientSession) GetClient() HandlerClient { } // +checklocks:s.mu -func (s *ClientSession) getClientUnlocked() HandlerClient { +func (s *ClientSession) getClientUnlocked() ClientWithSession { return s.client } -func (s *ClientSession) SetClient(client HandlerClient) HandlerClient { +func (s *ClientSession) SetClient(client ClientWithSession) ClientWithSession { if client == nil { panic("Use ClearClient to set the client to nil") } @@ -1551,7 +1551,7 @@ func (s *ClientSession) filterAsyncMessage(msg *events.AsyncMessage) *api.Server } } -func (s *ClientSession) NotifySessionResumed(client HandlerClient) { +func (s *ClientSession) NotifySessionResumed(client ClientWithSession) { s.mu.Lock() if len(s.pendingClientMessages) == 0 { s.mu.Unlock() diff --git a/cmd/proxy/proxy_client.go b/cmd/proxy/proxy_client.go index 935a2b9..aec1f03 100644 --- a/cmd/proxy/proxy_client.go +++ b/cmd/proxy/proxy_client.go @@ -27,25 +27,35 @@ import ( "time" "github.com/gorilla/websocket" - signaling "github.com/strukturag/nextcloud-spreed-signaling" + + "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/client" ) type ProxyClient struct { - signaling.Client + client.Client proxy *ProxyServer session atomic.Pointer[ProxySession] } -func NewProxyClient(ctx context.Context, proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) { +func NewProxyClient(ctx context.Context, proxy *ProxyServer, conn *websocket.Conn, addr string, agent string) (*ProxyClient, error) { client := &ProxyClient{ proxy: proxy, } - client.SetConn(ctx, conn, addr, client) + client.SetConn(ctx, conn, addr, agent, false, client) return client, nil } +func (c *ProxyClient) GetSessionId() api.PublicSessionId { + if session := c.GetSession(); session != nil { + return session.PublicId() + } + + return "" +} + func (c *ProxyClient) GetSession() *ProxySession { return c.session.Load() } @@ -54,18 +64,18 @@ func (c *ProxyClient) SetSession(session *ProxySession) { c.session.Store(session) } -func (c *ProxyClient) OnClosed(client signaling.HandlerClient) { - if session := c.GetSession(); session != nil { +func (c *ProxyClient) OnClosed() { + if session := c.session.Swap(nil); session != nil { session.MarkUsed() } - c.proxy.clientClosed(&c.Client) + c.proxy.clientClosed(c) } -func (c *ProxyClient) OnMessageReceived(client signaling.HandlerClient, data []byte) { +func (c *ProxyClient) OnMessageReceived(data []byte) { c.proxy.processMessage(c, data) } -func (c *ProxyClient) OnRTTReceived(client signaling.HandlerClient, rtt time.Duration) { +func (c *ProxyClient) OnRTTReceived(rtt time.Duration) { if session := c.GetSession(); session != nil { session.MarkUsed() } diff --git a/cmd/proxy/proxy_server.go b/cmd/proxy/proxy_server.go index 6817cf7..8d4de6c 100644 --- a/cmd/proxy/proxy_server.go +++ b/cmd/proxy/proxy_server.go @@ -51,6 +51,7 @@ import ( signaling "github.com/strukturag/nextcloud-spreed-signaling" "github.com/strukturag/nextcloud-spreed-signaling/api" "github.com/strukturag/nextcloud-spreed-signaling/async" + "github.com/strukturag/nextcloud-spreed-signaling/client" "github.com/strukturag/nextcloud-spreed-signaling/config" "github.com/strukturag/nextcloud-spreed-signaling/container" "github.com/strukturag/nextcloud-spreed-signaling/geoip" @@ -87,6 +88,8 @@ const ( ) var ( + InvalidFormat = client.InvalidFormat + defaultProxyFeatures = []string{ ProxyFeatureRemoteStreams, } @@ -279,7 +282,7 @@ func NewProxyServer(ctx context.Context, r *mux.Router, version string, config * if !trustedProxiesIps.Empty() { logger.Printf("Trusted proxies: %s", trustedProxiesIps) } else { - trustedProxiesIps = signaling.DefaultTrustedProxies + trustedProxiesIps = client.DefaultTrustedProxies logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } @@ -640,7 +643,7 @@ func (s *ProxyServer) Reload(config *goconf.ConfigFile) { if !trustedProxiesIps.Empty() { s.logger.Printf("Trusted proxies: %s", trustedProxiesIps) } else { - trustedProxiesIps = signaling.DefaultTrustedProxies + trustedProxiesIps = client.DefaultTrustedProxies s.logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } s.trustedProxies.Store(trustedProxiesIps) @@ -675,7 +678,7 @@ func (s *ProxyServer) welcomeHandler(w http.ResponseWriter, r *http.Request) { } func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { - addr := signaling.GetRealUserIP(r, s.trustedProxies.Load()) + addr := client.GetRealUserIP(r, s.trustedProxies.Load()) header := http.Header{} header.Set("Server", "nextcloud-spreed-signaling-proxy/"+s.version) header.Set("X-Spreed-Signaling-Features", strings.Join(s.welcomeMsg.Features, ", ")) @@ -685,14 +688,14 @@ func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { return } + agent := r.Header.Get("User-Agent") ctx := log.NewLoggerContext(r.Context(), s.logger) if conn.Subprotocol() == janus.EventsSubprotocol { - agent := r.Header.Get("User-Agent") janus.RunEventsHandler(ctx, s.mcu, conn, addr, agent) return } - client, err := NewProxyClient(ctx, s, conn, addr) + client, err := NewProxyClient(ctx, s, conn, addr, agent) if err != nil { s.logger.Printf("Could not create client for %s: %s", addr, err) return @@ -702,7 +705,7 @@ func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { client.ReadPump() } -func (s *ProxyServer) clientClosed(client *signaling.Client) { +func (s *ProxyServer) clientClosed(client *ProxyClient) { s.logger.Printf("Connection from %s closed", client.RemoteAddr()) } @@ -766,7 +769,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { } else { s.logger.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) } - client.SendError(signaling.InvalidFormat) + client.SendError(InvalidFormat) return } @@ -776,7 +779,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { } else { s.logger.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) } - client.SendMessage(message.NewErrorServerMessage(signaling.InvalidFormat)) + client.SendMessage(message.NewErrorServerMessage(InvalidFormat)) return } @@ -1654,7 +1657,7 @@ func (s *ProxyServer) getStats() api.StringMap { } func (s *ProxyServer) allowStatsAccess(r *http.Request) bool { - addr := signaling.GetRealUserIP(r, s.trustedProxies.Load()) + addr := client.GetRealUserIP(r, s.trustedProxies.Load()) ip := net.ParseIP(addr) if len(ip) == 0 { return false diff --git a/federation.go b/federation.go index a0f6296..1e04de5 100644 --- a/federation.go +++ b/federation.go @@ -43,6 +43,18 @@ import ( ) const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Maximum message size allowed from peer. + maxMessageSize = 64 * 1024 + initialFederationReconnectInterval = 100 * time.Millisecond maxFederationReconnectInterval = 8 * time.Second ) diff --git a/grpc_remote_client.go b/grpc_remote_client.go index 4c34137..0456d1a 100644 --- a/grpc_remote_client.go +++ b/grpc_remote_client.go @@ -34,6 +34,7 @@ import ( "google.golang.org/grpc/status" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/client" "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/grpc" "github.com/strukturag/nextcloud-spreed-signaling/log" @@ -57,7 +58,7 @@ type remoteGrpcClient struct { hub *Hub client grpc.RpcSessions_ProxySessionServer - sessionId string + sessionId api.PublicSessionId remoteAddr string country geoip.Country userAgent string @@ -66,7 +67,7 @@ type remoteGrpcClient struct { closeFunc context.CancelCauseFunc session atomic.Pointer[Session] - messages chan WritableClientMessage + messages chan client.WritableClientMessage } func newRemoteGrpcClient(hub *Hub, request grpc.RpcSessions_ProxySessionServer) (*remoteGrpcClient, error) { @@ -82,7 +83,7 @@ func newRemoteGrpcClient(hub *Hub, request grpc.RpcSessions_ProxySessionServer) hub: hub, client: request, - sessionId: getMD(md, "sessionId"), + sessionId: api.PublicSessionId(getMD(md, "sessionId")), remoteAddr: getMD(md, "remoteAddr"), country: geoip.Country(getMD(md, "country")), userAgent: getMD(md, "userAgent"), @@ -90,7 +91,7 @@ func newRemoteGrpcClient(hub *Hub, request grpc.RpcSessions_ProxySessionServer) closeCtx: closeCtx, closeFunc: closeFunc, - messages: make(chan WritableClientMessage, grpcRemoteClientMessageQueue), + messages: make(chan client.WritableClientMessage, grpcRemoteClientMessageQueue), } return result, nil } @@ -99,7 +100,7 @@ func (c *remoteGrpcClient) readPump() { var closeError error defer func() { c.closeFunc(closeError) - c.hub.OnClosed(c) + c.hub.processUnregister(c) }() for { @@ -117,7 +118,7 @@ func (c *remoteGrpcClient) readPump() { break } - c.hub.OnMessageReceived(c, msg.Message) + c.hub.processMessage(c, msg.Message) } } @@ -145,6 +146,10 @@ func (c *remoteGrpcClient) IsAuthenticated() bool { return c.GetSession() != nil } +func (c *remoteGrpcClient) GetSessionId() api.PublicSessionId { + return c.sessionId +} + func (c *remoteGrpcClient) GetSession() Session { session := c.session.Load() if session == nil { @@ -190,7 +195,7 @@ func (c *remoteGrpcClient) SendByeResponseWithReason(message *api.ClientMessage, return c.SendMessage(response) } -func (c *remoteGrpcClient) SendMessage(message WritableClientMessage) bool { +func (c *remoteGrpcClient) SendMessage(message client.WritableClientMessage) bool { if c.closeCtx.Err() != nil { return false } diff --git a/hub.go b/hub.go index 622197c..61d43e8 100644 --- a/hub.go +++ b/hub.go @@ -54,6 +54,7 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/api" "github.com/strukturag/nextcloud-spreed-signaling/async" "github.com/strukturag/nextcloud-spreed-signaling/async/events" + "github.com/strukturag/nextcloud-spreed-signaling/client" "github.com/strukturag/nextcloud-spreed-signaling/config" "github.com/strukturag/nextcloud-spreed-signaling/container" "github.com/strukturag/nextcloud-spreed-signaling/etcd" @@ -139,14 +140,20 @@ var ( // Allow time differences of up to one minute between server and proxy. tokenLeeway = time.Minute - - DefaultTrustedProxies = container.DefaultPrivateIPs() ) func init() { RegisterHubStats() } +type ClientWithSession interface { + client.HandlerClient + + IsAuthenticated() bool + GetSession() Session + SetSession(session Session) +} + type Hub struct { version string logger log.Logger @@ -174,7 +181,7 @@ type Hub struct { sid atomic.Uint64 // +checklocks:mu - clients map[uint64]HandlerClient + clients map[uint64]ClientWithSession // +checklocks:mu sessions map[uint64]Session // +checklocks:ru @@ -198,7 +205,7 @@ type Hub struct { // +checklocks:mu anonymousSessions map[*ClientSession]time.Time // +checklocks:mu - expectHelloClients map[HandlerClient]time.Time + expectHelloClients map[ClientWithSession]time.Time // +checklocks:mu dialoutSessions map[*ClientSession]bool // +checklocks:mu @@ -309,7 +316,7 @@ func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events events.AsyncEven if !trustedProxiesIps.Empty() { logger.Printf("Trusted proxies: %s", trustedProxiesIps) } else { - trustedProxiesIps = DefaultTrustedProxies + trustedProxiesIps = client.DefaultTrustedProxies logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } @@ -388,7 +395,7 @@ func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events events.AsyncEven roomInCall: make(chan *talk.BackendServerRoomRequest), roomParticipants: make(chan *talk.BackendServerRoomRequest), - clients: make(map[uint64]HandlerClient), + clients: make(map[uint64]ClientWithSession), sessions: make(map[uint64]Session), rooms: make(map[string]*Room), @@ -405,7 +412,7 @@ func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events events.AsyncEven expiredSessions: make(map[Session]time.Time), anonymousSessions: make(map[*ClientSession]time.Time), - expectHelloClients: make(map[HandlerClient]time.Time), + expectHelloClients: make(map[ClientWithSession]time.Time), dialoutSessions: make(map[*ClientSession]bool), remoteSessions: make(map[*RemoteSession]bool), federatedSessions: make(map[*ClientSession]bool), @@ -584,7 +591,7 @@ func (h *Hub) Reload(ctx context.Context, config *goconf.ConfigFile) { if !trustedProxiesIps.Empty() { h.logger.Printf("Trusted proxies: %s", trustedProxiesIps) } else { - trustedProxiesIps = DefaultTrustedProxies + trustedProxiesIps = client.DefaultTrustedProxies h.logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } h.trustedProxies.Store(trustedProxiesIps) @@ -894,7 +901,7 @@ func (h *Hub) startWaitAnonymousSessionRoomLocked(session *ClientSession) { h.anonymousSessions[session] = now.Add(anonmyousJoinRoomTimeout) } -func (h *Hub) startExpectHello(client HandlerClient) { +func (h *Hub) startExpectHello(client ClientWithSession) { h.mu.Lock() defer h.mu.Unlock() if !client.IsConnected() { @@ -910,16 +917,16 @@ func (h *Hub) startExpectHello(client HandlerClient) { h.expectHelloClients[client] = now.Add(initialHelloTimeout) } -func (h *Hub) processNewClient(client HandlerClient) { +func (h *Hub) processNewClient(client ClientWithSession) { h.startExpectHello(client) h.sendWelcome(client) } -func (h *Hub) sendWelcome(client HandlerClient) { +func (h *Hub) sendWelcome(client ClientWithSession) { client.SendMessage(h.getWelcomeMessage()) } -func (h *Hub) registerClient(client HandlerClient) uint64 { +func (h *Hub) registerClient(client ClientWithSession) uint64 { sid := h.sid.Add(1) for sid == 0 { sid = h.sid.Add(1) @@ -956,24 +963,17 @@ func (h *Hub) newSessionIdData(backend *talk.Backend) *SessionIdData { return sessionIdData } -func (h *Hub) processRegister(c HandlerClient, message *api.ClientMessage, backend *talk.Backend, auth *talk.BackendClientResponse) { - if !c.IsConnected() { +func (h *Hub) processRegister(client ClientWithSession, message *api.ClientMessage, backend *talk.Backend, auth *talk.BackendClientResponse) { + if !client.IsConnected() { // Client disconnected while waiting for "hello" response. return } if auth.Type == "error" { - c.SendMessage(message.NewErrorServerMessage(auth.Error)) + client.SendMessage(message.NewErrorServerMessage(auth.Error)) return } else if auth.Type != "auth" { - c.SendMessage(message.NewErrorServerMessage(UserAuthFailed)) - return - } - - client, ok := c.(*Client) - if !ok { - h.logger.Printf("Can't register non-client %T", c) - client.SendMessage(message.NewWrappedErrorServerMessage(errors.New("can't register non-client"))) + client.SendMessage(message.NewErrorServerMessage(UserAuthFailed)) return } @@ -1077,7 +1077,7 @@ func (h *Hub) processRegister(c HandlerClient, message *api.ClientMessage, backe h.sendHelloResponse(session, message) } -func (h *Hub) processUnregister(client HandlerClient) Session { +func (h *Hub) processUnregister(client ClientWithSession) Session { session := client.GetSession() h.mu.Lock() @@ -1090,10 +1090,8 @@ func (h *Hub) processUnregister(client HandlerClient) Session { h.mu.Unlock() if session != nil { h.logger.Printf("Unregister %s (private=%s)", session.PublicId(), session.PrivateId()) - if c, ok := client.(*Client); ok { - if cs, ok := session.(*ClientSession); ok { - cs.ClearClient(c) - } + if cs, ok := session.(*ClientSession); ok { + cs.ClearClient(client) } } @@ -1101,7 +1099,7 @@ func (h *Hub) processUnregister(client HandlerClient) Session { return session } -func (h *Hub) processMessage(client HandlerClient, data []byte) { +func (h *Hub) processMessage(client ClientWithSession, data []byte) { var message api.ClientMessage if err := message.UnmarshalJSON(data); err != nil { if session := client.GetSession(); session != nil { @@ -1198,8 +1196,8 @@ type remoteClientInfo struct { response *grpc.LookupResumeIdReply } -func (h *Hub) tryProxyResume(c HandlerClient, resumeId api.PrivateSessionId, message *api.ClientMessage) bool { - client, ok := c.(*Client) +func (h *Hub) tryProxyResume(c ClientWithSession, resumeId api.PrivateSessionId, message *api.ClientMessage) bool { + client, ok := c.(*HubClient) if !ok { return false } @@ -1212,7 +1210,7 @@ func (h *Hub) tryProxyResume(c HandlerClient, resumeId api.PrivateSessionId, mes return false } - rpcCtx, rpcCancel := context.WithTimeout(c.Context(), 5*time.Second) + rpcCtx, rpcCancel := context.WithTimeout(client.Context(), 5*time.Second) defer rpcCancel() var wg sync.WaitGroup @@ -1274,7 +1272,7 @@ func (h *Hub) tryProxyResume(c HandlerClient, resumeId api.PrivateSessionId, mes return true } -func (h *Hub) processHello(client HandlerClient, message *api.ClientMessage) { +func (h *Hub) processHello(client ClientWithSession, message *api.ClientMessage) { ctx := log.NewLoggerContext(client.Context(), h.logger) resumeId := message.Hello.ResumeId if resumeId != "" { @@ -1366,7 +1364,7 @@ func (h *Hub) processHello(client HandlerClient, message *api.ClientMessage) { } } -func (h *Hub) processHelloV1(ctx context.Context, client HandlerClient, message *api.ClientMessage) (*talk.Backend, *talk.BackendClientResponse, error) { +func (h *Hub) processHelloV1(ctx context.Context, client ClientWithSession, message *api.ClientMessage) (*talk.Backend, *talk.BackendClientResponse, error) { url := message.Hello.Auth.ParsedUrl backend := h.backend.GetBackend(url) if backend == nil { @@ -1390,7 +1388,7 @@ func (h *Hub) processHelloV1(ctx context.Context, client HandlerClient, message return backend, &auth, nil } -func (h *Hub) processHelloV2(ctx context.Context, client HandlerClient, message *api.ClientMessage) (*talk.Backend, *talk.BackendClientResponse, error) { +func (h *Hub) processHelloV2(ctx context.Context, client ClientWithSession, message *api.ClientMessage) (*talk.Backend, *talk.BackendClientResponse, error) { url := message.Hello.Auth.ParsedUrl backend := h.backend.GetBackend(url) if backend == nil { @@ -1543,11 +1541,11 @@ func (h *Hub) processHelloV2(ctx context.Context, client HandlerClient, message return backend, auth, nil } -func (h *Hub) processHelloClient(client HandlerClient, message *api.ClientMessage) { +func (h *Hub) processHelloClient(client ClientWithSession, message *api.ClientMessage) { // Make sure the client must send another "hello" in case of errors. defer h.startExpectHello(client) - var authFunc func(context.Context, HandlerClient, *api.ClientMessage) (*talk.Backend, *talk.BackendClientResponse, error) + var authFunc func(context.Context, ClientWithSession, *api.ClientMessage) (*talk.Backend, *talk.BackendClientResponse, error) switch message.Hello.Version { case api.HelloVersionV1: // Auth information contains a ticket that must be validated against the @@ -1574,7 +1572,7 @@ func (h *Hub) processHelloClient(client HandlerClient, message *api.ClientMessag h.processRegister(client, message, backend, auth) } -func (h *Hub) processHelloInternal(client HandlerClient, message *api.ClientMessage) { +func (h *Hub) processHelloInternal(client ClientWithSession, message *api.ClientMessage) { defer h.startExpectHello(client) if len(h.internalClientsSecret) == 0 { client.SendMessage(message.NewErrorServerMessage(InvalidClientType)) @@ -2957,7 +2955,7 @@ func (h *Hub) sendMcuMessageResponse(session *ClientSession, mcuClient sfu.Clien session.SendMessage(response_message) } -func (h *Hub) processByeMsg(client HandlerClient, message *api.ClientMessage) { +func (h *Hub) processByeMsg(client ClientWithSession, message *api.ClientMessage) { client.SendByeResponse(message) if session := h.processUnregister(client); session != nil { session.Close() @@ -3076,66 +3074,8 @@ func (h *Hub) GetServerInfoDialout() (result []talk.BackendServerInfoDialout) { return } -func GetRealUserIP(r *http.Request, trusted *container.IPList) string { - addr := r.RemoteAddr - if host, _, err := net.SplitHostPort(addr); err == nil { - addr = host - } - - ip := net.ParseIP(addr) - if len(ip) == 0 { - return addr - } - - // Don't check any headers if the server can be reached by untrusted clients directly. - if trusted == nil || !trusted.Contains(ip) { - return addr - } - - if realIP := r.Header.Get("X-Real-IP"); realIP != "" { - if ip := net.ParseIP(realIP); len(ip) > 0 { - return realIP - } - } - - // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For#selecting_an_ip_address - forwarded := strings.Split(strings.Join(r.Header.Values("X-Forwarded-For"), ","), ",") - if len(forwarded) > 0 { - slices.Reverse(forwarded) - var lastTrusted string - for _, hop := range forwarded { - hop = strings.TrimSpace(hop) - // Make sure to remove any port. - if host, _, err := net.SplitHostPort(hop); err == nil { - hop = host - } - - ip := net.ParseIP(hop) - if len(ip) == 0 { - continue - } - - if trusted.Contains(ip) { - lastTrusted = hop - continue - } - - return hop - } - - // If all entries in the "X-Forwarded-For" list are trusted, the left-most - // will be the client IP. This can happen if a subnet is trusted and the - // client also has an IP from this subnet. - if lastTrusted != "" { - return lastTrusted - } - } - - return addr -} - func (h *Hub) getRealUserIP(r *http.Request) string { - return GetRealUserIP(r, h.trustedProxies.Load()) + return client.GetRealUserIP(r, h.trustedProxies.Load()) } func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { @@ -3158,7 +3098,7 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { return } - client, err := NewClient(ctx, conn, addr, agent, h) + client, err := NewHubClient(ctx, conn, addr, agent, h) if err != nil { h.logger.Printf("Could not create client for %s: %s", addr, err) return @@ -3176,8 +3116,8 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { client.ReadPump() } -func (h *Hub) OnLookupCountry(client HandlerClient) geoip.Country { - ip := net.ParseIP(client.RemoteAddr()) +func (h *Hub) LookupCountry(addr string) geoip.Country { + ip := net.ParseIP(addr) if ip == nil { return geoip.NoCountry } @@ -3206,18 +3146,6 @@ func (h *Hub) OnLookupCountry(client HandlerClient) geoip.Country { return country } -func (h *Hub) OnClosed(client HandlerClient) { - h.processUnregister(client) -} - -func (h *Hub) OnMessageReceived(client HandlerClient, data []byte) { - h.processMessage(client, data) -} - -func (h *Hub) OnRTTReceived(client HandlerClient, rtt time.Duration) { - // Ignore -} - func (h *Hub) ShutdownChannel() <-chan struct{} { return h.shutdown.C } diff --git a/hub_client.go b/hub_client.go new file mode 100644 index 0000000..ad32ac9 --- /dev/null +++ b/hub_client.go @@ -0,0 +1,128 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2017 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "strings" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + + "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/client" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" +) + +var ( + InvalidFormat = client.InvalidFormat +) + +func init() { + RegisterClientStats() +} + +type HubClient struct { + client.Client + + hub *Hub + session atomic.Pointer[Session] +} + +func NewHubClient(ctx context.Context, conn *websocket.Conn, remoteAddress string, agent string, hub *Hub) (*HubClient, error) { + remoteAddress = strings.TrimSpace(remoteAddress) + if remoteAddress == "" { + remoteAddress = "unknown remote address" + } + agent = strings.TrimSpace(agent) + if agent == "" { + agent = "unknown user agent" + } + + client := &HubClient{ + hub: hub, + } + client.SetConn(ctx, conn, remoteAddress, agent, true, client) + return client, nil +} + +func (c *HubClient) OnLookupCountry(addr string) geoip.Country { + return c.hub.LookupCountry(addr) +} + +func (c *HubClient) OnClosed() { + c.hub.processUnregister(c) +} + +func (c *HubClient) OnMessageReceived(data []byte) { + c.hub.processMessage(c, data) +} + +func (c *HubClient) OnRTTReceived(rtt time.Duration) { + // Ignore +} + +func (c *HubClient) CloseSession() { + if session := c.GetSession(); session != nil { + session.Close() + } +} + +func (c *HubClient) IsAuthenticated() bool { + return c.GetSession() != nil +} + +func (c *HubClient) GetSession() Session { + session := c.session.Load() + if session == nil { + return nil + } + + return *session +} + +func (c *HubClient) SetSession(session Session) { + if session == nil { + c.session.Store(nil) + } else { + c.session.Store(&session) + } +} + +func (c *HubClient) GetSessionId() api.PublicSessionId { + session := c.GetSession() + if session == nil { + return "" + } + + return session.PublicId() +} + +func (c *HubClient) IsInRoom(id string) bool { + session := c.GetSession() + if session == nil { + return false + } + + return session.IsInRoom(id) +} diff --git a/client_test.go b/hub_client_stats_prometheus.go similarity index 62% rename from client_test.go rename to hub_client_stats_prometheus.go index 9b90e54..034d9a4 100644 --- a/client_test.go +++ b/hub_client_stats_prometheus.go @@ -1,6 +1,6 @@ /** * Standalone signaling server for the Nextcloud Spreed app. - * Copyright (C) 2025 struktur AG + * Copyright (C) 2021 struktur AG * * @author Joachim Bauch * @@ -22,26 +22,24 @@ package signaling import ( - "bytes" - "testing" + "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" + "github.com/strukturag/nextcloud-spreed-signaling/metrics" ) -func TestCounterWriter(t *testing.T) { - t.Parallel() - assert := assert.New(t) +var ( + statsClientCountries = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "signaling", + Subsystem: "client", + Name: "countries_total", + Help: "The total number of connections by country", + }, []string{"country"}) - var b bytes.Buffer - var written int - w := &counterWriter{ - w: &b, - counter: &written, - } - if count, err := w.Write(nil); assert.NoError(err) && assert.Equal(0, count) { - assert.Equal(0, written) - } - if count, err := w.Write([]byte("foo")); assert.NoError(err) && assert.Equal(3, count) { - assert.Equal(3, written) + clientStats = []prometheus.Collector{ + statsClientCountries, } +) + +func RegisterClientStats() { + metrics.RegisterAll(clientStats...) } diff --git a/hub_test.go b/hub_test.go index 1b1c7d1..8fce006 100644 --- a/hub_test.go +++ b/hub_test.go @@ -3084,233 +3084,6 @@ func TestJoinRoomSwitchClient(t *testing.T) { require.Empty(roomMsg.Room.RoomId) } -func TestGetRealUserIP(t *testing.T) { - t.Parallel() - testcases := []struct { - expected string - headers http.Header - trusted string - addr string - }{ - { - "192.168.1.2", - nil, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "10.11.12.13", - nil, - "192.168.0.0/16", - "10.11.12.13:23456", - }, - { - "10.11.12.13", - http.Header{ - http.CanonicalHeaderKey("x-real-ip"): []string{"10.11.12.13"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "2002:db8::1", - http.Header{ - http.CanonicalHeaderKey("x-real-ip"): []string{"2002:db8::1"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "11.12.13.14", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "10.11.12.13", - http.Header{ - http.CanonicalHeaderKey("x-real-ip"): []string{"10.11.12.13"}, - }, - "2001:db8::/48", - "[2001:db8::1]:23456", - }, - { - "2002:db8::1", - http.Header{ - http.CanonicalHeaderKey("x-real-ip"): []string{"2002:db8::1"}, - }, - "2001:db8::/48", - "[2001:db8::1]:23456", - }, - { - "2002:db8::1", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"2002:db8::1, 192.168.30.32"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "2002:db8::1", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"2002:db8::1, 2001:db8::1"}, - }, - "192.168.0.0/16, 2001:db8::/48", - "192.168.1.2:23456", - }, - { - "2002:db8::1", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"2002:db8::1, 192.168.30.32"}, - }, - "192.168.0.0/16, 2001:db8::/48", - "[2001:db8::1]:23456", - }, - { - "2002:db8::1", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"2002:db8::1, 2001:db8::2"}, - }, - "2001:db8::/48", - "[2001:db8::1]:23456", - }, - // "X-Real-IP" has preference before "X-Forwarded-For" - { - "10.11.12.13", - http.Header{ - http.CanonicalHeaderKey("x-real-ip"): []string{"10.11.12.13"}, - http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - // Multiple "X-Forwarded-For" headers are merged. - { - "11.12.13.14", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14", "192.168.30.32"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "11.12.13.14", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"1.2.3.4", "11.12.13.14", "192.168.30.32"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "11.12.13.14", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"1.2.3.4", "2.3.4.5", "11.12.13.14", "192.168.31.32", "192.168.30.32"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - // Headers are ignored if coming from untrusted clients. - { - "10.11.12.13", - http.Header{ - http.CanonicalHeaderKey("x-real-ip"): []string{"11.12.13.14"}, - }, - "192.168.0.0/16", - "10.11.12.13:23456", - }, - { - "10.11.12.13", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32"}, - }, - "192.168.0.0/16", - "10.11.12.13:23456", - }, - // X-Forwarded-For is filtered for trusted proxies. - { - "1.2.3.4", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 1.2.3.4"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "1.2.3.4", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 1.2.3.4, 192.168.2.3"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "10.11.12.13", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 1.2.3.4"}, - }, - "192.168.0.0/16", - "10.11.12.13:23456", - }, - // Invalid IPs are ignored. - { - "192.168.1.2", - http.Header{ - http.CanonicalHeaderKey("x-real-ip"): []string{"this-is-not-an-ip"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "11.12.13.14", - http.Header{ - http.CanonicalHeaderKey("x-real-ip"): []string{"this-is-not-an-ip"}, - http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "11.12.13.14", - http.Header{ - http.CanonicalHeaderKey("x-real-ip"): []string{"this-is-not-an-ip"}, - http.CanonicalHeaderKey("x-forwarded-for"): []string{"11.12.13.14, 192.168.30.32, proxy1"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "192.168.1.2", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"this-is-not-an-ip"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - { - "192.168.2.3", - http.Header{ - http.CanonicalHeaderKey("x-forwarded-for"): []string{"this-is-not-an-ip, 192.168.2.3"}, - }, - "192.168.0.0/16", - "192.168.1.2:23456", - }, - } - - for _, tc := range testcases { - trustedProxies, err := container.ParseIPList(tc.trusted) - if !assert.NoError(t, err, "invalid trusted proxies in %+v", tc) { - continue - } - request := &http.Request{ - RemoteAddr: tc.addr, - Header: tc.headers, - } - assert.Equal(t, tc.expected, GetRealUserIP(request, trustedProxies), "failed for %+v", tc) - } -} - func TestClientMessageToSessionIdWhileDisconnected(t *testing.T) { t.Parallel() require := require.New(t) @@ -5254,11 +5027,11 @@ func TestGeoipOverrides(t *testing.T) { return conf, err }) - assert.Equal(geoip.Loopback, hub.OnLookupCountry(&Client{addr: "127.0.0.1"})) - assert.Equal(geoip.UnknownCountry, hub.OnLookupCountry(&Client{addr: "8.8.8.8"})) - assert.EqualValues(country1, hub.OnLookupCountry(&Client{addr: "10.1.1.2"})) - assert.EqualValues(country2, hub.OnLookupCountry(&Client{addr: "10.2.1.2"})) - assert.EqualValues(strings.ToUpper(country3), hub.OnLookupCountry(&Client{addr: "192.168.10.20"})) + assert.Equal(geoip.Loopback, hub.LookupCountry("127.0.0.1")) + assert.Equal(geoip.UnknownCountry, hub.LookupCountry("8.8.8.8")) + assert.EqualValues(country1, hub.LookupCountry("10.1.1.2")) + assert.EqualValues(country2, hub.LookupCountry("10.2.1.2")) + assert.EqualValues(strings.ToUpper(country3), hub.LookupCountry("192.168.10.20")) } func TestDialoutStatus(t *testing.T) { diff --git a/remotesession.go b/remotesession.go index 2b56333..e1a40ac 100644 --- a/remotesession.go +++ b/remotesession.go @@ -29,6 +29,7 @@ import ( "time" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/client" "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/grpc" "github.com/strukturag/nextcloud-spreed-signaling/log" @@ -37,14 +38,14 @@ import ( type RemoteSession struct { logger log.Logger hub *Hub - client *Client + client *HubClient remoteClient *grpc.Client sessionId api.PublicSessionId proxy atomic.Pointer[grpc.SessionProxy] } -func NewRemoteSession(hub *Hub, client *Client, remoteClient *grpc.Client, sessionId api.PublicSessionId) (*RemoteSession, error) { +func NewRemoteSession(hub *Hub, client *HubClient, remoteClient *grpc.Client, sessionId api.PublicSessionId) (*RemoteSession, error) { remoteSession := &RemoteSession{ logger: hub.logger, hub: hub, @@ -67,6 +68,10 @@ func NewRemoteSession(hub *Hub, client *Client, remoteClient *grpc.Client, sessi return remoteSession, nil } +func (s *RemoteSession) GetSessionId() api.PublicSessionId { + return s.sessionId +} + func (s *RemoteSession) Country() geoip.Country { return s.client.Country() } @@ -107,7 +112,7 @@ func (s *RemoteSession) OnProxyClose(err error) { s.Close() } -func (s *RemoteSession) SendMessage(message WritableClientMessage) bool { +func (s *RemoteSession) SendMessage(message client.WritableClientMessage) bool { return s.sendMessage(message) == nil } @@ -140,20 +145,25 @@ func (s *RemoteSession) Close() { s.client.Close() } -func (s *RemoteSession) OnLookupCountry(client HandlerClient) geoip.Country { - return s.hub.OnLookupCountry(client) +func (s *RemoteSession) OnLookupCountry(addr string) geoip.Country { + return s.hub.LookupCountry(addr) } -func (s *RemoteSession) OnClosed(client HandlerClient) { +func (s *RemoteSession) OnClosed() { s.Close() } -func (s *RemoteSession) OnMessageReceived(client HandlerClient, message []byte) { +func (s *RemoteSession) OnMessageReceived(message []byte) { if err := s.sendProxyMessage(message); err != nil { s.logger.Printf("Error sending %s to the proxy for session %s: %s", string(message), s.sessionId, err) s.Close() } } -func (s *RemoteSession) OnRTTReceived(client HandlerClient, rtt time.Duration) { +func (s *RemoteSession) OnRTTReceived(rtt time.Duration) { + // Ignore +} + +func (s *RemoteSession) IsInRoom(id string) bool { + return s.client.IsInRoom(id) } diff --git a/testclient_test.go b/testclient_test.go index c6fffd2..7fe62ef 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -319,10 +319,8 @@ func (c *TestClient) WaitForClientRemoved(ctx context.Context) error { for { found := false for _, client := range c.hub.clients { - if cc, ok := client.(*Client); ok { - cc.mu.Lock() - conn := cc.conn - cc.mu.Unlock() + if cc, ok := client.(*HubClient); ok { + conn := cc.GetConn() if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() { found = true break @@ -736,7 +734,7 @@ func (c *TestClient) RunUntilMessage(ctx context.Context) (*api.ServerMessage, b func (c *TestClient) RunUntilMessageOrClosed(ctx context.Context) (*api.ServerMessage, bool) { select { case err := <-c.readErrorChan: - if c.assert.Error(err) && websocket.IsCloseError(err, websocket.CloseNoStatusReceived) { + if c.assert.Error(err) && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { return nil, true }