diff --git a/backend_client.go b/backend_client.go index 2e3fb3b..efd5ae1 100644 --- a/backend_client.go +++ b/backend_client.go @@ -39,6 +39,9 @@ import ( var ( ErrNotRedirecting = errors.New("not redirecting to different host") ErrUnsupportedContentType = errors.New("unsupported_content_type") + + ErrIncompleteResponse = errors.New("incomplete OCS response") + ErrThrottledResponse = errors.New("throttled OCS response") ) type BackendClient struct { @@ -193,8 +196,16 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ return err } else if ocs.Ocs == nil || ocs.Ocs.Data == nil { log.Printf("Incomplete OCS response %s from %s", string(body), req.URL) - return fmt.Errorf("incomplete OCS response") - } else if err := json.Unmarshal(*ocs.Ocs.Data, response); err != nil { + return ErrIncompleteResponse + } + + switch ocs.Ocs.Meta.StatusCode { + case http.StatusTooManyRequests: + log.Printf("Throttled OCS response %s from %s", string(body), req.URL) + return ErrThrottledResponse + } + + if err := json.Unmarshal(*ocs.Ocs.Data, response); err != nil { log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), req.URL, err) return err } diff --git a/backend_client_test.go b/backend_client_test.go index 5bae7d5..379bcd0 100644 --- a/backend_client_test.go +++ b/backend_client_test.go @@ -30,6 +30,7 @@ import ( "net/http/httptest" "net/url" "reflect" + "strings" "testing" "github.com/dlintw/goconf" @@ -47,6 +48,14 @@ func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) { Data: (*json.RawMessage)(&body), }, } + if strings.Contains(t.Name(), "Throttled") { + response.Ocs.Meta = OcsMeta{ + Status: "failure", + StatusCode: 429, + Message: "Reached maximum delay", + } + } + data, err := json.Marshal(response) if err != nil { t.Fatal(err) @@ -206,3 +215,40 @@ func TestPostOnRedirectStatusFound(t *testing.T) { t.Errorf("Expected empty response, got %+v", response) } } + +func TestHandleThrottled(t *testing.T) { + r := mux.NewRouter() + r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { + returnOCS(t, w, []byte("[]")) + }) + server := httptest.NewServer(r) + defer server.Close() + + u, err := url.Parse(server.URL + "/ocs/v2.php/one") + if err != nil { + t.Fatal(err) + } + + config := goconf.NewConfigFile() + config.AddOption("backend", "allowed", u.Host) + config.AddOption("backend", "secret", string(testBackendSecret)) + if u.Scheme == "http" { + config.AddOption("backend", "allowhttp", "true") + } + client, err := NewBackendClient(config, 1, "0.0", nil) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + request := map[string]string{ + "foo": "bar", + } + var response map[string]string + err = client.PerformJSONRequest(ctx, u, request, &response) + if err == nil { + t.Error("should have triggered an error") + } else if !errors.Is(err, ErrThrottledResponse) { + t.Error(err) + } +}