client: add response size limits
Some checks are pending
Go / Lint (latest) (push) Waiting to run
Go / Build (old, libolm) (push) Waiting to run
Go / Build (latest, libolm) (push) Waiting to run
Go / Build (old, goolm) (push) Waiting to run
Go / Build (latest, goolm) (push) Waiting to run

This commit is contained in:
Tulir Asokan 2025-10-18 13:37:19 +02:00
commit c50460cd6e
4 changed files with 124 additions and 48 deletions

145
client.go
View file

@ -111,6 +111,8 @@ type Client struct {
// Set to true to disable automatically sleeping on 429 errors.
IgnoreRateLimit bool
ResponseSizeLimit int64
txnID int32
// Should the ?user_id= query parameter be set in requests?
@ -143,6 +145,8 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown
return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName)
}
const WellKnownMaxSize = 64 * 1024
func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) {
wellKnownURL := url.URL{
Scheme: "https",
@ -168,11 +172,15 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve
if resp.StatusCode == http.StatusNotFound {
return nil, nil
} else if resp.ContentLength > WellKnownMaxSize {
return nil, errors.New(".well-known response too large")
}
data, err := io.ReadAll(resp.Body)
data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize))
if err != nil {
return nil, err
} else if len(data) >= WellKnownMaxSize {
return nil, errors.New(".well-known response too large")
}
var wellKnown ClientWellKnown
@ -395,24 +403,25 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin
return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody})
}
type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error)
type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error)
type FullRequest struct {
Method string
URL string
Headers http.Header
RequestJSON interface{}
RequestBytes []byte
RequestBody io.Reader
RequestLength int64
ResponseJSON interface{}
MaxAttempts int
BackoffDuration time.Duration
SensitiveContent bool
Handler ClientResponseHandler
DontReadResponse bool
Logger *zerolog.Logger
Client *http.Client
Method string
URL string
Headers http.Header
RequestJSON interface{}
RequestBytes []byte
RequestBody io.Reader
RequestLength int64
ResponseJSON interface{}
MaxAttempts int
BackoffDuration time.Duration
SensitiveContent bool
Handler ClientResponseHandler
DontReadResponse bool
ResponseSizeLimit int64
Logger *zerolog.Logger
Client *http.Client
}
var requestID int32
@ -537,10 +546,25 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque
if len(cli.AccessToken) > 0 {
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
}
if params.ResponseSizeLimit == 0 {
params.ResponseSizeLimit = cli.ResponseSizeLimit
}
if params.ResponseSizeLimit == 0 {
params.ResponseSizeLimit = DefaultResponseSizeLimit
}
if params.Client == nil {
params.Client = cli.Client
}
return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client)
return cli.executeCompiledRequest(
req,
params.MaxAttempts-1,
params.BackoffDuration,
params.ResponseJSON,
params.Handler,
params.DontReadResponse,
params.ResponseSizeLimit,
params.Client,
)
}
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
@ -551,7 +575,17 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
return log
}
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
func (cli *Client) doRetry(
req *http.Request,
cause error,
retries int,
backoff time.Duration,
responseJSON any,
handler ClientResponseHandler,
dontReadResponse bool,
sizeLimit int64,
client *http.Client,
) ([]byte, *http.Response, error) {
log := zerolog.Ctx(req.Context())
if req.Body != nil {
var err error
@ -585,11 +619,23 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff
if cli.UpdateRequestOnRetry != nil {
req = cli.UpdateRequestOnRetry(req, cause)
}
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client)
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client)
}
func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) {
contents, err := io.ReadAll(res.Body)
func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) {
if res.ContentLength > limit {
return nil, HTTPError{
Request: req,
Response: res,
Message: "not reading response",
WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
}
}
contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1))
if err == nil && len(contents) > int(limit) {
err = ErrBodyReadReachedLimit
}
if err != nil {
return nil, HTTPError{
Request: req,
@ -610,17 +656,20 @@ func closeTemp(log *zerolog.Logger, file *os.File) {
}
}
func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
log := zerolog.Ctx(req.Context())
file, err := os.CreateTemp("", "mautrix-response-")
if err != nil {
log.Warn().Err(err).Msg("Failed to create temporary file for streaming response")
_, err = handleNormalResponse(req, res, responseJSON)
_, err = handleNormalResponse(req, res, responseJSON, limit)
return nil, err
}
defer closeTemp(log, file)
if _, err = io.Copy(file, res.Body); err != nil {
var n int64
if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil {
return nil, fmt.Errorf("failed to copy response to file: %w", err)
} else if n > limit {
return nil, ErrBodyReadReachedLimit
} else if _, err = file.Seek(0, 0); err != nil {
return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err)
} else if err = json.NewDecoder(file).Decode(responseJSON); err != nil {
@ -630,12 +679,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON interfac
}
}
func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
return nil, nil
}
func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
if contents, err := readResponseBody(req, res); err != nil {
func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
if contents, err := readResponseBody(req, res, limit); err != nil {
return nil, err
} else if responseJSON == nil {
return contents, nil
@ -653,8 +702,12 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON in
}
}
const ErrorResponseSizeLimit = 512 * 1024
var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024
func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
contents, err := readResponseBody(req, res)
contents, err := readResponseBody(req, res, ErrorResponseSizeLimit)
if err != nil {
return contents, err
}
@ -673,7 +726,16 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
}
}
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
func (cli *Client) executeCompiledRequest(
req *http.Request,
retries int,
backoff time.Duration,
responseJSON any,
handler ClientResponseHandler,
dontReadResponse bool,
sizeLimit int64,
client *http.Client,
) ([]byte, *http.Response, error) {
cli.RequestStart(req)
startTime := time.Now()
res, err := client.Do(req)
@ -683,7 +745,9 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
}
if err != nil {
if retries > 0 && !errors.Is(err, context.Canceled) {
return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client)
return cli.doRetry(
req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
)
}
err = HTTPError{
Request: req,
@ -698,7 +762,9 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client)
return cli.doRetry(
req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
)
}
var body []byte
@ -706,7 +772,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
body, err = ParseErrorResponse(req, res)
cli.LogRequestDone(req, res, nil, nil, len(body), duration)
} else {
body, err = handler(req, res, responseJSON)
body, err = handler(req, res, responseJSON, sizeLimit)
cli.LogRequestDone(req, res, nil, err, len(body), duration)
}
return body, res, err
@ -1628,11 +1694,20 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy
}
// parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map.
func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
if res.ContentLength > limit {
return nil, HTTPError{
Request: req,
Response: res,
Message: "not reading response",
WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
}
}
response := make(RoomStateMap)
responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event)
*responsePtr = response
dec := json.NewDecoder(res.Body)
dec := json.NewDecoder(io.LimitReader(res.Body, limit))
arrayStart, err := dec.Token()
if err != nil {

View file

@ -82,6 +82,9 @@ var (
var (
ErrClientIsNil = errors.New("client is nil")
ErrClientHasNoHomeserver = errors.New("client has no homeserver set")
ErrResponseTooLong = errors.New("response content length too long")
ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body")
)
// HTTPError An HTTP Error response, which may wrap an underlying native Go Error.

View file

@ -48,7 +48,7 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien
ServerName: serverName,
Key: key,
ResponseSizeLimit: 128 * 1024 * 1024,
ResponseSizeLimit: mautrix.DefaultResponseSizeLimit,
}
}
@ -327,11 +327,14 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b
Request: req,
Response: resp,
Message: "response body too long",
WrappedError: fmt.Errorf("%.2f MiB", float64(resp.ContentLength)/1024/1024),
Message: "not reading response",
WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024),
}
}
body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1))
if err == nil && len(body) > int(c.ResponseSizeLimit) {
err = mautrix.ErrBodyReadReachedLimit
}
if err != nil {
return body, resp, mautrix.HTTPError{
Request: req,
@ -341,15 +344,6 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b
WrappedError: err,
}
}
if len(body) > int(c.ResponseSizeLimit) {
return body, resp, mautrix.HTTPError{
Request: req,
Response: resp,
Message: "failed to read response body",
WrappedError: fmt.Errorf("exceeded read limit"),
}
}
if params.ResponseJSON != nil {
err = json.Unmarshal(body, params.ResponseJSON)
if err != nil {

View file

@ -20,6 +20,8 @@ import (
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
)
type ResolvedServerName struct {
@ -171,9 +173,11 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
} else if resp.ContentLength > mautrix.WellKnownMaxSize {
return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength)
}
var respData RespWellKnown
err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData)
err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
} else if respData.Server == "" {