Only include body in 307/308 redirects if going to same host.
This commit is contained in:
parent
b422b4d379
commit
dc713ea8e8
|
@ -26,8 +26,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -40,7 +40,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUseLastResponse = fmt.Errorf("use last response")
|
ErrNotRedirecting = errors.New("not redirecting to different host")
|
||||||
|
ErrUnsupportedContentType = errors.New("unsupported_content_type")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -115,9 +116,17 @@ func (b *BackendClient) getPool(url *url.URL) (*HttpClientPool, error) {
|
||||||
pool, err := NewHttpClientPool(func() *http.Client {
|
pool, err := NewHttpClientPool(func() *http.Client {
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: b.transport,
|
Transport: b.transport,
|
||||||
|
// Only send body in redirect if going to same scheme / host.
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
// Should be http.ErrUseLastResponse with go 1.8
|
if len(via) >= 10 {
|
||||||
return ErrUseLastResponse
|
return errors.New("stopped after 10 redirects")
|
||||||
|
} else if len(via) > 0 {
|
||||||
|
viaReq := via[len(via)-1]
|
||||||
|
if req.URL.Scheme != viaReq.URL.Scheme || req.URL.Host != viaReq.URL.Host {
|
||||||
|
return ErrNotRedirecting
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}, b.maxConcurrentRequestsPerHost)
|
}, b.maxConcurrentRequestsPerHost)
|
||||||
|
@ -149,161 +158,6 @@ func isOcsRequest(u *url.URL) bool {
|
||||||
return strings.Contains(u.Path, "/ocs/v2.php") || strings.Contains(u.Path, "/ocs/v1.php")
|
return strings.Contains(u.Path, "/ocs/v2.php") || strings.Contains(u.Path, "/ocs/v1.php")
|
||||||
}
|
}
|
||||||
|
|
||||||
func closeBody(response *http.Response) {
|
|
||||||
if response.Body != nil {
|
|
||||||
response.Body.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// refererForURL returns a referer without any authentication info or
|
|
||||||
// an empty string if lastReq scheme is https and newReq scheme is http.
|
|
||||||
func refererForURL(lastReq, newReq *url.URL) string {
|
|
||||||
// https://tools.ietf.org/html/rfc7231#section-5.5.2
|
|
||||||
// "Clients SHOULD NOT include a Referer header field in a
|
|
||||||
// (non-secure) HTTP request if the referring page was
|
|
||||||
// transferred with a secure protocol."
|
|
||||||
if lastReq.Scheme == "https" && newReq.Scheme == "http" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
referer := lastReq.String()
|
|
||||||
if lastReq.User != nil {
|
|
||||||
// This is not very efficient, but is the best we can
|
|
||||||
// do without:
|
|
||||||
// - introducing a new method on URL
|
|
||||||
// - creating a race condition
|
|
||||||
// - copying the URL struct manually, which would cause
|
|
||||||
// maintenance problems down the line
|
|
||||||
auth := lastReq.User.String() + "@"
|
|
||||||
referer = strings.Replace(referer, auth, "", 1)
|
|
||||||
}
|
|
||||||
return referer
|
|
||||||
}
|
|
||||||
|
|
||||||
// urlErrorOp returns the (*url.Error).Op value to use for the
|
|
||||||
// provided (*Request).Method value.
|
|
||||||
func urlErrorOp(method string) string {
|
|
||||||
if method == "" {
|
|
||||||
return "Get"
|
|
||||||
}
|
|
||||||
return method[:1] + strings.ToLower(method[1:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func performRequestWithRedirects(ctx context.Context, client *http.Client, req *http.Request, body []byte) (*http.Response, error) {
|
|
||||||
var reqs []*http.Request
|
|
||||||
var resp *http.Response
|
|
||||||
|
|
||||||
uerr := func(err error) error {
|
|
||||||
var urlStr string
|
|
||||||
if resp != nil && resp.Request != nil {
|
|
||||||
urlStr = resp.Request.URL.String()
|
|
||||||
} else {
|
|
||||||
urlStr = req.URL.String()
|
|
||||||
}
|
|
||||||
return &url.Error{
|
|
||||||
Op: urlErrorOp(reqs[0].Method),
|
|
||||||
URL: urlStr,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
if len(reqs) >= 10 {
|
|
||||||
return nil, fmt.Errorf("stopped after 10 redirects")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(reqs) > 0 {
|
|
||||||
loc := resp.Header.Get("Location")
|
|
||||||
if loc == "" {
|
|
||||||
closeBody(resp)
|
|
||||||
return nil, uerr(fmt.Errorf("%d response missing Location header", resp.StatusCode))
|
|
||||||
}
|
|
||||||
u, err := req.URL.Parse(loc)
|
|
||||||
if err != nil {
|
|
||||||
closeBody(resp)
|
|
||||||
return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(reqs) == 1 {
|
|
||||||
log.Printf("Got a redirect from %s to %s, please check your configuration", req.URL, u)
|
|
||||||
}
|
|
||||||
|
|
||||||
host := ""
|
|
||||||
if req.Host != "" && req.Host != req.URL.Host {
|
|
||||||
// If the caller specified a custom Host header and the
|
|
||||||
// redirect location is relative, preserve the Host header
|
|
||||||
// through the redirect. See issue #22233.
|
|
||||||
if u, _ := url.Parse(loc); u != nil && !u.IsAbs() {
|
|
||||||
host = req.Host
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ireq := reqs[0]
|
|
||||||
req = &http.Request{
|
|
||||||
Method: ireq.Method,
|
|
||||||
URL: u,
|
|
||||||
Header: ireq.Header,
|
|
||||||
Host: host,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the Referer header from the most recent
|
|
||||||
// request URL to the new one, if it's not https->http:
|
|
||||||
if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" {
|
|
||||||
req.Header.Set("Referer", ref)
|
|
||||||
}
|
|
||||||
// Close the previous response's body. But
|
|
||||||
// read at least some of the body so if it's
|
|
||||||
// small the underlying TCP connection will be
|
|
||||||
// re-used. No need to check for errors: if it
|
|
||||||
// fails, the Transport won't reuse it anyway.
|
|
||||||
const maxBodySlurpSize = 2 << 10
|
|
||||||
if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize {
|
|
||||||
io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) // nolint
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
}
|
|
||||||
reqs = append(reqs, req)
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if body != nil {
|
|
||||||
req.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
||||||
req.ContentLength = int64(len(body))
|
|
||||||
}
|
|
||||||
resp, err = client.Do(req.WithContext(ctx))
|
|
||||||
if err != nil {
|
|
||||||
// Prefer context error if it has been cancelled.
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
err = ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if e, ok := err.(*url.Error); !ok || resp == nil || e.Err != ErrUseLastResponse {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch resp.StatusCode {
|
|
||||||
case 301, 302, 303:
|
|
||||||
break
|
|
||||||
case 307, 308:
|
|
||||||
if resp.Header.Get("Location") == "" {
|
|
||||||
// 308s have been observed in the wild being served
|
|
||||||
// without Location headers. Since Go 1.7 and earlier
|
|
||||||
// didn't follow these codes, just stop here instead
|
|
||||||
// of returning an error.
|
|
||||||
// See Issue 17773.
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
if req.Body == nil {
|
|
||||||
// We had a request body, and 307/308 require
|
|
||||||
// re-sending it, but GetBody is not defined. So just
|
|
||||||
// return this response to the user instead of an
|
|
||||||
// error, like we did in Go 1.7 and earlier.
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type CapabilitiesVersion struct {
|
type CapabilitiesVersion struct {
|
||||||
Major int `json:"major"`
|
Major int `json:"major"`
|
||||||
Minor int `json:"minor"`
|
Minor int `json:"minor"`
|
||||||
|
@ -378,7 +232,7 @@ func (b *BackendClient) getCapabilities(ctx context.Context, u *url.URL) (map[st
|
||||||
ct := resp.Header.Get("Content-Type")
|
ct := resp.Header.Get("Content-Type")
|
||||||
if !strings.HasPrefix(ct, "application/json") {
|
if !strings.HasPrefix(ct, "application/json") {
|
||||||
log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status)
|
log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status)
|
||||||
return nil, fmt.Errorf("unsupported_content_type")
|
return nil, ErrUnsupportedContentType
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
@ -483,14 +337,10 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &http.Request{
|
req, err := http.NewRequestWithContext(ctx, "POST", requestUrl.String(), bytes.NewReader(data))
|
||||||
Method: "POST",
|
if err != nil {
|
||||||
URL: requestUrl,
|
log.Printf("Could not create request to %s: %s", requestUrl, err)
|
||||||
Proto: "HTTP/1.1",
|
return err
|
||||||
ProtoMajor: 1,
|
|
||||||
ProtoMinor: 1,
|
|
||||||
Header: make(http.Header),
|
|
||||||
Host: u.Host,
|
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
|
@ -500,7 +350,7 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
|
||||||
// Add checksum so the backend can validate the request.
|
// Add checksum so the backend can validate the request.
|
||||||
AddBackendChecksum(req, data, secret)
|
AddBackendChecksum(req, data, secret)
|
||||||
|
|
||||||
resp, err := performRequestWithRedirects(ctx, c, req, data)
|
resp, err := c.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not send request %s to %s: %s", string(data), req.URL, err)
|
log.Printf("Could not send request %s to %s: %s", string(data), req.URL, err)
|
||||||
return err
|
return err
|
||||||
|
@ -510,7 +360,7 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ
|
||||||
ct := resp.Header.Get("Content-Type")
|
ct := resp.Header.Get("Content-Type")
|
||||||
if !strings.HasPrefix(ct, "application/json") {
|
if !strings.HasPrefix(ct, "application/json") {
|
||||||
log.Printf("Received unsupported content-type from %s: %s (%s)", req.URL, ct, resp.Status)
|
log.Printf("Received unsupported content-type from %s: %s (%s)", req.URL, ct, resp.Status)
|
||||||
return fmt.Errorf("unsupported_content_type")
|
return ErrUnsupportedContentType
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
|
|
@ -24,6 +24,7 @@ package signaling
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -35,10 +36,34 @@ import (
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) {
|
||||||
|
response := OcsResponse{
|
||||||
|
Ocs: &OcsBody{
|
||||||
|
Meta: OcsMeta{
|
||||||
|
Status: "OK",
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Message: "OK",
|
||||||
|
},
|
||||||
|
Data: (*json.RawMessage)(&body),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if _, err := w.Write(data); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPostOnRedirect(t *testing.T) {
|
func TestPostOnRedirect(t *testing.T) {
|
||||||
r := mux.NewRouter()
|
r := mux.NewRouter()
|
||||||
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
|
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Redirect(w, r, "/ocs/v2.php/two", http.StatusFound)
|
http.Redirect(w, r, "/ocs/v2.php/two", http.StatusTemporaryRedirect)
|
||||||
})
|
})
|
||||||
r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) {
|
r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
@ -53,27 +78,7 @@ func TestPostOnRedirect(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
returnOCS(t, w, body)
|
||||||
response := OcsResponse{
|
|
||||||
Ocs: &OcsBody{
|
|
||||||
Meta: OcsMeta{
|
|
||||||
Status: "OK",
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Message: "OK",
|
|
||||||
},
|
|
||||||
Data: (*json.RawMessage)(&body),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
if _, err := w.Write(data); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
server := httptest.NewServer(r)
|
server := httptest.NewServer(r)
|
||||||
|
@ -109,3 +114,95 @@ func TestPostOnRedirect(t *testing.T) {
|
||||||
t.Errorf("Expected %+v, got %+v", request, response)
|
t.Errorf("Expected %+v, got %+v", request, response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPostOnRedirectDifferentHost(t *testing.T) {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Redirect(w, r, "http://domain.invalid/ocs/v2.php/two", http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
server := httptest.NewServer(r)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := goconf.NewConfigFile()
|
||||||
|
config.AddOption("backend", "allowed", u.Host)
|
||||||
|
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||||
|
if u.Scheme == "http" {
|
||||||
|
config.AddOption("backend", "allowhttp", "true")
|
||||||
|
}
|
||||||
|
client, err := NewBackendClient(config, 1, "0.0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
request := map[string]string{
|
||||||
|
"foo": "bar",
|
||||||
|
}
|
||||||
|
var response map[string]string
|
||||||
|
err = client.PerformJSONRequest(ctx, u, request, &response)
|
||||||
|
if err != nil {
|
||||||
|
// The redirect to a different host should have failed.
|
||||||
|
if !errors.Is(err, ErrNotRedirecting) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Fatal("The redirect should have failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostOnRedirectStatusFound(t *testing.T) {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Redirect(w, r, "/ocs/v2.php/two", http.StatusFound)
|
||||||
|
})
|
||||||
|
r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) > 0 {
|
||||||
|
t.Errorf("Should not have received any body, got %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
returnOCS(t, w, []byte("{}"))
|
||||||
|
})
|
||||||
|
server := httptest.NewServer(r)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := goconf.NewConfigFile()
|
||||||
|
config.AddOption("backend", "allowed", u.Host)
|
||||||
|
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||||
|
if u.Scheme == "http" {
|
||||||
|
config.AddOption("backend", "allowhttp", "true")
|
||||||
|
}
|
||||||
|
client, err := NewBackendClient(config, 1, "0.0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
request := map[string]string{
|
||||||
|
"foo": "bar",
|
||||||
|
}
|
||||||
|
var response map[string]string
|
||||||
|
err = client.PerformJSONRequest(ctx, u, request, &response)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) > 0 {
|
||||||
|
t.Errorf("Expected empty response, got %+v", response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue