mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
client: add response size limits
This commit is contained in:
parent
827bb4c621
commit
c50460cd6e
4 changed files with 124 additions and 48 deletions
145
client.go
145
client.go
|
|
@ -111,6 +111,8 @@ type Client struct {
|
|||
// Set to true to disable automatically sleeping on 429 errors.
|
||||
IgnoreRateLimit bool
|
||||
|
||||
ResponseSizeLimit int64
|
||||
|
||||
txnID int32
|
||||
|
||||
// Should the ?user_id= query parameter be set in requests?
|
||||
|
|
@ -143,6 +145,8 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown
|
|||
return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName)
|
||||
}
|
||||
|
||||
const WellKnownMaxSize = 64 * 1024
|
||||
|
||||
func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) {
|
||||
wellKnownURL := url.URL{
|
||||
Scheme: "https",
|
||||
|
|
@ -168,11 +172,15 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve
|
|||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil
|
||||
} else if resp.ContentLength > WellKnownMaxSize {
|
||||
return nil, errors.New(".well-known response too large")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(data) >= WellKnownMaxSize {
|
||||
return nil, errors.New(".well-known response too large")
|
||||
}
|
||||
|
||||
var wellKnown ClientWellKnown
|
||||
|
|
@ -395,24 +403,25 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin
|
|||
return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody})
|
||||
}
|
||||
|
||||
type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error)
|
||||
type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error)
|
||||
|
||||
type FullRequest struct {
|
||||
Method string
|
||||
URL string
|
||||
Headers http.Header
|
||||
RequestJSON interface{}
|
||||
RequestBytes []byte
|
||||
RequestBody io.Reader
|
||||
RequestLength int64
|
||||
ResponseJSON interface{}
|
||||
MaxAttempts int
|
||||
BackoffDuration time.Duration
|
||||
SensitiveContent bool
|
||||
Handler ClientResponseHandler
|
||||
DontReadResponse bool
|
||||
Logger *zerolog.Logger
|
||||
Client *http.Client
|
||||
Method string
|
||||
URL string
|
||||
Headers http.Header
|
||||
RequestJSON interface{}
|
||||
RequestBytes []byte
|
||||
RequestBody io.Reader
|
||||
RequestLength int64
|
||||
ResponseJSON interface{}
|
||||
MaxAttempts int
|
||||
BackoffDuration time.Duration
|
||||
SensitiveContent bool
|
||||
Handler ClientResponseHandler
|
||||
DontReadResponse bool
|
||||
ResponseSizeLimit int64
|
||||
Logger *zerolog.Logger
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
var requestID int32
|
||||
|
|
@ -537,10 +546,25 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque
|
|||
if len(cli.AccessToken) > 0 {
|
||||
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
|
||||
}
|
||||
if params.ResponseSizeLimit == 0 {
|
||||
params.ResponseSizeLimit = cli.ResponseSizeLimit
|
||||
}
|
||||
if params.ResponseSizeLimit == 0 {
|
||||
params.ResponseSizeLimit = DefaultResponseSizeLimit
|
||||
}
|
||||
if params.Client == nil {
|
||||
params.Client = cli.Client
|
||||
}
|
||||
return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client)
|
||||
return cli.executeCompiledRequest(
|
||||
req,
|
||||
params.MaxAttempts-1,
|
||||
params.BackoffDuration,
|
||||
params.ResponseJSON,
|
||||
params.Handler,
|
||||
params.DontReadResponse,
|
||||
params.ResponseSizeLimit,
|
||||
params.Client,
|
||||
)
|
||||
}
|
||||
|
||||
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
|
||||
|
|
@ -551,7 +575,17 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
|
|||
return log
|
||||
}
|
||||
|
||||
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
|
||||
func (cli *Client) doRetry(
|
||||
req *http.Request,
|
||||
cause error,
|
||||
retries int,
|
||||
backoff time.Duration,
|
||||
responseJSON any,
|
||||
handler ClientResponseHandler,
|
||||
dontReadResponse bool,
|
||||
sizeLimit int64,
|
||||
client *http.Client,
|
||||
) ([]byte, *http.Response, error) {
|
||||
log := zerolog.Ctx(req.Context())
|
||||
if req.Body != nil {
|
||||
var err error
|
||||
|
|
@ -585,11 +619,23 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff
|
|||
if cli.UpdateRequestOnRetry != nil {
|
||||
req = cli.UpdateRequestOnRetry(req, cause)
|
||||
}
|
||||
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client)
|
||||
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client)
|
||||
}
|
||||
|
||||
func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) {
|
||||
contents, err := io.ReadAll(res.Body)
|
||||
func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) {
|
||||
if res.ContentLength > limit {
|
||||
return nil, HTTPError{
|
||||
Request: req,
|
||||
Response: res,
|
||||
|
||||
Message: "not reading response",
|
||||
WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
|
||||
}
|
||||
}
|
||||
contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1))
|
||||
if err == nil && len(contents) > int(limit) {
|
||||
err = ErrBodyReadReachedLimit
|
||||
}
|
||||
if err != nil {
|
||||
return nil, HTTPError{
|
||||
Request: req,
|
||||
|
|
@ -610,17 +656,20 @@ func closeTemp(log *zerolog.Logger, file *os.File) {
|
|||
}
|
||||
}
|
||||
|
||||
func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
|
||||
func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
|
||||
log := zerolog.Ctx(req.Context())
|
||||
file, err := os.CreateTemp("", "mautrix-response-")
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to create temporary file for streaming response")
|
||||
_, err = handleNormalResponse(req, res, responseJSON)
|
||||
_, err = handleNormalResponse(req, res, responseJSON, limit)
|
||||
return nil, err
|
||||
}
|
||||
defer closeTemp(log, file)
|
||||
if _, err = io.Copy(file, res.Body); err != nil {
|
||||
var n int64
|
||||
if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil {
|
||||
return nil, fmt.Errorf("failed to copy response to file: %w", err)
|
||||
} else if n > limit {
|
||||
return nil, ErrBodyReadReachedLimit
|
||||
} else if _, err = file.Seek(0, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err)
|
||||
} else if err = json.NewDecoder(file).Decode(responseJSON); err != nil {
|
||||
|
|
@ -630,12 +679,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON interfac
|
|||
}
|
||||
}
|
||||
|
||||
func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
|
||||
func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
|
||||
if contents, err := readResponseBody(req, res); err != nil {
|
||||
func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
|
||||
if contents, err := readResponseBody(req, res, limit); err != nil {
|
||||
return nil, err
|
||||
} else if responseJSON == nil {
|
||||
return contents, nil
|
||||
|
|
@ -653,8 +702,12 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON in
|
|||
}
|
||||
}
|
||||
|
||||
const ErrorResponseSizeLimit = 512 * 1024
|
||||
|
||||
var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024
|
||||
|
||||
func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
|
||||
contents, err := readResponseBody(req, res)
|
||||
contents, err := readResponseBody(req, res, ErrorResponseSizeLimit)
|
||||
if err != nil {
|
||||
return contents, err
|
||||
}
|
||||
|
|
@ -673,7 +726,16 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
|
||||
func (cli *Client) executeCompiledRequest(
|
||||
req *http.Request,
|
||||
retries int,
|
||||
backoff time.Duration,
|
||||
responseJSON any,
|
||||
handler ClientResponseHandler,
|
||||
dontReadResponse bool,
|
||||
sizeLimit int64,
|
||||
client *http.Client,
|
||||
) ([]byte, *http.Response, error) {
|
||||
cli.RequestStart(req)
|
||||
startTime := time.Now()
|
||||
res, err := client.Do(req)
|
||||
|
|
@ -683,7 +745,9 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
|
|||
}
|
||||
if err != nil {
|
||||
if retries > 0 && !errors.Is(err, context.Canceled) {
|
||||
return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client)
|
||||
return cli.doRetry(
|
||||
req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
|
||||
)
|
||||
}
|
||||
err = HTTPError{
|
||||
Request: req,
|
||||
|
|
@ -698,7 +762,9 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
|
|||
|
||||
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
|
||||
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
|
||||
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client)
|
||||
return cli.doRetry(
|
||||
req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
|
||||
)
|
||||
}
|
||||
|
||||
var body []byte
|
||||
|
|
@ -706,7 +772,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
|
|||
body, err = ParseErrorResponse(req, res)
|
||||
cli.LogRequestDone(req, res, nil, nil, len(body), duration)
|
||||
} else {
|
||||
body, err = handler(req, res, responseJSON)
|
||||
body, err = handler(req, res, responseJSON, sizeLimit)
|
||||
cli.LogRequestDone(req, res, nil, err, len(body), duration)
|
||||
}
|
||||
return body, res, err
|
||||
|
|
@ -1628,11 +1694,20 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy
|
|||
}
|
||||
|
||||
// parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map.
|
||||
func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
|
||||
func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
|
||||
if res.ContentLength > limit {
|
||||
return nil, HTTPError{
|
||||
Request: req,
|
||||
Response: res,
|
||||
|
||||
Message: "not reading response",
|
||||
WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
|
||||
}
|
||||
}
|
||||
response := make(RoomStateMap)
|
||||
responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event)
|
||||
*responsePtr = response
|
||||
dec := json.NewDecoder(res.Body)
|
||||
dec := json.NewDecoder(io.LimitReader(res.Body, limit))
|
||||
|
||||
arrayStart, err := dec.Token()
|
||||
if err != nil {
|
||||
|
|
|
|||
3
error.go
3
error.go
|
|
@ -82,6 +82,9 @@ var (
|
|||
var (
|
||||
ErrClientIsNil = errors.New("client is nil")
|
||||
ErrClientHasNoHomeserver = errors.New("client has no homeserver set")
|
||||
|
||||
ErrResponseTooLong = errors.New("response content length too long")
|
||||
ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body")
|
||||
)
|
||||
|
||||
// HTTPError An HTTP Error response, which may wrap an underlying native Go Error.
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien
|
|||
ServerName: serverName,
|
||||
Key: key,
|
||||
|
||||
ResponseSizeLimit: 128 * 1024 * 1024,
|
||||
ResponseSizeLimit: mautrix.DefaultResponseSizeLimit,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -327,11 +327,14 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b
|
|||
Request: req,
|
||||
Response: resp,
|
||||
|
||||
Message: "response body too long",
|
||||
WrappedError: fmt.Errorf("%.2f MiB", float64(resp.ContentLength)/1024/1024),
|
||||
Message: "not reading response",
|
||||
WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024),
|
||||
}
|
||||
}
|
||||
body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1))
|
||||
if err == nil && len(body) > int(c.ResponseSizeLimit) {
|
||||
err = mautrix.ErrBodyReadReachedLimit
|
||||
}
|
||||
if err != nil {
|
||||
return body, resp, mautrix.HTTPError{
|
||||
Request: req,
|
||||
|
|
@ -341,15 +344,6 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b
|
|||
WrappedError: err,
|
||||
}
|
||||
}
|
||||
if len(body) > int(c.ResponseSizeLimit) {
|
||||
return body, resp, mautrix.HTTPError{
|
||||
Request: req,
|
||||
Response: resp,
|
||||
|
||||
Message: "failed to read response body",
|
||||
WrappedError: fmt.Errorf("exceeded read limit"),
|
||||
}
|
||||
}
|
||||
if params.ResponseJSON != nil {
|
||||
err = json.Unmarshal(body, params.ResponseJSON)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
)
|
||||
|
||||
type ResolvedServerName struct {
|
||||
|
|
@ -171,9 +173,11 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*
|
|||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
|
||||
} else if resp.ContentLength > mautrix.WellKnownMaxSize {
|
||||
return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength)
|
||||
}
|
||||
var respData RespWellKnown
|
||||
err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData)
|
||||
err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
|
||||
} else if respData.Server == "" {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue