Merge pull request #1149 from strukturag/parallelize-tests

Parallelize more tests.
This commit is contained in:
Joachim Bauch 2025-12-09 14:12:29 +01:00 committed by GitHub
commit 67b557349d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 800 additions and 459 deletions

View file

@ -171,10 +171,19 @@ type BackendStorage interface {
GetBackends() []*Backend
}
type BackendStorageStats interface {
AddBackends(count int)
RemoveBackends(count int)
IncBackends()
DecBackends()
}
type backendStorageCommon struct {
mu sync.RWMutex
// +checklocks:mu
backends map[string][]*Backend
stats BackendStorageStats // +checklocksignore: Only written to from constructor
}
func (s *backendStorageCommon) GetBackends() []*Backend {
@ -224,21 +233,50 @@ type BackendConfiguration struct {
storage BackendStorage
}
type prometheusBackendStats struct{}
func (s *prometheusBackendStats) AddBackends(count int) {
statsBackendsCurrent.Add(float64(count))
}
func (s *prometheusBackendStats) RemoveBackends(count int) {
statsBackendsCurrent.Sub(float64(count))
}
func (s *prometheusBackendStats) IncBackends() {
statsBackendsCurrent.Inc()
}
func (s *prometheusBackendStats) DecBackends() {
statsBackendsCurrent.Dec()
}
var (
defaultBackendStats = &prometheusBackendStats{}
)
func NewBackendConfiguration(logger Logger, config *goconf.ConfigFile, etcdClient *EtcdClient) (*BackendConfiguration, error) {
return NewBackendConfigurationWithStats(logger, config, etcdClient, nil)
}
func NewBackendConfigurationWithStats(logger Logger, config *goconf.ConfigFile, etcdClient *EtcdClient, stats BackendStorageStats) (*BackendConfiguration, error) {
backendType, _ := config.GetString("backend", "backendtype")
if backendType == "" {
backendType = DefaultBackendType
}
RegisterBackendConfigurationStats()
if stats == nil {
RegisterBackendConfigurationStats()
stats = defaultBackendStats
}
var storage BackendStorage
var err error
switch backendType {
case BackendTypeStatic:
storage, err = NewBackendStorageStatic(logger, config)
storage, err = NewBackendStorageStatic(logger, config, stats)
case BackendTypeEtcd:
storage, err = NewBackendStorageEtcd(logger, config, etcdClient)
storage, err = NewBackendStorageEtcd(logger, config, etcdClient, stats)
default:
err = fmt.Errorf("unknown backend type: %s", backendType)
}

View file

@ -84,6 +84,26 @@ func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]str
}
}
type mockBackendStats struct {
value int
}
func (s *mockBackendStats) AddBackends(count int) {
s.value += count
}
func (s *mockBackendStats) RemoveBackends(count int) {
s.value -= count
}
func (s *mockBackendStats) IncBackends() {
s.value++
}
func (s *mockBackendStats) DecBackends() {
s.value--
}
func TestIsUrlAllowed_Compat(t *testing.T) {
t.Parallel()
logger := NewLoggerForTest(t)
@ -233,11 +253,13 @@ func TestParseBackendIds(t *testing.T) {
}
}
func TestBackendReloadNoChange(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsBackendsCurrent)
func TestBackendReloadNoChange(t *testing.T) {
t.Parallel()
stats := &mockBackendStats{}
logger := NewLoggerForTest(t)
require := require.New(t)
assert := assert.New(t)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
original_config.AddOption("backend", "allowall", "false")
@ -245,9 +267,9 @@ func TestBackendReloadNoChange(t *testing.T) { // nolint:paralleltest
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(logger, original_config, nil)
o_cfg, err := NewBackendConfigurationWithStats(logger, original_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
new_config := goconf.NewConfigFile()
new_config.AddOption("backend", "backends", "backend1, backend2")
@ -256,22 +278,24 @@ func TestBackendReloadNoChange(t *testing.T) { // nolint:paralleltest
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(logger, new_config, nil)
n_cfg, err := NewBackendConfigurationWithStats(logger, new_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 4)
assert.Equal(4, stats.value)
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, 4)
assert.Equal(4, stats.value)
if !reflect.DeepEqual(n_cfg, o_cfg) {
assert.Fail(t, "BackendConfiguration should be equal after Reload")
assert.Fail("BackendConfiguration should be equal after Reload")
}
}
func TestBackendReloadChangeExistingURL(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsBackendsCurrent)
func TestBackendReloadChangeExistingURL(t *testing.T) {
t.Parallel()
stats := &mockBackendStats{}
logger := NewLoggerForTest(t)
require := require.New(t)
assert := assert.New(t)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
original_config.AddOption("backend", "allowall", "false")
@ -279,10 +303,10 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { // nolint:paralleltest
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(logger, original_config, nil)
o_cfg, err := NewBackendConfigurationWithStats(logger, original_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
new_config := goconf.NewConfigFile()
new_config.AddOption("backend", "backends", "backend1, backend2")
new_config.AddOption("backend", "allowall", "false")
@ -291,26 +315,28 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { // nolint:paralleltest
new_config.AddOption("backend1", "sessionlimit", "10")
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(logger, new_config, nil)
n_cfg, err := NewBackendConfigurationWithStats(logger, new_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 4)
assert.Equal(4, stats.value)
original_config.RemoveOption("backend1", "url")
original_config.AddOption("backend1", "url", "http://domain3.invalid")
original_config.AddOption("backend1", "sessionlimit", "10")
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, 4)
assert.Equal(4, stats.value)
if !reflect.DeepEqual(n_cfg, o_cfg) {
assert.Fail(t, "BackendConfiguration should be equal after Reload")
assert.Fail("BackendConfiguration should be equal after Reload")
}
}
func TestBackendReloadChangeSecret(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsBackendsCurrent)
func TestBackendReloadChangeSecret(t *testing.T) {
t.Parallel()
stats := &mockBackendStats{}
logger := NewLoggerForTest(t)
require := require.New(t)
assert := assert.New(t)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
original_config.AddOption("backend", "allowall", "false")
@ -318,10 +344,10 @@ func TestBackendReloadChangeSecret(t *testing.T) { // nolint:paralleltest
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(logger, original_config, nil)
o_cfg, err := NewBackendConfigurationWithStats(logger, original_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
new_config := goconf.NewConfigFile()
new_config.AddOption("backend", "backends", "backend1, backend2")
new_config.AddOption("backend", "allowall", "false")
@ -329,32 +355,34 @@ func TestBackendReloadChangeSecret(t *testing.T) { // nolint:paralleltest
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend3")
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(logger, new_config, nil)
n_cfg, err := NewBackendConfigurationWithStats(logger, new_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 4)
assert.Equal(4, stats.value)
original_config.RemoveOption("backend1", "secret")
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend3")
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, 4)
assert.Equal(t, n_cfg, o_cfg, "BackendConfiguration should be equal after Reload")
assert.Equal(4, stats.value)
assert.Equal(n_cfg, o_cfg, "BackendConfiguration should be equal after Reload")
}
func TestBackendReloadAddBackend(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsBackendsCurrent)
func TestBackendReloadAddBackend(t *testing.T) {
t.Parallel()
stats := &mockBackendStats{}
logger := NewLoggerForTest(t)
require := require.New(t)
assert := assert.New(t)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1")
original_config.AddOption("backend", "allowall", "false")
original_config.AddOption("backend1", "url", "http://domain1.invalid")
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
o_cfg, err := NewBackendConfiguration(logger, original_config, nil)
o_cfg, err := NewBackendConfigurationWithStats(logger, original_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 1)
assert.Equal(1, stats.value)
new_config := goconf.NewConfigFile()
new_config.AddOption("backend", "backends", "backend1, backend2")
new_config.AddOption("backend", "allowall", "false")
@ -363,10 +391,10 @@ func TestBackendReloadAddBackend(t *testing.T) { // nolint:paralleltest
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
new_config.AddOption("backend2", "sessionlimit", "10")
n_cfg, err := NewBackendConfiguration(logger, new_config, nil)
n_cfg, err := NewBackendConfigurationWithStats(logger, new_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 3)
assert.Equal(3, stats.value)
original_config.RemoveOption("backend", "backends")
original_config.AddOption("backend", "backends", "backend1, backend2")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
@ -374,17 +402,19 @@ func TestBackendReloadAddBackend(t *testing.T) { // nolint:paralleltest
original_config.AddOption("backend2", "sessionlimit", "10")
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, 4)
assert.Equal(4, stats.value)
if !reflect.DeepEqual(n_cfg, o_cfg) {
assert.Fail(t, "BackendConfiguration should be equal after Reload")
assert.Fail("BackendConfiguration should be equal after Reload")
}
}
func TestBackendReloadRemoveHost(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsBackendsCurrent)
func TestBackendReloadRemoveHost(t *testing.T) {
t.Parallel()
stats := &mockBackendStats{}
logger := NewLoggerForTest(t)
require := require.New(t)
assert := assert.New(t)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
original_config.AddOption("backend", "allowall", "false")
@ -392,35 +422,37 @@ func TestBackendReloadRemoveHost(t *testing.T) { // nolint:paralleltest
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(logger, original_config, nil)
o_cfg, err := NewBackendConfigurationWithStats(logger, original_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
new_config := goconf.NewConfigFile()
new_config.AddOption("backend", "backends", "backend1")
new_config.AddOption("backend", "allowall", "false")
new_config.AddOption("backend1", "url", "http://domain1.invalid")
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
n_cfg, err := NewBackendConfiguration(logger, new_config, nil)
n_cfg, err := NewBackendConfigurationWithStats(logger, new_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 3)
assert.Equal(3, stats.value)
original_config.RemoveOption("backend", "backends")
original_config.AddOption("backend", "backends", "backend1")
original_config.RemoveSection("backend2")
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
if !reflect.DeepEqual(n_cfg, o_cfg) {
assert.Fail(t, "BackendConfiguration should be equal after Reload")
assert.Fail("BackendConfiguration should be equal after Reload")
}
}
func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsBackendsCurrent)
func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
t.Parallel()
stats := &mockBackendStats{}
logger := NewLoggerForTest(t)
require := require.New(t)
assert := assert.New(t)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
original_config.AddOption("backend", "allowall", "false")
@ -428,27 +460,27 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) { // nolint:para
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain1.invalid/bar/")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(logger, original_config, nil)
o_cfg, err := NewBackendConfigurationWithStats(logger, original_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
new_config := goconf.NewConfigFile()
new_config.AddOption("backend", "backends", "backend1")
new_config.AddOption("backend", "allowall", "false")
new_config.AddOption("backend1", "url", "http://domain1.invalid/foo/")
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
n_cfg, err := NewBackendConfiguration(logger, new_config, nil)
n_cfg, err := NewBackendConfigurationWithStats(logger, new_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 3)
assert.Equal(3, stats.value)
original_config.RemoveOption("backend", "backends")
original_config.AddOption("backend", "backends", "backend1")
original_config.RemoveSection("backend2")
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
if !reflect.DeepEqual(n_cfg, o_cfg) {
assert.Fail(t, "BackendConfiguration should be equal after Reload")
assert.Fail("BackendConfiguration should be equal after Reload")
}
}
@ -468,8 +500,9 @@ func mustParse(s string) *url.URL {
return p
}
func TestBackendConfiguration_EtcdCompat(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsBackendsCurrent)
func TestBackendConfiguration_EtcdCompat(t *testing.T) {
t.Parallel()
stats := &mockBackendStats{}
logger := NewLoggerForTest(t)
require := require.New(t)
@ -486,9 +519,7 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) { // nolint:paralleltest
config.AddOption("backend", "backendtype", "etcd")
config.AddOption("backend", "backendprefix", "/backends")
checkStatsValue(t, statsBackendsCurrent, 0)
cfg, err := NewBackendConfiguration(logger, config, client)
cfg, err := NewBackendConfigurationWithStats(logger, config, client, stats)
require.NoError(err)
defer cfg.Close()
@ -511,7 +542,7 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) { // nolint:paralleltest
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/1_one", []byte("{\"url\":\""+url1+"\",\"secret\":\""+secret1+"\"}"))
<-ch
checkStatsValue(t, statsBackendsCurrent, 1)
assert.Equal(1, stats.value)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) &&
assert.Equal([]string{url1}, backends[0].urls) &&
assert.Equal(secret1, string(backends[0].secret)) {
@ -526,7 +557,7 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) { // nolint:paralleltest
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/2_two", []byte("{\"url\":\""+url2+"\",\"secret\":\""+secret2+"\"}"))
<-ch
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 2) &&
assert.Equal([]string{url1}, backends[0].urls) &&
assert.Equal(secret1, string(backends[0].secret)) &&
@ -545,7 +576,7 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) { // nolint:paralleltest
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/3_three", []byte("{\"url\":\""+url3+"\",\"secret\":\""+secret3+"\"}"))
<-ch
checkStatsValue(t, statsBackendsCurrent, 3)
assert.Equal(3, stats.value)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 3) &&
assert.Equal([]string{url1}, backends[0].urls) &&
assert.Equal(secret1, string(backends[0].secret)) &&
@ -565,7 +596,7 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) { // nolint:paralleltest
drainWakeupChannel(ch)
DeleteEtcdValue(etcd, "/backends/1_one")
<-ch
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 2) {
assert.Equal([]string{url2}, backends[0].urls)
assert.Equal(secret2, string(backends[0].secret))
@ -576,7 +607,7 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) { // nolint:paralleltest
drainWakeupChannel(ch)
DeleteEtcdValue(etcd, "/backends/2_two")
<-ch
checkStatsValue(t, statsBackendsCurrent, 1)
assert.Equal(1, stats.value)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) {
assert.Equal([]string{url3}, backends[0].urls)
assert.Equal(secret3, string(backends[0].secret))
@ -629,8 +660,9 @@ func TestBackendCommonSecret(t *testing.T) {
}
}
func TestBackendChangeUrls(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsBackendsCurrent)
func TestBackendChangeUrls(t *testing.T) {
t.Parallel()
stats := &mockBackendStats{}
logger := NewLoggerForTest(t)
require := require.New(t)
@ -645,12 +677,10 @@ func TestBackendChangeUrls(t *testing.T) { // nolint:paralleltest
original_config.AddOption("backend1", "urls", u1.String())
original_config.AddOption("backend2", "urls", u2.String())
checkStatsValue(t, statsBackendsCurrent, 0)
cfg, err := NewBackendConfiguration(logger, original_config, nil)
cfg, err := NewBackendConfigurationWithStats(logger, original_config, nil, stats)
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
if b1 := cfg.GetBackend(u1); assert.NotNil(b1) {
assert.Equal("backend1", b1.Id())
assert.Equal(string(testBackendSecret), string(b1.Secret()))
@ -670,7 +700,7 @@ func TestBackendChangeUrls(t *testing.T) { // nolint:paralleltest
updated_config.AddOption("backend1", "urls", strings.Join([]string{u1.String(), u2.String()}, ","))
cfg.Reload(updated_config)
checkStatsValue(t, statsBackendsCurrent, 1)
assert.Equal(1, stats.value)
if b1 := cfg.GetBackend(u1); assert.NotNil(b1) {
assert.Equal("backend1", b1.Id())
assert.Equal(string(testBackendSecret)+"-backend1", string(b1.Secret()))
@ -684,7 +714,7 @@ func TestBackendChangeUrls(t *testing.T) { // nolint:paralleltest
// No change reload.
cfg.Reload(updated_config)
checkStatsValue(t, statsBackendsCurrent, 1)
assert.Equal(1, stats.value)
if b1 := cfg.GetBackend(u1); assert.NotNil(b1) {
assert.Equal("backend1", b1.Id())
assert.Equal(string(testBackendSecret)+"-backend1", string(b1.Secret()))
@ -703,7 +733,7 @@ func TestBackendChangeUrls(t *testing.T) { // nolint:paralleltest
updated_config.AddOption("backend1", "urls", u2.String())
cfg.Reload(updated_config)
checkStatsValue(t, statsBackendsCurrent, 1)
assert.Equal(1, stats.value)
if b1 := cfg.GetBackend(u2); assert.NotNil(b1) {
assert.Equal("backend1", b1.Id())
assert.Equal(string(testBackendSecret), string(b1.Secret()))
@ -715,13 +745,14 @@ func TestBackendChangeUrls(t *testing.T) { // nolint:paralleltest
updated_config.AddOption("backend", "secret", string(testBackendSecret))
cfg.Reload(updated_config)
checkStatsValue(t, statsBackendsCurrent, 0)
assert.Equal(0, stats.value)
b1 := cfg.GetBackend(u2)
assert.Nil(b1)
}
func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsBackendsCurrent)
func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) {
t.Parallel()
stats := &mockBackendStats{}
logger := NewLoggerForTest(t)
require := require.New(t)
@ -738,9 +769,7 @@ func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { // nolint:parallelt
config.AddOption("backend", "backendtype", "etcd")
config.AddOption("backend", "backendprefix", "/backends")
checkStatsValue(t, statsBackendsCurrent, 0)
cfg, err := NewBackendConfiguration(logger, config, client)
cfg, err := NewBackendConfigurationWithStats(logger, config, client, stats)
require.NoError(err)
defer cfg.Close()
@ -752,7 +781,7 @@ func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { // nolint:parallelt
require.NoError(storage.WaitForInitialized(ctx))
checkStatsValue(t, statsBackendsCurrent, 1)
assert.Equal(1, stats.value)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) &&
assert.Equal([]string{url1}, backends[0].Urls()) &&
assert.Equal(initialSecret1, string(backends[0].Secret())) {
@ -766,7 +795,7 @@ func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { // nolint:parallelt
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/1_one", []byte("{\"urls\":[\""+url1+"\",\""+url2+"\"],\"secret\":\""+secret1+"\"}"))
<-ch
checkStatsValue(t, statsBackendsCurrent, 1)
assert.Equal(1, stats.value)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) &&
assert.Equal([]string{url2, url1}, backends[0].Urls()) &&
assert.Equal(secret1, string(backends[0].Secret())) {
@ -786,7 +815,7 @@ func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { // nolint:parallelt
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/3_three", []byte("{\"urls\":[\""+url3+"\",\""+url4+"\"],\"secret\":\""+secret3+"\"}"))
<-ch
checkStatsValue(t, statsBackendsCurrent, 2)
assert.Equal(2, stats.value)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 2) &&
assert.Equal([]string{url2, url1}, backends[0].Urls()) &&
assert.Equal(secret1, string(backends[0].Secret())) &&
@ -806,7 +835,7 @@ func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { // nolint:parallelt
drainWakeupChannel(ch)
DeleteEtcdValue(etcd, "/backends/1_one")
<-ch
checkStatsValue(t, statsBackendsCurrent, 1)
assert.Equal(1, stats.value)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) {
assert.Equal([]string{url3, url4}, backends[0].Urls())
assert.Equal(secret3, string(backends[0].Secret()))
@ -816,7 +845,7 @@ func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { // nolint:parallelt
DeleteEtcdValue(etcd, "/backends/3_three")
<-ch
checkStatsValue(t, statsBackendsCurrent, 0)
assert.Equal(0, stats.value)
storage.mu.RLock()
_, found := storage.backends["domain1.invalid"]
storage.mu.RUnlock()

View file

@ -171,7 +171,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g
defer cancel()
assert.NoError(events1.Close(ctx))
})
client1, _ := NewGrpcClientsForTest(t, addr2)
client1, _ := NewGrpcClientsForTest(t, addr2, nil)
hub1, err := NewHub(ctx, config1, events1, grpcServer1, client1, nil, r1, "no-version")
require.NoError(err)
@ -196,7 +196,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g
defer cancel()
assert.NoError(events2.Close(ctx))
})
client2, _ := NewGrpcClientsForTest(t, addr1)
client2, _ := NewGrpcClientsForTest(t, addr1, nil)
hub2, err := NewHub(ctx, config2, events2, grpcServer2, client2, nil, r2, "no-version")
require.NoError(err)
@ -483,6 +483,7 @@ func RunTestBackendServer_RoomDisinvite(ctx context.Context, t *testing.T) {
defer cancel()
client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId)
defer client.CloseWithBye()
// Join room by id.
roomId := "test-room"
@ -550,7 +551,9 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) {
defer cancel()
client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId)
defer client1.CloseWithBye()
client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId)
defer client2.CloseWithBye()
// Join room by id.
roomId1 := "test-room1"
@ -780,7 +783,9 @@ func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) {
defer cancel()
client1, hello1 := NewTestClientWithHello(ctx, t, server1, hub1, testDefaultUserId+"1")
defer client1.CloseWithBye()
client2, hello2 := NewTestClientWithHello(ctx, t, server2, hub2, testDefaultUserId+"2")
defer client2.CloseWithBye()
session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId)
require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId)
@ -863,6 +868,7 @@ func TestBackendServer_ParticipantsUpdateEmptyPermissions(t *testing.T) {
defer cancel()
client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId)
defer client.CloseWithBye()
session := hub.GetSessionByPublicId(hello.Hello.SessionId)
assert.NotNil(session, "Session %s does not exist", hello.Hello.SessionId)
@ -927,7 +933,9 @@ func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) {
defer cancel()
client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1")
defer client1.CloseWithBye()
client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2")
defer client2.CloseWithBye()
// Join room by id.
roomId := "test-room"
@ -1100,7 +1108,9 @@ func TestBackendServer_InCallAll(t *testing.T) {
defer cancel()
client1, hello1 := NewTestClientWithHello(ctx, t, server1, hub1, testDefaultUserId+"1")
defer client1.CloseWithBye()
client2, hello2 := NewTestClientWithHello(ctx, t, server2, hub2, testDefaultUserId+"2")
defer client2.CloseWithBye()
session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId)
require.NotNil(session1, "Could not find session %s", hello1.Hello.SessionId)
@ -1258,6 +1268,7 @@ func TestBackendServer_RoomMessage(t *testing.T) {
defer cancel()
client, _ := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1")
defer client.CloseWithBye()
// Join room by id.
roomId := "test-room"

View file

@ -49,7 +49,7 @@ type backendStorageEtcd struct {
closeFunc context.CancelFunc
}
func NewBackendStorageEtcd(logger Logger, config *goconf.ConfigFile, etcdClient *EtcdClient) (BackendStorage, error) {
func NewBackendStorageEtcd(logger Logger, config *goconf.ConfigFile, etcdClient *EtcdClient, stats BackendStorageStats) (BackendStorage, error) {
if etcdClient == nil || !etcdClient.IsConfigured() {
return nil, errors.New("no etcd endpoints configured")
}
@ -64,6 +64,7 @@ func NewBackendStorageEtcd(logger Logger, config *goconf.ConfigFile, etcdClient
result := &backendStorageEtcd{
backendStorageCommon: backendStorageCommon{
backends: make(map[string][]*Backend),
stats: stats,
},
logger: logger,
etcdClient: etcdClient,
@ -231,7 +232,7 @@ func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data
}
updateBackendStats(backend)
if added {
statsBackendsCurrent.Inc()
s.stats.IncBackends()
}
s.wakeupForTesting()
}
@ -275,7 +276,7 @@ func (s *backendStorageEtcd) EtcdKeyDeleted(client *EtcdClient, key string, prev
if !seen[entry.Id()] {
seen[entry.Id()] = true
updateBackendStats(entry)
statsBackendsCurrent.Dec()
s.stats.DecBackends()
}
continue
}

View file

@ -43,7 +43,7 @@ type backendStorageStatic struct {
compatBackend *Backend
}
func NewBackendStorageStatic(logger Logger, config *goconf.ConfigFile) (BackendStorage, error) {
func NewBackendStorageStatic(logger Logger, config *goconf.ConfigFile, stats BackendStorageStats) (BackendStorage, error) {
allowAll, _ := config.GetBool("backend", "allowall")
allowHttp, _ := config.GetBool("backend", "allowhttp")
commonSecret, _ := GetStringOptionWithEnv(config, "backend", "secret")
@ -157,10 +157,11 @@ func NewBackendStorageStatic(logger Logger, config *goconf.ConfigFile) (BackendS
logger.Printf("WARNING: No backends configured, client connections will not be possible.")
}
statsBackendsCurrent.Add(float64(numBackends))
stats.AddBackends(numBackends)
return &backendStorageStatic{
backendStorageCommon: backendStorageCommon{
backends: backends,
stats: stats,
},
logger: logger,
@ -196,7 +197,7 @@ func (s *backendStorageStatic) RemoveBackendsForHost(host string, seen map[strin
backend.counted = false
}
}
statsBackendsCurrent.Sub(float64(deleted))
s.stats.RemoveBackends(deleted)
}
delete(s.backends, host)
}
@ -247,7 +248,7 @@ func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend, seen
if len(urls) == len(removed.urls) && removed.counted {
deleteBackendStats(removed)
delete(s.backendsById, removed.Id())
statsBackendsCurrent.Dec()
s.stats.DecBackends()
removed.counted = false
}
}
@ -276,7 +277,7 @@ func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend, seen
added.counted = true
}
}
statsBackendsCurrent.Add(float64(addedBackends))
s.stats.AddBackends(addedBackends)
}
func getConfiguredBackendIDs(backendIds string) (ids []string) {

View file

@ -27,6 +27,7 @@ import (
"fmt"
"os"
"sync/atomic"
"testing"
)
type CertificateReloader struct {
@ -49,17 +50,22 @@ func NewCertificateReloader(logger Logger, certFile string, keyFile string) (*Ce
return nil, fmt.Errorf("could not load certificate / key: %w", err)
}
deduplicate := defaultDeduplicateWatchEvents
if testing.Testing() {
deduplicate = 0
}
reloader := &CertificateReloader{
logger: logger,
certFile: certFile,
keyFile: keyFile,
}
reloader.certificate.Store(&pair)
reloader.certWatcher, err = NewFileWatcher(reloader.logger, certFile, reloader.reload)
reloader.certWatcher, err = NewFileWatcher(reloader.logger, certFile, reloader.reload, deduplicate)
if err != nil {
return nil, err
}
reloader.keyWatcher, err = NewFileWatcher(reloader.logger, keyFile, reloader.reload)
reloader.keyWatcher, err = NewFileWatcher(reloader.logger, keyFile, reloader.reload, deduplicate)
if err != nil {
reloader.certWatcher.Close() // nolint
return nil, err
@ -132,12 +138,17 @@ func NewCertPoolReloader(logger Logger, certFile string) (*CertPoolReloader, err
return nil, err
}
deduplicate := defaultDeduplicateWatchEvents
if testing.Testing() {
deduplicate = 0
}
reloader := &CertPoolReloader{
logger: logger,
certFile: certFile,
}
reloader.pool.Store(pool)
reloader.certWatcher, err = NewFileWatcher(reloader.logger, certFile, reloader.reload)
reloader.certWatcher, err = NewFileWatcher(reloader.logger, certFile, reloader.reload, deduplicate)
if err != nil {
return nil, err
}

View file

@ -23,22 +23,9 @@ package signaling
import (
"context"
"testing"
"time"
)
func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) {
t.Helper()
// Make sure test is not executed with "t.Parallel()"
t.Setenv("PARALLEL_CHECK", "1")
old := deduplicateWatchEvents.Load()
t.Cleanup(func() {
deduplicateWatchEvents.Store(old)
})
deduplicateWatchEvents.Store(int64(interval))
}
func (r *CertificateReloader) WaitForReload(ctx context.Context, counter uint64) error {
for counter == r.GetReloadCounter() {
if err := ctx.Err(); err != nil {

View file

@ -51,6 +51,7 @@ func TestBandwidth_Client(t *testing.T) {
hub.SetMcu(mcu)
client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId)
defer client.CloseWithBye()
// Join room by id.
roomId := "test-room"
@ -476,6 +477,7 @@ func TestPermissionHideDisplayNames(t *testing.T) {
defer cancel()
client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId)
defer client.CloseWithBye()
roomId := "test-room"
roomMsg := MustSucceed2(t, client.JoinRoom, ctx, roomId)
@ -541,6 +543,7 @@ func TestPermissionHideDisplayNames(t *testing.T) {
}
client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2")
defer client2.CloseWithBye()
roomMsg2 := MustSucceed2(t, client2.JoinRoom, ctx, roomId)
require.Equal(roomId, roomMsg2.Room.RoomId)

View file

@ -32,10 +32,6 @@ import (
"time"
)
var (
lookupDnsMonitorIP = net.LookupIP
)
const (
defaultDnsMonitorInterval = time.Second
)
@ -157,9 +153,12 @@ func (e *dnsMonitorEntry) runCallbacks(all []net.IP, add []net.IP, keep []net.IP
}
}
type DnsMonitorLookupFunc func(hostname string) ([]net.IP, error)
type DnsMonitor struct {
logger Logger
interval time.Duration
logger Logger
interval time.Duration
lookupFunc DnsMonitorLookupFunc
stopCtx context.Context
stopFunc func()
@ -176,15 +175,19 @@ type DnsMonitor struct {
checkHostnames func()
}
func NewDnsMonitor(logger Logger, interval time.Duration) (*DnsMonitor, error) {
func NewDnsMonitor(logger Logger, interval time.Duration, lookupFunc DnsMonitorLookupFunc) (*DnsMonitor, error) {
if interval < 0 {
interval = defaultDnsMonitorInterval
}
if lookupFunc == nil {
lookupFunc = net.LookupIP
}
stopCtx, stopFunc := context.WithCancel(context.Background())
monitor := &DnsMonitor{
logger: logger,
interval: interval,
logger: logger,
interval: interval,
lookupFunc: lookupFunc,
stopCtx: stopCtx,
stopFunc: stopFunc,
@ -347,7 +350,7 @@ func (m *DnsMonitor) checkHostname(entry *dnsMonitorEntry) {
return
}
ips, err := lookupDnsMonitorIP(entry.hostname)
ips, err := m.lookupFunc(entry.hostname)
if err != nil {
m.logger.Printf("Could not lookup %s: %s", entry.hostname, err)
return

View file

@ -44,15 +44,9 @@ type mockDnsLookup struct {
func newMockDnsLookupForTest(t *testing.T) *mockDnsLookup {
t.Helper()
t.Setenv("PARALLEL_CHECK", "1")
mock := &mockDnsLookup{
ips: make(map[string][]net.IP),
}
prev := lookupDnsMonitorIP
t.Cleanup(func() {
lookupDnsMonitorIP = prev
})
lookupDnsMonitorIP = mock.lookup
return mock
}
@ -86,12 +80,16 @@ func (m *mockDnsLookup) lookup(host string) ([]net.IP, error) {
return append([]net.IP{}, ips...), nil
}
func newDnsMonitorForTest(t *testing.T, interval time.Duration) *DnsMonitor {
func newDnsMonitorForTest(t *testing.T, interval time.Duration, lookup *mockDnsLookup) *DnsMonitor {
t.Helper()
require := require.New(t)
logger := NewLoggerForTest(t)
monitor, err := NewDnsMonitor(logger, interval)
var lookupFunc DnsMonitorLookupFunc
if lookup != nil {
lookupFunc = lookup.lookup
}
monitor, err := NewDnsMonitor(logger, interval, lookupFunc)
require.NoError(err)
t.Cleanup(func() {
@ -223,13 +221,14 @@ func (r *dnsMonitorReceiver) ExpectNone() {
r.expected = expectNone
}
func TestDnsMonitor(t *testing.T) { // nolint:paralleltest
func TestDnsMonitor(t *testing.T) {
t.Parallel()
lookup := newMockDnsLookupForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
interval := time.Millisecond
monitor := newDnsMonitorForTest(t, interval)
monitor := newDnsMonitorForTest(t, interval, lookup)
ip1 := net.ParseIP("192.168.0.1")
ip2 := net.ParseIP("192.168.1.1")
@ -297,7 +296,7 @@ func TestDnsMonitorIP(t *testing.T) {
defer cancel()
interval := time.Millisecond
monitor := newDnsMonitorForTest(t, interval)
monitor := newDnsMonitorForTest(t, interval, nil)
ip := "192.168.0.1"
ips := []net.IP{
@ -317,9 +316,10 @@ func TestDnsMonitorIP(t *testing.T) {
time.Sleep(5 * interval)
}
func TestDnsMonitorNoLookupIfEmpty(t *testing.T) { // nolint:paralleltest
func TestDnsMonitorNoLookupIfEmpty(t *testing.T) {
t.Parallel()
interval := time.Millisecond
monitor := newDnsMonitorForTest(t, interval)
monitor := newDnsMonitorForTest(t, interval, nil)
var checked atomic.Bool
monitor.checkHostnames = func() {
@ -402,14 +402,15 @@ func (r *deadlockMonitorReceiver) Close() {
r.wg.Wait()
}
func TestDnsMonitorDeadlock(t *testing.T) { // nolint:paralleltest
func TestDnsMonitorDeadlock(t *testing.T) {
t.Parallel()
lookup := newMockDnsLookupForTest(t)
ip1 := net.ParseIP("192.168.0.1")
ip2 := net.ParseIP("192.168.0.2")
lookup.Set("foo", []net.IP{ip1})
interval := time.Millisecond
monitor := newDnsMonitorForTest(t, interval)
monitor := newDnsMonitorForTest(t, interval, lookup)
r := newDeadlockMonitorReceiver(t, monitor)
r.Start()

View file

@ -29,7 +29,6 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/fsnotify/fsnotify"
@ -39,28 +38,21 @@ const (
defaultDeduplicateWatchEvents = 100 * time.Millisecond
)
var (
deduplicateWatchEvents atomic.Int64
)
func init() {
deduplicateWatchEvents.Store(int64(defaultDeduplicateWatchEvents))
}
type FileWatcherCallback func(filename string)
type FileWatcher struct {
logger Logger
filename string
target string
callback FileWatcherCallback
logger Logger
filename string
target string
callback FileWatcherCallback
deduplicate time.Duration
watcher *fsnotify.Watcher
closeCtx context.Context
closeFunc context.CancelFunc
}
func NewFileWatcher(logger Logger, filename string, callback FileWatcherCallback) (*FileWatcher, error) {
func NewFileWatcher(logger Logger, filename string, callback FileWatcherCallback, deduplicate time.Duration) (*FileWatcher, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
@ -74,10 +66,11 @@ func NewFileWatcher(logger Logger, filename string, callback FileWatcherCallback
closeCtx, closeFunc := context.WithCancel(context.Background())
w := &FileWatcher{
logger: logger,
filename: filename,
callback: callback,
watcher: watcher,
logger: logger,
filename: filename,
callback: callback,
deduplicate: deduplicate,
watcher: watcher,
closeCtx: closeCtx,
closeFunc: closeFunc,
@ -115,8 +108,7 @@ func (f *FileWatcher) run() {
timers := make(map[string]*time.Timer)
triggerEvent := func(event fsnotify.Event) {
deduplicate := time.Duration(deduplicateWatchEvents.Load())
if deduplicate <= 0 {
if f.deduplicate <= 0 {
f.callback(f.filename)
return
}
@ -128,7 +120,7 @@ func (f *FileWatcher) run() {
t, found := timers[filename]
mu.Unlock()
if !found {
t = time.AfterFunc(deduplicate, func() {
t = time.AfterFunc(f.deduplicate, func() {
f.callback(f.filename)
mu.Lock()
@ -139,7 +131,7 @@ func (f *FileWatcher) run() {
timers[filename] = t
mu.Unlock()
} else {
t.Reset(deduplicate)
t.Reset(f.deduplicate)
}
}

View file

@ -40,7 +40,7 @@ func TestFileWatcher_NotExist(t *testing.T) {
assert := assert.New(t)
tmpdir := t.TempDir()
logger := NewLoggerForTest(t)
if w, err := NewFileWatcher(logger, path.Join(tmpdir, "test.txt"), func(filename string) {}); !assert.ErrorIs(err, os.ErrNotExist) {
if w, err := NewFileWatcher(logger, path.Join(tmpdir, "test.txt"), func(filename string) {}, defaultDeduplicateWatchEvents); !assert.ErrorIs(err, os.ErrNotExist) {
if w != nil {
assert.NoError(w.Close())
}
@ -59,7 +59,7 @@ func TestFileWatcher_File(t *testing.T) { // nolint:paralleltest
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
})
}, defaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -94,7 +94,7 @@ func TestFileWatcher_CurrentDir(t *testing.T) { // nolint:paralleltest
require := require.New(t)
assert := assert.New(t)
tmpdir := t.TempDir()
require.NoError(os.Chdir(tmpdir))
t.Chdir(tmpdir)
filename := path.Join(tmpdir, "test.txt")
require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644))
@ -102,7 +102,7 @@ func TestFileWatcher_CurrentDir(t *testing.T) { // nolint:paralleltest
modified := make(chan struct{})
w, err := NewFileWatcher(logger, "./"+path.Base(filename), func(filename string) {
modified <- struct{}{}
})
}, defaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -144,7 +144,7 @@ func TestFileWatcher_Rename(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
})
}, defaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -188,7 +188,7 @@ func TestFileWatcher_Symlink(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
})
}, defaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -223,7 +223,7 @@ func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
})
}, defaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -260,7 +260,7 @@ func TestFileWatcher_OtherSymlink(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
})
}, defaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -291,7 +291,7 @@ func TestFileWatcher_RenameSymlinkTarget(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
})
}, defaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()
@ -345,7 +345,7 @@ func TestFileWatcher_UpdateSymlinkFolder(t *testing.T) {
modified := make(chan struct{})
w, err := NewFileWatcher(logger, filename, func(filename string) {
modified <- struct{}{}
})
}, defaultDeduplicateWatchEvents)
require.NoError(err)
defer w.Close()

View file

@ -51,8 +51,8 @@ func (c *GrpcClients) getWakeupChannelForTesting() <-chan struct{} {
return ch
}
func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, *DnsMonitor) {
dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually
func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient, lookup *mockDnsLookup) (*GrpcClients, *DnsMonitor) {
dnsMonitor := newDnsMonitorForTest(t, time.Hour, lookup) // will be updated manually
logger := NewLoggerForTest(t)
ctx := NewLoggerContext(t.Context(), logger)
client, err := NewGrpcClients(ctx, config, etcdClient, dnsMonitor, "0.0.0")
@ -64,15 +64,15 @@ func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, et
return client, dnsMonitor
}
func NewGrpcClientsForTest(t *testing.T, addr string) (*GrpcClients, *DnsMonitor) {
func NewGrpcClientsForTest(t *testing.T, addr string, lookup *mockDnsLookup) (*GrpcClients, *DnsMonitor) {
config := goconf.NewConfigFile()
config.AddOption("grpc", "targets", addr)
config.AddOption("grpc", "dnsdiscovery", "true")
return NewGrpcClientsForTestWithConfig(t, config, nil)
return NewGrpcClientsForTestWithConfig(t, config, nil, lookup)
}
func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) (*GrpcClients, *DnsMonitor) {
func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd, lookup *mockDnsLookup) (*GrpcClients, *DnsMonitor) {
config := goconf.NewConfigFile()
config.AddOption("etcd", "endpoints", etcd.Config().ListenClientUrls[0].String())
@ -86,7 +86,7 @@ func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) (*GrpcClients
assert.NoError(t, etcdClient.Close())
})
return NewGrpcClientsForTestWithConfig(t, config, etcdClient)
return NewGrpcClientsForTestWithConfig(t, config, etcdClient, lookup)
}
func drainWakeupChannel(ch <-chan struct{}) {
@ -122,7 +122,7 @@ func Test_GrpcClients_EtcdInitial(t *testing.T) { // nolint:paralleltest
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd, nil)
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.NoError(t, client.WaitForInitialized(ctx))
@ -138,7 +138,7 @@ func Test_GrpcClients_EtcdUpdate(t *testing.T) {
ctx := NewLoggerContext(t.Context(), logger)
assert := assert.New(t)
etcd := NewEtcdForTest(t)
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd, nil)
ch := client.getWakeupChannelForTesting()
ctx, cancel := context.WithTimeout(ctx, testTimeout)
@ -185,7 +185,7 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
ctx := NewLoggerContext(t.Context(), logger)
assert := assert.New(t)
etcd := NewEtcdForTest(t)
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd, nil)
ch := client.getWakeupChannelForTesting()
ctx, cancel := context.WithTimeout(ctx, testTimeout)
@ -232,7 +232,7 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) { // nolint:paralleltest
targetWithIp1 := fmt.Sprintf("%s (%s)", target, ip1)
targetWithIp2 := fmt.Sprintf("%s (%s)", target, ip2)
lookup.Set("testgrpc", []net.IP{ip1})
client, dnsMonitor := NewGrpcClientsForTest(t, target)
client, dnsMonitor := NewGrpcClientsForTest(t, target, lookup)
ch := client.getWakeupChannelForTesting()
ctx, cancel := context.WithTimeout(ctx, testTimeout)
@ -274,13 +274,14 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) { // nolint:paralleltest
})
}
func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) { // nolint:paralleltest
func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
t.Parallel()
assert := assert.New(t)
lookup := newMockDnsLookupForTest(t)
target := "testgrpc:12345"
ip1 := net.ParseIP("192.168.0.1")
targetWithIp1 := fmt.Sprintf("%s (%s)", target, ip1)
client, dnsMonitor := NewGrpcClientsForTest(t, target)
client, dnsMonitor := NewGrpcClientsForTest(t, target, lookup)
ch := client.getWakeupChannelForTesting()
testCtx, testCtxCancel := context.WithTimeout(context.Background(), testTimeout)
@ -339,7 +340,7 @@ func Test_GrpcClients_Encryption(t *testing.T) { // nolint:paralleltest
clientConfig.AddOption("grpc", "clientcertificate", clientCertFile)
clientConfig.AddOption("grpc", "clientkey", clientPrivkeyFile)
clientConfig.AddOption("grpc", "serverca", serverCertFile)
clients, _ := NewGrpcClientsForTestWithConfig(t, clientConfig, nil)
clients, _ := NewGrpcClientsForTestWithConfig(t, clientConfig, nil, nil)
ctx, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()

View file

@ -97,7 +97,8 @@ func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) {
return NewGrpcServerForTestWithConfig(t, config)
}
func Test_GrpcServer_ReloadCerts(t *testing.T) { // nolint:paralleltest
func Test_GrpcServer_ReloadCerts(t *testing.T) {
t.Parallel()
require := require.New(t)
assert := assert.New(t)
key, err := rsa.GenerateKey(rand.Reader, 1024)
@ -118,7 +119,6 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { // nolint:paralleltest
config.AddOption("grpc", "servercertificate", certFile)
config.AddOption("grpc", "serverkey", privkeyFile)
UpdateCertificateCheckIntervalForTest(t, 0)
server, addr := NewGrpcServerForTestWithConfig(t, config)
cp1 := x509.NewCertPool()
@ -167,7 +167,8 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { // nolint:paralleltest
}
}
func Test_GrpcServer_ReloadCA(t *testing.T) { // nolint:paralleltest
func Test_GrpcServer_ReloadCA(t *testing.T) {
t.Parallel()
logger := NewLoggerForTest(t)
require := require.New(t)
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
@ -194,7 +195,6 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { // nolint:paralleltest
config.AddOption("grpc", "serverkey", privkeyFile)
config.AddOption("grpc", "clientca", caFile)
UpdateCertificateCheckIntervalForTest(t, 0)
server, addr := NewGrpcServerForTestWithConfig(t, config)
pool := x509.NewCertPool()

View file

@ -246,7 +246,7 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
})
config1, err := getConfigFunc(server1)
require.NoError(err)
client1, _ := NewGrpcClientsForTest(t, addr2)
client1, _ := NewGrpcClientsForTest(t, addr2, nil)
h1, err := NewHub(ctx, config1, events1, grpcServer1, client1, nil, r1, "no-version")
require.NoError(err)
b1, err := NewBackendServer(ctx, config1, h1, "no-version")
@ -260,7 +260,7 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
})
config2, err := getConfigFunc(server2)
require.NoError(err)
client2, _ := NewGrpcClientsForTest(t, addr1)
client2, _ := NewGrpcClientsForTest(t, addr1, nil)
h2, err := NewHub(ctx, config2, events2, grpcServer2, client2, nil, r2, "no-version")
require.NoError(err)
b2, err := NewBackendServer(ctx, config2, h2, "no-version")
@ -809,16 +809,6 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) {
}
}
func performHousekeeping(hub *Hub, now time.Time) *sync.WaitGroup {
var wg sync.WaitGroup
wg.Add(1)
go func() {
hub.performHousekeeping(now)
wg.Done()
}()
return &wg
}
func Benchmark_DecodePrivateSessionIdCached(b *testing.B) {
require := require.New(b)
decodeCaches := make([]*LruCache[*SessionIdData], 0, numDecodeCaches)
@ -940,7 +930,7 @@ func TestExpectClientHello(t *testing.T) {
// Perform housekeeping in the future, this will cause the connection to
// be terminated due to the missing "Hello" request.
performHousekeeping(hub, time.Now().Add(initialHelloTimeout+time.Second))
hub.performHousekeeping(time.Now().Add(initialHelloTimeout + time.Second))
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
@ -1486,7 +1476,7 @@ func TestClientHelloResumeThrottle(t *testing.T) {
// Perform housekeeping in the future, this will cause the session to be
// cleaned up after it is expired.
performHousekeeping(hub, time.Now().Add(sessionExpireDuration+time.Second)).Wait()
hub.performHousekeeping(time.Now().Add(sessionExpireDuration + time.Second))
client = NewTestClient(t, server, hub)
defer client.CloseWithBye()
@ -1523,7 +1513,7 @@ func TestClientHelloResumeExpired(t *testing.T) {
// Perform housekeeping in the future, this will cause the session to be
// cleaned up after it is expired.
performHousekeeping(hub, time.Now().Add(sessionExpireDuration+time.Second)).Wait()
hub.performHousekeeping(time.Now().Add(sessionExpireDuration + time.Second))
client = NewTestClient(t, server, hub)
defer client.CloseWithBye()
@ -2803,7 +2793,7 @@ func TestExpectAnonymousJoinRoom(t *testing.T) {
// Perform housekeeping in the future, this will cause the connection to
// be terminated because the anonymous client didn't join a room.
performHousekeeping(hub, time.Now().Add(anonmyousJoinRoomTimeout+time.Second))
hub.performHousekeeping(time.Now().Add(anonmyousJoinRoomTimeout + time.Second))
if message, ok := client.RunUntilMessage(ctx); ok {
if checkMessageType(t, message, "bye") {
@ -2840,7 +2830,7 @@ func TestExpectAnonymousJoinRoomAfterLeave(t *testing.T) {
// Perform housekeeping in the future, this will keep the connection as the
// session joined a room.
performHousekeeping(hub, time.Now().Add(anonmyousJoinRoomTimeout+time.Second))
hub.performHousekeeping(time.Now().Add(anonmyousJoinRoomTimeout + time.Second))
// No message about the closing is sent to the new connection.
ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond)
@ -2854,7 +2844,7 @@ func TestExpectAnonymousJoinRoomAfterLeave(t *testing.T) {
// Perform housekeeping in the future, this will cause the connection to
// be terminated because the anonymous client didn't join a room.
performHousekeeping(hub, time.Now().Add(anonmyousJoinRoomTimeout+time.Second))
hub.performHousekeeping(time.Now().Add(anonmyousJoinRoomTimeout + time.Second))
if message, ok := client.RunUntilMessage(ctx); ok {
if checkMessageType(t, message, "bye") {
@ -5096,7 +5086,9 @@ func DoTestSwitchToMultiple(t *testing.T, details1 api.StringMap, details2 api.S
defer cancel()
client1, hello1 := NewTestClientWithHello(ctx, t, server1, hub1, testDefaultUserId+"1")
defer client1.CloseWithBye()
client2, hello2 := NewTestClientWithHello(ctx, t, server2, hub2, testDefaultUserId+"2")
defer client2.CloseWithBye()
roomSessionId1 := RoomSessionId("roomsession1")
roomId1 := "test-room"
@ -5421,7 +5413,7 @@ func TestGracefulShutdownOnExpiration(t *testing.T) {
case <-time.After(100 * time.Millisecond):
}
performHousekeeping(hub, time.Now().Add(sessionExpireDuration+time.Second))
hub.performHousekeeping(time.Now().Add(sessionExpireDuration + time.Second))
select {
case <-hub.ShutdownChannel():

View file

@ -213,6 +213,21 @@ func (s *mcuJanusSettings) Reload(config *goconf.ConfigFile) {
}
}
type mcuJanusStats interface {
IncSubscriber(streamType StreamType)
DecSubscriber(streamType StreamType)
}
type prometheusJanusStats struct{}
func (s *prometheusJanusStats) IncSubscriber(streamType StreamType) {
statsSubscribersCurrent.WithLabelValues(string(streamType)).Inc()
}
func (s *prometheusJanusStats) DecSubscriber(streamType StreamType) {
statsSubscribersCurrent.WithLabelValues(string(streamType)).Dec()
}
type mcuJanus struct {
logger Logger
@ -220,6 +235,7 @@ type mcuJanus struct {
mu sync.Mutex
settings *mcuJanusSettings
stats mcuJanusStats
createJanusGateway func(ctx context.Context, wsURL string, listener GatewayListener) (JanusGatewayInterface, error)
@ -265,6 +281,7 @@ func NewMcuJanus(ctx context.Context, url string, config *goconf.ConfigFile) (Mc
logger: LoggerFromContext(ctx),
url: url,
settings: settings,
stats: &prometheusJanusStats{},
closeChan: make(chan struct{}, 1),
clients: make(map[uint64]clientInterface),
@ -333,7 +350,22 @@ func (m *mcuJanus) Bandwidth() (result *McuClientBandwidthInfo) {
return
}
func (m *mcuJanus) updateBandwidthStats() {
type janusBandwidthStats interface {
SetBandwidth(incoming uint64, outgoing uint64)
}
type prometheusJanusBandwidthStats struct{}
func (s *prometheusJanusBandwidthStats) SetBandwidth(incoming uint64, outgoing uint64) {
statsJanusBandwidthCurrent.WithLabelValues("incoming").Set(float64(incoming))
statsJanusBandwidthCurrent.WithLabelValues("outgoing").Set(float64(outgoing))
}
var (
defaultJanusBandwidthStats = &prometheusJanusBandwidthStats{}
)
func (m *mcuJanus) updateBandwidthStats(stats janusBandwidthStats) {
if info := m.info.Load(); info != nil {
if !info.EventHandlers {
// Event handlers are disabled, no stats will be available.
@ -346,12 +378,14 @@ func (m *mcuJanus) updateBandwidthStats() {
}
}
if stats == nil {
stats = defaultJanusBandwidthStats
}
if bandwidth := m.Bandwidth(); bandwidth != nil {
statsJanusBandwidthCurrent.WithLabelValues("incoming").Set(float64(bandwidth.Received.Bytes()))
statsJanusBandwidthCurrent.WithLabelValues("outgoing").Set(float64(bandwidth.Sent.Bytes()))
stats.SetBandwidth(bandwidth.Received.Bytes(), bandwidth.Sent.Bytes())
} else {
statsJanusBandwidthCurrent.WithLabelValues("incoming").Set(0)
statsJanusBandwidthCurrent.WithLabelValues("outgoing").Set(0)
stats.SetBandwidth(0, 0)
}
}
@ -524,7 +558,7 @@ loop:
case <-ticker.C:
m.sendKeepalive(context.Background())
case <-bandwidthTicker.C:
m.updateBandwidthStats()
m.updateBandwidthStats(nil)
case <-m.closeChan:
break loop
}
@ -890,7 +924,7 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ
client.mcuJanusClient.handleMedia = client.handleMedia
m.registerClient(client)
go client.run(handle, client.closeChan)
statsSubscribersCurrent.WithLabelValues(string(streamType)).Inc()
m.stats.IncSubscriber(streamType)
statsSubscribersTotal.WithLabelValues(string(streamType)).Inc()
return client, nil
}
@ -1060,7 +1094,7 @@ func (m *mcuJanus) NewRemoteSubscriber(ctx context.Context, listener McuListener
client.mcuJanusClient.handleMedia = client.handleMedia
m.registerClient(client)
go client.run(handle, client.closeChan)
statsSubscribersCurrent.WithLabelValues(string(publisher.StreamType())).Inc()
m.stats.IncSubscriber(publisher.StreamType())
statsSubscribersTotal.WithLabelValues(string(publisher.StreamType())).Inc()
return client, nil
}

View file

@ -47,10 +47,13 @@ type TestJanusEventsServerHandler struct {
mcu Mcu
addr string
wg sync.WaitGroup
}
func (h *TestJanusEventsServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.t.Helper()
h.wg.Add(1)
defer h.wg.Done()
assert := assert.New(h.t)
conn, err := h.upgrader.Upgrade(w, r, nil)
assert.NoError(err)
@ -90,6 +93,8 @@ func NewTestJanusEventsHandlerServer(t *testing.T) (*httptest.Server, string, *T
server := httptest.NewServer(handler)
t.Cleanup(func() {
server.Close()
server.CloseClientConnections()
handler.wg.Wait()
})
url := strings.ReplaceAll(server.URL, "http://", "ws://")
url = strings.ReplaceAll(url, "https://", "wss://")
@ -113,6 +118,9 @@ func TestJanusEventsHandlerNoMcu(t *testing.T) {
}
conn, response, err := dialer.DialContext(ctx, url, nil)
require.NoError(err)
defer func() {
assert.NoError(conn.Close())
}()
assert.Equal(JanusEventsSubprotocol, response.Header.Get("Sec-WebSocket-Protocol"))
@ -145,6 +153,9 @@ func TestJanusEventsHandlerInvalidMcu(t *testing.T) {
}
conn, response, err := dialer.DialContext(ctx, url, nil)
require.NoError(err)
defer func() {
assert.NoError(conn.Close())
}()
assert.Equal(JanusEventsSubprotocol, response.Header.Get("Sec-WebSocket-Protocol"))
@ -178,6 +189,9 @@ func TestJanusEventsHandlerPublicIP(t *testing.T) {
}
conn, response, err := dialer.DialContext(ctx, url, nil)
require.NoError(err)
defer func() {
assert.NoError(conn.Close())
}()
assert.Equal(JanusEventsSubprotocol, response.Header.Get("Sec-WebSocket-Protocol"))
@ -305,6 +319,9 @@ func TestJanusEventsHandlerDifferentTypes(t *testing.T) {
}
conn, response, err := dialer.DialContext(ctx, url, nil)
require.NoError(err)
defer func() {
assert.NoError(conn.Close())
}()
assert.Equal(JanusEventsSubprotocol, response.Header.Get("Sec-WebSocket-Protocol"))
@ -519,6 +536,9 @@ func TestJanusEventsHandlerNotGrouped(t *testing.T) {
}
conn, response, err := dialer.DialContext(ctx, url, nil)
require.NoError(err)
defer func() {
assert.NoError(conn.Close())
}()
assert.Equal(JanusEventsSubprotocol, response.Header.Get("Sec-WebSocket-Protocol"))
@ -595,6 +615,9 @@ func TestJanusEventsHandlerGrouped(t *testing.T) {
}
conn, response, err := dialer.DialContext(ctx, url, nil)
require.NoError(err)
defer func() {
assert.NoError(conn.Close())
}()
assert.Equal(JanusEventsSubprotocol, response.Header.Get("Sec-WebSocket-Protocol"))

View file

@ -140,7 +140,7 @@ func (p *mcuJanusSubscriber) closeClient(ctx context.Context) bool {
return false
}
statsSubscribersCurrent.WithLabelValues(string(p.streamType)).Dec()
p.mcu.stats.DecSubscriber(p.streamType)
return true
}
@ -226,7 +226,7 @@ retry:
p.sid = strconv.FormatUint(handle.Id, 10)
p.listener.SubscriberSidUpdated(p)
p.closeChan = make(chan struct{}, 1)
statsSubscribersCurrent.WithLabelValues(string(p.streamType)).Inc()
p.mcu.stats.IncSubscriber(p.streamType)
go p.run(handle, p.closeChan)
p.logger.Printf("Already connected subscriber %d for %s, leaving and re-joining on handle %d", p.id, p.streamType, p.handleId.Load())
goto retry

View file

@ -575,7 +575,10 @@ func (g *TestJanusGateway) send(msg api.StringMap, t *transaction) (uint64, erro
func (g *TestJanusGateway) removeTransaction(id uint64) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.transactions, id)
if t, found := g.transactions[id]; found {
delete(g.transactions, id)
t.quit()
}
}
func (g *TestJanusGateway) removeSession(session *JanusSession) {
@ -1081,10 +1084,20 @@ func Test_JanusPublisherGetStreamsAudioVideo(t *testing.T) {
}
}
func Test_JanusPublisherSubscriber(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("incoming"))
ResetStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("outgoing"))
type mockBandwidthStats struct {
incoming uint64
outgoing uint64
}
func (s *mockBandwidthStats) SetBandwidth(incoming uint64, outgoing uint64) {
s.incoming = incoming
s.outgoing = outgoing
}
func Test_JanusPublisherSubscriber(t *testing.T) {
t.Parallel()
stats := &mockBandwidthStats{}
require := require.New(t)
assert := assert.New(t)
@ -1096,9 +1109,9 @@ func Test_JanusPublisherSubscriber(t *testing.T) { // nolint:paralleltest
// Bandwidth for unknown handles is ignored.
mcu.UpdateBandwidth(1234, "video", api.BandwidthFromBytes(100), api.BandwidthFromBytes(200))
mcu.updateBandwidthStats()
checkStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("incoming"), 0)
checkStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("outgoing"), 0)
mcu.updateBandwidthStats(stats)
assert.EqualValues(0, stats.incoming)
assert.EqualValues(0, stats.outgoing)
pubId := PublicSessionId("publisher-id")
listener1 := &TestMcuListener{
@ -1128,9 +1141,9 @@ func Test_JanusPublisherSubscriber(t *testing.T) { // nolint:paralleltest
assert.Equal(api.BandwidthFromBytes(1000), bw.Sent)
assert.Equal(api.BandwidthFromBytes(2000), bw.Received)
}
mcu.updateBandwidthStats()
checkStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("incoming"), 2000)
checkStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("outgoing"), 1000)
mcu.updateBandwidthStats(stats)
assert.EqualValues(2000, stats.incoming)
assert.EqualValues(1000, stats.outgoing)
listener2 := &TestMcuListener{
id: pubId,
@ -1156,11 +1169,11 @@ func Test_JanusPublisherSubscriber(t *testing.T) { // nolint:paralleltest
assert.Equal(api.BandwidthFromBytes(4000), bw.Sent)
assert.Equal(api.BandwidthFromBytes(6000), bw.Received)
}
checkStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("incoming"), 2000)
checkStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("outgoing"), 1000)
mcu.updateBandwidthStats()
checkStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("incoming"), 6000)
checkStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("outgoing"), 4000)
assert.EqualValues(2000, stats.incoming)
assert.EqualValues(1000, stats.outgoing)
mcu.updateBandwidthStats(stats)
assert.EqualValues(6000, stats.incoming)
assert.EqualValues(4000, stats.outgoing)
}
func Test_JanusSubscriberPublisher(t *testing.T) {
@ -1402,18 +1415,61 @@ func Test_JanusRemotePublisher(t *testing.T) {
assert.EqualValues(1, removed.Load())
}
func Test_JanusSubscriberNoSuchRoom(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"))
t.Cleanup(func() {
if !t.Failed() {
checkStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"), 0)
}
})
type mockJanusStats struct {
called atomic.Bool
mu sync.Mutex
// +checklocks:mu
value map[StreamType]int
}
func (s *mockJanusStats) Value(streamType StreamType) int {
s.mu.Lock()
defer s.mu.Unlock()
return s.value[streamType]
}
func (s *mockJanusStats) IncSubscriber(streamType StreamType) {
s.called.Store(true)
s.mu.Lock()
defer s.mu.Unlock()
if s.value == nil {
s.value = make(map[StreamType]int)
}
s.value[streamType]++
}
func (s *mockJanusStats) DecSubscriber(streamType StreamType) {
s.called.Store(true)
s.mu.Lock()
defer s.mu.Unlock()
if s.value == nil {
s.value = make(map[StreamType]int)
}
s.value[streamType]--
}
func Test_JanusSubscriberNoSuchRoom(t *testing.T) {
t.Parallel()
require := require.New(t)
assert := assert.New(t)
stats := &mockJanusStats{}
t.Cleanup(func() {
if !t.Failed() {
assert.True(stats.called.Load(), "stats were not called")
assert.Equal(0, stats.Value("video"))
}
})
mcu, gateway := newMcuJanusForTesting(t)
mcu.stats = stats
gateway.registerHandlers(map[string]TestJanusHandler{
"configure": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) {
assert.EqualValues(1, room.id)
@ -1501,18 +1557,21 @@ func Test_JanusSubscriberNoSuchRoom(t *testing.T) { // nolint:paralleltest
client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo)
}
func test_JanusSubscriberAlreadyJoined(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"))
t.Cleanup(func() {
if !t.Failed() {
checkStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"), 0)
}
})
func test_JanusSubscriberAlreadyJoined(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
stats := &mockJanusStats{}
t.Cleanup(func() {
if !t.Failed() {
assert.True(stats.called.Load(), "stats were not called")
assert.Equal(0, stats.Value("video"))
}
})
mcu, gateway := newMcuJanusForTesting(t)
mcu.stats = stats
gateway.registerHandlers(map[string]TestJanusHandler{
"configure": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) {
assert.EqualValues(1, room.id)
@ -1602,26 +1661,32 @@ func test_JanusSubscriberAlreadyJoined(t *testing.T) { // nolint:paralleltest
client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo)
}
func Test_JanusSubscriberAlreadyJoined(t *testing.T) { // nolint:paralleltest
func Test_JanusSubscriberAlreadyJoined(t *testing.T) {
t.Parallel()
test_JanusSubscriberAlreadyJoined(t)
}
func Test_JanusSubscriberAlreadyJoinedAttachError(t *testing.T) { // nolint:paralleltest
func Test_JanusSubscriberAlreadyJoinedAttachError(t *testing.T) {
t.Parallel()
test_JanusSubscriberAlreadyJoined(t)
}
func Test_JanusSubscriberTimeout(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"))
t.Cleanup(func() {
if !t.Failed() {
checkStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"), 0)
}
})
func Test_JanusSubscriberTimeout(t *testing.T) {
t.Parallel()
require := require.New(t)
assert := assert.New(t)
stats := &mockJanusStats{}
t.Cleanup(func() {
if !t.Failed() {
assert.True(stats.called.Load(), "stats were not called")
assert.Equal(0, stats.Value("video"))
}
})
mcu, gateway := newMcuJanusForTesting(t)
mcu.stats = stats
gateway.registerHandlers(map[string]TestJanusHandler{
"configure": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) {
assert.EqualValues(1, room.id)
@ -1713,18 +1778,22 @@ func Test_JanusSubscriberTimeout(t *testing.T) { // nolint:paralleltest
client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo)
}
func Test_JanusSubscriberCloseEmptyStreams(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"))
t.Cleanup(func() {
if !t.Failed() {
checkStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"), 0)
}
})
func Test_JanusSubscriberCloseEmptyStreams(t *testing.T) {
t.Parallel()
require := require.New(t)
assert := assert.New(t)
stats := &mockJanusStats{}
t.Cleanup(func() {
if !t.Failed() {
assert.True(stats.called.Load(), "stats were not called")
assert.Equal(0, stats.Value("video"))
}
})
mcu, gateway := newMcuJanusForTesting(t)
mcu.stats = stats
gateway.registerHandlers(map[string]TestJanusHandler{
"configure": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) {
assert.EqualValues(1, room.id)
@ -1823,18 +1892,22 @@ func Test_JanusSubscriberCloseEmptyStreams(t *testing.T) { // nolint:paralleltes
assert.Nil(handle, "subscriber should have been closed")
}
func Test_JanusSubscriberRoomDestroyed(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"))
t.Cleanup(func() {
if !t.Failed() {
checkStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"), 0)
}
})
func Test_JanusSubscriberRoomDestroyed(t *testing.T) {
t.Parallel()
require := require.New(t)
assert := assert.New(t)
stats := &mockJanusStats{}
t.Cleanup(func() {
if !t.Failed() {
assert.True(stats.called.Load(), "stats were not called")
assert.Equal(0, stats.Value("video"))
}
})
mcu, gateway := newMcuJanusForTesting(t)
mcu.stats = stats
gateway.registerHandlers(map[string]TestJanusHandler{
"configure": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) {
assert.EqualValues(1, room.id)
@ -1933,18 +2006,22 @@ func Test_JanusSubscriberRoomDestroyed(t *testing.T) { // nolint:paralleltest
assert.Nil(handle, "subscriber should have been closed")
}
func Test_JanusSubscriberUpdateOffer(t *testing.T) { // nolint:paralleltest
ResetStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"))
t.Cleanup(func() {
if !t.Failed() {
checkStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"), 0)
}
})
func Test_JanusSubscriberUpdateOffer(t *testing.T) {
t.Parallel()
require := require.New(t)
assert := assert.New(t)
stats := &mockJanusStats{}
t.Cleanup(func() {
if !t.Failed() {
assert.True(stats.called.Load(), "stats were not called")
assert.Equal(0, stats.Value("video"))
}
})
mcu, gateway := newMcuJanusForTesting(t)
mcu.stats = stats
gateway.registerHandlers(map[string]TestJanusHandler{
"configure": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) {
assert.EqualValues(1, room.id)

View file

@ -845,13 +845,13 @@ type proxyTestOptions struct {
servers []*TestProxyServerHandler
}
func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions, idx int) (*mcuProxy, *goconf.ConfigFile) {
func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions, idx int, lookup *mockDnsLookup) (*mcuProxy, *goconf.ConfigFile) {
t.Helper()
require := require.New(t)
if options.etcd == nil {
options.etcd = NewEtcdForTest(t)
}
grpcClients, dnsMonitor := NewGrpcClientsWithEtcdForTest(t, options.etcd)
grpcClients, dnsMonitor := NewGrpcClientsWithEtcdForTest(t, options.etcd, lookup)
tokenKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(err)
@ -933,20 +933,20 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions, idx i
return proxy, cfg
}
func newMcuProxyForTestWithServers(t *testing.T, servers []*TestProxyServerHandler, idx int) *mcuProxy {
func newMcuProxyForTestWithServers(t *testing.T, servers []*TestProxyServerHandler, idx int, lookup *mockDnsLookup) *mcuProxy {
t.Helper()
proxy, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
servers: servers,
}, idx)
}, idx, lookup)
return proxy
}
func newMcuProxyForTest(t *testing.T, idx int) *mcuProxy {
func newMcuProxyForTest(t *testing.T, idx int, lookup *mockDnsLookup) *mcuProxy {
t.Helper()
server := NewProxyServerForTest(t, "DE")
return newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{server}, idx)
return newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{server}, idx, lookup)
}
func Test_ProxyAddRemoveConnections(t *testing.T) {
@ -961,7 +961,7 @@ func Test_ProxyAddRemoveConnections(t *testing.T) {
servers: []*TestProxyServerHandler{
server1,
},
}, 0)
}, 0, nil)
server2 := NewProxyServerForTest(t, "DE")
server1.servers = append(server1.servers, server2)
@ -1023,7 +1023,8 @@ func Test_ProxyAddRemoveConnections(t *testing.T) {
assert.NoError(waitCtx.Err(), "error while waiting for connection to be removed")
}
func Test_ProxyAddRemoveConnectionsDnsDiscovery(t *testing.T) { // nolint:paralleltest
func Test_ProxyAddRemoveConnectionsDnsDiscovery(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
@ -1052,7 +1053,7 @@ func Test_ProxyAddRemoveConnectionsDnsDiscovery(t *testing.T) { // nolint:parall
servers: []*TestProxyServerHandler{
server1,
},
}, 0)
}, 0, lookup)
if connections := mcu.getConnections(); assert.Len(connections, 1) && assert.NotNil(connections[0].ip) {
assert.True(ip1.Equal(connections[0].ip), "ip addresses differ: expected %s, got %s", ip1.String(), connections[0].ip.String())
@ -1141,7 +1142,7 @@ func Test_ProxyAddRemoveConnectionsDnsDiscovery(t *testing.T) { // nolint:parall
func Test_ProxyPublisherSubscriber(t *testing.T) {
t.Parallel()
mcu := newMcuProxyForTest(t, 0)
mcu := newMcuProxyForTest(t, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1176,7 +1177,7 @@ func Test_ProxyPublisherSubscriber(t *testing.T) {
func Test_ProxyPublisherCodecs(t *testing.T) {
t.Parallel()
mcu := newMcuProxyForTest(t, 0)
mcu := newMcuProxyForTest(t, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1202,7 +1203,7 @@ func Test_ProxyPublisherCodecs(t *testing.T) {
func Test_ProxyWaitForPublisher(t *testing.T) {
t.Parallel()
mcu := newMcuProxyForTest(t, 0)
mcu := newMcuProxyForTest(t, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1256,7 +1257,7 @@ func Test_ProxyPublisherBandwidth(t *testing.T) {
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
server1,
server2,
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1325,7 +1326,7 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) {
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
server1,
server2,
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1397,7 +1398,7 @@ func Test_ProxyPublisherLoad(t *testing.T) {
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
server1,
server2,
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1446,7 +1447,7 @@ func Test_ProxyPublisherCountry(t *testing.T) {
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
serverDE,
serverUS,
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1493,7 +1494,7 @@ func Test_ProxyPublisherContinent(t *testing.T) {
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
serverDE,
serverUS,
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1540,7 +1541,7 @@ func Test_ProxySubscriberCountry(t *testing.T) {
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
serverDE,
serverUS,
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1583,7 +1584,7 @@ func Test_ProxySubscriberContinent(t *testing.T) {
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
serverDE,
serverUS,
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1626,7 +1627,7 @@ func Test_ProxySubscriberBandwidth(t *testing.T) {
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
serverDE,
serverUS,
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1689,7 +1690,7 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) {
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
serverDE,
serverUS,
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -1824,7 +1825,7 @@ func Test_ProxyRemotePublisher(t *testing.T) {
server1,
server2,
},
}, 1)
}, 1, nil)
hub1.proxy.Store(mcu1)
mcu2, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
@ -1832,7 +1833,7 @@ func Test_ProxyRemotePublisher(t *testing.T) {
server1,
server2,
},
}, 2)
}, 2, nil)
hub2.proxy.Store(mcu2)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
@ -1909,7 +1910,7 @@ func Test_ProxyMultipleRemotePublisher(t *testing.T) {
server2,
server3,
},
}, 1)
}, 1, nil)
hub1.proxy.Store(mcu1)
mcu2, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
@ -1918,7 +1919,7 @@ func Test_ProxyMultipleRemotePublisher(t *testing.T) {
server2,
server3,
},
}, 2)
}, 2, nil)
hub2.proxy.Store(mcu2)
mcu3, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
@ -1927,7 +1928,7 @@ func Test_ProxyMultipleRemotePublisher(t *testing.T) {
server2,
server3,
},
}, 3)
}, 3, nil)
hub3.proxy.Store(mcu3)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
@ -2009,7 +2010,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) {
server1,
server2,
},
}, 1)
}, 1, nil)
hub1.proxy.Store(mcu1)
mcu2, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
@ -2017,7 +2018,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) {
server1,
server2,
},
}, 2)
}, 2, nil)
hub2.proxy.Store(mcu2)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
@ -2103,14 +2104,14 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) {
servers: []*TestProxyServerHandler{
server1,
},
}, 1)
}, 1, nil)
hub1.proxy.Store(mcu1)
mcu2, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
servers: []*TestProxyServerHandler{
server2,
},
}, 2)
}, 2, nil)
hub2.proxy.Store(mcu2)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
@ -2214,14 +2215,14 @@ func Test_ProxyConnectToken(t *testing.T) {
servers: []*TestProxyServerHandler{
server1,
},
}, 1)
}, 1, nil)
hub1.proxy.Store(mcu1)
mcu2, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
servers: []*TestProxyServerHandler{
server2,
},
}, 2)
}, 2, nil)
hub2.proxy.Store(mcu2)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
@ -2296,14 +2297,14 @@ func Test_ProxyPublisherToken(t *testing.T) {
servers: []*TestProxyServerHandler{
server1,
},
}, 1)
}, 1, nil)
hub1.proxy.Store(mcu1)
mcu2, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
servers: []*TestProxyServerHandler{
server2,
},
}, 2)
}, 2, nil)
hub2.proxy.Store(mcu2)
// Support remote subscribers for the tests.
server1.servers = append(server1.servers, server2)
@ -2359,7 +2360,7 @@ func Test_ProxyPublisherTimeout(t *testing.T) {
server := NewProxyServerForTest(t, "DE")
mcu, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
servers: []*TestProxyServerHandler{server},
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -2399,7 +2400,7 @@ func Test_ProxySubscriberTimeout(t *testing.T) {
server := NewProxyServerForTest(t, "DE")
mcu, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
servers: []*TestProxyServerHandler{server},
}, 0)
}, 0, nil)
ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
defer cancel()
@ -2459,7 +2460,7 @@ func Test_ProxyReconnectAfter(t *testing.T) {
server := NewProxyServerForTest(t, "DE")
mcu, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
servers: []*TestProxyServerHandler{server},
}, 0)
}, 0, nil)
connections := mcu.getSortedConnections(nil)
require.Len(connections, 1)
@ -2499,7 +2500,7 @@ func Test_ProxyReconnectAfterShutdown(t *testing.T) {
server := NewProxyServerForTest(t, "DE")
mcu, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
servers: []*TestProxyServerHandler{server},
}, 0)
}, 0, nil)
connections := mcu.getSortedConnections(nil)
require.Len(connections, 1)
@ -2538,7 +2539,7 @@ func Test_ProxyResume(t *testing.T) {
server := NewProxyServerForTest(t, "DE")
mcu, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
servers: []*TestProxyServerHandler{server},
}, 0)
}, 0, nil)
connections := mcu.getSortedConnections(nil)
require.Len(connections, 1)
@ -2570,7 +2571,7 @@ func Test_ProxyResumeFail(t *testing.T) {
server := NewProxyServerForTest(t, "DE")
mcu, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{
servers: []*TestProxyServerHandler{server},
}, 0)
}, 0, nil)
connections := mcu.getSortedConnections(nil)
require.Len(connections, 1)

