diff --git a/backend_client.go b/backend_client.go index de64e1d..f9cf3a9 100644 --- a/backend_client.go +++ b/backend_client.go @@ -24,7 +24,6 @@ package signaling import ( "bytes" "context" - "crypto/tls" "encoding/json" "errors" "fmt" @@ -33,8 +32,6 @@ import ( "net/http" "net/url" "strings" - "sync" - "time" "github.com/dlintw/goconf" ) @@ -44,31 +41,13 @@ var ( ErrUnsupportedContentType = errors.New("unsupported_content_type") ) -const ( - // Name of the "Talk" app in Nextcloud. - AppNameSpreed = "spreed" - - // Name of capability to enable the "v3" API for the signaling endpoint. - FeatureSignalingV3Api = "signaling-v3" - - // Cache received capabilities for one hour. - CapabilitiesCacheDuration = time.Hour -) - type BackendClient struct { - hub *Hub - transport *http.Transport - version string - backends *BackendConfiguration - clients map[string]*HttpClientPool + hub *Hub + version string + backends *BackendConfiguration - mu sync.Mutex - - maxConcurrentRequestsPerHost int - - capabilitiesLock sync.RWMutex - capabilities map[string]map[string]interface{} - nextCapabilities map[string]time.Time + pool *HttpClientPool + capabilities *Capabilities } func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string) (*BackendClient, error) { @@ -82,24 +61,22 @@ func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost in log.Println("WARNING: Backend verification is disabled!") } - tlsconfig := &tls.Config{ - InsecureSkipVerify: skipverify, + pool, err := NewHttpClientPool(maxConcurrentRequestsPerHost, skipverify) + if err != nil { + return nil, err } - transport := &http.Transport{ - MaxIdleConnsPerHost: maxConcurrentRequestsPerHost, - TLSClientConfig: tlsconfig, + + capabilities, err := NewCapabilities(version, pool) + if err != nil { + return nil, err } return &BackendClient{ - transport: transport, - version: version, - backends: backends, - clients: make(map[string]*HttpClientPool), + version: version, + backends: backends, - maxConcurrentRequestsPerHost: maxConcurrentRequestsPerHost, - - capabilities: make(map[string]map[string]interface{}), - nextCapabilities: make(map[string]time.Time), + pool: pool, + capabilities: capabilities, }, nil } @@ -107,38 +84,6 @@ func (b *BackendClient) Reload(config *goconf.ConfigFile) { b.backends.Reload(config) } -func (b *BackendClient) getPool(url *url.URL) (*HttpClientPool, error) { - b.mu.Lock() - defer b.mu.Unlock() - if pool, found := b.clients[url.Host]; found { - return pool, nil - } - - pool, err := NewHttpClientPool(func() *http.Client { - return &http.Client{ - Transport: b.transport, - // Only send body in redirect if going to same scheme / host. - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= 10 { - return errors.New("stopped after 10 redirects") - } else if len(via) > 0 { - viaReq := via[len(via)-1] - if req.URL.Scheme != viaReq.URL.Scheme || req.URL.Host != viaReq.URL.Host { - return ErrNotRedirecting - } - } - return nil - }, - } - }, b.maxConcurrentRequestsPerHost) - if err != nil { - return nil, err - } - - b.clients[url.Host] = pool - return pool, nil -} - func (b *BackendClient) GetCompatBackend() *Backend { return b.backends.GetCompatBackend() } @@ -159,140 +104,6 @@ func isOcsRequest(u *url.URL) bool { return strings.Contains(u.Path, "/ocs/v2.php") || strings.Contains(u.Path, "/ocs/v1.php") } -type CapabilitiesVersion struct { - Major int `json:"major"` - Minor int `json:"minor"` - Micro int `json:"micro"` - String string `json:"string"` - Edition string `json:"edition"` - ExtendedSupport bool `json:"extendedSupport"` -} - -type CapabilitiesResponse struct { - Version CapabilitiesVersion `json:"version"` - Capabilities map[string]map[string]interface{} `json:"capabilities"` -} - -func (b *BackendClient) getCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, error) { - key := u.String() - now := time.Now() - - b.capabilitiesLock.RLock() - if caps, found := b.capabilities[key]; found { - if next, found := b.nextCapabilities[key]; found && next.After(now) { - b.capabilitiesLock.RUnlock() - return caps, nil - } - } - b.capabilitiesLock.RUnlock() - - capUrl := *u - if !strings.Contains(capUrl.Path, "ocs/v2.php") { - if !strings.HasSuffix(capUrl.Path, "/") { - capUrl.Path += "/" - } - capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities" - } else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 { - capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities" - } - - log.Printf("Capabilities expired for %s, updating", capUrl.String()) - - pool, err := b.getPool(&capUrl) - if err != nil { - log.Printf("Could not get client pool for host %s: %s", capUrl.Host, err) - return nil, err - } - - c, err := pool.Get(ctx) - if err != nil { - log.Printf("Could not get client for host %s: %s", capUrl.Host, err) - return nil, err - } - defer pool.Put(c) - - req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil) - if err != nil { - log.Printf("Could not create request to %s: %s", &capUrl, err) - return nil, err - } - req.Header.Set("Accept", "application/json") - req.Header.Set("OCS-APIRequest", "true") - req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+b.version) - - resp, err := c.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - ct := resp.Header.Get("Content-Type") - if !strings.HasPrefix(ct, "application/json") { - log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status) - return nil, ErrUnsupportedContentType - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Printf("Could not read response body from %s: %s", capUrl.String(), err) - return nil, err - } - - var ocs OcsResponse - if err := json.Unmarshal(body, &ocs); err != nil { - log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err) - return nil, err - } else if ocs.Ocs == nil || ocs.Ocs.Data == nil { - log.Printf("Incomplete OCS response %s from %s", string(body), u) - return nil, fmt.Errorf("incomplete OCS response") - } - - var response CapabilitiesResponse - if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil { - log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err) - return nil, err - } - - capa, found := response.Capabilities[AppNameSpreed] - if !found { - log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response) - return nil, nil - } - - log.Printf("Received capabilities %+v from %s", capa, capUrl.String()) - b.capabilitiesLock.Lock() - b.capabilities[key] = capa - b.nextCapabilities[key] = now.Add(CapabilitiesCacheDuration) - b.capabilitiesLock.Unlock() - return capa, nil -} - -func (b *BackendClient) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool { - caps, err := b.getCapabilities(ctx, u) - if err != nil { - log.Printf("Could not get capabilities for %s: %s", u, err) - return false - } - - featuresInterface := caps["features"] - if featuresInterface == nil { - return false - } - - features, ok := featuresInterface.([]interface{}) - if !ok { - log.Printf("Invalid features list received for %s: %+v", u, featuresInterface) - return false - } - - for _, entry := range features { - if entry == feature { - return true - } - } - return false -} - // PerformJSONRequest sends a JSON POST request to the given url and decodes // the result into "response". func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, request interface{}, response interface{}) error { @@ -306,7 +117,7 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ } var requestUrl *url.URL - if b.HasCapabilityFeature(ctx, u, FeatureSignalingV3Api) { + if b.capabilities.HasCapabilityFeature(ctx, u, FeatureSignalingV3Api) { newUrl := *u newUrl.Path = strings.Replace(newUrl.Path, "/spreed/api/v1/signaling/", "/spreed/api/v3/signaling/", -1) newUrl.Path = strings.Replace(newUrl.Path, "/spreed/api/v2/signaling/", "/spreed/api/v3/signaling/", -1) @@ -315,13 +126,7 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ requestUrl = u } - pool, err := b.getPool(u) - if err != nil { - log.Printf("Could not get client pool for host %s: %s", u.Host, err) - return err - } - - c, err := pool.Get(ctx) + c, pool, err := b.pool.Get(ctx, u) if err != nil { log.Printf("Could not get client for host %s: %s", u.Host, err) return err diff --git a/capabilities.go b/capabilities.go new file mode 100644 index 0000000..ef105fa --- /dev/null +++ b/capabilities.go @@ -0,0 +1,212 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2022 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +const ( + // Name of the "Talk" app in Nextcloud. + AppNameSpreed = "spreed" + + // Name of capability to enable the "v3" API for the signaling endpoint. + FeatureSignalingV3Api = "signaling-v3" + + // Cache received capabilities for one hour. + CapabilitiesCacheDuration = time.Hour +) + +type capabilitiesEntry struct { + nextUpdate time.Time + capabilities map[string]interface{} +} + +type Capabilities struct { + mu sync.RWMutex + + version string + pool *HttpClientPool + entries map[string]*capabilitiesEntry +} + +func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) { + result := &Capabilities{ + version: version, + pool: pool, + entries: make(map[string]*capabilitiesEntry), + } + + return result, nil +} + +type CapabilitiesVersion struct { + Major int `json:"major"` + Minor int `json:"minor"` + Micro int `json:"micro"` + String string `json:"string"` + Edition string `json:"edition"` + ExtendedSupport bool `json:"extendedSupport"` +} + +type CapabilitiesResponse struct { + Version CapabilitiesVersion `json:"version"` + Capabilities map[string]map[string]interface{} `json:"capabilities"` +} + +func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + now := time.Now() + if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) { + return entry.capabilities, true + } + + return nil, false +} + +func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) { + now := time.Now() + entry := &capabilitiesEntry{ + nextUpdate: now.Add(CapabilitiesCacheDuration), + capabilities: capabilities, + } + + c.mu.Lock() + defer c.mu.Unlock() + c.entries[key] = entry +} + +func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, error) { + key := u.String() + + if caps, found := c.getCapabilities(key); found { + return caps, nil + } + + capUrl := *u + if !strings.Contains(capUrl.Path, "ocs/v2.php") { + if !strings.HasSuffix(capUrl.Path, "/") { + capUrl.Path += "/" + } + capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities" + } else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 { + capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities" + } + + log.Printf("Capabilities expired for %s, updating", capUrl.String()) + + client, pool, err := c.pool.Get(ctx, &capUrl) + if err != nil { + log.Printf("Could not get client for host %s: %s", capUrl.Host, err) + return nil, err + } + defer pool.Put(client) + + req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil) + if err != nil { + log.Printf("Could not create request to %s: %s", &capUrl, err) + return nil, err + } + req.Header.Set("Accept", "application/json") + req.Header.Set("OCS-APIRequest", "true") + req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+c.version) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + ct := resp.Header.Get("Content-Type") + if !strings.HasPrefix(ct, "application/json") { + log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status) + return nil, ErrUnsupportedContentType + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("Could not read response body from %s: %s", capUrl.String(), err) + return nil, err + } + + var ocs OcsResponse + if err := json.Unmarshal(body, &ocs); err != nil { + log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err) + return nil, err + } else if ocs.Ocs == nil || ocs.Ocs.Data == nil { + log.Printf("Incomplete OCS response %s from %s", string(body), u) + return nil, fmt.Errorf("incomplete OCS response") + } + + var response CapabilitiesResponse + if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil { + log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err) + return nil, err + } + + capa, found := response.Capabilities[AppNameSpreed] + if !found { + log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response) + return nil, nil + } + + log.Printf("Received capabilities %+v from %s", capa, capUrl.String()) + c.setCapabilities(key, capa) + return capa, nil +} + +func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool { + caps, err := c.loadCapabilities(ctx, u) + if err != nil { + log.Printf("Could not get capabilities for %s: %s", u, err) + return false + } + + featuresInterface := caps["features"] + if featuresInterface == nil { + return false + } + + features, ok := featuresInterface.([]interface{}) + if !ok { + log.Printf("Invalid features list received for %s: %+v", u, featuresInterface) + return false + } + + for _, entry := range features { + if entry == feature { + return true + } + } + return false +} diff --git a/http_client_pool.go b/http_client_pool.go new file mode 100644 index 0000000..96a577a --- /dev/null +++ b/http_client_pool.go @@ -0,0 +1,142 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2017 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/http" + "net/url" + "sync" +) + +type Pool struct { + pool chan *http.Client +} + +func (p *Pool) get(ctx context.Context) (client *http.Client, err error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case client := <-p.pool: + return client, nil + } +} + +func (p *Pool) Put(c *http.Client) { + p.pool <- c +} + +func newPool(constructor func() *http.Client, size int) (*Pool, error) { + if size <= 0 { + return nil, fmt.Errorf("can't create empty pool") + } + + p := &Pool{ + pool: make(chan *http.Client, size), + } + for size > 0 { + c := constructor() + p.pool <- c + size-- + } + return p, nil +} + +type HttpClientPool struct { + mu sync.Mutex + + transport *http.Transport + clients map[string]*Pool + + maxConcurrentRequestsPerHost int +} + +func NewHttpClientPool(maxConcurrentRequestsPerHost int, skipVerify bool) (*HttpClientPool, error) { + if maxConcurrentRequestsPerHost <= 0 { + return nil, fmt.Errorf("can't create empty pool") + } + + tlsconfig := &tls.Config{ + InsecureSkipVerify: skipVerify, + } + transport := &http.Transport{ + MaxIdleConnsPerHost: maxConcurrentRequestsPerHost, + TLSClientConfig: tlsconfig, + } + + result := &HttpClientPool{ + transport: transport, + clients: make(map[string]*Pool), + + maxConcurrentRequestsPerHost: maxConcurrentRequestsPerHost, + } + return result, nil +} + +func (p *HttpClientPool) getPool(url *url.URL) (*Pool, error) { + p.mu.Lock() + defer p.mu.Unlock() + if pool, found := p.clients[url.Host]; found { + return pool, nil + } + + pool, err := newPool(func() *http.Client { + return &http.Client{ + Transport: p.transport, + // Only send body in redirect if going to same scheme / host. + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } else if len(via) > 0 { + viaReq := via[len(via)-1] + if req.URL.Scheme != viaReq.URL.Scheme || req.URL.Host != viaReq.URL.Host { + return ErrNotRedirecting + } + } + return nil + }, + } + }, p.maxConcurrentRequestsPerHost) + if err != nil { + return nil, err + } + + p.clients[url.Host] = pool + return pool, nil +} + +func (p *HttpClientPool) Get(ctx context.Context, url *url.URL) (*http.Client, *Pool, error) { + pool, err := p.getPool(url) + if err != nil { + return nil, nil, err + } + + client, err := pool.get(ctx) + if err != nil { + return nil, nil, err + } + + return client, pool, err +} diff --git a/pool_test.go b/http_client_pool_test.go similarity index 62% rename from pool_test.go rename to http_client_pool_test.go index d3f073b..6b6214c 100644 --- a/pool_test.go +++ b/http_client_pool_test.go @@ -23,38 +23,52 @@ package signaling import ( "context" - "net/http" + "net/url" "testing" "time" ) func TestHttpClientPool(t *testing.T) { - transport := &http.Transport{} - if _, err := NewHttpClientPool(func() *http.Client { - return &http.Client{ - Transport: transport, - } - }, 0); err == nil { + if _, err := NewHttpClientPool(0, false); err == nil { t.Error("should not be possible to create empty pool") } - pool, err := NewHttpClientPool(func() *http.Client { - return &http.Client{ - Transport: transport, - } - }, 1) + pool, err := NewHttpClientPool(1, false) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse("http://localhost/foo/bar") if err != nil { t.Fatal(err) } ctx := context.Background() - if _, err := pool.Get(ctx); err != nil { + if _, _, err := pool.Get(ctx, u); err != nil { t.Fatal(err) } ctx2, cancel := context.WithTimeout(ctx, 10*time.Millisecond) defer cancel() - if _, err := pool.Get(ctx2); err == nil { + if _, _, err := pool.Get(ctx2, u); err == nil { + t.Error("fetching from empty pool should have timed out") + } else if err != context.DeadlineExceeded { + t.Errorf("fetching from empty pool should have timed out, got %s", err) + } + + // Pools are separated by hostname, so can get client for different host. + u2, err := url.Parse("http://local.host/foo/bar") + if err != nil { + t.Fatal(err) + } + + if _, _, err := pool.Get(ctx, u2); err != nil { + t.Fatal(err) + } + + ctx3, cancel2 := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel2() + if _, _, err := pool.Get(ctx3, u2); err == nil { t.Error("fetching from empty pool should have timed out") } else if err != context.DeadlineExceeded { t.Errorf("fetching from empty pool should have timed out, got %s", err) diff --git a/pool.go b/pool.go deleted file mode 100644 index 9cf8e3c..0000000 --- a/pool.go +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Standalone signaling server for the Nextcloud Spreed app. - * Copyright (C) 2017 struktur AG - * - * @author Joachim Bauch - * - * @license GNU AGPL version 3 or any later version - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package signaling - -import ( - "context" - "fmt" - "net/http" -) - -type HttpClientPool struct { - pool chan *http.Client -} - -func NewHttpClientPool(constructor func() *http.Client, size int) (*HttpClientPool, error) { - if size <= 0 { - return nil, fmt.Errorf("can't create empty pool") - } - - p := &HttpClientPool{ - pool: make(chan *http.Client, size), - } - for size > 0 { - c := constructor() - p.pool <- c - size-- - } - return p, nil -} - -func (p *HttpClientPool) Get(ctx context.Context) (client *http.Client, err error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case client := <-p.pool: - return client, nil - } -} - -func (p *HttpClientPool) Put(c *http.Client) { - p.pool <- c -}