mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-06-13 11:22:14 +02:00
Merge pull request #251 from strukturag/jwt-auth
Support hello auth version "2.0" with JWT
This commit is contained in:
commit
aaa9b2dde2
|
@ -40,6 +40,11 @@ const (
|
||||||
HeaderBackendSignalingRandom = "Spreed-Signaling-Random"
|
HeaderBackendSignalingRandom = "Spreed-Signaling-Random"
|
||||||
HeaderBackendSignalingChecksum = "Spreed-Signaling-Checksum"
|
HeaderBackendSignalingChecksum = "Spreed-Signaling-Checksum"
|
||||||
HeaderBackendServer = "Spreed-Signaling-Backend"
|
HeaderBackendServer = "Spreed-Signaling-Backend"
|
||||||
|
|
||||||
|
ConfigGroupSignaling = "signaling"
|
||||||
|
|
||||||
|
ConfigKeyHelloV2TokenKey = "hello-v2-token-key"
|
||||||
|
ConfigKeySessionPingLimit = "session-ping-limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newRandomString(length int) string {
|
func newRandomString(length int) string {
|
||||||
|
|
|
@ -142,7 +142,7 @@ type HelloProxyClientMessage struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *HelloProxyClientMessage) CheckValid() error {
|
func (m *HelloProxyClientMessage) CheckValid() error {
|
||||||
if m.Version != HelloVersion {
|
if m.Version != HelloVersionV1 {
|
||||||
return fmt.Errorf("unsupported hello version: %s", m.Version)
|
return fmt.Errorf("unsupported hello version: %s", m.Version)
|
||||||
}
|
}
|
||||||
if m.ResumeId == "" {
|
if m.ResumeId == "" {
|
||||||
|
|
|
@ -27,11 +27,16 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Version that must be sent in a "hello" message.
|
// Version 1.0 validates auth params against the Nextcloud instance.
|
||||||
HelloVersion = "1.0"
|
HelloVersionV1 = "1.0"
|
||||||
|
|
||||||
|
// Version 2.0 validates auth params encoded as JWT.
|
||||||
|
HelloVersionV2 = "2.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientMessage is a message that is sent from a client to the server.
|
// ClientMessage is a message that is sent from a client to the server.
|
||||||
|
@ -325,6 +330,23 @@ func (p *ClientTypeInternalAuthParams) CheckValid() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HelloV2AuthParams struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *HelloV2AuthParams) CheckValid() error {
|
||||||
|
if p.Token == "" {
|
||||||
|
return fmt.Errorf("token missing")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HelloV2TokenClaims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
|
||||||
|
UserData *json.RawMessage `json:"userdata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type HelloClientMessageAuth struct {
|
type HelloClientMessageAuth struct {
|
||||||
// The client type that is connecting. Leave empty to use the default
|
// The client type that is connecting. Leave empty to use the default
|
||||||
// "HelloClientTypeClient"
|
// "HelloClientTypeClient"
|
||||||
|
@ -336,6 +358,7 @@ type HelloClientMessageAuth struct {
|
||||||
parsedUrl *url.URL
|
parsedUrl *url.URL
|
||||||
|
|
||||||
internalParams ClientTypeInternalAuthParams
|
internalParams ClientTypeInternalAuthParams
|
||||||
|
helloV2Params HelloV2AuthParams
|
||||||
}
|
}
|
||||||
|
|
||||||
// Type "hello"
|
// Type "hello"
|
||||||
|
@ -352,8 +375,8 @@ type HelloClientMessage struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *HelloClientMessage) CheckValid() error {
|
func (m *HelloClientMessage) CheckValid() error {
|
||||||
if m.Version != HelloVersion {
|
if m.Version != HelloVersionV1 && m.Version != HelloVersionV2 {
|
||||||
return fmt.Errorf("unsupported hello version: %s", m.Version)
|
return InvalidHelloVersion
|
||||||
}
|
}
|
||||||
if m.ResumeId == "" {
|
if m.ResumeId == "" {
|
||||||
if m.Auth.Params == nil || len(*m.Auth.Params) == 0 {
|
if m.Auth.Params == nil || len(*m.Auth.Params) == 0 {
|
||||||
|
@ -375,6 +398,17 @@ func (m *HelloClientMessage) CheckValid() error {
|
||||||
|
|
||||||
m.Auth.parsedUrl = u
|
m.Auth.parsedUrl = u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch m.Version {
|
||||||
|
case HelloVersionV1:
|
||||||
|
// No additional validation necessary.
|
||||||
|
case HelloVersionV2:
|
||||||
|
if err := json.Unmarshal(*m.Auth.Params, &m.Auth.helloV2Params); err != nil {
|
||||||
|
return err
|
||||||
|
} else if err := m.Auth.helloV2Params.CheckValid(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
case HelloClientTypeInternal:
|
case HelloClientTypeInternal:
|
||||||
if err := json.Unmarshal(*m.Auth.Params, &m.Auth.internalParams); err != nil {
|
if err := json.Unmarshal(*m.Auth.Params, &m.Auth.internalParams); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -397,6 +431,7 @@ const (
|
||||||
ServerFeatureTransientData = "transient-data"
|
ServerFeatureTransientData = "transient-data"
|
||||||
ServerFeatureInCallAll = "incall-all"
|
ServerFeatureInCallAll = "incall-all"
|
||||||
ServerFeatureWelcome = "welcome"
|
ServerFeatureWelcome = "welcome"
|
||||||
|
ServerFeatureHelloV2 = "hello-v2"
|
||||||
|
|
||||||
// Features for internal clients only.
|
// Features for internal clients only.
|
||||||
ServerFeatureInternalVirtualSessions = "virtual-sessions"
|
ServerFeatureInternalVirtualSessions = "virtual-sessions"
|
||||||
|
@ -408,12 +443,14 @@ var (
|
||||||
ServerFeatureTransientData,
|
ServerFeatureTransientData,
|
||||||
ServerFeatureInCallAll,
|
ServerFeatureInCallAll,
|
||||||
ServerFeatureWelcome,
|
ServerFeatureWelcome,
|
||||||
|
ServerFeatureHelloV2,
|
||||||
}
|
}
|
||||||
DefaultFeaturesInternal = []string{
|
DefaultFeaturesInternal = []string{
|
||||||
ServerFeatureInternalVirtualSessions,
|
ServerFeatureInternalVirtualSessions,
|
||||||
ServerFeatureTransientData,
|
ServerFeatureTransientData,
|
||||||
ServerFeatureInCallAll,
|
ServerFeatureInCallAll,
|
||||||
ServerFeatureWelcome,
|
ServerFeatureWelcome,
|
||||||
|
ServerFeatureHelloV2,
|
||||||
}
|
}
|
||||||
DefaultWelcomeFeatures = []string{
|
DefaultWelcomeFeatures = []string{
|
||||||
ServerFeatureAudioVideoPermissions,
|
ServerFeatureAudioVideoPermissions,
|
||||||
|
@ -421,6 +458,7 @@ var (
|
||||||
ServerFeatureTransientData,
|
ServerFeatureTransientData,
|
||||||
ServerFeatureInCallAll,
|
ServerFeatureInCallAll,
|
||||||
ServerFeatureWelcome,
|
ServerFeatureWelcome,
|
||||||
|
ServerFeatureHelloV2,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -90,16 +90,18 @@ func TestClientMessage(t *testing.T) {
|
||||||
|
|
||||||
func TestHelloClientMessage(t *testing.T) {
|
func TestHelloClientMessage(t *testing.T) {
|
||||||
internalAuthParams := []byte("{\"backend\":\"https://domain.invalid\"}")
|
internalAuthParams := []byte("{\"backend\":\"https://domain.invalid\"}")
|
||||||
|
tokenAuthParams := []byte("{\"token\":\"invalid-token\"}")
|
||||||
valid_messages := []testCheckValid{
|
valid_messages := []testCheckValid{
|
||||||
|
// Hello version 1
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Params: &json.RawMessage{'{', '}'},
|
Params: &json.RawMessage{'{', '}'},
|
||||||
Url: "https://domain.invalid",
|
Url: "https://domain.invalid",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Type: "client",
|
Type: "client",
|
||||||
Params: &json.RawMessage{'{', '}'},
|
Params: &json.RawMessage{'{', '}'},
|
||||||
|
@ -107,61 +109,116 @@ func TestHelloClientMessage(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Type: "internal",
|
Type: "internal",
|
||||||
Params: (*json.RawMessage)(&internalAuthParams),
|
Params: (*json.RawMessage)(&internalAuthParams),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
|
ResumeId: "the-resume-id",
|
||||||
|
},
|
||||||
|
// Hello version 2
|
||||||
|
&HelloClientMessage{
|
||||||
|
Version: HelloVersionV2,
|
||||||
|
Auth: HelloClientMessageAuth{
|
||||||
|
Params: (*json.RawMessage)(&tokenAuthParams),
|
||||||
|
Url: "https://domain.invalid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&HelloClientMessage{
|
||||||
|
Version: HelloVersionV2,
|
||||||
|
Auth: HelloClientMessageAuth{
|
||||||
|
Type: "client",
|
||||||
|
Params: (*json.RawMessage)(&tokenAuthParams),
|
||||||
|
Url: "https://domain.invalid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&HelloClientMessage{
|
||||||
|
Version: HelloVersionV2,
|
||||||
ResumeId: "the-resume-id",
|
ResumeId: "the-resume-id",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
invalid_messages := []testCheckValid{
|
invalid_messages := []testCheckValid{
|
||||||
|
// Hello version 1
|
||||||
&HelloClientMessage{},
|
&HelloClientMessage{},
|
||||||
&HelloClientMessage{Version: "0.0"},
|
&HelloClientMessage{Version: "0.0"},
|
||||||
&HelloClientMessage{Version: HelloVersion},
|
&HelloClientMessage{Version: HelloVersionV1},
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Params: &json.RawMessage{'{', '}'},
|
Params: &json.RawMessage{'{', '}'},
|
||||||
Type: "invalid-type",
|
Type: "invalid-type",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Url: "https://domain.invalid",
|
Url: "https://domain.invalid",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Params: &json.RawMessage{'{', '}'},
|
Params: &json.RawMessage{'{', '}'},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Params: &json.RawMessage{'{', '}'},
|
Params: &json.RawMessage{'{', '}'},
|
||||||
Url: "invalid-url",
|
Url: "invalid-url",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Type: "internal",
|
Type: "internal",
|
||||||
Params: &json.RawMessage{'{', '}'},
|
Params: &json.RawMessage{'{', '}'},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&HelloClientMessage{
|
&HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Type: "internal",
|
Type: "internal",
|
||||||
Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON.
|
Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON.
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
// Hello version 2
|
||||||
|
&HelloClientMessage{
|
||||||
|
Version: HelloVersionV2,
|
||||||
|
Auth: HelloClientMessageAuth{
|
||||||
|
Url: "https://domain.invalid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&HelloClientMessage{
|
||||||
|
Version: HelloVersionV2,
|
||||||
|
Auth: HelloClientMessageAuth{
|
||||||
|
Params: (*json.RawMessage)(&tokenAuthParams),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&HelloClientMessage{
|
||||||
|
Version: HelloVersionV2,
|
||||||
|
Auth: HelloClientMessageAuth{
|
||||||
|
Params: (*json.RawMessage)(&tokenAuthParams),
|
||||||
|
Url: "invalid-url",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&HelloClientMessage{
|
||||||
|
Version: HelloVersionV2,
|
||||||
|
Auth: HelloClientMessageAuth{
|
||||||
|
Params: (*json.RawMessage)(&internalAuthParams),
|
||||||
|
Url: "https://domain.invalid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&HelloClientMessage{
|
||||||
|
Version: HelloVersionV2,
|
||||||
|
Auth: HelloClientMessageAuth{
|
||||||
|
Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON.
|
||||||
|
Url: "https://domain.invalid",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testMessages(t, "hello", valid_messages, invalid_messages)
|
testMessages(t, "hello", valid_messages, invalid_messages)
|
||||||
|
|
109
capabilities.go
109
capabilities.go
|
@ -43,8 +43,14 @@ const (
|
||||||
|
|
||||||
// Cache received capabilities for one hour.
|
// Cache received capabilities for one hour.
|
||||||
CapabilitiesCacheDuration = time.Hour
|
CapabilitiesCacheDuration = time.Hour
|
||||||
|
|
||||||
|
// Don't invalidate more than once per minute.
|
||||||
|
maxInvalidateInterval = time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Can be overwritten by tests.
|
||||||
|
var getCapabilitiesNow = time.Now
|
||||||
|
|
||||||
type capabilitiesEntry struct {
|
type capabilitiesEntry struct {
|
||||||
nextUpdate time.Time
|
nextUpdate time.Time
|
||||||
capabilities map[string]interface{}
|
capabilities map[string]interface{}
|
||||||
|
@ -56,6 +62,7 @@ type Capabilities struct {
|
||||||
version string
|
version string
|
||||||
pool *HttpClientPool
|
pool *HttpClientPool
|
||||||
entries map[string]*capabilitiesEntry
|
entries map[string]*capabilitiesEntry
|
||||||
|
nextInvalidate map[string]time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) {
|
func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) {
|
||||||
|
@ -63,6 +70,7 @@ func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error
|
||||||
version: version,
|
version: version,
|
||||||
pool: pool,
|
pool: pool,
|
||||||
entries: make(map[string]*capabilitiesEntry),
|
entries: make(map[string]*capabilitiesEntry),
|
||||||
|
nextInvalidate: make(map[string]time.Time),
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
@ -86,7 +94,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
defer c.mu.RUnlock()
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
now := time.Now()
|
now := getCapabilitiesNow()
|
||||||
if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) {
|
if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) {
|
||||||
return entry.capabilities, true
|
return entry.capabilities, true
|
||||||
}
|
}
|
||||||
|
@ -95,7 +103,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) {
|
func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) {
|
||||||
now := time.Now()
|
now := getCapabilitiesNow()
|
||||||
entry := &capabilitiesEntry{
|
entry := &capabilitiesEntry{
|
||||||
nextUpdate: now.Add(CapabilitiesCacheDuration),
|
nextUpdate: now.Add(CapabilitiesCacheDuration),
|
||||||
capabilities: capabilities,
|
capabilities: capabilities,
|
||||||
|
@ -106,11 +114,28 @@ func (c *Capabilities) setCapabilities(key string, capabilities map[string]inter
|
||||||
c.entries[key] = entry
|
c.entries[key] = entry
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, error) {
|
func (c *Capabilities) invalidateCapabilities(key string) {
|
||||||
key := u.String()
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
now := getCapabilitiesNow()
|
||||||
|
if entry, found := c.nextInvalidate[key]; found && entry.After(now) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(c.entries, key)
|
||||||
|
c.nextInvalidate[key] = now.Add(maxInvalidateInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Capabilities) getKeyForUrl(u *url.URL) string {
|
||||||
|
key := u.String()
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, bool, error) {
|
||||||
|
key := c.getKeyForUrl(u)
|
||||||
if caps, found := c.getCapabilities(key); found {
|
if caps, found := c.getCapabilities(key); found {
|
||||||
return caps, nil
|
return caps, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
capUrl := *u
|
capUrl := *u
|
||||||
|
@ -128,14 +153,14 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st
|
||||||
client, pool, err := c.pool.Get(ctx, &capUrl)
|
client, pool, err := c.pool.Get(ctx, &capUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
|
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
defer pool.Put(client)
|
defer pool.Put(client)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not create request to %s: %s", &capUrl, err)
|
log.Printf("Could not create request to %s: %s", &capUrl, err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
req.Header.Set("OCS-APIRequest", "true")
|
req.Header.Set("OCS-APIRequest", "true")
|
||||||
|
@ -143,56 +168,56 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
ct := resp.Header.Get("Content-Type")
|
ct := resp.Header.Get("Content-Type")
|
||||||
if !strings.HasPrefix(ct, "application/json") {
|
if !strings.HasPrefix(ct, "application/json") {
|
||||||
log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status)
|
log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status)
|
||||||
return nil, ErrUnsupportedContentType
|
return nil, false, ErrUnsupportedContentType
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not read response body from %s: %s", capUrl.String(), err)
|
log.Printf("Could not read response body from %s: %s", capUrl.String(), err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var ocs OcsResponse
|
var ocs OcsResponse
|
||||||
if err := json.Unmarshal(body, &ocs); err != nil {
|
if err := json.Unmarshal(body, &ocs); err != nil {
|
||||||
log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err)
|
log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
} else if ocs.Ocs == nil || ocs.Ocs.Data == nil {
|
} else if ocs.Ocs == nil || ocs.Ocs.Data == nil {
|
||||||
log.Printf("Incomplete OCS response %s from %s", string(body), u)
|
log.Printf("Incomplete OCS response %s from %s", string(body), u)
|
||||||
return nil, fmt.Errorf("incomplete OCS response")
|
return nil, false, fmt.Errorf("incomplete OCS response")
|
||||||
}
|
}
|
||||||
|
|
||||||
var response CapabilitiesResponse
|
var response CapabilitiesResponse
|
||||||
if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil {
|
if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil {
|
||||||
log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err)
|
log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
capaObj, found := response.Capabilities[AppNameSpreed]
|
capaObj, found := response.Capabilities[AppNameSpreed]
|
||||||
if !found || capaObj == nil {
|
if !found || capaObj == nil {
|
||||||
log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response)
|
log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response)
|
||||||
return nil, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var capa map[string]interface{}
|
var capa map[string]interface{}
|
||||||
if err := json.Unmarshal(*capaObj, &capa); err != nil {
|
if err := json.Unmarshal(*capaObj, &capa); err != nil {
|
||||||
log.Printf("Unsupported capabilities received for app spreed from %s: %+v", capUrl.String(), response)
|
log.Printf("Unsupported capabilities received for app spreed from %s: %+v", capUrl.String(), response)
|
||||||
return nil, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Received capabilities %+v from %s", capa, capUrl.String())
|
log.Printf("Received capabilities %+v from %s", capa, capUrl.String())
|
||||||
c.setCapabilities(key, capa)
|
c.setCapabilities(key, capa)
|
||||||
return capa, nil
|
return capa, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool {
|
func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool {
|
||||||
caps, err := c.loadCapabilities(ctx, u)
|
caps, _, err := c.loadCapabilities(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not get capabilities for %s: %s", u, err)
|
log.Printf("Could not get capabilities for %s: %s", u, err)
|
||||||
return false
|
return false
|
||||||
|
@ -217,80 +242,86 @@ func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, fea
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool) {
|
func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool, bool) {
|
||||||
caps, err := c.loadCapabilities(ctx, u)
|
caps, cached, err := c.loadCapabilities(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not get capabilities for %s: %s", u, err)
|
log.Printf("Could not get capabilities for %s: %s", u, err)
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
configInterface := caps["config"]
|
configInterface := caps["config"]
|
||||||
if configInterface == nil {
|
if configInterface == nil {
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
config, ok := configInterface.(map[string]interface{})
|
config, ok := configInterface.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("Invalid config mapping received from %s: %+v", u, configInterface)
|
log.Printf("Invalid config mapping received from %s: %+v", u, configInterface)
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
groupInterface := config[group]
|
groupInterface := config[group]
|
||||||
if groupInterface == nil {
|
if groupInterface == nil {
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
groupConfig, ok := groupInterface.(map[string]interface{})
|
groupConfig, ok := groupInterface.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface)
|
log.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface)
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return groupConfig, true
|
return groupConfig, cached, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool) {
|
func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool, bool) {
|
||||||
groupConfig, found := c.getConfigGroup(ctx, u, group)
|
groupConfig, cached, found := c.getConfigGroup(ctx, u, group)
|
||||||
if !found {
|
if !found {
|
||||||
return 0, false
|
return 0, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
value, found := groupConfig[key]
|
value, found := groupConfig[key]
|
||||||
if !found {
|
if !found {
|
||||||
return 0, false
|
return 0, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch value := value.(type) {
|
switch value := value.(type) {
|
||||||
case int:
|
case int:
|
||||||
return value, true
|
return value, cached, true
|
||||||
case float32:
|
case float32:
|
||||||
return int(value), true
|
return int(value), cached, true
|
||||||
case float64:
|
case float64:
|
||||||
return int(value), true
|
return int(value), cached, true
|
||||||
default:
|
default:
|
||||||
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
|
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, false
|
return 0, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool) {
|
func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool, bool) {
|
||||||
groupConfig, found := c.getConfigGroup(ctx, u, group)
|
groupConfig, cached, found := c.getConfigGroup(ctx, u, group)
|
||||||
if !found {
|
if !found {
|
||||||
return "", false
|
return "", cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
value, found := groupConfig[key]
|
value, found := groupConfig[key]
|
||||||
if !found {
|
if !found {
|
||||||
return "", false
|
return "", cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch value := value.(type) {
|
switch value := value.(type) {
|
||||||
case string:
|
case string:
|
||||||
return value, true
|
return value, cached, true
|
||||||
default:
|
default:
|
||||||
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
|
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", false
|
return "", cached, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Capabilities) InvalidateCapabilities(u *url.URL) {
|
||||||
|
key := c.getKeyForUrl(u)
|
||||||
|
|
||||||
|
c.invalidateCapabilities(key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,12 +28,14 @@ import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
|
func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse)) (*url.URL, *Capabilities) {
|
||||||
pool, err := NewHttpClientPool(1, false)
|
pool, err := NewHttpClientPool(1, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -84,6 +86,10 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if callback != nil {
|
||||||
|
callback(response)
|
||||||
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(response)
|
data, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not marshal %+v: %s", response, err)
|
t.Errorf("Could not marshal %+v: %s", response, err)
|
||||||
|
@ -110,6 +116,19 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
|
||||||
return u, capabilities
|
return u, capabilities
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
|
||||||
|
return NewCapabilitiesForTestWithCallback(t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetCapabilitiesGetNow(t *testing.T, f func() time.Time) {
|
||||||
|
old := getCapabilitiesNow
|
||||||
|
t.Cleanup(func() {
|
||||||
|
getCapabilitiesNow = old
|
||||||
|
})
|
||||||
|
|
||||||
|
getCapabilitiesNow = f
|
||||||
|
}
|
||||||
|
|
||||||
func TestCapabilities(t *testing.T) {
|
func TestCapabilities(t *testing.T) {
|
||||||
url, capabilities := NewCapabilitiesForTest(t)
|
url, capabilities := NewCapabilitiesForTest(t)
|
||||||
|
|
||||||
|
@ -124,34 +143,122 @@ func TestCapabilities(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedString := "bar"
|
expectedString := "bar"
|
||||||
if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
t.Error("could not find value for \"foo\"")
|
t.Error("could not find value for \"foo\"")
|
||||||
} else if value != expectedString {
|
} else if value != expectedString {
|
||||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found {
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found {
|
||||||
t.Errorf("should not have found value for \"baz\", got %s", value)
|
t.Errorf("should not have found value for \"baz\", got %s", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found {
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found {
|
||||||
t.Errorf("should not have found value for \"invalid\", got %s", value)
|
t.Errorf("should not have found value for \"invalid\", got %s", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found {
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found {
|
||||||
t.Errorf("should not have found value for \"baz\", got %s", value)
|
t.Errorf("should not have found value for \"baz\", got %s", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedInt := 42
|
expectedInt := 42
|
||||||
if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found {
|
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found {
|
||||||
t.Error("could not find value for \"baz\"")
|
t.Error("could not find value for \"baz\"")
|
||||||
} else if value != expectedInt {
|
} else if value != expectedInt {
|
||||||
t.Errorf("expected value %d, got %d", expectedInt, value)
|
t.Errorf("expected value %d, got %d", expectedInt, value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found {
|
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found {
|
||||||
t.Errorf("should not have found value for \"foo\", got %d", value)
|
t.Errorf("should not have found value for \"foo\", got %d", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found {
|
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found {
|
||||||
t.Errorf("should not have found value for \"invalid\", got %d", value)
|
t.Errorf("should not have found value for \"invalid\", got %d", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found {
|
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found {
|
||||||
t.Errorf("should not have found value for \"baz\", got %d", value)
|
t.Errorf("should not have found value for \"baz\", got %d", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidateCapabilities(t *testing.T) {
|
||||||
|
var called uint32
|
||||||
|
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) {
|
||||||
|
atomic.AddUint32(&called, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
expectedString := "bar"
|
||||||
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
|
t.Error("could not find value for \"foo\"")
|
||||||
|
} else if value != expectedString {
|
||||||
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if cached {
|
||||||
|
t.Errorf("expected direct response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := atomic.LoadUint32(&called); value != 1 {
|
||||||
|
t.Errorf("expected called %d, got %d", 1, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidating will cause the capabilities to be reloaded.
|
||||||
|
capabilities.InvalidateCapabilities(url)
|
||||||
|
|
||||||
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
|
t.Error("could not find value for \"foo\"")
|
||||||
|
} else if value != expectedString {
|
||||||
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if cached {
|
||||||
|
t.Errorf("expected direct response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := atomic.LoadUint32(&called); value != 2 {
|
||||||
|
t.Errorf("expected called %d, got %d", 2, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidating is throttled to about once per minute.
|
||||||
|
capabilities.InvalidateCapabilities(url)
|
||||||
|
|
||||||
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
|
t.Error("could not find value for \"foo\"")
|
||||||
|
} else if value != expectedString {
|
||||||
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := atomic.LoadUint32(&called); value != 2 {
|
||||||
|
t.Errorf("expected called %d, got %d", 2, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// At a later time, invalidating can be done again.
|
||||||
|
SetCapabilitiesGetNow(t, func() time.Time {
|
||||||
|
return time.Now().Add(2 * time.Minute)
|
||||||
|
})
|
||||||
|
|
||||||
|
capabilities.InvalidateCapabilities(url)
|
||||||
|
|
||||||
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
|
t.Error("could not find value for \"foo\"")
|
||||||
|
} else if value != expectedString {
|
||||||
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if cached {
|
||||||
|
t.Errorf("expected direct response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := atomic.LoadUint32(&called); value != 3 {
|
||||||
|
t.Errorf("expected called %d, got %d", 3, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -603,7 +603,7 @@ func main() {
|
||||||
request := &signaling.ClientMessage{
|
request := &signaling.ClientMessage{
|
||||||
Type: "hello",
|
Type: "hello",
|
||||||
Hello: &signaling.HelloClientMessage{
|
Hello: &signaling.HelloClientMessage{
|
||||||
Version: signaling.HelloVersion,
|
Version: signaling.HelloVersionV1,
|
||||||
Auth: signaling.HelloClientMessageAuth{
|
Auth: signaling.HelloClientMessageAuth{
|
||||||
Url: backendUrl + "/auth",
|
Url: backendUrl + "/auth",
|
||||||
Params: &json.RawMessage{'{', '}'},
|
Params: &json.RawMessage{'{', '}'},
|
||||||
|
|
|
@ -238,7 +238,7 @@ func TestBandwidth_Backend(t *testing.T) {
|
||||||
params := TestBackendClientAuthParams{
|
params := TestBackendClientAuthParams{
|
||||||
UserId: testDefaultUserId,
|
UserId: testDefaultUserId,
|
||||||
}
|
}
|
||||||
if err := client.SendHelloParams(server.URL+"/one", "client", params); err != nil {
|
if err := client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -140,7 +140,7 @@ Message format (Client -> Server):
|
||||||
"id": "unique-request-id",
|
"id": "unique-request-id",
|
||||||
"type": "hello",
|
"type": "hello",
|
||||||
"hello": {
|
"hello": {
|
||||||
"version": "the-protocol-version-must-be-1.0",
|
"version": "the-protocol-version",
|
||||||
"auth": {
|
"auth": {
|
||||||
"url": "the-url-to-the-auth-backend",
|
"url": "the-url-to-the-auth-backend",
|
||||||
"params": {
|
"params": {
|
||||||
|
@ -159,7 +159,7 @@ Message format (Server -> Client):
|
||||||
"sessionid": "the-unique-session-id",
|
"sessionid": "the-unique-session-id",
|
||||||
"resumeid": "the-unique-resume-id",
|
"resumeid": "the-unique-resume-id",
|
||||||
"userid": "the-user-id-for-known-users",
|
"userid": "the-user-id-for-known-users",
|
||||||
"version": "the-protocol-version-must-be-1.0",
|
"version": "the-protocol-version",
|
||||||
"server": {
|
"server": {
|
||||||
"features": ["optional", "list, "of", "feature", "ids"],
|
"features": ["optional", "list, "of", "feature", "ids"],
|
||||||
...additional information about the server...
|
...additional information about the server...
|
||||||
|
@ -172,12 +172,82 @@ future version. Clients should use the data from the
|
||||||
[`welcome` message](#welcome-message) instead.
|
[`welcome` message](#welcome-message) instead.
|
||||||
|
|
||||||
|
|
||||||
|
### Protocol version "1.0"
|
||||||
|
|
||||||
|
For protocol version `1.0` in the `hello` request, the `params` from the `auth`
|
||||||
|
field are sent to the Nextcloud backend for [validation](#backend-validation).
|
||||||
|
|
||||||
|
|
||||||
|
### Protocol version "2.0"
|
||||||
|
|
||||||
|
For protocol version `2.0` in the `hello` request, the `params` from the `auth`
|
||||||
|
field must contain a `token` entry containing a [JWT](https://jwt.io/).
|
||||||
|
|
||||||
|
The JWT must contain the following fields:
|
||||||
|
- `iss`: URL of the Nextcloud server that issued the token.
|
||||||
|
- `iat`: Timestamp when the token has been issued.
|
||||||
|
- `exp`: Timestamp of the token expiration.
|
||||||
|
- `sub`: User Id (if known).
|
||||||
|
- `userdata`: Optional JSON containing more user data.
|
||||||
|
|
||||||
|
It must be signed with an RSA, ECDSA or Ed25519 key.
|
||||||
|
|
||||||
|
Example token:
|
||||||
|
```
|
||||||
|
eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJpc3MiOiJodHRwczovL25leHRjbG91ZC1tYXN0ZXIubG9jYWwvIiwiaWF0IjoxNjU0ODQyMDgwLCJleHAiOjE2NTQ4NDIzODAsInN1YiI6ImFkbWluIiwidXNlcmRhdGEiOnsiZGlzcGxheW5hbWUiOiJBZG1pbmlzdHJhdG9yIn19.5rV0jh89_0fG2L-BUPtciu1q49PoYkLboj33EOdD0qQeYcvE7_di2r5WXM1WmKUCOGeX3hzn6qldDMrJBNuxvQ
|
||||||
|
```
|
||||||
|
|
||||||
|
Example public key:
|
||||||
|
```
|
||||||
|
-----BEGIN PUBLIC KEY-----
|
||||||
|
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEIoCsNSCXyxK25zvSKRio0uiBzwub
|
||||||
|
ONq3tiGTPZo3p2Ogn6wAhhsuSxbFuUQDWMX7Tsu9fDzVdwpRHPT4y3V9cA==
|
||||||
|
-----END PUBLIC KEY-----
|
||||||
|
```
|
||||||
|
|
||||||
|
Example payload:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"iss": "https://nextcloud-master.local/",
|
||||||
|
"iat": 1654842080,
|
||||||
|
"exp": 1654842380,
|
||||||
|
"sub": "admin",
|
||||||
|
"userdata": {
|
||||||
|
"displayname": "Administrator"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The public key is retrieved from the capabilities of the Nextcloud instance
|
||||||
|
in `config` key `hello-v2-token-key` inside `signaling`.
|
||||||
|
|
||||||
|
```
|
||||||
|
"spreed": {
|
||||||
|
"features": [
|
||||||
|
"audio",
|
||||||
|
"video",
|
||||||
|
"chat-v2",
|
||||||
|
"conversation-v4",
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"config": {
|
||||||
|
…
|
||||||
|
"signaling": {
|
||||||
|
"hello-v2-token-key": "-----BEGIN RSA PUBLIC KEY----- ..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Backend validation
|
### Backend validation
|
||||||
|
|
||||||
The server validates the connection request against the passed auth backend
|
For `hello` protocol version `1.0`, the server validates the connection request
|
||||||
(needs to make sure the passed url / hostname is in a whitelist). It performs
|
against the passed auth backend (needs to make sure the passed url / hostname
|
||||||
a POST request and passes the provided `params` as JSON payload in the body
|
is in a whitelist).
|
||||||
of the request.
|
|
||||||
|
It performs a POST request and passes the provided `params` as JSON payload in
|
||||||
|
the body of the request.
|
||||||
|
|
||||||
Message format (Server -> Auth backend):
|
Message format (Server -> Auth backend):
|
||||||
|
|
||||||
|
@ -236,7 +306,7 @@ Message format (Client -> Server):
|
||||||
"id": "unique-request-id",
|
"id": "unique-request-id",
|
||||||
"type": "hello",
|
"type": "hello",
|
||||||
"hello": {
|
"hello": {
|
||||||
"version": "the-protocol-version-must-be-1.0",
|
"version": "the-protocol-version",
|
||||||
"auth": {
|
"auth": {
|
||||||
"type": "the-client-type",
|
"type": "the-client-type",
|
||||||
...other attributes depending on the client type...
|
...other attributes depending on the client type...
|
||||||
|
@ -294,7 +364,7 @@ Message format (Client -> Server):
|
||||||
"id": "unique-request-id",
|
"id": "unique-request-id",
|
||||||
"type": "hello",
|
"type": "hello",
|
||||||
"hello": {
|
"hello": {
|
||||||
"version": "the-protocol-version-must-be-1.0",
|
"version": "the-protocol-version",
|
||||||
"resumeid": "the-resume-id-from-the-original-hello-response"
|
"resumeid": "the-resume-id-from-the-original-hello-response"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -306,7 +376,7 @@ Message format (Server -> Client):
|
||||||
"type": "hello",
|
"type": "hello",
|
||||||
"hello": {
|
"hello": {
|
||||||
"sessionid": "the-unique-session-id",
|
"sessionid": "the-unique-session-id",
|
||||||
"version": "the-protocol-version-must-be-1.0"
|
"version": "the-protocol-version"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
170
hub.go
170
hub.go
|
@ -22,12 +22,16 @@
|
||||||
package signaling
|
package signaling
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
|
@ -40,6 +44,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/dlintw/goconf"
|
"github.com/dlintw/goconf"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/gorilla/securecookie"
|
"github.com/gorilla/securecookie"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
@ -48,12 +53,15 @@ import (
|
||||||
var (
|
var (
|
||||||
DuplicateClient = NewError("duplicate_client", "Client already registered.")
|
DuplicateClient = NewError("duplicate_client", "Client already registered.")
|
||||||
HelloExpected = NewError("hello_expected", "Expected Hello request.")
|
HelloExpected = NewError("hello_expected", "Expected Hello request.")
|
||||||
|
InvalidHelloVersion = NewError("invalid_hello_version", "The hello version is not supported.")
|
||||||
UserAuthFailed = NewError("auth_failed", "The user could not be authenticated.")
|
UserAuthFailed = NewError("auth_failed", "The user could not be authenticated.")
|
||||||
RoomJoinFailed = NewError("room_join_failed", "Could not join the room.")
|
RoomJoinFailed = NewError("room_join_failed", "Could not join the room.")
|
||||||
InvalidClientType = NewError("invalid_client_type", "The client type is not supported.")
|
InvalidClientType = NewError("invalid_client_type", "The client type is not supported.")
|
||||||
InvalidBackendUrl = NewError("invalid_backend", "The backend URL is not supported.")
|
InvalidBackendUrl = NewError("invalid_backend", "The backend URL is not supported.")
|
||||||
InvalidToken = NewError("invalid_token", "The passed token is invalid.")
|
InvalidToken = NewError("invalid_token", "The passed token is invalid.")
|
||||||
NoSuchSession = NewError("no_such_session", "The session to resume does not exist.")
|
NoSuchSession = NewError("no_such_session", "The session to resume does not exist.")
|
||||||
|
TokenNotValidYet = NewError("token_not_valid_yet", "The token is not valid yet.")
|
||||||
|
TokenExpired = NewError("token_expired", "The token is expired.")
|
||||||
|
|
||||||
// Maximum number of concurrent requests to a backend.
|
// Maximum number of concurrent requests to a backend.
|
||||||
defaultMaxConcurrentRequestsPerHost = 8
|
defaultMaxConcurrentRequestsPerHost = 8
|
||||||
|
@ -850,11 +858,19 @@ func (h *Hub) processMessage(client *Client, data []byte) {
|
||||||
if err := message.CheckValid(); err != nil {
|
if err := message.CheckValid(); err != nil {
|
||||||
if session := client.GetSession(); session != nil {
|
if session := client.GetSession(); session != nil {
|
||||||
log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err)
|
log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err)
|
||||||
|
if err, ok := err.(*Error); ok {
|
||||||
|
session.SendMessage(message.NewErrorServerMessage(err))
|
||||||
|
} else {
|
||||||
session.SendMessage(message.NewErrorServerMessage(InvalidFormat))
|
session.SendMessage(message.NewErrorServerMessage(InvalidFormat))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err)
|
log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err)
|
||||||
|
if err, ok := err.(*Error); ok {
|
||||||
|
client.SendMessage(message.NewErrorServerMessage(err))
|
||||||
|
} else {
|
||||||
client.SendMessage(message.NewErrorServerMessage(InvalidFormat))
|
client.SendMessage(message.NewErrorServerMessage(InvalidFormat))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -896,7 +912,7 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage)
|
||||||
Id: message.Id,
|
Id: message.Id,
|
||||||
Type: "hello",
|
Type: "hello",
|
||||||
Hello: &HelloServerMessage{
|
Hello: &HelloServerMessage{
|
||||||
Version: HelloVersion,
|
Version: message.Hello.Version,
|
||||||
SessionId: session.PublicId(),
|
SessionId: session.PublicId(),
|
||||||
ResumeId: session.PrivateId(),
|
ResumeId: session.PrivateId(),
|
||||||
UserId: session.UserId(),
|
UserId: session.UserId(),
|
||||||
|
@ -975,31 +991,163 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hub) processHelloClient(client *Client, message *ClientMessage) {
|
func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
|
||||||
// Make sure the client must send another "hello" in case of errors.
|
|
||||||
defer h.startExpectHello(client)
|
|
||||||
|
|
||||||
url := message.Hello.Auth.parsedUrl
|
url := message.Hello.Auth.parsedUrl
|
||||||
backend := h.backend.GetBackend(url)
|
backend := h.backend.GetBackend(url)
|
||||||
if backend == nil {
|
if backend == nil {
|
||||||
client.SendMessage(message.NewErrorServerMessage(InvalidBackendUrl))
|
return nil, nil, InvalidBackendUrl
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run in timeout context to prevent blocking too long.
|
// Run in timeout context to prevent blocking too long.
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
request := NewBackendClientAuthRequest(message.Hello.Auth.Params)
|
|
||||||
var auth BackendClientResponse
|
var auth BackendClientResponse
|
||||||
|
request := NewBackendClientAuthRequest(message.Hello.Auth.Params)
|
||||||
if err := h.backend.PerformJSONRequest(ctx, url, request, &auth); err != nil {
|
if err := h.backend.PerformJSONRequest(ctx, url, request, &auth); err != nil {
|
||||||
client.SendMessage(message.NewWrappedErrorServerMessage(err))
|
return nil, nil, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jojo): Validate response
|
// TODO(jojo): Validate response
|
||||||
|
|
||||||
h.processRegister(client, message, backend, &auth)
|
return backend, &auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
|
||||||
|
url := message.Hello.Auth.parsedUrl
|
||||||
|
backend := h.backend.GetBackend(url)
|
||||||
|
if backend == nil {
|
||||||
|
return nil, nil, InvalidBackendUrl
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := jwt.ParseWithClaims(message.Hello.Auth.helloV2Params.Token, &HelloV2TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
// Only public-private-key algorithms are supported.
|
||||||
|
var loadKeyFunc func([]byte) (interface{}, error)
|
||||||
|
switch token.Method.(type) {
|
||||||
|
case *jwt.SigningMethodRSA:
|
||||||
|
loadKeyFunc = func(data []byte) (interface{}, error) {
|
||||||
|
return jwt.ParseRSAPublicKeyFromPEM(data)
|
||||||
|
}
|
||||||
|
case *jwt.SigningMethodECDSA:
|
||||||
|
loadKeyFunc = func(data []byte) (interface{}, error) {
|
||||||
|
return jwt.ParseECPublicKeyFromPEM(data)
|
||||||
|
}
|
||||||
|
case *jwt.SigningMethodEd25519:
|
||||||
|
loadKeyFunc = func(data []byte) (interface{}, error) {
|
||||||
|
if !bytes.HasPrefix(data, []byte("-----BEGIN ")) {
|
||||||
|
// Nextcloud sends the Ed25519 key as base64-encoded public key data.
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(string(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
key := ed25519.PublicKey(decoded)
|
||||||
|
data, err = x509.MarshalPKIXPublicKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "PUBLIC KEY",
|
||||||
|
Bytes: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return jwt.ParseEdPublicKeyFromPEM(data)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Printf("Unexpected signing method: %v", token.Header["alg"])
|
||||||
|
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run in timeout context to prevent blocking too long.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
keyData, cached, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
|
||||||
|
if !found {
|
||||||
|
if cached {
|
||||||
|
// The Nextcloud instance might just have enabled JWT but we probably use
|
||||||
|
// the cached capabilities without the public key. Make sure to re-fetch.
|
||||||
|
h.backend.capabilities.InvalidateCapabilities(url)
|
||||||
|
keyData, _, found = h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("No key found for issuer")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := loadKeyFunc([]byte(keyData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not parse token key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if err, ok := err.(*jwt.ValidationError); ok {
|
||||||
|
if err.Errors&jwt.ValidationErrorIssuedAt == jwt.ValidationErrorIssuedAt {
|
||||||
|
return nil, nil, TokenNotValidYet
|
||||||
|
}
|
||||||
|
if err.Errors&jwt.ValidationErrorExpired == jwt.ValidationErrorExpired {
|
||||||
|
return nil, nil, TokenExpired
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, InvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(*HelloV2TokenClaims)
|
||||||
|
if !ok || !token.Valid {
|
||||||
|
return nil, nil, InvalidToken
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if !claims.VerifyIssuedAt(now, true) {
|
||||||
|
return nil, nil, TokenNotValidYet
|
||||||
|
}
|
||||||
|
if !claims.VerifyExpiresAt(now, true) {
|
||||||
|
return nil, nil, TokenExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
auth := &BackendClientResponse{
|
||||||
|
Type: "auth",
|
||||||
|
Auth: &BackendClientAuthResponse{
|
||||||
|
Version: message.Hello.Version,
|
||||||
|
UserId: claims.Subject,
|
||||||
|
User: claims.UserData,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return backend, auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) processHelloClient(client *Client, message *ClientMessage) {
|
||||||
|
// Make sure the client must send another "hello" in case of errors.
|
||||||
|
defer h.startExpectHello(client)
|
||||||
|
|
||||||
|
var authFunc func(*Client, *ClientMessage) (*Backend, *BackendClientResponse, error)
|
||||||
|
switch message.Hello.Version {
|
||||||
|
case HelloVersionV1:
|
||||||
|
// Auth information contains a ticket that must be validated against the
|
||||||
|
// Nextcloud instance.
|
||||||
|
authFunc = h.processHelloV1
|
||||||
|
case HelloVersionV2:
|
||||||
|
// Auth information contains a JWT that contains all information of the user.
|
||||||
|
authFunc = h.processHelloV2
|
||||||
|
default:
|
||||||
|
client.SendMessage(message.NewErrorServerMessage(InvalidHelloVersion))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backend, auth, err := authFunc(client, message)
|
||||||
|
if err != nil {
|
||||||
|
if e, ok := err.(*Error); ok {
|
||||||
|
client.SendMessage(message.NewErrorServerMessage(e))
|
||||||
|
} else {
|
||||||
|
client.SendMessage(message.NewWrappedErrorServerMessage(err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.processRegister(client, message, backend, auth)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) {
|
func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) {
|
||||||
|
|
447
hub_test.go
447
hub_test.go
|
@ -23,12 +23,21 @@ package signaling
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -37,6 +46,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/dlintw/goconf"
|
"github.com/dlintw/goconf"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
@ -53,6 +63,13 @@ var (
|
||||||
"local",
|
"local",
|
||||||
"clustered",
|
"clustered",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
testHelloV2Algorithms = []string{
|
||||||
|
"RSA",
|
||||||
|
"ECDSA",
|
||||||
|
"Ed25519",
|
||||||
|
"Ed25519_Nextcloud",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// Only used for testing.
|
// Only used for testing.
|
||||||
|
@ -511,6 +528,131 @@ func processPingRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureAuthTokens(t *testing.T) (string, string) {
|
||||||
|
if privateKey := os.Getenv("PRIVATE_AUTH_TOKEN_" + t.Name()); privateKey != "" {
|
||||||
|
publicKey := os.Getenv("PUBLIC_AUTH_TOKEN_" + t.Name())
|
||||||
|
if publicKey == "" {
|
||||||
|
// should not happen, always both keys are created
|
||||||
|
t.Fatal("public key is empty")
|
||||||
|
}
|
||||||
|
return privateKey, publicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
var private []byte
|
||||||
|
var public []byte
|
||||||
|
|
||||||
|
if strings.Contains(t.Name(), "ECDSA") {
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
private, err = x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
private = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "ECDSA PRIVATE KEY",
|
||||||
|
Bytes: private,
|
||||||
|
})
|
||||||
|
|
||||||
|
public, err = x509.MarshalPKIXPublicKey(&key.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
public = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "ECDSA PUBLIC KEY",
|
||||||
|
Bytes: public,
|
||||||
|
})
|
||||||
|
} else if strings.Contains(t.Name(), "Ed25519") {
|
||||||
|
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
private, err = x509.MarshalPKCS8PrivateKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
private = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "Ed25519 PRIVATE KEY",
|
||||||
|
Bytes: private,
|
||||||
|
})
|
||||||
|
|
||||||
|
public, err = x509.MarshalPKIXPublicKey(publicKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
public = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "Ed25519 PUBLIC KEY",
|
||||||
|
Bytes: public,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
private = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
|
})
|
||||||
|
|
||||||
|
public, err = x509.MarshalPKIXPublicKey(&key.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
public = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PUBLIC KEY",
|
||||||
|
Bytes: public,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey := base64.StdEncoding.EncodeToString(private)
|
||||||
|
t.Setenv("PRIVATE_AUTH_TOKEN_"+t.Name(), privateKey)
|
||||||
|
publicKey := base64.StdEncoding.EncodeToString(public)
|
||||||
|
t.Setenv("PUBLIC_AUTH_TOKEN_"+t.Name(), publicKey)
|
||||||
|
return privateKey, publicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPrivateAuthToken(t *testing.T) (key interface{}) {
|
||||||
|
private, _ := ensureAuthTokens(t)
|
||||||
|
data, err := base64.StdEncoding.DecodeString(private)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(t.Name(), "ECDSA") {
|
||||||
|
key, err = jwt.ParseECPrivateKeyFromPEM(data)
|
||||||
|
} else if strings.Contains(t.Name(), "Ed25519") {
|
||||||
|
key, err = jwt.ParseEdPrivateKeyFromPEM(data)
|
||||||
|
} else {
|
||||||
|
key, err = jwt.ParseRSAPrivateKeyFromPEM(data)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPublicAuthToken(t *testing.T) (key interface{}) {
|
||||||
|
_, public := ensureAuthTokens(t)
|
||||||
|
data, err := base64.StdEncoding.DecodeString(public)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(t.Name(), "ECDSA") {
|
||||||
|
key, err = jwt.ParseECPublicKeyFromPEM(data)
|
||||||
|
} else if strings.Contains(t.Name(), "Ed25519") {
|
||||||
|
key, err = jwt.ParseEdPublicKeyFromPEM(data)
|
||||||
|
} else {
|
||||||
|
key, err = jwt.ParseRSAPublicKeyFromPEM(data)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
func registerBackendHandler(t *testing.T, router *mux.Router) {
|
func registerBackendHandler(t *testing.T, router *mux.Router) {
|
||||||
registerBackendHandlerUrl(t, router, "/")
|
registerBackendHandlerUrl(t, router, "/")
|
||||||
}
|
}
|
||||||
|
@ -555,6 +697,37 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) {
|
||||||
if strings.Contains(t.Name(), "MultiRoom") {
|
if strings.Contains(t.Name(), "MultiRoom") {
|
||||||
signaling[ConfigKeySessionPingLimit] = 2
|
signaling[ConfigKeySessionPingLimit] = 2
|
||||||
}
|
}
|
||||||
|
useV2 := true
|
||||||
|
if os.Getenv("SKIP_V2_CAPABILITIES") != "" {
|
||||||
|
useV2 = false
|
||||||
|
}
|
||||||
|
if strings.Contains(t.Name(), "V2") && useV2 {
|
||||||
|
key := getPublicAuthToken(t)
|
||||||
|
public, err := x509.MarshalPKIXPublicKey(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var pemType string
|
||||||
|
if strings.Contains(t.Name(), "ECDSA") {
|
||||||
|
pemType = "ECDSA PUBLIC KEY"
|
||||||
|
} else if strings.Contains(t.Name(), "Ed25519") {
|
||||||
|
pemType = "Ed25519 PUBLIC KEY"
|
||||||
|
} else {
|
||||||
|
pemType = "RSA PUBLIC KEY"
|
||||||
|
}
|
||||||
|
|
||||||
|
public = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: pemType,
|
||||||
|
Bytes: public,
|
||||||
|
})
|
||||||
|
if strings.Contains(t.Name(), "Ed25519_Nextcloud") {
|
||||||
|
// Simulate Nextcloud which returns the Ed25519 key as base64-encoded data.
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(key.(ed25519.PublicKey))
|
||||||
|
signaling[ConfigKeyHelloV2TokenKey] = encoded
|
||||||
|
} else {
|
||||||
|
signaling[ConfigKeyHelloV2TokenKey] = string(public)
|
||||||
|
}
|
||||||
|
}
|
||||||
spreedCapa, _ := json.Marshal(map[string]interface{}{
|
spreedCapa, _ := json.Marshal(map[string]interface{}{
|
||||||
"features": features,
|
"features": features,
|
||||||
"config": config,
|
"config": config,
|
||||||
|
@ -665,7 +838,35 @@ func TestExpectClientHello(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientHello(t *testing.T) {
|
func TestExpectClientHelloUnsupportedVersion(t *testing.T) {
|
||||||
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
client := NewTestClient(t, server, hub)
|
||||||
|
defer client.CloseWithBye()
|
||||||
|
|
||||||
|
params := TestBackendClientAuthParams{
|
||||||
|
UserId: testDefaultUserId,
|
||||||
|
}
|
||||||
|
if err := client.SendHelloParams(server.URL, "0.0", "", params); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
message, err := client.RunUntilMessage(ctx)
|
||||||
|
if err := checkUnexpectedClose(err); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkMessageType(message, "error"); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else if message.Error.Code != "invalid_hello_version" {
|
||||||
|
t.Errorf("Expected \"invalid_hello_version\" reason, got %+v", message.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientHelloV1(t *testing.T) {
|
||||||
hub, _, _, server := CreateHubForTest(t)
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
client := NewTestClient(t, server, hub)
|
client := NewTestClient(t, server, hub)
|
||||||
|
@ -690,6 +891,232 @@ func TestClientHello(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientHelloV2(t *testing.T) {
|
||||||
|
for _, algo := range testHelloV2Algorithms {
|
||||||
|
t.Run(algo, func(t *testing.T) {
|
||||||
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
client := NewTestClient(t, server, hub)
|
||||||
|
defer client.CloseWithBye()
|
||||||
|
|
||||||
|
if err := client.SendHelloV2(testDefaultUserId); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
hello, err := client.RunUntilHello(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if hello.Hello.UserId != testDefaultUserId {
|
||||||
|
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
|
||||||
|
}
|
||||||
|
if hello.Hello.SessionId == "" {
|
||||||
|
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := hub.decodeSessionId(hello.Hello.SessionId, publicSessionName)
|
||||||
|
if data == nil {
|
||||||
|
t.Fatalf("Could not decode session id: %s", hello.Hello.SessionId)
|
||||||
|
}
|
||||||
|
|
||||||
|
hub.mu.RLock()
|
||||||
|
session := hub.sessions[data.Sid]
|
||||||
|
hub.mu.RUnlock()
|
||||||
|
if session == nil {
|
||||||
|
t.Fatalf("Could not get session for id %+v", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
var userdata map[string]string
|
||||||
|
if err := json.Unmarshal(*session.UserData(), &userdata); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected := "Displayname " + testDefaultUserId; userdata["displayname"] != expected {
|
||||||
|
t.Errorf("Expected displayname %s, got %s", expected, userdata["displayname"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientHelloV2_IssuedInFuture(t *testing.T) {
|
||||||
|
for _, algo := range testHelloV2Algorithms {
|
||||||
|
t.Run(algo, func(t *testing.T) {
|
||||||
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
client := NewTestClient(t, server, hub)
|
||||||
|
defer client.CloseWithBye()
|
||||||
|
|
||||||
|
issuedAt := time.Now().Add(time.Minute)
|
||||||
|
expiresAt := issuedAt.Add(time.Second)
|
||||||
|
if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
message, err := client.RunUntilMessage(ctx)
|
||||||
|
if err := checkUnexpectedClose(err); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkMessageType(message, "error"); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else if message.Error.Code != "token_not_valid_yet" {
|
||||||
|
t.Errorf("Expected \"token_not_valid_yet\" reason, got %+v", message.Error)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientHelloV2_Expired(t *testing.T) {
|
||||||
|
for _, algo := range testHelloV2Algorithms {
|
||||||
|
t.Run(algo, func(t *testing.T) {
|
||||||
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
client := NewTestClient(t, server, hub)
|
||||||
|
defer client.CloseWithBye()
|
||||||
|
|
||||||
|
issuedAt := time.Now().Add(-time.Minute)
|
||||||
|
if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, issuedAt.Add(time.Second)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
message, err := client.RunUntilMessage(ctx)
|
||||||
|
if err := checkUnexpectedClose(err); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkMessageType(message, "error"); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else if message.Error.Code != "token_expired" {
|
||||||
|
t.Errorf("Expected \"token_expired\" reason, got %+v", message.Error)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientHelloV2_IssuedAtMissing(t *testing.T) {
|
||||||
|
for _, algo := range testHelloV2Algorithms {
|
||||||
|
t.Run(algo, func(t *testing.T) {
|
||||||
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
client := NewTestClient(t, server, hub)
|
||||||
|
defer client.CloseWithBye()
|
||||||
|
|
||||||
|
var issuedAt time.Time
|
||||||
|
expiresAt := time.Now().Add(time.Minute)
|
||||||
|
if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
message, err := client.RunUntilMessage(ctx)
|
||||||
|
if err := checkUnexpectedClose(err); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkMessageType(message, "error"); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else if message.Error.Code != "token_not_valid_yet" {
|
||||||
|
t.Errorf("Expected \"token_not_valid_yet\" reason, got %+v", message.Error)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientHelloV2_ExpiresAtMissing(t *testing.T) {
|
||||||
|
for _, algo := range testHelloV2Algorithms {
|
||||||
|
t.Run(algo, func(t *testing.T) {
|
||||||
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
client := NewTestClient(t, server, hub)
|
||||||
|
defer client.CloseWithBye()
|
||||||
|
|
||||||
|
issuedAt := time.Now().Add(-time.Minute)
|
||||||
|
var expiresAt time.Time
|
||||||
|
if err := client.SendHelloV2WithTimes(testDefaultUserId, issuedAt, expiresAt); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
message, err := client.RunUntilMessage(ctx)
|
||||||
|
if err := checkUnexpectedClose(err); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkMessageType(message, "error"); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else if message.Error.Code != "token_expired" {
|
||||||
|
t.Errorf("Expected \"token_expired\" reason, got %+v", message.Error)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientHelloV2_CachedCapabilities(t *testing.T) {
|
||||||
|
for _, algo := range testHelloV2Algorithms {
|
||||||
|
t.Run(algo, func(t *testing.T) {
|
||||||
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Simulate old-style Nextcloud without capabilities for Hello V2.
|
||||||
|
t.Setenv("SKIP_V2_CAPABILITIES", "1")
|
||||||
|
|
||||||
|
client1 := NewTestClient(t, server, hub)
|
||||||
|
defer client1.CloseWithBye()
|
||||||
|
|
||||||
|
if err := client1.SendHelloV1(testDefaultUserId + "1"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hello1, err := client1.RunUntilHello(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if hello1.Hello.UserId != testDefaultUserId+"1" {
|
||||||
|
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"1", hello1.Hello)
|
||||||
|
}
|
||||||
|
if hello1.Hello.SessionId == "" {
|
||||||
|
t.Errorf("Expected session id, got %+v", hello1.Hello)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate updated Nextcloud with capabilities for Hello V2.
|
||||||
|
t.Setenv("SKIP_V2_CAPABILITIES", "")
|
||||||
|
|
||||||
|
client2 := NewTestClient(t, server, hub)
|
||||||
|
defer client2.CloseWithBye()
|
||||||
|
|
||||||
|
if err := client2.SendHelloV2(testDefaultUserId + "2"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hello2, err := client2.RunUntilHello(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if hello2.Hello.UserId != testDefaultUserId+"2" {
|
||||||
|
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello2.Hello)
|
||||||
|
}
|
||||||
|
if hello2.Hello.SessionId == "" {
|
||||||
|
t.Errorf("Expected session id, got %+v", hello2.Hello)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClientHelloWithSpaces(t *testing.T) {
|
func TestClientHelloWithSpaces(t *testing.T) {
|
||||||
hub, _, _, server := CreateHubForTest(t)
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
@ -826,7 +1253,7 @@ func TestClientHelloSessionLimit(t *testing.T) {
|
||||||
params1 := TestBackendClientAuthParams{
|
params1 := TestBackendClientAuthParams{
|
||||||
UserId: testDefaultUserId,
|
UserId: testDefaultUserId,
|
||||||
}
|
}
|
||||||
if err := client.SendHelloParams(server1.URL+"/one", "client", params1); err != nil {
|
if err := client.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", params1); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -851,7 +1278,7 @@ func TestClientHelloSessionLimit(t *testing.T) {
|
||||||
params2 := TestBackendClientAuthParams{
|
params2 := TestBackendClientAuthParams{
|
||||||
UserId: testDefaultUserId + "2",
|
UserId: testDefaultUserId + "2",
|
||||||
}
|
}
|
||||||
if err := client2.SendHelloParams(server1.URL+"/one", "client", params2); err != nil {
|
if err := client2.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", params2); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -867,7 +1294,7 @@ func TestClientHelloSessionLimit(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// The client can connect to a different backend.
|
// The client can connect to a different backend.
|
||||||
if err := client2.SendHelloParams(server1.URL+"/two", "client", params2); err != nil {
|
if err := client2.SendHelloParams(server1.URL+"/two", HelloVersionV1, "client", params2); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -894,7 +1321,7 @@ func TestClientHelloSessionLimit(t *testing.T) {
|
||||||
params3 := TestBackendClientAuthParams{
|
params3 := TestBackendClientAuthParams{
|
||||||
UserId: testDefaultUserId + "3",
|
UserId: testDefaultUserId + "3",
|
||||||
}
|
}
|
||||||
if err := client3.SendHelloParams(server1.URL+"/one", "client", params3); err != nil {
|
if err := client3.SendHelloParams(server1.URL+"/one", HelloVersionV1, "client", params3); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1498,7 +1925,7 @@ func TestClientHelloClient_V3Api(t *testing.T) {
|
||||||
}
|
}
|
||||||
// The "/api/v1/signaling/" URL will be changed to use "v3" as the "signaling-v3"
|
// The "/api/v1/signaling/" URL will be changed to use "v3" as the "signaling-v3"
|
||||||
// feature is returned by the capabilities endpoint.
|
// feature is returned by the capabilities endpoint.
|
||||||
if err := client.SendHelloParams(server.URL+"/ocs/v2.php/apps/spreed/api/v1/signaling/backend", "client", params); err != nil {
|
if err := client.SendHelloParams(server.URL+"/ocs/v2.php/apps/spreed/api/v1/signaling/backend", HelloVersionV1, "client", params); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3841,7 +4268,7 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) {
|
||||||
params1 := TestBackendClientAuthParams{
|
params1 := TestBackendClientAuthParams{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
}
|
}
|
||||||
if err := client1.SendHelloParams(server.URL+"/one", "client", params1); err != nil {
|
if err := client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params1); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
hello1, err := client1.RunUntilHello(ctx)
|
hello1, err := client1.RunUntilHello(ctx)
|
||||||
|
@ -3855,7 +4282,7 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) {
|
||||||
params2 := TestBackendClientAuthParams{
|
params2 := TestBackendClientAuthParams{
|
||||||
UserId: "user2",
|
UserId: "user2",
|
||||||
}
|
}
|
||||||
if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil {
|
if err := client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", params2); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
hello2, err := client2.RunUntilHello(ctx)
|
hello2, err := client2.RunUntilHello(ctx)
|
||||||
|
@ -3911,7 +4338,7 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) {
|
||||||
params1 := TestBackendClientAuthParams{
|
params1 := TestBackendClientAuthParams{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
}
|
}
|
||||||
if err := client1.SendHelloParams(server.URL+"/one", "client", params1); err != nil {
|
if err := client1.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", params1); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
hello1, err := client1.RunUntilHello(ctx)
|
hello1, err := client1.RunUntilHello(ctx)
|
||||||
|
@ -3925,7 +4352,7 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) {
|
||||||
params2 := TestBackendClientAuthParams{
|
params2 := TestBackendClientAuthParams{
|
||||||
UserId: "user2",
|
UserId: "user2",
|
||||||
}
|
}
|
||||||
if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil {
|
if err := client2.SendHelloParams(server.URL+"/two", HelloVersionV1, "client", params2); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
hello2, err := client2.RunUntilHello(ctx)
|
hello2, err := client2.RunUntilHello(ctx)
|
||||||
|
|
|
@ -589,7 +589,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) {
|
||||||
Id: message.Id,
|
Id: message.Id,
|
||||||
Type: "hello",
|
Type: "hello",
|
||||||
Hello: &signaling.HelloProxyServerMessage{
|
Hello: &signaling.HelloProxyServerMessage{
|
||||||
Version: signaling.HelloVersion,
|
Version: signaling.HelloVersionV1,
|
||||||
SessionId: session.PublicId(),
|
SessionId: session.PublicId(),
|
||||||
Server: &signaling.WelcomeServerMessage{
|
Server: &signaling.WelcomeServerMessage{
|
||||||
Version: s.version,
|
Version: s.version,
|
||||||
|
|
|
@ -29,11 +29,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
ConfigGroupSignaling = "signaling"
|
|
||||||
ConfigKeySessionPingLimit = "session-ping-limit"
|
|
||||||
)
|
|
||||||
|
|
||||||
type pingEntries struct {
|
type pingEntries struct {
|
||||||
url *url.URL
|
url *url.URL
|
||||||
|
|
||||||
|
@ -124,7 +119,7 @@ func (p *RoomPing) publishEntries(entries *pingEntries, timeout time.Duration) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
limit, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
|
limit, _, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
|
||||||
if !found || limit <= 0 {
|
if !found || limit <= 0 {
|
||||||
// Limit disabled while waiting for the next iteration, fallback to sending
|
// Limit disabled while waiting for the next iteration, fallback to sending
|
||||||
// one request per room.
|
// one request per room.
|
||||||
|
@ -193,7 +188,7 @@ func (p *RoomPing) sendPingsCombined(url *url.URL, entries []BackendPingEntry, l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RoomPing) SendPings(ctx context.Context, room *Room, url *url.URL, entries []BackendPingEntry) error {
|
func (p *RoomPing) SendPings(ctx context.Context, room *Room, url *url.URL, entries []BackendPingEntry) error {
|
||||||
limit, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
|
limit, _, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
|
||||||
if !found || limit <= 0 {
|
if !found || limit <= 0 {
|
||||||
// Old-style Nextcloud or session limit not configured. Perform one request
|
// Old-style Nextcloud or session limit not configured. Perform one request
|
||||||
// per room. Don't queue to avoid sending all ping requests to old-style
|
// per room. Don't queue to avoid sending all ping requests to old-style
|
||||||
|
|
|
@ -37,6 +37,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -359,11 +360,14 @@ func (c *TestClient) WaitForSessionRemoved(ctx context.Context, sessionId string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TestClient) WriteJSON(data interface{}) error {
|
func (c *TestClient) WriteJSON(data interface{}) error {
|
||||||
|
if !strings.Contains(c.t.Name(), "HelloUnsupportedVersion") {
|
||||||
if msg, ok := data.(*ClientMessage); ok {
|
if msg, ok := data.(*ClientMessage); ok {
|
||||||
if err := msg.CheckValid(); err != nil {
|
if err := msg.CheckValid(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return c.conn.WriteJSON(data)
|
return c.conn.WriteJSON(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -374,10 +378,63 @@ func (c *TestClient) EnsuerWriteJSON(data interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TestClient) SendHello(userid string) error {
|
func (c *TestClient) SendHello(userid string) error {
|
||||||
|
return c.SendHelloV1(userid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TestClient) SendHelloV1(userid string) error {
|
||||||
params := TestBackendClientAuthParams{
|
params := TestBackendClientAuthParams{
|
||||||
UserId: userid,
|
UserId: userid,
|
||||||
}
|
}
|
||||||
return c.SendHelloParams(c.server.URL, "", params)
|
return c.SendHelloParams(c.server.URL, HelloVersionV1, "", params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TestClient) SendHelloV2(userid string) error {
|
||||||
|
now := time.Now()
|
||||||
|
return c.SendHelloV2WithTimes(userid, now, now.Add(time.Minute))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TestClient) SendHelloV2WithTimes(userid string, issuedAt time.Time, expiresAt time.Time) error {
|
||||||
|
userdata := map[string]string{
|
||||||
|
"displayname": "Displayname " + userid,
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(userdata)
|
||||||
|
if err != nil {
|
||||||
|
c.t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := &HelloV2TokenClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: c.server.URL,
|
||||||
|
Subject: userid,
|
||||||
|
},
|
||||||
|
UserData: (*json.RawMessage)(&data),
|
||||||
|
}
|
||||||
|
if !issuedAt.IsZero() {
|
||||||
|
claims.IssuedAt = jwt.NewNumericDate(issuedAt)
|
||||||
|
}
|
||||||
|
if !expiresAt.IsZero() {
|
||||||
|
claims.ExpiresAt = jwt.NewNumericDate(expiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token *jwt.Token
|
||||||
|
if strings.Contains(c.t.Name(), "ECDSA") {
|
||||||
|
token = jwt.NewWithClaims(jwt.SigningMethodES256, claims)
|
||||||
|
} else if strings.Contains(c.t.Name(), "Ed25519") {
|
||||||
|
token = jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||||
|
} else {
|
||||||
|
token = jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
|
}
|
||||||
|
private := getPrivateAuthToken(c.t)
|
||||||
|
tokenString, err := token.SignedString(private)
|
||||||
|
if err != nil {
|
||||||
|
c.t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := HelloV2AuthParams{
|
||||||
|
Token: tokenString,
|
||||||
|
}
|
||||||
|
return c.SendHelloParams(c.server.URL, HelloVersionV2, "", params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TestClient) SendHelloResume(resumeId string) error {
|
func (c *TestClient) SendHelloResume(resumeId string) error {
|
||||||
|
@ -385,7 +442,7 @@ func (c *TestClient) SendHelloResume(resumeId string) error {
|
||||||
Id: "1234",
|
Id: "1234",
|
||||||
Type: "hello",
|
Type: "hello",
|
||||||
Hello: &HelloClientMessage{
|
Hello: &HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: HelloVersionV1,
|
||||||
ResumeId: resumeId,
|
ResumeId: resumeId,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -396,7 +453,7 @@ func (c *TestClient) SendHelloClient(userid string) error {
|
||||||
params := TestBackendClientAuthParams{
|
params := TestBackendClientAuthParams{
|
||||||
UserId: userid,
|
UserId: userid,
|
||||||
}
|
}
|
||||||
return c.SendHelloParams(c.server.URL, "client", params)
|
return c.SendHelloParams(c.server.URL, HelloVersionV1, "client", params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TestClient) SendHelloInternal() error {
|
func (c *TestClient) SendHelloInternal() error {
|
||||||
|
@ -411,10 +468,10 @@ func (c *TestClient) SendHelloInternal() error {
|
||||||
Token: token,
|
Token: token,
|
||||||
Backend: backend,
|
Backend: backend,
|
||||||
}
|
}
|
||||||
return c.SendHelloParams("", "internal", params)
|
return c.SendHelloParams("", HelloVersionV1, "internal", params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TestClient) SendHelloParams(url string, clientType string, params interface{}) error {
|
func (c *TestClient) SendHelloParams(url string, version string, clientType string, params interface{}) error {
|
||||||
data, err := json.Marshal(params)
|
data, err := json.Marshal(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.t.Fatal(err)
|
c.t.Fatal(err)
|
||||||
|
@ -424,7 +481,7 @@ func (c *TestClient) SendHelloParams(url string, clientType string, params inter
|
||||||
Id: "1234",
|
Id: "1234",
|
||||||
Type: "hello",
|
Type: "hello",
|
||||||
Hello: &HelloClientMessage{
|
Hello: &HelloClientMessage{
|
||||||
Version: HelloVersion,
|
Version: version,
|
||||||
Auth: HelloClientMessageAuth{
|
Auth: HelloClientMessageAuth{
|
||||||
Type: clientType,
|
Type: clientType,
|
||||||
Url: url,
|
Url: url,
|
||||||
|
|
Loading…
Reference in a new issue