View file

@ -63,34 +63,30 @@ func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient {
return result
}
func TestLoopbackNatsClient_Subscribe(t *testing.T) { // nolint:paralleltest
ensureNoGoroutinesLeak(t, func(t *testing.T) {
client := CreateLoopbackNatsClientForTest(t)
func TestLoopbackNatsClient_Subscribe(t *testing.T) {
t.Parallel()
testNatsClient_Subscribe(t, client)
})
client := CreateLoopbackNatsClientForTest(t)
testNatsClient_Subscribe(t, client)
}
func TestLoopbackClient_PublishAfterClose(t *testing.T) { // nolint:paralleltest
ensureNoGoroutinesLeak(t, func(t *testing.T) {
client := CreateLoopbackNatsClientForTest(t)
func TestLoopbackClient_PublishAfterClose(t *testing.T) {
t.Parallel()
testNatsClient_PublishAfterClose(t, client)
})
client := CreateLoopbackNatsClientForTest(t)
testNatsClient_PublishAfterClose(t, client)
}
func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { // nolint:paralleltest
ensureNoGoroutinesLeak(t, func(t *testing.T) {
client := CreateLoopbackNatsClientForTest(t)
func TestLoopbackClient_SubscribeAfterClose(t *testing.T) {
t.Parallel()
testNatsClient_SubscribeAfterClose(t, client)
})
client := CreateLoopbackNatsClientForTest(t)
testNatsClient_SubscribeAfterClose(t, client)
}
func TestLoopbackClient_BadSubjects(t *testing.T) { // nolint:paralleltest
ensureNoGoroutinesLeak(t, func(t *testing.T) {
client := CreateLoopbackNatsClientForTest(t)
func TestLoopbackClient_BadSubjects(t *testing.T) {
t.Parallel()
testNatsClient_BadSubjects(t, client)
})
client := CreateLoopbackNatsClientForTest(t)
testNatsClient_BadSubjects(t, client)
}

View file

@ -31,13 +31,13 @@ import (
"github.com/stretchr/testify/require"
)
func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string) (ProxyConfig, *DnsMonitor) {
func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, lookup *mockDnsLookup, urls ...string) (ProxyConfig, *DnsMonitor) {
cfg := goconf.NewConfigFile()
cfg.AddOption("mcu", "url", strings.Join(urls, " "))
if dns {
cfg.AddOption("mcu", "dnsdiscovery", "true")
}
dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually
dnsMonitor := newDnsMonitorForTest(t, time.Hour, lookup) // will be updated manually
logger := NewLoggerForTest(t)
p, err := NewProxyConfigStatic(logger, cfg, proxy, dnsMonitor)
require.NoError(t, err)
@ -59,7 +59,7 @@ func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls ..
func TestProxyConfigStaticSimple(t *testing.T) {
t.Parallel()
proxy := newMcuProxyForConfig(t)
config, _ := newProxyConfigStatic(t, proxy, false, "https://foo/")
config, _ := newProxyConfigStatic(t, proxy, false, nil, "https://foo/")
proxy.Expect("add", "https://foo/")
require.NoError(t, config.Start())
@ -73,10 +73,11 @@ func TestProxyConfigStaticSimple(t *testing.T) {
updateProxyConfigStatic(t, config, false, "https://bar/", "https://baz/")
}
func TestProxyConfigStaticDNS(t *testing.T) { // nolint:paralleltest
func TestProxyConfigStaticDNS(t *testing.T) {
t.Parallel()
lookup := newMockDnsLookupForTest(t)
proxy := newMcuProxyForConfig(t)
config, dnsMonitor := newProxyConfigStatic(t, proxy, true, "https://foo/")
config, dnsMonitor := newProxyConfigStatic(t, proxy, true, lookup, "https://foo/")
require.NoError(t, config.Start())
time.Sleep(time.Millisecond)

View file

@ -25,6 +25,45 @@ import (
"sync"
)
type publisherStatsCounterStats interface {
IncPublisherStream(streamType StreamType)
DecPublisherStream(streamType StreamType)
IncSubscriberStream(streamType StreamType)
DecSubscriberStream(streamType StreamType)
AddSubscriberStreams(streamType StreamType, count int)
SubSubscriberStreams(streamType StreamType, count int)
}
type prometheusPublisherStats struct{}
func (s *prometheusPublisherStats) IncPublisherStream(streamType StreamType) {
statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Inc()
}
func (s *prometheusPublisherStats) DecPublisherStream(streamType StreamType) {
statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec()
}
func (s *prometheusPublisherStats) IncSubscriberStream(streamType StreamType) {
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Inc()
}
func (s *prometheusPublisherStats) DecSubscriberStream(streamType StreamType) {
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Dec()
}
func (s *prometheusPublisherStats) AddSubscriberStreams(streamType StreamType, count int) {
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Add(float64(count))
}
func (s *prometheusPublisherStats) SubSubscriberStreams(streamType StreamType, count int) {
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(count))
}
var (
defaultPublisherStats = &prometheusPublisherStats{} // +checklocksignore: Global readonly variable.
)
type publisherStatsCounter struct {
mu sync.Mutex
@ -32,16 +71,23 @@ type publisherStatsCounter struct {
streamTypes map[StreamType]bool
// +checklocks:mu
subscribers map[string]bool
// +checklocks:mu
stats publisherStatsCounterStats
}
func (c *publisherStatsCounter) Reset() {
c.mu.Lock()
defer c.mu.Unlock()
stats := c.stats
if stats == nil {
stats = defaultPublisherStats
}
count := len(c.subscribers)
for streamType := range c.streamTypes {
statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec()
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(count))
stats.DecPublisherStream(streamType)
stats.SubSubscriberStreams(streamType, count)
}
c.streamTypes = nil
c.subscribers = nil
@ -55,17 +101,22 @@ func (c *publisherStatsCounter) EnableStream(streamType StreamType, enable bool)
return
}
stats := c.stats
if stats == nil {
stats = defaultPublisherStats
}
if enable {
if c.streamTypes == nil {
c.streamTypes = make(map[StreamType]bool)
}
c.streamTypes[streamType] = true
statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Inc()
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Add(float64(len(c.subscribers)))
stats.IncPublisherStream(streamType)
stats.AddSubscriberStreams(streamType, len(c.subscribers))
} else {
delete(c.streamTypes, streamType)
statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec()
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(len(c.subscribers)))
stats.DecPublisherStream(streamType)
stats.SubSubscriberStreams(streamType, len(c.subscribers))
}
}
@ -77,12 +128,17 @@ func (c *publisherStatsCounter) AddSubscriber(id string) {
return
}
stats := c.stats
if stats == nil {
stats = defaultPublisherStats
}
if c.subscribers == nil {
c.subscribers = make(map[string]bool)
}
c.subscribers[id] = true
for streamType := range c.streamTypes {
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Inc()
stats.IncSubscriberStream(streamType)
}
}
@ -94,8 +150,13 @@ func (c *publisherStatsCounter) RemoveSubscriber(id string) {
return
}
stats := c.stats
if stats == nil {
stats = defaultPublisherStats
}
delete(c.subscribers, id)
for streamType := range c.streamTypes {
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Dec()
stats.DecSubscriberStream(streamType)
}
}

View file

@ -23,89 +23,155 @@ package signaling
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPublisherStatsCounter(t *testing.T) { // nolint:paralleltest
RegisterJanusMcuStats()
type mockPublisherStats struct {
publishers map[StreamType]int
subscribers map[StreamType]int
}
var c publisherStatsCounter
func (s *mockPublisherStats) IncPublisherStream(streamType StreamType) {
if s.publishers == nil {
s.publishers = make(map[StreamType]int)
}
s.publishers[streamType]++
}
func (s *mockPublisherStats) DecPublisherStream(streamType StreamType) {
if s.publishers == nil {
s.publishers = make(map[StreamType]int)
}
s.publishers[streamType]--
}
func (s *mockPublisherStats) IncSubscriberStream(streamType StreamType) {
if s.subscribers == nil {
s.subscribers = make(map[StreamType]int)
}
s.subscribers[streamType]++
}
func (s *mockPublisherStats) DecSubscriberStream(streamType StreamType) {
if s.subscribers == nil {
s.subscribers = make(map[StreamType]int)
}
s.subscribers[streamType]--
}
func (s *mockPublisherStats) AddSubscriberStreams(streamType StreamType, count int) {
if s.subscribers == nil {
s.subscribers = make(map[StreamType]int)
}
s.subscribers[streamType] += count
}
func (s *mockPublisherStats) SubSubscriberStreams(streamType StreamType, count int) {
if s.subscribers == nil {
s.subscribers = make(map[StreamType]int)
}
s.subscribers[streamType] -= count
}
func (s *mockPublisherStats) Publishers(streamType StreamType) int {
return s.publishers[streamType]
}
func (s *mockPublisherStats) Subscribers(streamType StreamType) int {
return s.subscribers[streamType]
}
func TestPublisherStatsPrometheus(t *testing.T) {
t.Parallel()
RegisterJanusMcuStats()
collectAndLint(t, commonMcuStats...)
}
func TestPublisherStatsCounter(t *testing.T) {
t.Parallel()
assert := assert.New(t)
stats := &mockPublisherStats{}
c := publisherStatsCounter{
stats: stats,
}
c.Reset()
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 0)
assert.Equal(0, stats.Publishers("audio"))
c.EnableStream("audio", false)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 0)
assert.Equal(0, stats.Publishers("audio"))
c.EnableStream("audio", true)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
assert.Equal(1, stats.Publishers("audio"))
c.EnableStream("audio", true)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
assert.Equal(1, stats.Publishers("audio"))
c.EnableStream("video", true)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
assert.Equal(1, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
c.EnableStream("audio", false)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 0)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
assert.Equal(0, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
c.EnableStream("audio", false)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 0)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
assert.Equal(0, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
c.AddSubscriber("1")
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 0)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 0)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 1)
assert.Equal(0, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
assert.Equal(0, stats.Subscribers("audio"))
assert.Equal(1, stats.Subscribers("video"))
c.EnableStream("audio", true)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 1)
assert.Equal(1, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
assert.Equal(1, stats.Subscribers("audio"))
assert.Equal(1, stats.Subscribers("video"))
c.AddSubscriber("1")
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 1)
assert.Equal(1, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
assert.Equal(1, stats.Subscribers("audio"))
assert.Equal(1, stats.Subscribers("video"))
c.AddSubscriber("2")
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 2)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 2)
assert.Equal(1, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
assert.Equal(2, stats.Subscribers("audio"))
assert.Equal(2, stats.Subscribers("video"))
c.RemoveSubscriber("3")
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 2)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 2)
assert.Equal(1, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
assert.Equal(2, stats.Subscribers("audio"))
assert.Equal(2, stats.Subscribers("video"))
c.RemoveSubscriber("1")
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 1)
assert.Equal(1, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
assert.Equal(1, stats.Subscribers("audio"))
assert.Equal(1, stats.Subscribers("video"))
c.AddSubscriber("1")
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 2)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 2)
assert.Equal(1, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
assert.Equal(2, stats.Subscribers("audio"))
assert.Equal(2, stats.Subscribers("video"))
c.EnableStream("audio", false)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 0)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 0)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 2)
assert.Equal(0, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
assert.Equal(0, stats.Subscribers("audio"))
assert.Equal(2, stats.Subscribers("video"))
c.EnableStream("audio", true)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 1)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 1)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 2)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 2)
assert.Equal(1, stats.Publishers("audio"))
assert.Equal(1, stats.Publishers("video"))
assert.Equal(2, stats.Subscribers("audio"))
assert.Equal(2, stats.Subscribers("video"))
c.EnableStream("audio", false)
c.EnableStream("video", false)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("audio"), 0)
checkStatsValue(t, statsMcuPublisherStreamTypesCurrent.WithLabelValues("video"), 0)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("audio"), 0)
checkStatsValue(t, statsMcuSubscriberStreamTypesCurrent.WithLabelValues("video"), 0)
collectAndLint(t, commonMcuStats...)
assert.Equal(0, stats.Publishers("audio"))
assert.Equal(0, stats.Publishers("video"))
assert.Equal(0, stats.Subscribers("audio"))
assert.Equal(0, stats.Subscribers("video"))
}

