diff --git a/api_backend.go b/api_backend.go index 901e035..ad62bc9 100644 --- a/api_backend.go +++ b/api_backend.go @@ -40,6 +40,11 @@ const ( HeaderBackendSignalingRandom = "Spreed-Signaling-Random" HeaderBackendSignalingChecksum = "Spreed-Signaling-Checksum" HeaderBackendServer = "Spreed-Signaling-Backend" + + ConfigGroupSignaling = "signaling" + + ConfigKeyHelloV2TokenKey = "hello-v2-token-key" + ConfigKeySessionPingLimit = "session-ping-limit" ) func newRandomString(length int) string { diff --git a/api_proxy.go b/api_proxy.go index 3e79c1c..3184be4 100644 --- a/api_proxy.go +++ b/api_proxy.go @@ -142,7 +142,7 @@ type HelloProxyClientMessage struct { } func (m *HelloProxyClientMessage) CheckValid() error { - if m.Version != HelloVersion { + if m.Version != HelloVersionV1 { return fmt.Errorf("unsupported hello version: %s", m.Version) } if m.ResumeId == "" { diff --git a/api_signaling.go b/api_signaling.go index 0373bb5..a736476 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -27,11 +27,16 @@ import ( "net/url" "sort" "strings" + + "github.com/golang-jwt/jwt/v4" ) const ( - // Version that must be sent in a "hello" message. - HelloVersion = "1.0" + // Version 1.0 validates auth params against the Nextcloud instance. + HelloVersionV1 = "1.0" + + // Version 2.0 validates auth params encoded as JWT. + HelloVersionV2 = "2.0" ) // ClientMessage is a message that is sent from a client to the server. @@ -325,6 +330,23 @@ func (p *ClientTypeInternalAuthParams) CheckValid() error { return nil } +type HelloV2AuthParams struct { + Token string `json:"token"` +} + +func (p *HelloV2AuthParams) CheckValid() error { + if p.Token == "" { + return fmt.Errorf("token missing") + } + return nil +} + +type HelloV2TokenClaims struct { + jwt.RegisteredClaims + + UserData *json.RawMessage `json:"userdata,omitempty"` +} + type HelloClientMessageAuth struct { // The client type that is connecting. Leave empty to use the default // "HelloClientTypeClient" @@ -336,6 +358,7 @@ type HelloClientMessageAuth struct { parsedUrl *url.URL internalParams ClientTypeInternalAuthParams + helloV2Params HelloV2AuthParams } // Type "hello" @@ -352,8 +375,8 @@ type HelloClientMessage struct { } func (m *HelloClientMessage) CheckValid() error { - if m.Version != HelloVersion { - return fmt.Errorf("unsupported hello version: %s", m.Version) + if m.Version != HelloVersionV1 && m.Version != HelloVersionV2 { + return InvalidHelloVersion } if m.ResumeId == "" { if m.Auth.Params == nil || len(*m.Auth.Params) == 0 { @@ -375,6 +398,17 @@ func (m *HelloClientMessage) CheckValid() error { m.Auth.parsedUrl = u } + + switch m.Version { + case HelloVersionV1: + // No additional validation necessary. + case HelloVersionV2: + if err := json.Unmarshal(*m.Auth.Params, &m.Auth.helloV2Params); err != nil { + return err + } else if err := m.Auth.helloV2Params.CheckValid(); err != nil { + return err + } + } case HelloClientTypeInternal: if err := json.Unmarshal(*m.Auth.Params, &m.Auth.internalParams); err != nil { return err @@ -397,6 +431,7 @@ const ( ServerFeatureTransientData = "transient-data" ServerFeatureInCallAll = "incall-all" ServerFeatureWelcome = "welcome" + ServerFeatureHelloV2 = "hello-v2" // Features for internal clients only. ServerFeatureInternalVirtualSessions = "virtual-sessions" @@ -408,12 +443,14 @@ var ( ServerFeatureTransientData, ServerFeatureInCallAll, ServerFeatureWelcome, + ServerFeatureHelloV2, } DefaultFeaturesInternal = []string{ ServerFeatureInternalVirtualSessions, ServerFeatureTransientData, ServerFeatureInCallAll, ServerFeatureWelcome, + ServerFeatureHelloV2, } DefaultWelcomeFeatures = []string{ ServerFeatureAudioVideoPermissions, @@ -421,6 +458,7 @@ var ( ServerFeatureTransientData, ServerFeatureInCallAll, ServerFeatureWelcome, + ServerFeatureHelloV2, } ) diff --git a/api_signaling_test.go b/api_signaling_test.go index 6e9bc7a..94e54c7 100644 --- a/api_signaling_test.go +++ b/api_signaling_test.go @@ -90,16 +90,18 @@ func TestClientMessage(t *testing.T) { func TestHelloClientMessage(t *testing.T) { internalAuthParams := []byte("{\"backend\":\"https://domain.invalid\"}") + tokenAuthParams := []byte("{\"token\":\"invalid-token\"}") valid_messages := []testCheckValid{ + // Hello version 1 &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, Url: "https://domain.invalid", }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Type: "client", Params: &json.RawMessage{'{', '}'}, @@ -107,61 +109,116 @@ func TestHelloClientMessage(t *testing.T) { }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Type: "internal", Params: (*json.RawMessage)(&internalAuthParams), }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, + ResumeId: "the-resume-id", + }, + // Hello version 2 + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: (*json.RawMessage)(&tokenAuthParams), + Url: "https://domain.invalid", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Type: "client", + Params: (*json.RawMessage)(&tokenAuthParams), + Url: "https://domain.invalid", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, ResumeId: "the-resume-id", }, } invalid_messages := []testCheckValid{ + // Hello version 1 &HelloClientMessage{}, &HelloClientMessage{Version: "0.0"}, - &HelloClientMessage{Version: HelloVersion}, + &HelloClientMessage{Version: HelloVersionV1}, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, Type: "invalid-type", }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Url: "https://domain.invalid", }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, Url: "invalid-url", }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Type: "internal", Params: &json.RawMessage{'{', '}'}, }, }, &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, Auth: HelloClientMessageAuth{ Type: "internal", Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON. }, }, + // Hello version 2 + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Url: "https://domain.invalid", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: (*json.RawMessage)(&tokenAuthParams), + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: (*json.RawMessage)(&tokenAuthParams), + Url: "invalid-url", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: (*json.RawMessage)(&internalAuthParams), + Url: "https://domain.invalid", + }, + }, + &HelloClientMessage{ + Version: HelloVersionV2, + Auth: HelloClientMessageAuth{ + Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON. + Url: "https://domain.invalid", + }, + }, } testMessages(t, "hello", valid_messages, invalid_messages) diff --git a/capabilities.go b/capabilities.go index 491bd7a..12914bc 100644 --- a/capabilities.go +++ b/capabilities.go @@ -43,8 +43,14 @@ const ( // Cache received capabilities for one hour. CapabilitiesCacheDuration = time.Hour + + // Don't invalidate more than once per minute. + maxInvalidateInterval = time.Minute ) +// Can be overwritten by tests. +var getCapabilitiesNow = time.Now + type capabilitiesEntry struct { nextUpdate time.Time capabilities map[string]interface{} @@ -53,16 +59,18 @@ type capabilitiesEntry struct { type Capabilities struct { mu sync.RWMutex - version string - pool *HttpClientPool - entries map[string]*capabilitiesEntry + version string + pool *HttpClientPool + entries map[string]*capabilitiesEntry + nextInvalidate map[string]time.Time } func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) { result := &Capabilities{ - version: version, - pool: pool, - entries: make(map[string]*capabilitiesEntry), + version: version, + pool: pool, + entries: make(map[string]*capabilitiesEntry), + nextInvalidate: make(map[string]time.Time), } return result, nil @@ -86,7 +94,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool c.mu.RLock() defer c.mu.RUnlock() - now := time.Now() + now := getCapabilitiesNow() if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) { return entry.capabilities, true } @@ -95,7 +103,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool } func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) { - now := time.Now() + now := getCapabilitiesNow() entry := &capabilitiesEntry{ nextUpdate: now.Add(CapabilitiesCacheDuration), capabilities: capabilities, @@ -106,11 +114,28 @@ func (c *Capabilities) setCapabilities(key string, capabilities map[string]inter c.entries[key] = entry } -func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, error) { - key := u.String() +func (c *Capabilities) invalidateCapabilities(key string) { + c.mu.Lock() + defer c.mu.Unlock() + now := getCapabilitiesNow() + if entry, found := c.nextInvalidate[key]; found && entry.After(now) { + return + } + + delete(c.entries, key) + c.nextInvalidate[key] = now.Add(maxInvalidateInterval) +} + +func (c *Capabilities) getKeyForUrl(u *url.URL) string { + key := u.String() + return key +} + +func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, bool, error) { + key := c.getKeyForUrl(u) if caps, found := c.getCapabilities(key); found { - return caps, nil + return caps, true, nil } capUrl := *u @@ -128,14 +153,14 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st client, pool, err := c.pool.Get(ctx, &capUrl) if err != nil { log.Printf("Could not get client for host %s: %s", capUrl.Host, err) - return nil, err + return nil, false, err } defer pool.Put(client) req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil) if err != nil { log.Printf("Could not create request to %s: %s", &capUrl, err) - return nil, err + return nil, false, err } req.Header.Set("Accept", "application/json") req.Header.Set("OCS-APIRequest", "true") @@ -143,56 +168,56 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st resp, err := client.Do(req) if err != nil { - return nil, err + return nil, false, err } defer resp.Body.Close() ct := resp.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status) - return nil, ErrUnsupportedContentType + return nil, false, ErrUnsupportedContentType } body, err := io.ReadAll(resp.Body) if err != nil { log.Printf("Could not read response body from %s: %s", capUrl.String(), err) - return nil, err + return nil, false, err } var ocs OcsResponse if err := json.Unmarshal(body, &ocs); err != nil { log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err) - return nil, err + return nil, false, err } else if ocs.Ocs == nil || ocs.Ocs.Data == nil { log.Printf("Incomplete OCS response %s from %s", string(body), u) - return nil, fmt.Errorf("incomplete OCS response") + return nil, false, fmt.Errorf("incomplete OCS response") } var response CapabilitiesResponse if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil { log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err) - return nil, err + return nil, false, err } capaObj, found := response.Capabilities[AppNameSpreed] if !found || capaObj == nil { log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response) - return nil, nil + return nil, false, nil } var capa map[string]interface{} if err := json.Unmarshal(*capaObj, &capa); err != nil { log.Printf("Unsupported capabilities received for app spreed from %s: %+v", capUrl.String(), response) - return nil, nil + return nil, false, nil } log.Printf("Received capabilities %+v from %s", capa, capUrl.String()) c.setCapabilities(key, capa) - return capa, nil + return capa, false, nil } func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool { - caps, err := c.loadCapabilities(ctx, u) + caps, _, err := c.loadCapabilities(ctx, u) if err != nil { log.Printf("Could not get capabilities for %s: %s", u, err) return false @@ -217,80 +242,86 @@ func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, fea return false } -func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool) { - caps, err := c.loadCapabilities(ctx, u) +func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool, bool) { + caps, cached, err := c.loadCapabilities(ctx, u) if err != nil { log.Printf("Could not get capabilities for %s: %s", u, err) - return nil, false + return nil, cached, false } configInterface := caps["config"] if configInterface == nil { - return nil, false + return nil, cached, false } config, ok := configInterface.(map[string]interface{}) if !ok { log.Printf("Invalid config mapping received from %s: %+v", u, configInterface) - return nil, false + return nil, cached, false } groupInterface := config[group] if groupInterface == nil { - return nil, false + return nil, cached, false } groupConfig, ok := groupInterface.(map[string]interface{}) if !ok { log.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface) - return nil, false + return nil, cached, false } - return groupConfig, true + return groupConfig, cached, true } -func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool) { - groupConfig, found := c.getConfigGroup(ctx, u, group) +func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool, bool) { + groupConfig, cached, found := c.getConfigGroup(ctx, u, group) if !found { - return 0, false + return 0, cached, false } value, found := groupConfig[key] if !found { - return 0, false + return 0, cached, false } switch value := value.(type) { case int: - return value, true + return value, cached, true case float32: - return int(value), true + return int(value), cached, true case float64: - return int(value), true + return int(value), cached, true default: log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) } - return 0, false + return 0, cached, false } -func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool) { - groupConfig, found := c.getConfigGroup(ctx, u, group) +func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool, bool) { + groupConfig, cached, found := c.getConfigGroup(ctx, u, group) if !found { - return "", false + return "", cached, false } value, found := groupConfig[key] if !found { - return "", false + return "", cached, false } switch value := value.(type) { case string: - return value, true + return value, cached, true default: log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) } - return "", false + return "", cached, false +} + +func (c *Capabilities) InvalidateCapabilities(u *url.URL) { + key := c.getKeyForUrl(u) + + c.invalidateCapabilities(key) } diff --git a/capabilities_test.go b/capabilities_test.go index 19eb087..22f653b 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -28,12 +28,14 @@ import ( "net/http/httptest" "net/url" "strings" + "sync/atomic" "testing" + "time" "github.com/gorilla/mux" ) -func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { +func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse)) (*url.URL, *Capabilities) { pool, err := NewHttpClientPool(1, false) if err != nil { t.Fatal(err) @@ -84,6 +86,10 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { }, } + if callback != nil { + callback(response) + } + data, err := json.Marshal(response) if err != nil { t.Errorf("Could not marshal %+v: %s", response, err) @@ -110,6 +116,19 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { return u, capabilities } +func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { + return NewCapabilitiesForTestWithCallback(t, nil) +} + +func SetCapabilitiesGetNow(t *testing.T, f func() time.Time) { + old := getCapabilitiesNow + t.Cleanup(func() { + getCapabilitiesNow = old + }) + + getCapabilitiesNow = f +} + func TestCapabilities(t *testing.T) { url, capabilities := NewCapabilitiesForTest(t) @@ -124,34 +143,122 @@ func TestCapabilities(t *testing.T) { } expectedString := "bar" - if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { t.Error("could not find value for \"foo\"") } else if value != expectedString { t.Errorf("expected value %s, got %s", expectedString, value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found { t.Errorf("should not have found value for \"baz\", got %s", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found { t.Errorf("should not have found value for \"invalid\", got %s", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found { t.Errorf("should not have found value for \"baz\", got %s", value) + } else if !cached { + t.Errorf("expected cached response") } expectedInt := 42 - if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found { t.Error("could not find value for \"baz\"") } else if value != expectedInt { t.Errorf("expected value %d, got %d", expectedInt, value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found { t.Errorf("should not have found value for \"foo\", got %d", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found { t.Errorf("should not have found value for \"invalid\", got %d", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found { t.Errorf("should not have found value for \"baz\", got %d", value) + } else if !cached { + t.Errorf("expected cached response") + } +} + +func TestInvalidateCapabilities(t *testing.T) { + var called uint32 + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) { + atomic.AddUint32(&called, 1) + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + expectedString := "bar" + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := atomic.LoadUint32(&called); value != 1 { + t.Errorf("expected called %d, got %d", 1, value) + } + + // Invalidating will cause the capabilities to be reloaded. + capabilities.InvalidateCapabilities(url) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := atomic.LoadUint32(&called); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } + + // Invalidating is throttled to about once per minute. + capabilities.InvalidateCapabilities(url) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if !cached { + t.Errorf("expected cached response") + } + + if value := atomic.LoadUint32(&called); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } + + // At a later time, invalidating can be done again. + SetCapabilitiesGetNow(t, func() time.Time { + return time.Now().Add(2 * time.Minute) + }) + + capabilities.InvalidateCapabilities(url) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := atomic.LoadUint32(&called); value != 3 { + t.Errorf("expected called %d, got %d", 3, value) } } diff --git a/client/main.go b/client/main.go index 1338de9..a9b2229 100644 --- a/client/main.go +++ b/client/main.go @@ -603,7 +603,7 @@ func main() { request := &signaling.ClientMessage{ Type: "hello", Hello: &signaling.HelloClientMessage{ - Version: signaling.HelloVersion, + Version: signaling.HelloVersionV1, Auth: signaling.HelloClientMessageAuth{ Url: backendUrl + "/auth", Params: &json.RawMessage{'{', '}'}, diff --git a/clientsession_test.go b/clientsession_test.go index 427761f..eb152dd 100644 --- a/clientsession_test.go +++ b/clientsession_test.go @@ -238,7 +238,7 @@ func TestBandwidth_Backend(t *testing.T) { params := TestBackendClientAuthParams{ UserId: testDefaultUserId, } - if err := client.SendHelloParams(server.URL+"/one", "client", params); err != nil { + if err := client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params); err != nil { t.Fatal(err) } diff --git a/docs/standalone-signaling-api-v1.md b/docs/standalone-signaling-api-v1.md index ed54800..ca41211 100644 --- a/docs/standalone-signaling-api-v1.md +++ b/docs/standalone-signaling-api-v1.md @@ -140,7 +140,7 @@ Message format (Client -> Server): "id": "unique-request-id", "type": "hello", "hello": { - "version": "the-protocol-version-must-be-1.0", + "version": "the-protocol-version", "auth": { "url": "the-url-to-the-auth-backend", "params": { @@ -159,7 +159,7 @@ Message format (Server -> Client): "sessionid": "the-unique-session-id", "resumeid": "the-unique-resume-id", "userid": "the-user-id-for-known-users", - "version": "the-protocol-version-must-be-1.0", + "version": "the-protocol-version", "server": { "features": ["optional", "list, "of", "feature", "ids"], ...additional information about the server... @@ -172,12 +172,82 @@ future version. Clients should use the data from the [`welcome` message](#welcome-message) instead. +### Protocol version "1.0" + +For protocol version `1.0` in the `hello` request, the `params` from the `auth` +field are sent to the Nextcloud backend for [validation](#backend-validation). + + +### Protocol version "2.0" + +For protocol version `2.0` in the `hello` request, the `params` from the `auth` +field must contain a `token` entry containing a [JWT](https://jwt.io/). + +The JWT must contain the following fields: +- `iss`: URL of the Nextcloud server that issued the token. +- `iat`: Timestamp when the token has been issued. +- `exp`: Timestamp of the token expiration. +- `sub`: User Id (if known). +- `userdata`: Optional JSON containing more user data. + +It must be signed with an RSA, ECDSA or Ed25519 key. + +Example token: +``` +eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJpc3MiOiJodHRwczovL25leHRjbG91ZC1tYXN0ZXIubG9jYWwvIiwiaWF0IjoxNjU0ODQyMDgwLCJleHAiOjE2NTQ4NDIzODAsInN1YiI6ImFkbWluIiwidXNlcmRhdGEiOnsiZGlzcGxheW5hbWUiOiJBZG1pbmlzdHJhdG9yIn19.5rV0jh89_0fG2L-BUPtciu1q49PoYkLboj33EOdD0qQeYcvE7_di2r5WXM1WmKUCOGeX3hzn6qldDMrJBNuxvQ +``` + +Example public key: +``` +-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEIoCsNSCXyxK25zvSKRio0uiBzwub +ONq3tiGTPZo3p2Ogn6wAhhsuSxbFuUQDWMX7Tsu9fDzVdwpRHPT4y3V9cA== +-----END PUBLIC KEY----- +``` + +Example payload: +``` +{ + "iss": "https://nextcloud-master.local/", + "iat": 1654842080, + "exp": 1654842380, + "sub": "admin", + "userdata": { + "displayname": "Administrator" + } +} +``` + +The public key is retrieved from the capabilities of the Nextcloud instance +in `config` key `hello-v2-token-key` inside `signaling`. + +``` + "spreed": { + "features": [ + "audio", + "video", + "chat-v2", + "conversation-v4", + ... + ], + "config": { + … + "signaling": { + "hello-v2-token-key": "-----BEGIN RSA PUBLIC KEY----- ..." + } + } + }, +``` + + ### Backend validation -The server validates the connection request against the passed auth backend -(needs to make sure the passed url / hostname is in a whitelist). It performs -a POST request and passes the provided `params` as JSON payload in the body -of the request. +For `hello` protocol version `1.0`, the server validates the connection request +against the passed auth backend (needs to make sure the passed url / hostname +is in a whitelist). + +It performs a POST request and passes the provided `params` as JSON payload in +the body of the request. Message format (Server -> Auth backend): @@ -236,7 +306,7 @@ Message format (Client -> Server): "id": "unique-request-id", "type": "hello", "hello": { - "version": "the-protocol-version-must-be-1.0", + "version": "the-protocol-version", "auth": { "type": "the-client-type", ...other attributes depending on the client type... @@ -294,7 +364,7 @@ Message format (Client -> Server): "id": "unique-request-id", "type": "hello", "hello": { - "version": "the-protocol-version-must-be-1.0", + "version": "the-protocol-version", "resumeid": "the-resume-id-from-the-original-hello-response" } } @@ -306,7 +376,7 @@ Message format (Server -> Client): "type": "hello", "hello": { "sessionid": "the-unique-session-id", - "version": "the-protocol-version-must-be-1.0" + "version": "the-protocol-version" } } diff --git a/hub.go b/hub.go index c5194a5..c7dc1af 100644 --- a/hub.go +++ b/hub.go @@ -22,12 +22,16 @@ package signaling import ( + "bytes" "context" + "crypto/ed25519" "crypto/hmac" "crypto/sha256" + "crypto/x509" "encoding/base64" "encoding/hex" "encoding/json" + "encoding/pem" "errors" "fmt" "hash/fnv" @@ -40,20 +44,24 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/golang-jwt/jwt/v4" "github.com/gorilla/mux" "github.com/gorilla/securecookie" "github.com/gorilla/websocket" ) var ( - DuplicateClient = NewError("duplicate_client", "Client already registered.") - HelloExpected = NewError("hello_expected", "Expected Hello request.") - UserAuthFailed = NewError("auth_failed", "The user could not be authenticated.") - RoomJoinFailed = NewError("room_join_failed", "Could not join the room.") - InvalidClientType = NewError("invalid_client_type", "The client type is not supported.") - InvalidBackendUrl = NewError("invalid_backend", "The backend URL is not supported.") - InvalidToken = NewError("invalid_token", "The passed token is invalid.") - NoSuchSession = NewError("no_such_session", "The session to resume does not exist.") + DuplicateClient = NewError("duplicate_client", "Client already registered.") + HelloExpected = NewError("hello_expected", "Expected Hello request.") + InvalidHelloVersion = NewError("invalid_hello_version", "The hello version is not supported.") + UserAuthFailed = NewError("auth_failed", "The user could not be authenticated.") + RoomJoinFailed = NewError("room_join_failed", "Could not join the room.") + InvalidClientType = NewError("invalid_client_type", "The client type is not supported.") + InvalidBackendUrl = NewError("invalid_backend", "The backend URL is not supported.") + InvalidToken = NewError("invalid_token", "The passed token is invalid.") + NoSuchSession = NewError("no_such_session", "The session to resume does not exist.") + TokenNotValidYet = NewError("token_not_valid_yet", "The token is not valid yet.") + TokenExpired = NewError("token_expired", "The token is expired.") // Maximum number of concurrent requests to a backend. defaultMaxConcurrentRequestsPerHost = 8 @@ -850,10 +858,18 @@ func (h *Hub) processMessage(client *Client, data []byte) { if err := message.CheckValid(); err != nil { if session := client.GetSession(); session != nil { log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) - session.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + if err, ok := err.(*Error); ok { + session.SendMessage(message.NewErrorServerMessage(err)) + } else { + session.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + } } else { log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) - client.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + if err, ok := err.(*Error); ok { + client.SendMessage(message.NewErrorServerMessage(err)) + } else { + client.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + } } return } @@ -896,7 +912,7 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage) Id: message.Id, Type: "hello", Hello: &HelloServerMessage{ - Version: HelloVersion, + Version: message.Hello.Version, SessionId: session.PublicId(), ResumeId: session.PrivateId(), UserId: session.UserId(), @@ -975,31 +991,163 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) { } } -func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { - // Make sure the client must send another "hello" in case of errors. - defer h.startExpectHello(client) - +func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) { url := message.Hello.Auth.parsedUrl backend := h.backend.GetBackend(url) if backend == nil { - client.SendMessage(message.NewErrorServerMessage(InvalidBackendUrl)) - return + return nil, nil, InvalidBackendUrl } // Run in timeout context to prevent blocking too long. ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout) defer cancel() - request := NewBackendClientAuthRequest(message.Hello.Auth.Params) var auth BackendClientResponse + request := NewBackendClientAuthRequest(message.Hello.Auth.Params) if err := h.backend.PerformJSONRequest(ctx, url, request, &auth); err != nil { - client.SendMessage(message.NewWrappedErrorServerMessage(err)) - return + return nil, nil, err } // TODO(jojo): Validate response - h.processRegister(client, message, backend, &auth) + return backend, &auth, nil +} + +func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) { + url := message.Hello.Auth.parsedUrl + backend := h.backend.GetBackend(url) + if backend == nil { + return nil, nil, InvalidBackendUrl + } + + token, err := jwt.ParseWithClaims(message.Hello.Auth.helloV2Params.Token, &HelloV2TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + // Only public-private-key algorithms are supported. + var loadKeyFunc func([]byte) (interface{}, error) + switch token.Method.(type) { + case *jwt.SigningMethodRSA: + loadKeyFunc = func(data []byte) (interface{}, error) { + return jwt.ParseRSAPublicKeyFromPEM(data) + } + case *jwt.SigningMethodECDSA: + loadKeyFunc = func(data []byte) (interface{}, error) { + return jwt.ParseECPublicKeyFromPEM(data) + } + case *jwt.SigningMethodEd25519: + loadKeyFunc = func(data []byte) (interface{}, error) { + if !bytes.HasPrefix(data, []byte("-----BEGIN ")) { + // Nextcloud sends the Ed25519 key as base64-encoded public key data. + decoded, err := base64.StdEncoding.DecodeString(string(data)) + if err != nil { + return nil, err + } + + key := ed25519.PublicKey(decoded) + data, err = x509.MarshalPKIXPublicKey(key) + if err != nil { + return nil, err + } + + data = pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: data, + }) + } + return jwt.ParseEdPublicKeyFromPEM(data) + } + default: + log.Printf("Unexpected signing method: %v", token.Header["alg"]) + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + + // Run in timeout context to prevent blocking too long. + ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout) + defer cancel() + + keyData, cached, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey) + if !found { + if cached { + // The Nextcloud instance might just have enabled JWT but we probably use + // the cached capabilities without the public key. Make sure to re-fetch. + h.backend.capabilities.InvalidateCapabilities(url) + keyData, _, found = h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey) + } + if !found { + return nil, fmt.Errorf("No key found for issuer") + } + } + + key, err := loadKeyFunc([]byte(keyData)) + if err != nil { + return nil, fmt.Errorf("Could not parse token key: %w", err) + } + + return key, nil + }) + if err != nil { + if err, ok := err.(*jwt.ValidationError); ok { + if err.Errors&jwt.ValidationErrorIssuedAt == jwt.ValidationErrorIssuedAt { + return nil, nil, TokenNotValidYet + } + if err.Errors&jwt.ValidationErrorExpired == jwt.ValidationErrorExpired { + return nil, nil, TokenExpired + } + } + + return nil, nil, InvalidToken + } + + claims, ok := token.Claims.(*HelloV2TokenClaims) + if !ok || !token.Valid { + return nil, nil, InvalidToken + } + now := time.Now() + if !claims.VerifyIssuedAt(now, true) { + return nil, nil, TokenNotValidYet + } + if !claims.VerifyExpiresAt(now, true) { + return nil, nil, TokenExpired + } + + auth := &BackendClientResponse{ + Type: "auth", + Auth: &BackendClientAuthResponse{ + Version: message.Hello.Version, + UserId: claims.Subject, + User: claims.UserData, + }, + } + return backend, auth, nil +} + +func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { + // Make sure the client must send another "hello" in case of errors. + defer h.startExpectHello(client) + + var authFunc func(*Client, *ClientMessage) (*Backend, *BackendClientResponse, error) + switch message.Hello.Version { + case HelloVersionV1: + // Auth information contains a ticket that must be validated against the + // Nextcloud instance. + authFunc = h.processHelloV1 + case HelloVersionV2: + // Auth information contains a JWT that contains all information of the user. + authFunc = h.processHelloV2 + default: + client.SendMessage(message.NewErrorServerMessage(InvalidHelloVersion)) + return + } + + backend, auth, err := authFunc(client, message) + if err != nil { + if e, ok := err.(*Error); ok { + client.SendMessage(message.NewErrorServerMessage(e)) + } else { + client.SendMessage(message.NewWrappedErrorServerMessage(err)) + } + return + } + + h.processRegister(client, message, backend, auth) } func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) { diff --git a/hub_test.go b/hub_test.go index 535328a..abe6903 100644 --- a/hub_test.go +++ b/hub_test.go @@ -23,12 +23,21 @@ package signaling import ( "context" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" "encoding/json" + "encoding/pem" "errors" "io" "net/http" "net/http/httptest" "net/url" + "os" "reflect" "strings" "sync" @@ -37,6 +46,7 @@ import ( "time" "github.com/dlintw/goconf" + "github.com/golang-jwt/jwt/v4" "github.com/gorilla/mux" "github.com/gorilla/websocket" ) @@ -53,6 +63,13 @@ var ( "local", "clustered", } + + testHelloV2Algorithms = []string{ + "RSA", + "ECDSA", + "Ed25519", + "Ed25519_Nextcloud", + } ) // Only used for testing. @@ -511,6 +528,131 @@ func processPingRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re return response } +func ensureAuthTokens(t *testing.T) (string, string) { + if privateKey := os.Getenv("PRIVATE_AUTH_TOKEN_" + t.Name()); privateKey != "" { + publicKey := os.Getenv("PUBLIC_AUTH_TOKEN_" + t.Name()) + if publicKey == "" { + // should not happen, always both keys are created + t.Fatal("public key is empty") + } + return privateKey, publicKey + } + + var private []byte + var public []byte + + if strings.Contains(t.Name(), "ECDSA") { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + private, err = x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatal(err) + } + private = pem.EncodeToMemory(&pem.Block{ + Type: "ECDSA PRIVATE KEY", + Bytes: private, + }) + + public, err = x509.MarshalPKIXPublicKey(&key.PublicKey) + if err != nil { + t.Fatal(err) + } + public = pem.EncodeToMemory(&pem.Block{ + Type: "ECDSA PUBLIC KEY", + Bytes: public, + }) + } else if strings.Contains(t.Name(), "Ed25519") { + publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + private, err = x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + t.Fatal(err) + } + private = pem.EncodeToMemory(&pem.Block{ + Type: "Ed25519 PRIVATE KEY", + Bytes: private, + }) + + public, err = x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + t.Fatal(err) + } + public = pem.EncodeToMemory(&pem.Block{ + Type: "Ed25519 PUBLIC KEY", + Bytes: public, + }) + } else { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + + private = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + public, err = x509.MarshalPKIXPublicKey(&key.PublicKey) + if err != nil { + t.Fatal(err) + } + public = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: public, + }) + } + + privateKey := base64.StdEncoding.EncodeToString(private) + t.Setenv("PRIVATE_AUTH_TOKEN_"+t.Name(), privateKey) + publicKey := base64.StdEncoding.EncodeToString(public) + t.Setenv("PUBLIC_AUTH_TOKEN_"+t.Name(), publicKey) + return privateKey, publicKey +} + +func getPrivateAuthToken(t *testing.T) (key interface{}) { + private, _ := ensureAuthTokens(t) + data, err := base64.StdEncoding.DecodeString(private) + if err != nil { + t.Fatal(err) + } + if strings.Contains(t.Name(), "ECDSA") { + key, err = jwt.ParseECPrivateKeyFromPEM(data) + } else if strings.Contains(t.Name(), "Ed25519") { + key, err = jwt.ParseEdPrivateKeyFromPEM(data) + } else { + key, err = jwt.ParseRSAPrivateKeyFromPEM(data) + } + if err != nil { + t.Fatal(err) + } + return key +} + +func getPublicAuthToken(t *testing.T) (key interface{}) { + _, public := ensureAuthTokens(t) + data, err := base64.StdEncoding.DecodeString(public) + if err != nil { + t.Fatal(err) + } + if strings.Contains(t.Name(), "ECDSA") { + key, err = jwt.ParseECPublicKeyFromPEM(data) + } else if strings.Contains(t.Name(), "Ed25519") { + key, err = jwt.ParseEdPublicKeyFromPEM(data) + } else { + key, err = jwt.ParseRSAPublicKeyFromPEM(data) + } + if err != nil { + t.Fatal(err) + } + return key +} + func registerBackendHandler(t *testing.T, router *mux.Router) { registerBackendHandlerUrl(t, router, "/") } @@ -555,6 +697,37 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { if strings.Contains(t.Name(), "MultiRoom") { signaling[ConfigKeySessionPingLimit] = 2 } + useV2 := true + if os.Getenv("SKIP_V2_CAPABILITIES") != "" { + useV2 = false + } + if strings.Contains(t.Name(), "V2") && useV2 { + key := getPublicAuthToken(t) + public, err := x509.MarshalPKIXPublicKey(key) + if err != nil { + t.Fatal(err) + } + var pemType string + if strings.Contains(t.Name(), "ECDSA") { + pemType = "ECDSA PUBLIC KEY" + } else if strings.Contains(t.Name(), "Ed25519") { + pemType = "Ed25519 PUBLIC KEY" + } else { + pemType = "RSA PUBLIC KEY" + } + + public = pem.EncodeToMemory(&pem.Block{ + Type: pemType, + Bytes: public, + }) + if strings.Contains(t.Name(), "Ed25519_Nextcloud") { + // Simulate Nextcloud which returns the Ed25519 key as base64-encoded data. + encoded := base64.StdEncoding.EncodeToString(key.(ed25519.PublicKey)) + signaling[ConfigKeyHelloV2TokenKey] = encoded + } else { + signaling[ConfigKeyHelloV2TokenKey] = string(public) + } + } spreedCapa, _ := json.Marshal(map[string]interface{}{ "features": features, "config": config, @@ -665,7 +838,35 @@ func TestExpectClientHello(t *testing.T) { } } -func TestClientHello(t *testing.T) { +func TestExpectClientHelloUnsupportedVersion(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + params := TestBackendClientAuthParams{ + UserId: testDefaultUserId, + } + if err := client.SendHelloParams(server.URL, "0.0", "", params); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "invalid_hello_version" { + t.Errorf("Expected \"invalid_hello_version\" reason, got %+v", message.Error) + } +} + +func TestClientHelloV1(t *testing.T) { hub, _, _, server := CreateHubForTest(t) client := NewTestClient(t, server, hub) @@ -690,6 +891,232 @@ func TestClientHello(t *testing.T) { } } +func TestClientHelloV2(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + if err := client.SendHelloV2(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + hello, err := client.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + + data := hub.decodeSessionId(hello.Hello.SessionId, publicSessionName) + if data == nil { + t.Fatalf("Could not decode session id: %s", hello.Hello.SessionId) + } + + hub.mu.RLock() + session := hub.sessions[data.Sid] + hub.mu.RUnlock() + if session == nil { + t.Fatalf("Could not get session for id %+v", data) + } + + var userdata map[string]string + if err := json.Unmarshal(*session.UserData(), &userdata); err != nil { + t.Fatal(err) + } + + if expected := "Displayname " + testDefaultUserId; userdata["displayname"] != expected { + t.Errorf("Expected displayname %s, got %s", expected, userdata["displayname"]) + } + }) + } +} + +func TestClientHelloV2_IssuedInFuture(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + issuedAt := time.Now().Add(time.Minute) + expiresAt := issuedAt.Add(time.Second) + if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "token_not_valid_yet" { + t.Errorf("Expected \"token_not_valid_yet\" reason, got %+v", message.Error) + } + }) + } +} + +func TestClientHelloV2_Expired(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + issuedAt := time.Now().Add(-time.Minute) + if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, issuedAt.Add(time.Second)); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "token_expired" { + t.Errorf("Expected \"token_expired\" reason, got %+v", message.Error) + } + }) + } +} + +func TestClientHelloV2_IssuedAtMissing(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + var issuedAt time.Time + expiresAt := time.Now().Add(time.Minute) + if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "token_not_valid_yet" { + t.Errorf("Expected \"token_not_valid_yet\" reason, got %+v", message.Error) + } + }) + } +} + +func TestClientHelloV2_ExpiresAtMissing(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + issuedAt := time.Now().Add(-time.Minute) + var expiresAt time.Time + if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + message, err := client.RunUntilMessage(ctx) + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + + if err := checkMessageType(message, "error"); err != nil { + t.Error(err) + } else if message.Error.Code != "token_expired" { + t.Errorf("Expected \"token_expired\" reason, got %+v", message.Error) + } + }) + } +} + +func TestClientHelloV2_CachedCapabilities(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // Simulate old-style Nextcloud without capabilities for Hello V2. + t.Setenv("SKIP_V2_CAPABILITIES", "1") + + client1 := NewTestClient(t, server, hub) + defer client1.CloseWithBye() + + if err := client1.SendHelloV1(testDefaultUserId + "1"); err != nil { + t.Fatal(err) + } + + hello1, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + if hello1.Hello.UserId != testDefaultUserId+"1" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"1", hello1.Hello) + } + if hello1.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello1.Hello) + } + + // Simulate updated Nextcloud with capabilities for Hello V2. + t.Setenv("SKIP_V2_CAPABILITIES", "") + + client2 := NewTestClient(t, server, hub) + defer client2.CloseWithBye() + + if err := client2.SendHelloV2(testDefaultUserId + "2"); err != nil { + t.Fatal(err) + } + + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + if hello2.Hello.UserId != testDefaultUserId+"2" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello2.Hello) + } + if hello2.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello2.Hello) + } + }) + } +} + func TestClientHelloWithSpaces(t *testing.T) { hub, _, _, server := CreateHubForTest(t) @@ -826,7 +1253,7 @@ func TestClientHelloSessionLimit(t *testing.T) { params1 := TestBackendClientAuthParams{ UserId: testDefaultUserId, } - if err := client.SendHelloParams(server1.URL+"/one", "client", params1); err != nil { + if err := client.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", params1); err != nil { t.Fatal(err) } @@ -851,7 +1278,7 @@ func TestClientHelloSessionLimit(t *testing.T) { params2 := TestBackendClientAuthParams{ UserId: testDefaultUserId + "2", } - if err := client2.SendHelloParams(server1.URL+"/one", "client", params2); err != nil { + if err := client2.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", params2); err != nil { t.Fatal(err) } @@ -867,7 +1294,7 @@ func TestClientHelloSessionLimit(t *testing.T) { } // The client can connect to a different backend. - if err := client2.SendHelloParams(server1.URL+"/two", "client", params2); err != nil { + if err := client2.SendHelloParams(server1.URL+"/two", HelloVersionV1, "client", params2); err != nil { t.Fatal(err) } @@ -894,7 +1321,7 @@ func TestClientHelloSessionLimit(t *testing.T) { params3 := TestBackendClientAuthParams{ UserId: testDefaultUserId + "3", } - if err := client3.SendHelloParams(server1.URL+"/one", "client", params3); err != nil { + if err := client3.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", params3); err != nil { t.Fatal(err) } @@ -1498,7 +1925,7 @@ func TestClientHelloClient_V3Api(t *testing.T) { } // The "/api/v1/signaling/" URL will be changed to use "v3" as the "signaling-v3" // feature is returned by the capabilities endpoint. - if err := client.SendHelloParams(server.URL+"/ocs/v2.php/apps/spreed/api/v1/signaling/backend", "client", params); err != nil { + if err := client.SendHelloParams(server.URL+"/ocs/v2.php/apps/spreed/api/v1/signaling/backend", HelloVersionV1, "client", params); err != nil { t.Fatal(err) } @@ -3841,7 +4268,7 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { params1 := TestBackendClientAuthParams{ UserId: "user1", } - if err := client1.SendHelloParams(server.URL+"/one", "client", params1); err != nil { + if err := client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params1); err != nil { t.Fatal(err) } hello1, err := client1.RunUntilHello(ctx) @@ -3855,7 +4282,7 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { params2 := TestBackendClientAuthParams{ UserId: "user2", } - if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil { + if err := client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", params2); err != nil { t.Fatal(err) } hello2, err := client2.RunUntilHello(ctx) @@ -3911,7 +4338,7 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) { params1 := TestBackendClientAuthParams{ UserId: "user1", } - if err := client1.SendHelloParams(server.URL+"/one", "client", params1); err != nil { + if err := client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params1); err != nil { t.Fatal(err) } hello1, err := client1.RunUntilHello(ctx) @@ -3925,7 +4352,7 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) { params2 := TestBackendClientAuthParams{ UserId: "user2", } - if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil { + if err := client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", params2); err != nil { t.Fatal(err) } hello2, err := client2.RunUntilHello(ctx) diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index d5de159..d1e2938 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -589,7 +589,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { Id: message.Id, Type: "hello", Hello: &signaling.HelloProxyServerMessage{ - Version: signaling.HelloVersion, + Version: signaling.HelloVersionV1, SessionId: session.PublicId(), Server: &signaling.WelcomeServerMessage{ Version: s.version, diff --git a/room_ping.go b/room_ping.go index d3cf231..48c301a 100644 --- a/room_ping.go +++ b/room_ping.go @@ -29,11 +29,6 @@ import ( "time" ) -const ( - ConfigGroupSignaling = "signaling" - ConfigKeySessionPingLimit = "session-ping-limit" -) - type pingEntries struct { url *url.URL @@ -124,7 +119,7 @@ func (p *RoomPing) publishEntries(entries *pingEntries, timeout time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - limit, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit) + limit, _, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit) if !found || limit <= 0 { // Limit disabled while waiting for the next iteration, fallback to sending // one request per room. @@ -193,7 +188,7 @@ func (p *RoomPing) sendPingsCombined(url *url.URL, entries []BackendPingEntry, l } func (p *RoomPing) SendPings(ctx context.Context, room *Room, url *url.URL, entries []BackendPingEntry) error { - limit, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit) + limit, _, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit) if !found || limit <= 0 { // Old-style Nextcloud or session limit not configured. Perform one request // per room. Don't queue to avoid sending all ping requests to old-style diff --git a/testclient_test.go b/testclient_test.go index f3d33ff..d38a870 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -37,6 +37,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v4" "github.com/gorilla/websocket" ) @@ -359,11 +360,14 @@ func (c *TestClient) WaitForSessionRemoved(ctx context.Context, sessionId string } func (c *TestClient) WriteJSON(data interface{}) error { - if msg, ok := data.(*ClientMessage); ok { - if err := msg.CheckValid(); err != nil { - return err + if !strings.Contains(c.t.Name(), "HelloUnsupportedVersion") { + if msg, ok := data.(*ClientMessage); ok { + if err := msg.CheckValid(); err != nil { + return err + } } } + return c.conn.WriteJSON(data) } @@ -374,10 +378,63 @@ func (c *TestClient) EnsuerWriteJSON(data interface{}) { } func (c *TestClient) SendHello(userid string) error { + return c.SendHelloV1(userid) +} + +func (c *TestClient) SendHelloV1(userid string) error { params := TestBackendClientAuthParams{ UserId: userid, } - return c.SendHelloParams(c.server.URL, "", params) + return c.SendHelloParams(c.server.URL, HelloVersionV1, "", params) +} + +func (c *TestClient) SendHelloV2(userid string) error { + now := time.Now() + return c.SendHelloV2WithTimes(userid, now, now.Add(time.Minute)) +} + +func (c *TestClient) SendHelloV2WithTimes(userid string, issuedAt time.Time, expiresAt time.Time) error { + userdata := map[string]string{ + "displayname": "Displayname " + userid, + } + + data, err := json.Marshal(userdata) + if err != nil { + c.t.Fatal(err) + } + + claims := &HelloV2TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: c.server.URL, + Subject: userid, + }, + UserData: (*json.RawMessage)(&data), + } + if !issuedAt.IsZero() { + claims.IssuedAt = jwt.NewNumericDate(issuedAt) + } + if !expiresAt.IsZero() { + claims.ExpiresAt = jwt.NewNumericDate(expiresAt) + } + + var token *jwt.Token + if strings.Contains(c.t.Name(), "ECDSA") { + token = jwt.NewWithClaims(jwt.SigningMethodES256, claims) + } else if strings.Contains(c.t.Name(), "Ed25519") { + token = jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + } else { + token = jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + } + private := getPrivateAuthToken(c.t) + tokenString, err := token.SignedString(private) + if err != nil { + c.t.Fatal(err) + } + + params := HelloV2AuthParams{ + Token: tokenString, + } + return c.SendHelloParams(c.server.URL, HelloVersionV2, "", params) } func (c *TestClient) SendHelloResume(resumeId string) error { @@ -385,7 +442,7 @@ func (c *TestClient) SendHelloResume(resumeId string) error { Id: "1234", Type: "hello", Hello: &HelloClientMessage{ - Version: HelloVersion, + Version: HelloVersionV1, ResumeId: resumeId, }, } @@ -396,7 +453,7 @@ func (c *TestClient) SendHelloClient(userid string) error { params := TestBackendClientAuthParams{ UserId: userid, } - return c.SendHelloParams(c.server.URL, "client", params) + return c.SendHelloParams(c.server.URL, HelloVersionV1, "client", params) } func (c *TestClient) SendHelloInternal() error { @@ -411,10 +468,10 @@ func (c *TestClient) SendHelloInternal() error { Token: token, Backend: backend, } - return c.SendHelloParams("", "internal", params) + return c.SendHelloParams("", HelloVersionV1, "internal", params) } -func (c *TestClient) SendHelloParams(url string, clientType string, params interface{}) error { +func (c *TestClient) SendHelloParams(url string, version string, clientType string, params interface{}) error { data, err := json.Marshal(params) if err != nil { c.t.Fatal(err) @@ -424,7 +481,7 @@ func (c *TestClient) SendHelloParams(url string, clientType string, params inter Id: "1234", Type: "hello", Hello: &HelloClientMessage{ - Version: HelloVersion, + Version: version, Auth: HelloClientMessageAuth{ Type: clientType, Url: url,