diff --git a/federation/client.go b/federation/client.go index 5c316e56..c84b437a 100644 --- a/federation/client.go +++ b/federation/client.go @@ -30,6 +30,8 @@ type Client struct { ServerName string UserAgent string Key *SigningKey + + ResponseSizeLimit int64 } func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client { @@ -45,6 +47,8 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien UserAgent: mautrix.DefaultUserAgent, ServerName: serverName, Key: key, + + ResponseSizeLimit: 128 * 1024 * 1024, } } @@ -318,7 +322,16 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b body, err = mautrix.ParseErrorResponse(req, resp) return body, resp, err } else if params.ResponseJSON != nil || !params.DontReadBody { - body, err = io.ReadAll(resp.Body) + if resp.ContentLength > c.ResponseSizeLimit { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "response body too long", + WrappedError: fmt.Errorf("%.2f MiB", float64(resp.ContentLength)/1024/1024), + } + } + body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1)) if err != nil { return body, resp, mautrix.HTTPError{ Request: req, @@ -328,6 +341,15 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b WrappedError: err, } } + if len(body) > int(c.ResponseSizeLimit) { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "failed to read response body", + WrappedError: fmt.Errorf("exceeded read limit"), + } + } if params.ResponseJSON != nil { err = json.Unmarshal(body, params.ResponseJSON) if err != nil {