Merge pull request #870 from strukturag/improve-memory

Improve memory allocations
This commit is contained in:
Joachim Bauch 2025-04-16 13:24:13 +02:00 committed by GitHub
commit bfc153c2e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 130 additions and 36 deletions

View file

@ -22,12 +22,10 @@
package signaling
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
@ -51,6 +49,7 @@ type BackendClient struct {
pool *HttpClientPool
capabilities *Capabilities
buffers BufferPool
}
func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string, etcdClient *EtcdClient) (*BackendClient, error) {
@ -140,13 +139,14 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
}
defer pool.Put(c)
data, err := json.Marshal(request)
data, err := b.buffers.MarshalAsJSON(request)
if err != nil {
log.Printf("Could not marshal request %+v: %s", request, err)
return err
}
req, err := http.NewRequestWithContext(ctx, "POST", requestUrl.String(), bytes.NewReader(data))
defer b.buffers.Put(data)
req, err := http.NewRequestWithContext(ctx, "POST", requestUrl.String(), data)
if err != nil {
log.Printf("Could not create request to %s: %s", requestUrl, err)
return err
@ -160,11 +160,11 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
}
// Add checksum so the backend can validate the request.
AddBackendChecksum(req, data, secret)
AddBackendChecksum(req, data.Bytes(), secret)
resp, err := c.Do(req)
if err != nil {
log.Printf("Could not send request %s to %s: %s", string(data), req.URL, err)
log.Printf("Could not send request %s to %s: %s", data.String(), req.URL, err)
return err
}
defer resp.Body.Close()
@ -175,12 +175,14 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
return ErrUnsupportedContentType
}
body, err := io.ReadAll(resp.Body)
body, err := b.buffers.ReadAll(resp.Body)
if err != nil {
log.Printf("Could not read response body from %s: %s", req.URL, err)
return err
}
defer b.buffers.Put(body)
if isOcsRequest(u) || req.Header.Get("OCS-APIRequest") != "" {
// OCS response are wrapped in an OCS container that needs to be parsed
// to get the actual contents:
@ -191,17 +193,17 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
// }
// }
var ocs OcsResponse
if err := json.Unmarshal(body, &ocs); err != nil {
log.Printf("Could not decode OCS response %s from %s: %s", string(body), req.URL, err)
if err := json.Unmarshal(body.Bytes(), &ocs); err != nil {
log.Printf("Could not decode OCS response %s from %s: %s", body.String(), req.URL, err)
return err
} else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 {
log.Printf("Incomplete OCS response %s from %s", string(body), req.URL)
log.Printf("Incomplete OCS response %s from %s", body.String(), req.URL)
return ErrIncompleteResponse
}
switch ocs.Ocs.Meta.StatusCode {
case http.StatusTooManyRequests:
log.Printf("Throttled OCS response %s from %s", string(body), req.URL)
log.Printf("Throttled OCS response %s from %s", body.String(), req.URL)
return ErrThrottledResponse
}
@ -209,8 +211,8 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
log.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), req.URL, err)
return err
}
} else if err := json.Unmarshal(body, response); err != nil {
log.Printf("Could not decode response body %s from %s: %s", string(body), req.URL, err)
} else if err := json.Unmarshal(body.Bytes(), response); err != nil {
log.Printf("Could not decode response body %s from %s: %s", body.String(), req.URL, err)
return err
}
return nil

View file

@ -70,6 +70,8 @@ type BackendServer struct {
statsAllowedIps atomic.Pointer[AllowedIps]
invalidSecret []byte
buffers BufferPool
}
func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*BackendServer, error) {
@ -284,14 +286,15 @@ func (b *BackendServer) parseRequestBody(f func(http.ResponseWriter, *http.Reque
return
}
body, err := io.ReadAll(r.Body)
body, err := b.buffers.ReadAll(r.Body)
if err != nil {
log.Println("Error reading body: ", err)
http.Error(w, "Could not read body", http.StatusBadRequest)
return
}
defer b.buffers.Put(body)
f(w, r, body)
f(w, r, body.Bytes())
}
}

79
buffer_pool.go Normal file
View file

@ -0,0 +1,79 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2024 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"bytes"
"encoding/json"
"io"
"sync"
)
type BufferPool struct {
buffers sync.Pool
copyBuffers sync.Pool
}
func (p *BufferPool) Get() *bytes.Buffer {
b := p.buffers.Get()
if b == nil {
return bytes.NewBuffer(nil)
}
return b.(*bytes.Buffer)
}
func (p *BufferPool) Put(b *bytes.Buffer) {
if b == nil {
return
}
b.Reset()
p.buffers.Put(b)
}
func (p *BufferPool) ReadAll(r io.Reader) (*bytes.Buffer, error) {
buf := p.copyBuffers.Get()
if buf == nil {
buf = make([]byte, 1024)
}
defer p.copyBuffers.Put(buf)
b := p.Get()
if _, err := io.CopyBuffer(b, r, buf.([]byte)); err != nil {
p.Put(b)
return nil, err
}
return b, nil
}
func (p *BufferPool) MarshalAsJSON(v any) (*bytes.Buffer, error) {
b := p.Get()
encoder := json.NewEncoder(b)
if err := encoder.Encode(v); err != nil {
p.Put(b)
return nil, err
}
return b, nil
}

