diff --git a/backend_client.go b/backend_client.go index a9a3967..363e52c 100644 --- a/backend_client.go +++ b/backend_client.go @@ -34,6 +34,7 @@ import ( "net/url" "strings" "sync" + "time" "github.com/dlintw/goconf" ) @@ -42,6 +43,17 @@ var ( ErrUseLastResponse = fmt.Errorf("use last response") ) +const ( + // Name of the "Talk" app in Nextcloud. + AppNameSpreed = "spreed" + + // Name of capability to enable the "v3" API for the signaling endpoint. + FeatureSignalingV3Api = "signaling-v3" + + // Cache received capabilities for one hour. + CapabilitiesCacheDuration = time.Hour +) + type BackendClient struct { transport *http.Transport version string @@ -51,6 +63,10 @@ type BackendClient struct { mu sync.Mutex maxConcurrentRequestsPerHost int + + capabilitiesLock sync.RWMutex + capabilities map[string]map[string]interface{} + nextCapabilities map[string]time.Time } func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string) (*BackendClient, error) { @@ -79,6 +95,9 @@ func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost in clients: make(map[string]*HttpClientPool), maxConcurrentRequestsPerHost: maxConcurrentRequestsPerHost, + + capabilities: make(map[string]map[string]interface{}), + nextCapabilities: make(map[string]time.Time), }, nil } @@ -285,6 +304,144 @@ func performRequestWithRedirects(ctx context.Context, client *http.Client, req * } } +type CapabilitiesVersion struct { + Major int `json:"major"` + Minor int `json:"minor"` + Micro int `json:"micro"` + String string `json:"string"` + Edition string `json:"edition"` + ExtendedSupport bool `json:"extendedSupport"` +} + +type CapabilitiesResponse struct { + Version CapabilitiesVersion `json:"version"` + Capabilities map[string]map[string]interface{} `json:"capabilities"` +} + +func (b *BackendClient) getCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, error) { + key := u.String() + now := time.Now() + + b.capabilitiesLock.RLock() + if caps, found := b.capabilities[key]; found { + if next, found := b.nextCapabilities[key]; found && next.After(now) { + b.capabilitiesLock.RUnlock() + return caps, nil + } + } + b.capabilitiesLock.RUnlock() + + capUrl := *u + if !strings.Contains(capUrl.Path, "ocs/v2.php") { + if !strings.HasSuffix(capUrl.Path, "/") { + capUrl.Path += "/" + } + capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities" + } else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 { + capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities" + } + + log.Printf("Capabilities expired for %s, updating", capUrl.String()) + + pool, err := b.getPool(&capUrl) + if err != nil { + log.Printf("Could not get client pool for host %s: %s", capUrl.Host, err) + return nil, err + } + + c, err := pool.Get(ctx) + if err != nil { + log.Printf("Could not get client for host %s: %s", capUrl.Host, err) + return nil, err + } + defer pool.Put(c) + + req := &http.Request{ + Method: "GET", + URL: &capUrl, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + req.Header.Set("Accept", "application/json") + req.Header.Set("OCS-APIRequest", "true") + req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+b.version) + + resp, err := c.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + ct := resp.Header.Get("Content-Type") + if !strings.HasPrefix(ct, "application/json") { + log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status) + return nil, fmt.Errorf("unsupported_content_type") + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("Could not read response body from %s: %s", capUrl.String(), err) + return nil, err + } + + var ocs OcsResponse + if err := json.Unmarshal(body, &ocs); err != nil { + log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err) + return nil, err + } else if ocs.Ocs == nil || ocs.Ocs.Data == nil { + log.Printf("Incomplete OCS response %s from %s", string(body), u) + return nil, fmt.Errorf("incomplete OCS response") + } + + var response CapabilitiesResponse + if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil { + log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err) + return nil, err + } + + capa, found := response.Capabilities[AppNameSpreed] + if !found { + log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response) + return nil, nil + } + + log.Printf("Received capabilities %+v from %s", capa, capUrl.String()) + b.capabilitiesLock.Lock() + b.capabilities[key] = capa + b.nextCapabilities[key] = now.Add(CapabilitiesCacheDuration) + b.capabilitiesLock.Unlock() + return capa, nil +} + +func (b *BackendClient) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool { + caps, err := b.getCapabilities(ctx, u) + if err != nil { + log.Printf("Could not get capabilities for %s: %s", u, err) + return false + } + + featuresInterface := caps["features"] + if featuresInterface == nil { + return false + } + + features, ok := featuresInterface.([]interface{}) + if !ok { + log.Printf("Invalid features list received for %s: %+v", u, featuresInterface) + return false + } + + for _, entry := range features { + if entry == feature { + return true + } + } + return false +} + // PerformJSONRequest sends a JSON POST request to the given url and decodes // the result into "response". func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, request interface{}, response interface{}) error { @@ -297,6 +454,16 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ return fmt.Errorf("no backend secret configured for for %s", u) } + var requestUrl *url.URL + if b.HasCapabilityFeature(ctx, u, FeatureSignalingV3Api) { + newUrl := *u + newUrl.Path = strings.Replace(newUrl.Path, "/spreed/api/v1/signaling/", "/spreed/api/v3/signaling/", -1) + newUrl.Path = strings.Replace(newUrl.Path, "/spreed/api/v2/signaling/", "/spreed/api/v3/signaling/", -1) + requestUrl = &newUrl + } else { + requestUrl = u + } + pool, err := b.getPool(u) if err != nil { log.Printf("Could not get client pool for host %s: %s", u.Host, err) @@ -318,7 +485,7 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ req := &http.Request{ Method: "POST", - URL: u, + URL: requestUrl, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, @@ -335,20 +502,20 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ resp, err := performRequestWithRedirects(ctx, c, req, data) if err != nil { - log.Printf("Could not send request %s to %s: %s", string(data), u.String(), err) + log.Printf("Could not send request %s to %s: %s", string(data), req.URL, err) return err } defer resp.Body.Close() ct := resp.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { - log.Printf("Received unsupported content-type from %s: %s (%s)", u.String(), ct, resp.Status) + log.Printf("Received unsupported content-type from %s: %s (%s)", req.URL, ct, resp.Status) return fmt.Errorf("unsupported_content_type") } body, err := ioutil.ReadAll(resp.Body) if err != nil { - log.Printf("Could not read response body from %s: %s", u.String(), err) + log.Printf("Could not read response body from %s: %s", req.URL, err) return err } @@ -363,17 +530,17 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ // } var ocs OcsResponse if err := json.Unmarshal(body, &ocs); err != nil { - log.Printf("Could not decode OCS response %s from %s: %s", string(body), u, err) + log.Printf("Could not decode OCS response %s from %s: %s", string(body), req.URL, err) return err } else if ocs.Ocs == nil || ocs.Ocs.Data == nil { - log.Printf("Incomplete OCS response %s from %s", string(body), u) + log.Printf("Incomplete OCS response %s from %s", string(body), req.URL) return fmt.Errorf("incomplete OCS response") } else if err := json.Unmarshal(*ocs.Ocs.Data, response); err != nil { - log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), u, err) + log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), req.URL, err) return err } } else if err := json.Unmarshal(body, response); err != nil { - log.Printf("Could not decode response body %s from %s: %s", string(body), u, err) + log.Printf("Could not decode response body %s from %s: %s", string(body), req.URL, err) return err } return nil diff --git a/hub_test.go b/hub_test.go index 7904c1e..e56cf9f 100644 --- a/hub_test.go +++ b/hub_test.go @@ -181,11 +181,11 @@ func validateBackendChecksum(t *testing.T, f func(http.ResponseWriter, *http.Req rnd := r.Header.Get(HeaderBackendSignalingRandom) checksum := r.Header.Get(HeaderBackendSignalingChecksum) if rnd == "" || checksum == "" { - t.Fatal("No checksum headers found") + t.Fatalf("No checksum headers found in request to %s", r.URL) } if verify := CalculateBackendChecksum(rnd, body, testBackendSecret); verify != checksum { - t.Fatal("Backend checksum verification failed") + t.Fatalf("Backend checksum verification failed for request to %s", r.URL) } var request BackendClientRequest @@ -355,7 +355,54 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { if !strings.HasSuffix(url, "/") { url += "/" } - router.HandleFunc(url+"ocs/v2.php/apps/spreed/api/v1/signaling/backend", handleFunc) + + handleCapabilitiesFunc := func(w http.ResponseWriter, r *http.Request) { + features := []string{ + "foo", + "bar", + } + if strings.Contains(t.Name(), "V3Api") { + features = append(features, "signaling-v3") + } + response := &CapabilitiesResponse{ + Version: CapabilitiesVersion{ + Major: 20, + }, + Capabilities: map[string]map[string]interface{}{ + "spreed": { + "features": features, + }, + }, + } + + data, err := json.Marshal(response) + if err != nil { + t.Errorf("Could not marshal %+v: %s", response, err) + } + + var ocs OcsResponse + ocs.Ocs = &OcsBody{ + Meta: OcsMeta{ + Status: "ok", + StatusCode: http.StatusOK, + Message: http.StatusText(http.StatusOK), + }, + Data: (*json.RawMessage)(&data), + } + if data, err = json.Marshal(ocs); err != nil { + t.Fatal(err) + } + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(data) // nolint + } + router.HandleFunc(url+"ocs/v2.php/cloud/capabilities", handleCapabilitiesFunc) + + if strings.Contains(t.Name(), "V3Api") { + router.HandleFunc(url+"ocs/v2.php/apps/spreed/api/v3/signaling/backend", handleFunc) + } else { + router.HandleFunc(url+"ocs/v2.php/apps/spreed/api/v1/signaling/backend", handleFunc) + } } func performHousekeeping(hub *Hub, now time.Time) *sync.WaitGroup { @@ -1189,6 +1236,40 @@ func TestClientHelloClient(t *testing.T) { } } +func TestClientHelloClient_V3Api(t *testing.T) { + hub, _, _, server, shutdown := CreateHubForTest(t) + defer shutdown() + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + params := TestBackendClientAuthParams{ + UserId: testDefaultUserId, + } + // The "/api/v1/signaling/" URL will be changed to use "v3" as the "signaling-v3" + // feature is returned by the capabilities endpoint. + if err := client.SendHelloParams(server.URL+"/ocs/v2.php/apps/spreed/api/v1/signaling/backend", "client", params); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + if hello, err := client.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + if hello.Hello.ResumeId == "" { + t.Errorf("Expected resume id, got %+v", hello.Hello) + } + } +} + func TestClientHelloInternal(t *testing.T) { hub, _, _, server, shutdown := CreateHubForTest(t) defer shutdown()