From dc713ea8e87dea55e0f66767953d0ae00d46b0d0 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 7 Jul 2021 12:25:05 +0200 Subject: [PATCH] Only include body in 307/308 redirects if going to same host. --- backend_client.go | 190 +++++------------------------------------ backend_client_test.go | 141 +++++++++++++++++++++++++----- 2 files changed, 139 insertions(+), 192 deletions(-) diff --git a/backend_client.go b/backend_client.go index 363e52c..5b13c2f 100644 --- a/backend_client.go +++ b/backend_client.go @@ -26,8 +26,8 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" - "io" "io/ioutil" "log" "net/http" @@ -40,7 +40,8 @@ import ( ) var ( - ErrUseLastResponse = fmt.Errorf("use last response") + ErrNotRedirecting = errors.New("not redirecting to different host") + ErrUnsupportedContentType = errors.New("unsupported_content_type") ) const ( @@ -115,9 +116,17 @@ func (b *BackendClient) getPool(url *url.URL) (*HttpClientPool, error) { pool, err := NewHttpClientPool(func() *http.Client { return &http.Client{ Transport: b.transport, + // Only send body in redirect if going to same scheme / host. CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Should be http.ErrUseLastResponse with go 1.8 - return ErrUseLastResponse + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } else if len(via) > 0 { + viaReq := via[len(via)-1] + if req.URL.Scheme != viaReq.URL.Scheme || req.URL.Host != viaReq.URL.Host { + return ErrNotRedirecting + } + } + return nil }, } }, b.maxConcurrentRequestsPerHost) @@ -149,161 +158,6 @@ func isOcsRequest(u *url.URL) bool { return strings.Contains(u.Path, "/ocs/v2.php") || strings.Contains(u.Path, "/ocs/v1.php") } -func closeBody(response *http.Response) { - if response.Body != nil { - response.Body.Close() - } -} - -// refererForURL returns a referer without any authentication info or -// an empty string if lastReq scheme is https and newReq scheme is http. -func refererForURL(lastReq, newReq *url.URL) string { - // https://tools.ietf.org/html/rfc7231#section-5.5.2 - // "Clients SHOULD NOT include a Referer header field in a - // (non-secure) HTTP request if the referring page was - // transferred with a secure protocol." - if lastReq.Scheme == "https" && newReq.Scheme == "http" { - return "" - } - referer := lastReq.String() - if lastReq.User != nil { - // This is not very efficient, but is the best we can - // do without: - // - introducing a new method on URL - // - creating a race condition - // - copying the URL struct manually, which would cause - // maintenance problems down the line - auth := lastReq.User.String() + "@" - referer = strings.Replace(referer, auth, "", 1) - } - return referer -} - -// urlErrorOp returns the (*url.Error).Op value to use for the -// provided (*Request).Method value. -func urlErrorOp(method string) string { - if method == "" { - return "Get" - } - return method[:1] + strings.ToLower(method[1:]) -} - -func performRequestWithRedirects(ctx context.Context, client *http.Client, req *http.Request, body []byte) (*http.Response, error) { - var reqs []*http.Request - var resp *http.Response - - uerr := func(err error) error { - var urlStr string - if resp != nil && resp.Request != nil { - urlStr = resp.Request.URL.String() - } else { - urlStr = req.URL.String() - } - return &url.Error{ - Op: urlErrorOp(reqs[0].Method), - URL: urlStr, - Err: err, - } - } - for { - if len(reqs) >= 10 { - return nil, fmt.Errorf("stopped after 10 redirects") - } - - if len(reqs) > 0 { - loc := resp.Header.Get("Location") - if loc == "" { - closeBody(resp) - return nil, uerr(fmt.Errorf("%d response missing Location header", resp.StatusCode)) - } - u, err := req.URL.Parse(loc) - if err != nil { - closeBody(resp) - return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err)) - } - - if len(reqs) == 1 { - log.Printf("Got a redirect from %s to %s, please check your configuration", req.URL, u) - } - - host := "" - if req.Host != "" && req.Host != req.URL.Host { - // If the caller specified a custom Host header and the - // redirect location is relative, preserve the Host header - // through the redirect. See issue #22233. - if u, _ := url.Parse(loc); u != nil && !u.IsAbs() { - host = req.Host - } - } - ireq := reqs[0] - req = &http.Request{ - Method: ireq.Method, - URL: u, - Header: ireq.Header, - Host: host, - } - - // Add the Referer header from the most recent - // request URL to the new one, if it's not https->http: - if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" { - req.Header.Set("Referer", ref) - } - // Close the previous response's body. But - // read at least some of the body so if it's - // small the underlying TCP connection will be - // re-used. No need to check for errors: if it - // fails, the Transport won't reuse it anyway. - const maxBodySlurpSize = 2 << 10 - if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { - io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) // nolint - } - resp.Body.Close() - } - reqs = append(reqs, req) - var err error - - if body != nil { - req.Body = ioutil.NopCloser(bytes.NewReader(body)) - req.ContentLength = int64(len(body)) - } - resp, err = client.Do(req.WithContext(ctx)) - if err != nil { - // Prefer context error if it has been cancelled. - select { - case <-ctx.Done(): - err = ctx.Err() - default: - } - if e, ok := err.(*url.Error); !ok || resp == nil || e.Err != ErrUseLastResponse { - return nil, err - } - } - - switch resp.StatusCode { - case 301, 302, 303: - break - case 307, 308: - if resp.Header.Get("Location") == "" { - // 308s have been observed in the wild being served - // without Location headers. Since Go 1.7 and earlier - // didn't follow these codes, just stop here instead - // of returning an error. - // See Issue 17773. - return resp, nil - } - if req.Body == nil { - // We had a request body, and 307/308 require - // re-sending it, but GetBody is not defined. So just - // return this response to the user instead of an - // error, like we did in Go 1.7 and earlier. - return resp, nil - } - default: - return resp, nil - } - } -} - type CapabilitiesVersion struct { Major int `json:"major"` Minor int `json:"minor"` @@ -378,7 +232,7 @@ func (b *BackendClient) getCapabilities(ctx context.Context, u *url.URL) (map[st ct := resp.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status) - return nil, fmt.Errorf("unsupported_content_type") + return nil, ErrUnsupportedContentType } body, err := ioutil.ReadAll(resp.Body) @@ -483,14 +337,10 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ return err } - req := &http.Request{ - Method: "POST", - URL: requestUrl, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - Host: u.Host, + req, err := http.NewRequestWithContext(ctx, "POST", requestUrl.String(), bytes.NewReader(data)) + if err != nil { + log.Printf("Could not create request to %s: %s", requestUrl, err) + return err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") @@ -500,7 +350,7 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ // Add checksum so the backend can validate the request. AddBackendChecksum(req, data, secret) - resp, err := performRequestWithRedirects(ctx, c, req, data) + resp, err := c.Do(req) if err != nil { log.Printf("Could not send request %s to %s: %s", string(data), req.URL, err) return err @@ -510,7 +360,7 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ ct := resp.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { log.Printf("Received unsupported content-type from %s: %s (%s)", req.URL, ct, resp.Status) - return fmt.Errorf("unsupported_content_type") + return ErrUnsupportedContentType } body, err := ioutil.ReadAll(resp.Body) diff --git a/backend_client_test.go b/backend_client_test.go index 96c1279..88927fa 100644 --- a/backend_client_test.go +++ b/backend_client_test.go @@ -24,6 +24,7 @@ package signaling import ( "context" "encoding/json" + "errors" "io/ioutil" "net/http" "net/http/httptest" @@ -35,10 +36,34 @@ import ( "github.com/gorilla/mux" ) +func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) { + response := OcsResponse{ + Ocs: &OcsBody{ + Meta: OcsMeta{ + Status: "OK", + StatusCode: http.StatusOK, + Message: "OK", + }, + Data: (*json.RawMessage)(&body), + }, + } + data, err := json.Marshal(response) + if err != nil { + t.Fatal(err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write(data); err != nil { + t.Error(err) + } +} + func TestPostOnRedirect(t *testing.T) { r := mux.NewRouter() r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "/ocs/v2.php/two", http.StatusFound) + http.Redirect(w, r, "/ocs/v2.php/two", http.StatusTemporaryRedirect) }) r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) @@ -53,27 +78,7 @@ func TestPostOnRedirect(t *testing.T) { return } - w.Header().Set("Content-Type", "application/json") - response := OcsResponse{ - Ocs: &OcsBody{ - Meta: OcsMeta{ - Status: "OK", - StatusCode: http.StatusOK, - Message: "OK", - }, - Data: (*json.RawMessage)(&body), - }, - } - data, err := json.Marshal(response) - if err != nil { - t.Fatal(err) - return - } - - w.WriteHeader(http.StatusOK) - if _, err := w.Write(data); err != nil { - t.Error(err) - } + returnOCS(t, w, body) }) server := httptest.NewServer(r) @@ -109,3 +114,95 @@ func TestPostOnRedirect(t *testing.T) { t.Errorf("Expected %+v, got %+v", request, response) } } + +func TestPostOnRedirectDifferentHost(t *testing.T) { + r := mux.NewRouter() + r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "http://domain.invalid/ocs/v2.php/two", http.StatusTemporaryRedirect) + }) + server := httptest.NewServer(r) + defer server.Close() + + u, err := url.Parse(server.URL + "/ocs/v2.php/one") + if err != nil { + t.Fatal(err) + } + + config := goconf.NewConfigFile() + config.AddOption("backend", "allowed", u.Host) + config.AddOption("backend", "secret", string(testBackendSecret)) + if u.Scheme == "http" { + config.AddOption("backend", "allowhttp", "true") + } + client, err := NewBackendClient(config, 1, "0.0") + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + request := map[string]string{ + "foo": "bar", + } + var response map[string]string + err = client.PerformJSONRequest(ctx, u, request, &response) + if err != nil { + // The redirect to a different host should have failed. + if !errors.Is(err, ErrNotRedirecting) { + t.Fatal(err) + } + } else { + t.Fatal("The redirect should have failed") + } +} + +func TestPostOnRedirectStatusFound(t *testing.T) { + r := mux.NewRouter() + r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/ocs/v2.php/two", http.StatusFound) + }) + r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + return + } + + if len(body) > 0 { + t.Errorf("Should not have received any body, got %s", string(body)) + } + + returnOCS(t, w, []byte("{}")) + }) + server := httptest.NewServer(r) + defer server.Close() + + u, err := url.Parse(server.URL + "/ocs/v2.php/one") + if err != nil { + t.Fatal(err) + } + + config := goconf.NewConfigFile() + config.AddOption("backend", "allowed", u.Host) + config.AddOption("backend", "secret", string(testBackendSecret)) + if u.Scheme == "http" { + config.AddOption("backend", "allowhttp", "true") + } + client, err := NewBackendClient(config, 1, "0.0") + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + request := map[string]string{ + "foo": "bar", + } + var response map[string]string + err = client.PerformJSONRequest(ctx, u, request, &response) + if err != nil { + t.Error(err) + } + + if len(response) > 0 { + t.Errorf("Expected empty response, got %+v", response) + } +}