View file

@ -192,7 +192,7 @@ func main() {
}
}()
dnsMonitor, err := signaling.NewDnsMonitor(logger, dnsMonitorInterval)
dnsMonitor, err := signaling.NewDnsMonitor(logger, dnsMonitorInterval, nil)
if err != nil {
logger.Fatal("Could not create DNS monitor: ", err)
}

View file

@ -58,7 +58,7 @@ func assertCollectorChangeBy(t *testing.T, collector prometheus.Collector, delta
})
}
func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64) {
func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64) { // nolint:unused
// Make sure test is not executed with "t.Parallel()"
t.Setenv("PARALLEL_CHECK", "1")
@ -71,7 +71,7 @@ func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64
pc := make([]uintptr, 10)
n := runtime.Callers(2, pc)
if n == 0 {
assert.InEpsilon(value, v, 0.0001, "failed for %s", desc)
assert.InDelta(value, v, 0.0001, "failed for %s", desc)
return
}
@ -88,7 +88,7 @@ func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64
break
}
}
assert.InEpsilon(value, v, 0.0001, "Unexpected value for %s at\n%s", desc, stack.String())
assert.InDelta(value, v, 0.0001, "Unexpected value for %s at\n%s", desc, stack.String())
}
}

View file

@ -221,7 +221,10 @@ func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Se
messageChan := make(chan []byte)
readErrorChan := make(chan error, 1)
closing := make(chan struct{})
closed := make(chan struct{})
go func() {
defer close(closed)
for {
messageType, data, err := conn.ReadMessage()
if err != nil {
@ -231,9 +234,17 @@ func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Se
return
}
messageChan <- data
select {
case messageChan <- data:
case <-closing:
return
}
}
}()
t.Cleanup(func() {
close(closing)
<-closed
})
return &TestClient{
t: t,

View file

@ -95,8 +95,6 @@ type TransientData struct {
listeners map[TransientListener]bool
// +checklocks:mu
timers map[string]*time.Timer
// +checklocks:mu
ttlCh chan<- struct{}
}
// NewTransientData creates a new transient data container.
@ -181,7 +179,10 @@ func (t *TransientData) RemoveListener(listener TransientListener) {
// +checklocks:t.mu
func (t *TransientData) updateTTL(key string, value any, ttl time.Duration) {
if ttl <= 0 {
delete(t.timers, key)
if old, found := t.timers[key]; found {
old.Stop()
delete(t.timers, key)
}
} else {
t.removeAfterTTL(key, value, ttl)
}
@ -189,25 +190,20 @@ func (t *TransientData) updateTTL(key string, value any, ttl time.Duration) {
// +checklocks:t.mu
func (t *TransientData) removeAfterTTL(key string, value any, ttl time.Duration) {
if ttl <= 0 {
return
}
if old, found := t.timers[key]; found {
old.Stop()
}
if ttl <= 0 {
delete(t.timers, key)
return
}
timer := time.AfterFunc(ttl, func() {
t.mu.Lock()
defer t.mu.Unlock()
t.compareAndRemove(key, value)
if t.ttlCh != nil {
select {
case t.ttlCh <- struct{}{}:
default:
}
}
})
if t.timers == nil {
t.timers = make(map[string]*time.Timer)

View file

@ -26,6 +26,7 @@ import (
"net/http/httptest"
"sync"
"testing"
"testing/synctest"
"time"
"github.com/stretchr/testify/assert"
@ -34,59 +35,63 @@ import (
"github.com/strukturag/nextcloud-spreed-signaling/api"
)
func (t *TransientData) SetTTLChannel(ch chan<- struct{}) {
t.mu.Lock()
defer t.mu.Unlock()
t.ttlCh = ch
}
func Test_TransientData(t *testing.T) {
t.Parallel()
assert := assert.New(t)
data := NewTransientData()
assert.False(data.Set("foo", nil))
assert.True(data.Set("foo", "bar"))
assert.False(data.Set("foo", "bar"))
assert.True(data.Set("foo", "baz"))
assert.False(data.CompareAndSet("foo", "bar", "lala"))
assert.True(data.CompareAndSet("foo", "baz", "lala"))
assert.False(data.CompareAndSet("test", nil, nil))
assert.True(data.CompareAndSet("test", nil, "123"))
assert.False(data.CompareAndSet("test", nil, "456"))
assert.False(data.CompareAndRemove("test", "1234"))
assert.True(data.CompareAndRemove("test", "123"))
assert.False(data.Remove("lala"))
assert.True(data.Remove("foo"))
SynctestTest(t, func(t *testing.T) {
assert := assert.New(t)
data := NewTransientData()
assert.False(data.Set("foo", nil))
assert.True(data.Set("foo", "bar"))
assert.False(data.Set("foo", "bar"))
assert.True(data.Set("foo", "baz"))
assert.False(data.CompareAndSet("foo", "bar", "lala"))
assert.True(data.CompareAndSet("foo", "baz", "lala"))
assert.False(data.CompareAndSet("test", nil, nil))
assert.True(data.CompareAndSet("test", nil, "123"))
assert.False(data.CompareAndSet("test", nil, "456"))
assert.False(data.CompareAndRemove("test", "1234"))
assert.True(data.CompareAndRemove("test", "123"))
assert.False(data.Remove("lala"))
assert.True(data.Remove("foo"))
ttlCh := make(chan struct{})
data.SetTTLChannel(ttlCh)
assert.True(data.SetTTL("test", "1234", time.Millisecond))
assert.Equal("1234", data.GetData()["test"])
// Data is removed after the TTL
<-ttlCh
assert.Nil(data.GetData()["test"])
assert.True(data.SetTTL("test", "1234", time.Millisecond))
assert.Equal("1234", data.GetData()["test"])
// Data is removed after the TTL
start := time.Now()
time.Sleep(time.Millisecond)
synctest.Wait()
assert.Equal(time.Millisecond, time.Since(start))
assert.Nil(data.GetData()["test"])
assert.True(data.SetTTL("test", "1234", time.Millisecond))
assert.Equal("1234", data.GetData()["test"])
assert.True(data.SetTTL("test", "2345", 3*time.Millisecond))
assert.Equal("2345", data.GetData()["test"])
// Data is removed after the TTL only if the value still matches
time.Sleep(2 * time.Millisecond)
assert.Equal("2345", data.GetData()["test"])
// Data is removed after the (second) TTL
<-ttlCh
assert.Nil(data.GetData()["test"])
assert.True(data.SetTTL("test", "1234", time.Millisecond))
assert.Equal("1234", data.GetData()["test"])
assert.True(data.SetTTL("test", "2345", 3*time.Millisecond))
assert.Equal("2345", data.GetData()["test"])
start = time.Now()
// Data is removed after the TTL only if the value still matches
time.Sleep(2 * time.Millisecond)
synctest.Wait()
assert.Equal("2345", data.GetData()["test"])
// Data is removed after the (second) TTL
time.Sleep(time.Millisecond)
synctest.Wait()
assert.Equal(3*time.Millisecond, time.Since(start))
assert.Nil(data.GetData()["test"])
// Setting existing key will update the TTL
assert.True(data.SetTTL("test", "1234", time.Millisecond))
assert.False(data.SetTTL("test", "1234", 3*time.Millisecond))
// Data still exists after the first TTL
time.Sleep(2 * time.Millisecond)
assert.Equal("1234", data.GetData()["test"])
// Data is removed after the (updated) TTL
<-ttlCh
assert.Nil(data.GetData()["test"])
// Setting existing key will update the TTL
assert.True(data.SetTTL("test", "1234", time.Millisecond))
assert.False(data.SetTTL("test", "1234", 3*time.Millisecond))
start = time.Now()
// Data still exists after the first TTL
time.Sleep(2 * time.Millisecond)
synctest.Wait()
assert.Equal("1234", data.GetData()["test"])
// Data is removed after the (updated) TTL
time.Sleep(time.Millisecond)
synctest.Wait()
assert.Equal(3*time.Millisecond, time.Since(start))
assert.Nil(data.GetData()["test"])
})
}
type MockTransientListener struct {