View file

@ -25,7 +25,6 @@ import (
"context"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"net/url"
@ -182,18 +181,20 @@ func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Tim
return e.errorIfMustRevalidate(ErrUnsupportedContentType)
}
body, err := io.ReadAll(response.Body)
body, err := e.c.buffers.ReadAll(response.Body)
if err != nil {
log.Printf("Could not read response body from %s: %s", url, err)
return e.errorIfMustRevalidate(err)
}
defer e.c.buffers.Put(body)
var ocs OcsResponse
if err := json.Unmarshal(body, &ocs); err != nil {
log.Printf("Could not decode OCS response %s from %s: %s", string(body), url, err)
if err := json.Unmarshal(body.Bytes(), &ocs); err != nil {
log.Printf("Could not decode OCS response %s from %s: %s", body.String(), url, err)
return e.errorIfMustRevalidate(err)
} else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 {
log.Printf("Incomplete OCS response %s from %s", string(body), url)
log.Printf("Incomplete OCS response %s from %s", body.String(), url)
return e.errorIfMustRevalidate(ErrIncompleteResponse)
}
@ -239,6 +240,8 @@ type Capabilities struct {
pool *HttpClientPool
entries map[string]*capabilitiesEntry
nextInvalidate map[string]time.Time
buffers BufferPool
}
func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) {

View file

@ -82,11 +82,7 @@ func IsValidCountry(country string) bool {
var (
InvalidFormat = NewError("invalid_format", "Invalid data format.")
bufferPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
bufferPool BufferPool
)
type WritableClientMessage interface {
@ -391,10 +387,8 @@ func (c *Client) ReadPump() {
continue
}
decodeBuffer := bufferPool.Get().(*bytes.Buffer)
decodeBuffer.Reset()
if _, err := decodeBuffer.ReadFrom(reader); err != nil {
bufferPool.Put(decodeBuffer)
decodeBuffer, err := bufferPool.ReadAll(reader)
if err != nil {
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Error reading message from client %s: %v", sessionId, err)
} else {

View file

@ -82,8 +82,9 @@ func TestConcurrentStringStringMap(t *testing.T) {
defer wg.Done()
key := "key-" + strconv.Itoa(x)
rnd := newRandomString(32)
for y := 0; y < count; y = y + 1 {
value := newRandomString(32)
value := rnd + "-" + strconv.Itoa(y)
m.Set(key, value)
if v, found := m.Get(key); !assert.True(found, "Expected entry for key %s", key) ||
!assert.Equal(value, v, "Unexpected value for key %s", key) {

View file

@ -46,6 +46,8 @@ const (
var (
ErrFederationNotSupported = NewError("federation_unsupported", "The target server does not support federation.")
federationWriteBufferPool = &sync.Pool{}
)
func isClosedError(err error) bool {
@ -102,7 +104,9 @@ func NewFederationClient(ctx context.Context, hub *Hub, session *ClientSession,
return nil, fmt.Errorf("expected federation room message, got %+v", message)
}
var dialer websocket.Dialer
dialer := &websocket.Dialer{
WriteBufferPool: federationWriteBufferPool,
}
if hub.skipFederationVerify {
dialer.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
@ -130,7 +134,7 @@ func NewFederationClient(ctx context.Context, hub *Hub, session *ClientSession,
reconnectDelay: initialFederationReconnectInterval,
dialer: &dialer,
dialer: dialer,
url: url,
closer: NewCloser(),
}

3
hub.go
View file

@ -105,6 +105,8 @@ var (
websocketReadBufferSize = 4096
websocketWriteBufferSize = 4096
websocketWriteBufferPool = &sync.Pool{}
// Delay after which a screen publisher should be cleaned up.
cleanupScreenPublisherDelay = time.Second
@ -322,6 +324,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
upgrader: websocket.Upgrader{
ReadBufferSize: websocketReadBufferSize,
WriteBufferSize: websocketWriteBufferSize,
WriteBufferPool: websocketWriteBufferPool,
},
cookie: NewSessionIdCodec([]byte(hashKey), blockBytes),
info: NewWelcomeServerMessage(version, DefaultFeatures...),

View file

@ -767,7 +767,7 @@ func TestWebsocketFeatures(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
conn, response, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil)
conn, response, err := testClientDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil)
require.NoError(err)
defer conn.Close() // nolint

View file

@ -119,8 +119,9 @@ const (
var (
janusDialer = websocket.Dialer{
Subprotocols: []string{"janus-protocol"},
Proxy: http.ProxyFromEnvironment,
Subprotocols: []string{"janus-protocol"},
Proxy: http.ProxyFromEnvironment,
WriteBufferPool: &sync.Pool{},
}
)

View file

@ -49,6 +49,10 @@ var (
testInternalSecret = []byte("internal-secret")
ErrNoMessageReceived = fmt.Errorf("no message was received by the server")
testClientDialer = websocket.Dialer{
WriteBufferPool: &sync.Pool{},
}
)
type TestBackendClientAuthParams struct {
@ -226,7 +230,7 @@ type TestClient struct {
func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Server, hub *Hub) *TestClient {
// Reference "hub" to prevent compiler error.
conn, _, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil)
conn, _, err := testClientDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil)
require.NoError(t, err)
messageChan := make(chan []byte)