Switch to "github.com/stretchr/testify" for tests.

This commit is contained in:
Joachim Bauch 2024-08-29 17:12:33 +02:00
commit 03cad99b8d
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
50 changed files with 3082 additions and 6234 deletions

View file

@ -24,19 +24,17 @@ package signaling
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAllowedIps(t *testing.T) {
require := require.New(t)
a, err := ParseAllowedIps("127.0.0.1, 192.168.0.1, 192.168.1.1/24")
if err != nil {
t.Fatal(err)
}
if a.Empty() {
t.Fatal("should not be empty")
}
if expected := `[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`; a.String() != expected {
t.Errorf("expected %s, got %s", expected, a.String())
}
require.NoError(err)
require.False(a.Empty())
require.Equal(`[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`, a.String())
allowed := []string{
"127.0.0.1",
@ -51,22 +49,18 @@ func TestAllowedIps(t *testing.T) {
for _, addr := range allowed {
t.Run(addr, func(t *testing.T) {
ip := net.ParseIP(addr)
if ip == nil {
t.Errorf("error parsing %s", addr)
} else if !a.Allowed(ip) {
t.Errorf("should allow %s", addr)
assert := assert.New(t)
if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) {
assert.True(a.Allowed(ip), "should allow %s", addr)
}
})
}
for _, addr := range notAllowed {
t.Run(addr, func(t *testing.T) {
ip := net.ParseIP(addr)
if ip == nil {
t.Errorf("error parsing %s", addr)
} else if a.Allowed(ip) {
t.Errorf("should not allow %s", addr)
assert := assert.New(t)
if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) {
assert.False(a.Allowed(ip), "should not allow %s", addr)
}
})
}

View file

@ -24,42 +24,36 @@ package signaling
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBackendChecksum(t *testing.T) {
t.Parallel()
assert := assert.New(t)
rnd := newRandomString(32)
body := []byte{1, 2, 3, 4, 5}
secret := []byte("shared-secret")
check1 := CalculateBackendChecksum(rnd, body, secret)
check2 := CalculateBackendChecksum(rnd, body, secret)
if check1 != check2 {
t.Errorf("Expected equal checksums, got %s and %s", check1, check2)
}
assert.Equal(check1, check2, "Expected equal checksums")
if !ValidateBackendChecksumValue(check1, rnd, body, secret) {
t.Errorf("Checksum %s could not be validated", check1)
}
if ValidateBackendChecksumValue(check1[1:], rnd, body, secret) {
t.Errorf("Checksum %s should not be valid", check1[1:])
}
if ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret) {
t.Errorf("Checksum %s should not be valid", check1[:len(check1)-1])
}
assert.True(ValidateBackendChecksumValue(check1, rnd, body, secret), "Checksum should be valid")
assert.False(ValidateBackendChecksumValue(check1[1:], rnd, body, secret), "Checksum should not be valid")
assert.False(ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret), "Checksum should not be valid")
request := &http.Request{
Header: make(http.Header),
}
request.Header.Set("Spreed-Signaling-Random", rnd)
request.Header.Set("Spreed-Signaling-Checksum", check1)
if !ValidateBackendChecksum(request, body, secret) {
t.Errorf("Checksum %s could not be validated from request", check1)
}
assert.True(ValidateBackendChecksum(request, body, secret), "Checksum could not be validated from request")
}
func TestValidNumbers(t *testing.T) {
t.Parallel()
assert := assert.New(t)
valid := []string{
"+12",
"+12345",
@ -72,13 +66,9 @@ func TestValidNumbers(t *testing.T) {
"+123-45",
}
for _, number := range valid {
if !isValidNumber(number) {
t.Errorf("number %s should be valid", number)
}
assert.True(isValidNumber(number), "number %s should be valid", number)
}
for _, number := range invalid {
if isValidNumber(number) {
t.Errorf("number %s should not be valid", number)
}
assert.False(isValidNumber(number), "number %s should not be valid", number)
}
}

View file

@ -24,9 +24,10 @@ package signaling
import (
"encoding/json"
"fmt"
"reflect"
"sort"
"testing"
"github.com/stretchr/testify/assert"
)
type testCheckValid interface {
@ -53,40 +54,34 @@ func wrapMessage(messageType string, msg testCheckValid) *ClientMessage {
}
func testMessages(t *testing.T, messageType string, valid_messages []testCheckValid, invalid_messages []testCheckValid) {
t.Helper()
assert := assert.New(t)
for _, msg := range valid_messages {
if err := msg.CheckValid(); err != nil {
t.Errorf("Message %+v should be valid, got %s", msg, err)
}
assert.NoError(msg.CheckValid(), "Message %+v should be valid", msg)
// If the inner message is valid, it should also be valid in a wrapped
// ClientMessage.
if wrapped := wrapMessage(messageType, msg); wrapped == nil {
t.Errorf("Unknown message type: %s", messageType)
} else if err := wrapped.CheckValid(); err != nil {
t.Errorf("Message %+v should be valid, got %s", wrapped, err)
if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) {
assert.NoError(wrapped.CheckValid(), "Message %+v should be valid", wrapped)
}
}
for _, msg := range invalid_messages {
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert.Error(msg.CheckValid(), "Message %+v should not be valid", msg)
// If the inner message is invalid, it should also be invalid in a
// wrapped ClientMessage.
if wrapped := wrapMessage(messageType, msg); wrapped == nil {
t.Errorf("Unknown message type: %s", messageType)
} else if err := wrapped.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", wrapped)
if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) {
assert.Error(wrapped.CheckValid(), "Message %+v should not be valid", wrapped)
}
}
}
func TestClientMessage(t *testing.T) {
t.Parallel()
assert := assert.New(t)
// The message needs a type.
msg := ClientMessage{}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert.Error(msg.CheckValid())
}
func TestHelloClientMessage(t *testing.T) {
@ -229,9 +224,8 @@ func TestHelloClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "hello",
}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert := assert.New(t)
assert.Error(msg.CheckValid())
}
func TestMessageClientMessage(t *testing.T) {
@ -311,9 +305,8 @@ func TestMessageClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "message",
}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert := assert.New(t)
assert.Error(msg.CheckValid())
}
func TestByeClientMessage(t *testing.T) {
@ -330,9 +323,8 @@ func TestByeClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "bye",
}
if err := msg.CheckValid(); err != nil {
t.Errorf("Message %+v should be valid, got %s", msg, err)
}
assert := assert.New(t)
assert.NoError(msg.CheckValid())
}
func TestRoomClientMessage(t *testing.T) {
@ -349,42 +341,31 @@ func TestRoomClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "room",
}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert := assert.New(t)
assert.Error(msg.CheckValid())
}
func TestErrorMessages(t *testing.T) {
t.Parallel()
assert := assert.New(t)
id := "request-id"
msg := ClientMessage{
Id: id,
}
err1 := msg.NewErrorServerMessage(&Error{})
if err1.Id != id {
t.Errorf("Expected id %s, got %+v", id, err1)
}
if err1.Type != "error" || err1.Error == nil {
t.Errorf("Expected type \"error\", got %+v", err1)
}
assert.Equal(id, err1.Id, "%+v", err1)
assert.Equal("error", err1.Type, "%+v", err1)
assert.NotNil(err1.Error, "%+v", err1)
err2 := msg.NewWrappedErrorServerMessage(fmt.Errorf("test-error"))
if err2.Id != id {
t.Errorf("Expected id %s, got %+v", id, err2)
}
if err2.Type != "error" || err2.Error == nil {
t.Errorf("Expected type \"error\", got %+v", err2)
}
if err2.Error.Code != "internal_error" {
t.Errorf("Expected code \"internal_error\", got %+v", err2)
}
if err2.Error.Message != "test-error" {
t.Errorf("Expected message \"test-error\", got %+v", err2)
assert.Equal(id, err2.Id, "%+v", err2)
assert.Equal("error", err2.Type, "%+v", err2)
if assert.NotNil(err2.Error, "%+v", err2) {
assert.Equal("internal_error", err2.Error.Code, "%+v", err2)
assert.Equal("test-error", err2.Error.Message, "%+v", err2)
}
// Test "error" interface
if err2.Error.Error() != "test-error" {
t.Errorf("Expected error string \"test-error\", got %+v", err2)
}
assert.Equal("test-error", err2.Error.Error(), "%+v", err2)
}
func TestIsChatRefresh(t *testing.T) {
@ -397,9 +378,7 @@ func TestIsChatRefresh(t *testing.T) {
Data: data_true,
},
}
if !msg.IsChatRefresh() {
t.Error("message should be detected as chat refresh")
}
assert.True(t, msg.IsChatRefresh())
data_false := []byte("{\"type\":\"chat\",\"chat\":{\"refresh\":false}}")
msg = ServerMessage{
@ -408,9 +387,7 @@ func TestIsChatRefresh(t *testing.T) {
Data: data_false,
},
}
if msg.IsChatRefresh() {
t.Error("message should not be detected as chat refresh")
}
assert.False(t, msg.IsChatRefresh())
}
func assertEqualStrings(t *testing.T, expected, result []string) {
@ -427,27 +404,22 @@ func assertEqualStrings(t *testing.T, expected, result []string) {
sort.Strings(result)
}
if !reflect.DeepEqual(expected, result) {
t.Errorf("Expected %+v, got %+v", expected, result)
}
assert.Equal(t, expected, result)
}
func Test_Welcome_AddRemoveFeature(t *testing.T) {
t.Parallel()
assert := assert.New(t)
var msg WelcomeServerMessage
assertEqualStrings(t, []string{}, msg.Features)
msg.AddFeature("one", "two", "one")
assertEqualStrings(t, []string{"one", "two"}, msg.Features)
if !sort.StringsAreSorted(msg.Features) {
t.Errorf("features should be sorted, got %+v", msg.Features)
}
assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features)
msg.AddFeature("three")
assertEqualStrings(t, []string{"one", "two", "three"}, msg.Features)
if !sort.StringsAreSorted(msg.Features) {
t.Errorf("features should be sorted, got %+v", msg.Features)
}
assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features)
msg.RemoveFeature("three", "one")
assertEqualStrings(t, []string{"two"}, msg.Features)

View file

@ -25,6 +25,8 @@ import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
var (
@ -51,7 +53,7 @@ func getRealAsyncEventsForTest(t *testing.T) AsyncEvents {
url := startLocalNatsServer(t)
events, err := NewAsyncEvents(url)
if err != nil {
t.Fatal(err)
require.NoError(t, err)
}
return events
}
@ -59,7 +61,7 @@ func getRealAsyncEventsForTest(t *testing.T) AsyncEvents {
func getLoopbackAsyncEventsForTest(t *testing.T) AsyncEvents {
events, err := NewAsyncEvents(NatsLoopbackUrl)
if err != nil {
t.Fatal(err)
require.NoError(t, err)
}
t.Cleanup(func() {

View file

@ -24,17 +24,17 @@ package signaling
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"github.com/dlintw/goconf"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) {
@ -57,37 +57,29 @@ func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) {
}
data, err := json.Marshal(response)
if err != nil {
t.Fatal(err)
return
}
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if _, err := w.Write(data); err != nil {
t.Error(err)
}
_, err = w.Write(data)
assert.NoError(t, err)
}
func TestPostOnRedirect(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
r := mux.NewRouter()
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/ocs/v2.php/two", http.StatusTemporaryRedirect)
})
r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
return
}
require.NoError(err)
var request map[string]string
if err := json.Unmarshal(body, &request); err != nil {
t.Fatal(err)
return
}
err = json.Unmarshal(body, &request)
require.NoError(err)
returnOCS(t, w, body)
})
@ -96,9 +88,7 @@ func TestPostOnRedirect(t *testing.T) {
defer server.Close()
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
config := goconf.NewConfigFile()
config.AddOption("backend", "allowed", u.Host)
@ -107,9 +97,7 @@ func TestPostOnRedirect(t *testing.T) {
config.AddOption("backend", "allowhttp", "true")
}
client, err := NewBackendClient(config, 1, "0.0", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
ctx := context.Background()
request := map[string]string{
@ -117,18 +105,17 @@ func TestPostOnRedirect(t *testing.T) {
}
var response map[string]string
err = client.PerformJSONRequest(ctx, u, request, &response)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
if response == nil || !reflect.DeepEqual(request, response) {
t.Errorf("Expected %+v, got %+v", request, response)
if assert.NotNil(t, response) {
assert.Equal(t, request, response)
}
}
func TestPostOnRedirectDifferentHost(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
r := mux.NewRouter()
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://domain.invalid/ocs/v2.php/two", http.StatusTemporaryRedirect)
@ -137,9 +124,7 @@ func TestPostOnRedirectDifferentHost(t *testing.T) {
defer server.Close()
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
config := goconf.NewConfigFile()
config.AddOption("backend", "allowed", u.Host)
@ -148,9 +133,7 @@ func TestPostOnRedirectDifferentHost(t *testing.T) {
config.AddOption("backend", "allowhttp", "true")
}
client, err := NewBackendClient(config, 1, "0.0", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
ctx := context.Background()
request := map[string]string{
@ -160,41 +143,33 @@ func TestPostOnRedirectDifferentHost(t *testing.T) {
err = client.PerformJSONRequest(ctx, u, request, &response)
if err != nil {
// The redirect to a different host should have failed.
if !errors.Is(err, ErrNotRedirecting) {
t.Fatal(err)
}
require.ErrorIs(err, ErrNotRedirecting)
} else {
t.Fatal("The redirect should have failed")
require.Fail("The redirect should have failed")
}
}
func TestPostOnRedirectStatusFound(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
r := mux.NewRouter()
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/ocs/v2.php/two", http.StatusFound)
})
r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
return
}
if len(body) > 0 {
t.Errorf("Should not have received any body, got %s", string(body))
}
require.NoError(err)
assert.Empty(string(body), "Should not have received any body, got %s", string(body))
returnOCS(t, w, []byte("{}"))
})
server := httptest.NewServer(r)
defer server.Close()
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
config := goconf.NewConfigFile()
config.AddOption("backend", "allowed", u.Host)
@ -203,9 +178,7 @@ func TestPostOnRedirectStatusFound(t *testing.T) {
config.AddOption("backend", "allowhttp", "true")
}
client, err := NewBackendClient(config, 1, "0.0", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
ctx := context.Background()
request := map[string]string{
@ -213,18 +186,16 @@ func TestPostOnRedirectStatusFound(t *testing.T) {
}
var response map[string]string
err = client.PerformJSONRequest(ctx, u, request, &response)
if err != nil {
t.Error(err)
}
if len(response) > 0 {
t.Errorf("Expected empty response, got %+v", response)
if assert.NoError(err) {
assert.Empty(response, "Expected empty response, got %+v", response)
}
}
func TestHandleThrottled(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
r := mux.NewRouter()
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
returnOCS(t, w, []byte("[]"))
@ -233,9 +204,7 @@ func TestHandleThrottled(t *testing.T) {
defer server.Close()
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
config := goconf.NewConfigFile()
config.AddOption("backend", "allowed", u.Host)
@ -244,9 +213,7 @@ func TestHandleThrottled(t *testing.T) {
config.AddOption("backend", "allowhttp", "true")
}
client, err := NewBackendClient(config, 1, "0.0", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
ctx := context.Background()
request := map[string]string{
@ -254,9 +221,7 @@ func TestHandleThrottled(t *testing.T) {
}
var response map[string]string
err = client.PerformJSONRequest(ctx, u, request, &response)
if err == nil {
t.Error("should have triggered an error")
} else if !errors.Is(err, ErrThrottledResponse) {
t.Error(err)
if assert.Error(err) {
assert.ErrorIs(err, ErrThrottledResponse)
}
}

View file

@ -22,7 +22,6 @@
package signaling
import (
"bytes"
"context"
"net/url"
"reflect"
@ -31,32 +30,30 @@ import (
"github.com/dlintw/goconf"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testUrls(t *testing.T, config *BackendConfiguration, valid_urls []string, invalid_urls []string) {
for _, u := range valid_urls {
u := u
t.Run(u, func(t *testing.T) {
assert := assert.New(t)
parsed, err := url.ParseRequestURI(u)
if err != nil {
t.Errorf("The url %s should be valid, got %s", u, err)
if !assert.NoError(err, "The url %s should be valid", u) {
return
}
if !config.IsUrlAllowed(parsed) {
t.Errorf("The url %s should be allowed", u)
}
if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) {
t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret))
}
assert.True(config.IsUrlAllowed(parsed), "The url %s should be allowed", u)
secret := config.GetSecret(parsed)
assert.Equal(string(testBackendSecret), string(secret), "Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret))
})
}
for _, u := range invalid_urls {
u := u
t.Run(u, func(t *testing.T) {
assert := assert.New(t)
parsed, _ := url.ParseRequestURI(u)
if config.IsUrlAllowed(parsed) {
t.Errorf("The url %s should not be allowed", u)
}
assert.False(config.IsUrlAllowed(parsed), "The url %s should not be allowed", u)
})
}
}
@ -65,28 +62,24 @@ func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]str
for _, entry := range valid_urls {
entry := entry
t.Run(entry[0], func(t *testing.T) {
assert := assert.New(t)
u := entry[0]
parsed, err := url.ParseRequestURI(u)
if err != nil {
t.Errorf("The url %s should be valid, got %s", u, err)
if !assert.NoError(err, "The url %s should be valid", u) {
return
}
if !config.IsUrlAllowed(parsed) {
t.Errorf("The url %s should be allowed", u)
}
assert.True(config.IsUrlAllowed(parsed), "The url %s should be allowed", u)
s := entry[1]
if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) {
t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret))
}
secret := config.GetSecret(parsed)
assert.Equal(s, string(secret), "Expected secret %s for url %s, got %s", s, u, string(secret))
})
}
for _, u := range invalid_urls {
u := u
t.Run(u, func(t *testing.T) {
assert := assert.New(t)
parsed, _ := url.ParseRequestURI(u)
if config.IsUrlAllowed(parsed) {
t.Errorf("The url %s should not be allowed", u)
}
assert.False(config.IsUrlAllowed(parsed), "The url %s should not be allowed", u)
})
}
}
@ -108,9 +101,7 @@ func TestIsUrlAllowed_Compat(t *testing.T) {
config.AddOption("backend", "allowhttp", "true")
config.AddOption("backend", "secret", string(testBackendSecret))
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
testUrls(t, cfg, valid_urls, invalid_urls)
}
@ -130,9 +121,7 @@ func TestIsUrlAllowed_CompatForceHttps(t *testing.T) {
config.AddOption("backend", "allowed", "domain.invalid")
config.AddOption("backend", "secret", string(testBackendSecret))
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
testUrls(t, cfg, valid_urls, invalid_urls)
}
@ -176,9 +165,7 @@ func TestIsUrlAllowed(t *testing.T) {
config.AddOption("lala", "url", "https://otherdomain.invalid/")
config.AddOption("lala", "secret", string(testBackendSecret)+"-lala")
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
testBackends(t, cfg, valid_urls, invalid_urls)
}
@ -194,9 +181,7 @@ func TestIsUrlAllowed_EmptyAllowlist(t *testing.T) {
config.AddOption("backend", "allowed", "")
config.AddOption("backend", "secret", string(testBackendSecret))
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
testUrls(t, cfg, valid_urls, invalid_urls)
}
@ -215,9 +200,7 @@ func TestIsUrlAllowed_AllowAll(t *testing.T) {
config.AddOption("backend", "allowed", "")
config.AddOption("backend", "secret", string(testBackendSecret))
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
testUrls(t, cfg, valid_urls, invalid_urls)
}
@ -238,16 +221,16 @@ func TestParseBackendIds(t *testing.T) {
{"backend1,backend2, backend1", []string{"backend1", "backend2"}},
}
assert := assert.New(t)
for _, test := range testcases {
ids := getConfiguredBackendIDs(test.s)
if !reflect.DeepEqual(ids, test.ids) {
t.Errorf("List of ids differs, expected %+v, got %+v", test.ids, ids)
}
assert.Equal(test.ids, ids, "List of ids differs for \"%s\"", test.s)
}
}
func TestBackendReloadNoChange(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
current := testutil.ToFloat64(statsBackendsCurrent)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
@ -257,9 +240,7 @@ func TestBackendReloadNoChange(t *testing.T) {
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+2)
new_config := goconf.NewConfigFile()
@ -270,20 +251,19 @@ func TestBackendReloadNoChange(t *testing.T) {
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+4)
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, current+4)
if !reflect.DeepEqual(n_cfg, o_cfg) {
t.Error("BackendConfiguration should be equal after Reload")
assert.Fail(t, "BackendConfiguration should be equal after Reload")
}
}
func TestBackendReloadChangeExistingURL(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
current := testutil.ToFloat64(statsBackendsCurrent)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
@ -293,9 +273,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+2)
new_config := goconf.NewConfigFile()
@ -307,9 +285,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+4)
original_config.RemoveOption("backend1", "url")
@ -319,12 +295,13 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, current+4)
if !reflect.DeepEqual(n_cfg, o_cfg) {
t.Error("BackendConfiguration should be equal after Reload")
assert.Fail(t, "BackendConfiguration should be equal after Reload")
}
}
func TestBackendReloadChangeSecret(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
current := testutil.ToFloat64(statsBackendsCurrent)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
@ -334,9 +311,7 @@ func TestBackendReloadChangeSecret(t *testing.T) {
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+2)
new_config := goconf.NewConfigFile()
@ -347,9 +322,7 @@ func TestBackendReloadChangeSecret(t *testing.T) {
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+4)
original_config.RemoveOption("backend1", "secret")
@ -358,12 +331,13 @@ func TestBackendReloadChangeSecret(t *testing.T) {
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, current+4)
if !reflect.DeepEqual(n_cfg, o_cfg) {
t.Error("BackendConfiguration should be equal after Reload")
assert.Fail(t, "BackendConfiguration should be equal after Reload")
}
}
func TestBackendReloadAddBackend(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
current := testutil.ToFloat64(statsBackendsCurrent)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1")
@ -371,9 +345,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
original_config.AddOption("backend1", "url", "http://domain1.invalid")
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+1)
new_config := goconf.NewConfigFile()
@ -385,9 +357,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
new_config.AddOption("backend2", "sessionlimit", "10")
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+3)
original_config.RemoveOption("backend", "backends")
@ -399,12 +369,13 @@ func TestBackendReloadAddBackend(t *testing.T) {
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, current+4)
if !reflect.DeepEqual(n_cfg, o_cfg) {
t.Error("BackendConfiguration should be equal after Reload")
assert.Fail(t, "BackendConfiguration should be equal after Reload")
}
}
func TestBackendReloadRemoveHost(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
current := testutil.ToFloat64(statsBackendsCurrent)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
@ -414,9 +385,7 @@ func TestBackendReloadRemoveHost(t *testing.T) {
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+2)
new_config := goconf.NewConfigFile()
@ -425,9 +394,7 @@ func TestBackendReloadRemoveHost(t *testing.T) {
new_config.AddOption("backend1", "url", "http://domain1.invalid")
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+3)
original_config.RemoveOption("backend", "backends")
@ -437,12 +404,13 @@ func TestBackendReloadRemoveHost(t *testing.T) {
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, current+2)
if !reflect.DeepEqual(n_cfg, o_cfg) {
t.Error("BackendConfiguration should be equal after Reload")
assert.Fail(t, "BackendConfiguration should be equal after Reload")
}
}
func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
current := testutil.ToFloat64(statsBackendsCurrent)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
@ -452,9 +420,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
original_config.AddOption("backend2", "url", "http://domain1.invalid/bar/")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+2)
new_config := goconf.NewConfigFile()
@ -463,9 +429,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
new_config.AddOption("backend1", "url", "http://domain1.invalid/foo/")
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
checkStatsValue(t, statsBackendsCurrent, current+3)
original_config.RemoveOption("backend", "backends")
@ -475,7 +439,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
o_cfg.Reload(original_config)
checkStatsValue(t, statsBackendsCurrent, current+2)
if !reflect.DeepEqual(n_cfg, o_cfg) {
t.Error("BackendConfiguration should be equal after Reload")
assert.Fail(t, "BackendConfiguration should be equal after Reload")
}
}
@ -500,6 +464,8 @@ func mustParse(s string) *url.URL {
func TestBackendConfiguration_Etcd(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
etcd, client := NewEtcdClientForTest(t)
url1 := "https://domain1.invalid/foo"
@ -513,9 +479,7 @@ func TestBackendConfiguration_Etcd(t *testing.T) {
config.AddOption("backend", "backendprefix", "/backends")
cfg, err := NewBackendConfiguration(config, client)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer cfg.Close()
storage := cfg.storage.(*backendStorageEtcd)
@ -524,31 +488,25 @@ func TestBackendConfiguration_Etcd(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if err := storage.WaitForInitialized(ctx); err != nil {
t.Fatal(err)
}
require.NoError(storage.WaitForInitialized(ctx))
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
t.Errorf("Expected one backend, got %+v", backends)
} else if backends[0].url != url1 {
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
} else if string(backends[0].secret) != initialSecret1 {
t.Errorf("Expected backend secret %s, got %s", initialSecret1, string(backends[0].secret))
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) &&
assert.Equal(url1, backends[0].url) &&
assert.Equal(initialSecret1, string(backends[0].secret)) {
if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
assert.Fail("Expected backend %+v, got %+v", backends[0], backend)
}
}
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/1_one", []byte("{\"url\":\""+url1+"\",\"secret\":\""+secret1+"\"}"))
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
t.Errorf("Expected one backend, got %+v", backends)
} else if backends[0].url != url1 {
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
} else if string(backends[0].secret) != secret1 {
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) &&
assert.Equal(url1, backends[0].url) &&
assert.Equal(secret1, string(backends[0].secret)) {
if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
assert.Fail("Expected backend %+v, got %+v", backends[0], backend)
}
}
url2 := "https://domain1.invalid/bar"
@ -557,20 +515,16 @@ func TestBackendConfiguration_Etcd(t *testing.T) {
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/2_two", []byte("{\"url\":\""+url2+"\",\"secret\":\""+secret2+"\"}"))
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 {
t.Errorf("Expected two backends, got %+v", backends)
} else if backends[0].url != url1 {
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
} else if string(backends[0].secret) != secret1 {
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
} else if backends[1].url != url2 {
t.Errorf("Expected backend url %s, got %s", url2, backends[1].url)
} else if string(backends[1].secret) != secret2 {
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret))
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
t.Errorf("Expected backend %+v, got %+v", backends[1], backend)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 2) &&
assert.Equal(url1, backends[0].url) &&
assert.Equal(secret1, string(backends[0].secret)) &&
assert.Equal(url2, backends[1].url) &&
assert.Equal(secret2, string(backends[1].secret)) {
if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
assert.Fail("Expected backend %+v, got %+v", backends[0], backend)
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
assert.Fail("Expected backend %+v, got %+v", backends[1], backend)
}
}
url3 := "https://domain2.invalid/foo"
@ -579,70 +533,54 @@ func TestBackendConfiguration_Etcd(t *testing.T) {
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/3_three", []byte("{\"url\":\""+url3+"\",\"secret\":\""+secret3+"\"}"))
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 3 {
t.Errorf("Expected three backends, got %+v", backends)
} else if backends[0].url != url1 {
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
} else if string(backends[0].secret) != secret1 {
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
} else if backends[1].url != url2 {
t.Errorf("Expected backend url %s, got %s", url2, backends[1].url)
} else if string(backends[1].secret) != secret2 {
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret))
} else if backends[2].url != url3 {
t.Errorf("Expected backend url %s, got %s", url3, backends[2].url)
} else if string(backends[2].secret) != secret3 {
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[2].secret))
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
t.Errorf("Expected backend %+v, got %+v", backends[1], backend)
} else if backend := cfg.GetBackend(mustParse(url3)); backend != backends[2] {
t.Errorf("Expected backend %+v, got %+v", backends[2], backend)
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 3) &&
assert.Equal(url1, backends[0].url) &&
assert.Equal(secret1, string(backends[0].secret)) &&
assert.Equal(url2, backends[1].url) &&
assert.Equal(secret2, string(backends[1].secret)) &&
assert.Equal(url3, backends[2].url) &&
assert.Equal(secret3, string(backends[2].secret)) {
if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
assert.Fail("Expected backend %+v, got %+v", backends[0], backend)
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
assert.Fail("Expected backend %+v, got %+v", backends[1], backend)
} else if backend := cfg.GetBackend(mustParse(url3)); backend != backends[2] {
assert.Fail("Expected backend %+v, got %+v", backends[2], backend)
}
}
drainWakeupChannel(ch)
DeleteEtcdValue(etcd, "/backends/1_one")
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 {
t.Errorf("Expected two backends, got %+v", backends)
} else if backends[0].url != url2 {
t.Errorf("Expected backend url %s, got %s", url2, backends[0].url)
} else if string(backends[0].secret) != secret2 {
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[0].secret))
} else if backends[1].url != url3 {
t.Errorf("Expected backend url %s, got %s", url3, backends[1].url)
} else if string(backends[1].secret) != secret3 {
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[1].secret))
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 2) {
assert.Equal(url2, backends[0].url)
assert.Equal(secret2, string(backends[0].secret))
assert.Equal(url3, backends[1].url)
assert.Equal(secret3, string(backends[1].secret))
}
drainWakeupChannel(ch)
DeleteEtcdValue(etcd, "/backends/2_two")
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
t.Errorf("Expected one backend, got %+v", backends)
} else if backends[0].url != url3 {
t.Errorf("Expected backend url %s, got %s", url3, backends[0].url)
} else if string(backends[0].secret) != secret3 {
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[0].secret))
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) {
assert.Equal(url3, backends[0].url)
assert.Equal(secret3, string(backends[0].secret))
}
if _, found := storage.backends["domain1.invalid"]; found {
t.Errorf("Should have removed host information for %s", "domain1.invalid")
assert.Fail("Should have removed host information for %s", "domain1.invalid")
}
}
func TestBackendCommonSecret(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
u1, err := url.Parse("http://domain1.invalid")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
u2, err := url.Parse("http://domain2.invalid")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
original_config := goconf.NewConfigFile()
original_config.AddOption("backend", "backends", "backend1, backend2")
original_config.AddOption("backend", "secret", string(testBackendSecret))
@ -650,19 +588,13 @@ func TestBackendCommonSecret(t *testing.T) {
original_config.AddOption("backend2", "url", u2.String())
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
if b1 := cfg.GetBackend(u1); b1 == nil {
t.Error("didn't get backend")
} else if !bytes.Equal(b1.Secret(), testBackendSecret) {
t.Errorf("expected secret %s, got %s", string(testBackendSecret), string(b1.Secret()))
if b1 := cfg.GetBackend(u1); assert.NotNil(b1) {
assert.Equal(string(testBackendSecret), string(b1.Secret()))
}
if b2 := cfg.GetBackend(u2); b2 == nil {
t.Error("didn't get backend")
} else if !bytes.Equal(b2.Secret(), []byte(string(testBackendSecret)+"-backend2")) {
t.Errorf("expected secret %s, got %s", string(testBackendSecret)+"-backend2", string(b2.Secret()))
if b2 := cfg.GetBackend(u2); assert.NotNil(b2) {
assert.Equal(string(testBackendSecret)+"-backend2", string(b2.Secret()))
}
updated_config := goconf.NewConfigFile()
@ -673,14 +605,10 @@ func TestBackendCommonSecret(t *testing.T) {
updated_config.AddOption("backend2", "url", u2.String())
cfg.Reload(updated_config)
if b1 := cfg.GetBackend(u1); b1 == nil {
t.Error("didn't get backend")
} else if !bytes.Equal(b1.Secret(), []byte(string(testBackendSecret)+"-backend1")) {
t.Errorf("expected secret %s, got %s", string(testBackendSecret)+"-backend1", string(b1.Secret()))
if b1 := cfg.GetBackend(u1); assert.NotNil(b1) {
assert.Equal(string(testBackendSecret)+"-backend1", string(b1.Secret()))
}
if b2 := cfg.GetBackend(u2); b2 == nil {
t.Error("didn't get backend")
} else if !bytes.Equal(b2.Secret(), testBackendSecret) {
t.Errorf("expected secret %s, got %s", string(testBackendSecret), string(b2.Secret()))
if b2 := cfg.GetBackend(u2); assert.NotNil(b2) {
assert.Equal(string(testBackendSecret), string(b2.Secret()))
}
}

File diff suppressed because it is too large Load diff

View file

@ -25,6 +25,7 @@ import (
"testing"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed"
)
@ -67,9 +68,7 @@ func Test_BackendStorageEtcdNoLeak(t *testing.T) {
config.AddOption("backend", "backendprefix", "/backends")
cfg, err := NewBackendConfiguration(config, client)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
<-tl.closed
cfg.Close()

View file

@ -25,17 +25,20 @@ import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBackoff_Exponential(t *testing.T) {
t.Parallel()
backoff, err := NewExponentialBackoff(100*time.Millisecond, 500*time.Millisecond)
if err != nil {
t.Fatal(err)
}
assert := assert.New(t)
minWait := 100 * time.Millisecond
backoff, err := NewExponentialBackoff(minWait, 500*time.Millisecond)
require.NoError(t, err)
waitTimes := []time.Duration{
100 * time.Millisecond,
minWait,
200 * time.Millisecond,
400 * time.Millisecond,
500 * time.Millisecond,
@ -43,23 +46,17 @@ func TestBackoff_Exponential(t *testing.T) {
}
for _, wait := range waitTimes {
if backoff.NextWait() != wait {
t.Errorf("Wait time should be %s, got %s", wait, backoff.NextWait())
}
assert.Equal(wait, backoff.NextWait())
a := time.Now()
backoff.Wait(context.Background())
b := time.Now()
if b.Sub(a) < wait {
t.Errorf("Should have waited %s, got %s", wait, b.Sub(a))
}
assert.GreaterOrEqual(b.Sub(a), wait)
}
backoff.Reset()
a := time.Now()
backoff.Wait(context.Background())
b := time.Now()
if b.Sub(a) < 100*time.Millisecond {
t.Errorf("Should have waited %s, got %s", 100*time.Millisecond, b.Sub(a))
}
assert.GreaterOrEqual(b.Sub(a), minWait)
}

View file

@ -37,17 +37,16 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse, http.ResponseWriter) error) (*url.URL, *Capabilities) {
require := require.New(t)
pool, err := NewHttpClientPool(1, false)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
capabilities, err := NewCapabilities("0.0", pool)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
r := mux.NewRouter()
server := httptest.NewServer(r)
@ -56,9 +55,7 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie
})
u, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
handleCapabilitiesFunc := func(w http.ResponseWriter, r *http.Request) {
features := []string{
@ -91,9 +88,7 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie
}
data, err := json.Marshal(response)
if err != nil {
t.Errorf("Could not marshal %+v: %s", response, err)
}
assert.NoError(t, err, "Could not marshal %+v", response)
var ocs OcsResponse
ocs.Ocs = &OcsBody{
@ -104,9 +99,9 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie
},
Data: data,
}
if data, err = json.Marshal(ocs); err != nil {
t.Fatal(err)
}
data, err = json.Marshal(ocs)
require.NoError(err)
var cc []string
if !strings.Contains(t.Name(), "NoCache") {
if strings.Contains(t.Name(), "ShortCache") {
@ -177,70 +172,50 @@ func SetCapabilitiesGetNow(t *testing.T, capabilities *Capabilities, f func() ti
func TestCapabilities(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
url, capabilities := NewCapabilitiesForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if !capabilities.HasCapabilityFeature(ctx, url, "foo") {
t.Error("should have capability \"foo\"")
}
if capabilities.HasCapabilityFeature(ctx, url, "lala") {
t.Error("should not have capability \"lala\"")
}
assert.True(capabilities.HasCapabilityFeature(ctx, url, "foo"))
assert.False(capabilities.HasCapabilityFeature(ctx, url, "lala"))
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.True(cached)
}
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found {
t.Errorf("should not have found value for \"baz\", got %s", value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); assert.False(found, "should not have found value for \"baz\", got %s", value) {
assert.True(cached)
}
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found {
t.Errorf("should not have found value for \"invalid\", got %s", value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); assert.False(found, "should not have found value for \"invalid\", got %s", value) {
assert.True(cached)
}
if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found {
t.Errorf("should not have found value for \"baz\", got %s", value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); assert.False(found, "should not have found value for \"baz\", got %s", value) {
assert.True(cached)
}
expectedInt := 42
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found {
t.Error("could not find value for \"baz\"")
} else if value != expectedInt {
t.Errorf("expected value %d, got %d", expectedInt, value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); assert.True(found) {
assert.Equal(expectedInt, value)
assert.True(cached)
}
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found {
t.Errorf("should not have found value for \"foo\", got %d", value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); assert.False(found, "should not have found value for \"foo\", got %d", value) {
assert.True(cached)
}
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found {
t.Errorf("should not have found value for \"invalid\", got %d", value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); assert.False(found, "should not have found value for \"invalid\", got %d", value) {
assert.True(cached)
}
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found {
t.Errorf("should not have found value for \"baz\", got %d", value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); assert.False(found, "should not have found value for \"baz\", got %d", value) {
assert.True(cached)
}
}
func TestInvalidateCapabilities(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
called.Add(1)
@ -251,47 +226,35 @@ func TestInvalidateCapabilities(t *testing.T) {
defer cancel()
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
value := called.Load()
assert.EqualValues(1, value)
// Invalidating will cause the capabilities to be reloaded.
capabilities.InvalidateCapabilities(url)
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
value = called.Load()
assert.EqualValues(2, value)
// Invalidating is throttled to about once per minute.
capabilities.InvalidateCapabilities(url)
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.True(cached)
}
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
value = called.Load()
assert.EqualValues(2, value)
// At a later time, invalidating can be done again.
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
@ -300,22 +263,19 @@ func TestInvalidateCapabilities(t *testing.T) {
capabilities.InvalidateCapabilities(url)
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 3 {
t.Errorf("expected called %d, got %d", 3, value)
}
value = called.Load()
assert.EqualValues(3, value)
}
func TestCapabilitiesNoCache(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
called.Add(1)
@ -326,51 +286,40 @@ func TestCapabilitiesNoCache(t *testing.T) {
defer cancel()
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
value := called.Load()
assert.EqualValues(1, value)
// Capabilities are cached for some time if no "Cache-Control" header is set.
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.True(cached)
}
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
value = called.Load()
assert.EqualValues(1, value)
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(minCapabilitiesCacheDuration)
})
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
value = called.Load()
assert.EqualValues(2, value)
}
func TestCapabilitiesShortCache(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
called.Add(1)
@ -381,76 +330,58 @@ func TestCapabilitiesShortCache(t *testing.T) {
defer cancel()
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
value := called.Load()
assert.EqualValues(1, value)
// Capabilities are cached for some time if no "Cache-Control" header is set.
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.True(cached)
}
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
value = called.Load()
assert.EqualValues(1, value)
// The capabilities are cached for a minumum duration.
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(minCapabilitiesCacheDuration / 2)
})
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if !cached {
t.Errorf("expected cached response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.True(cached)
}
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(minCapabilitiesCacheDuration)
})
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
value = called.Load()
assert.EqualValues(2, value)
}
func TestCapabilitiesNoCacheETag(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
ct := w.Header().Get("Content-Type")
switch called.Add(1) {
case 1:
if ct == "" {
t.Error("expected content-type on first request")
}
assert.NotEmpty(ct, "expected content-type on first request")
case 2:
if ct != "" {
t.Errorf("expected no content-type on second request, got %s", ct)
}
assert.Empty(ct, "expected no content-type on second request")
}
return nil
})
@ -459,38 +390,31 @@ func TestCapabilitiesNoCacheETag(t *testing.T) {
defer cancel()
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
value := called.Load()
assert.EqualValues(1, value)
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(minCapabilitiesCacheDuration)
})
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
value = called.Load()
assert.EqualValues(2, value)
}
func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
if called.Add(1) == 2 {
@ -504,17 +428,13 @@ func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) {
defer cancel()
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
value := called.Load()
assert.EqualValues(1, value)
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(time.Minute)
@ -522,22 +442,19 @@ func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) {
// Expired capabilities can still be used even in case of update errors if
// "must-revalidate" is not set.
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
value = called.Load()
assert.EqualValues(2, value)
}
func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
if called.Add(1) == 2 {
@ -551,17 +468,13 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) {
defer cancel()
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
value := called.Load()
assert.EqualValues(1, value)
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(minCapabilitiesCacheDuration)
@ -569,22 +482,19 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) {
// Expired capabilities can still be used even in case of update errors if
// "must-revalidate" is not set.
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
value = called.Load()
assert.EqualValues(2, value)
}
func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
if called.Add(1) == 2 {
@ -598,17 +508,13 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) {
defer cancel()
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
value := called.Load()
assert.EqualValues(1, value)
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(minCapabilitiesCacheDuration)
@ -616,11 +522,9 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) {
// Capabilities will be cleared if "must-revalidate" is set and an error
// occurs while fetching the updated data.
if value, _, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); found {
t.Errorf("should not have found value for \"foo\", got %s", value)
}
capaValue, _, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo")
assert.False(found, "should not have found value for \"foo\", got %s", capaValue)
if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
value = called.Load()
assert.EqualValues(2, value)
}

View file

@ -23,6 +23,8 @@ package signaling
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestChannelWaiters(t *testing.T) {
@ -42,9 +44,9 @@ func TestChannelWaiters(t *testing.T) {
select {
case <-ch1:
t.Error("should have not received another event")
assert.Fail(t, "should have not received another event")
case <-ch2:
t.Error("should have not received another event")
assert.Fail(t, "should have not received another event")
default:
}
@ -60,7 +62,7 @@ func TestChannelWaiters(t *testing.T) {
<-ch2
select {
case <-ch3:
t.Error("should have not received another event")
assert.Fail(t, "should have not received another event")
default:
}
}

View file

@ -26,6 +26,9 @@ import (
"net/url"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
@ -119,9 +122,7 @@ func Test_permissionsEqual(t *testing.T) {
t.Run(strconv.Itoa(idx), func(t *testing.T) {
t.Parallel()
equal := permissionsEqual(test.a, test.b)
if equal != test.equal {
t.Errorf("Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal])
}
assert.Equal(t, test.equal, equal, "Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal])
})
}
}
@ -129,17 +130,16 @@ func Test_permissionsEqual(t *testing.T) {
func TestBandwidth_Client(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
require.NoError(err)
require.NoError(mcu.Start(ctx))
defer mcu.Stop()
hub.SetMcu(mcu)
@ -147,31 +147,23 @@ func TestBandwidth_Client(t *testing.T) {
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
if err := client.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
require.NoError(client.SendHello(testDefaultUserId))
hello, err := client.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// Join room by id.
roomId := "test-room"
if room, err := client.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err := client.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
// We will receive a "joined" event.
if err := client.RunUntilJoined(ctx, hello.Hello); err != nil {
t.Error(err)
}
assert.NoError(client.RunUntilJoined(ctx, hello.Hello))
// Client may not send an offer with audio and video.
bitrate := 10000
if err := client.SendMessage(MessageClientMessageRecipient{
require.NoError(client.SendMessage(MessageClientMessageRecipient{
Type: "session",
SessionId: hello.Hello.SessionId,
}, MessageClientMessageData{
@ -182,22 +174,13 @@ func TestBandwidth_Client(t *testing.T) {
Payload: map[string]interface{}{
"sdp": MockSdpOfferAudioAndVideo,
},
}); err != nil {
t.Fatal(err)
}
}))
if err := client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil {
t.Fatal(err)
}
require.NoError(client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo))
pub := mcu.GetPublisher(hello.Hello.SessionId)
if pub == nil {
t.Fatal("Could not find publisher")
}
if pub.bitrate != bitrate {
t.Errorf("Expected bitrate %d, got %d", bitrate, pub.bitrate)
}
require.NotNil(pub)
assert.Equal(bitrate, pub.bitrate)
}
func TestBandwidth_Backend(t *testing.T) {
@ -206,13 +189,9 @@ func TestBandwidth_Backend(t *testing.T) {
hub, _, _, server := CreateHubWithMultipleBackendsForTest(t)
u, err := url.Parse(server.URL + "/one")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
backend := hub.backend.GetBackend(u)
if backend == nil {
t.Fatal("Could not get backend")
}
require.NotNil(t, backend, "Could not get backend")
backend.maxScreenBitrate = 1000
backend.maxStreamBitrate = 2000
@ -221,11 +200,8 @@ func TestBandwidth_Backend(t *testing.T) {
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
require.NoError(t, err)
require.NoError(t, mcu.Start(ctx))
defer mcu.Stop()
hub.SetMcu(mcu)
@ -237,37 +213,31 @@ func TestBandwidth_Backend(t *testing.T) {
for _, streamType := range streamTypes {
t.Run(string(streamType), func(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
params := TestBackendClientAuthParams{
UserId: testDefaultUserId,
}
if err := client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params); err != nil {
t.Fatal(err)
}
require.NoError(client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params))
hello, err := client.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// Join room by id.
roomId := "test-room"
if room, err := client.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err := client.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
// We will receive a "joined" event.
if err := client.RunUntilJoined(ctx, hello.Hello); err != nil {
t.Error(err)
}
require.NoError(client.RunUntilJoined(ctx, hello.Hello))
// Client may not send an offer with audio and video.
bitrate := 10000
if err := client.SendMessage(MessageClientMessageRecipient{
require.NoError(client.SendMessage(MessageClientMessageRecipient{
Type: "session",
SessionId: hello.Hello.SessionId,
}, MessageClientMessageData{
@ -278,18 +248,12 @@ func TestBandwidth_Backend(t *testing.T) {
Payload: map[string]interface{}{
"sdp": MockSdpOfferAudioAndVideo,
},
}); err != nil {
t.Fatal(err)
}
}))
if err := client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil {
t.Fatal(err)
}
require.NoError(client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo))
pub := mcu.GetPublisher(hello.Hello.SessionId)
if pub == nil {
t.Fatal("Could not find publisher")
}
require.NotNil(pub, "Could not find publisher")
var expectBitrate int
if streamType == StreamTypeVideo {
@ -297,9 +261,7 @@ func TestBandwidth_Backend(t *testing.T) {
} else {
expectBitrate = backend.maxScreenBitrate
}
if pub.bitrate != expectBitrate {
t.Errorf("Expected bitrate %d, got %d", expectBitrate, pub.bitrate)
}
assert.Equal(expectBitrate, pub.bitrate)
})
}
}

View file

@ -24,6 +24,8 @@ package signaling
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCloserMulti(t *testing.T) {
@ -39,24 +41,16 @@ func TestCloserMulti(t *testing.T) {
}()
}
if closer.IsClosed() {
t.Error("should not be closed")
}
assert.False(t, closer.IsClosed())
closer.Close()
if !closer.IsClosed() {
t.Error("should be closed")
}
assert.True(t, closer.IsClosed())
wg.Wait()
}
func TestCloserCloseBeforeWait(t *testing.T) {
closer := NewCloser()
closer.Close()
if !closer.IsClosed() {
t.Error("should be closed")
}
assert.True(t, closer.IsClosed())
<-closer.C
if !closer.IsClosed() {
t.Error("should be closed")
}
assert.True(t, closer.IsClosed())
}

View file

@ -25,75 +25,52 @@ import (
"strconv"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConcurrentStringStringMap(t *testing.T) {
assert := assert.New(t)
var m ConcurrentStringStringMap
if m.Len() != 0 {
t.Errorf("Expected %d entries, got %d", 0, m.Len())
}
if v, found := m.Get("foo"); found {
t.Errorf("Expected missing entry, got %s", v)
}
assert.Equal(0, m.Len())
v, found := m.Get("foo")
assert.False(found, "Expected missing entry, got %s", v)
m.Set("foo", "bar")
if m.Len() != 1 {
t.Errorf("Expected %d entries, got %d", 1, m.Len())
}
if v, found := m.Get("foo"); !found {
t.Errorf("Expected entry")
} else if v != "bar" {
t.Errorf("Expected bar, got %s", v)
assert.Equal(1, m.Len())
if v, found := m.Get("foo"); assert.True(found) {
assert.Equal("bar", v)
}
m.Set("foo", "baz")
if m.Len() != 1 {
t.Errorf("Expected %d entries, got %d", 1, m.Len())
}
if v, found := m.Get("foo"); !found {
t.Errorf("Expected entry")
} else if v != "baz" {
t.Errorf("Expected baz, got %s", v)
assert.Equal(1, m.Len())
if v, found := m.Get("foo"); assert.True(found) {
assert.Equal("baz", v)
}
m.Set("lala", "lolo")
if m.Len() != 2 {
t.Errorf("Expected %d entries, got %d", 2, m.Len())
}
if v, found := m.Get("lala"); !found {
t.Errorf("Expected entry")
} else if v != "lolo" {
t.Errorf("Expected lolo, got %s", v)
assert.Equal(2, m.Len())
if v, found := m.Get("lala"); assert.True(found) {
assert.Equal("lolo", v)
}
// Deleting missing entries doesn't do anything.
m.Del("xyz")
if m.Len() != 2 {
t.Errorf("Expected %d entries, got %d", 2, m.Len())
assert.Equal(2, m.Len())
if v, found := m.Get("foo"); assert.True(found) {
assert.Equal("baz", v)
}
if v, found := m.Get("foo"); !found {
t.Errorf("Expected entry")
} else if v != "baz" {
t.Errorf("Expected baz, got %s", v)
}
if v, found := m.Get("lala"); !found {
t.Errorf("Expected entry")
} else if v != "lolo" {
t.Errorf("Expected lolo, got %s", v)
if v, found := m.Get("lala"); assert.True(found) {
assert.Equal("lolo", v)
}
m.Del("lala")
if m.Len() != 1 {
t.Errorf("Expected %d entries, got %d", 2, m.Len())
}
if v, found := m.Get("foo"); !found {
t.Errorf("Expected entry")
} else if v != "baz" {
t.Errorf("Expected baz, got %s", v)
}
if v, found := m.Get("lala"); found {
t.Errorf("Expected missing entry, got %s", v)
assert.Equal(1, m.Len())
if v, found := m.Get("foo"); assert.True(found) {
assert.Equal("baz", v)
}
v, found = m.Get("lala")
assert.False(found, "Expected missing entry, got %s", v)
m.Clear()
var wg sync.WaitGroup
@ -108,18 +85,13 @@ func TestConcurrentStringStringMap(t *testing.T) {
for y := 0; y < count; y = y + 1 {
value := newRandomString(32)
m.Set(key, value)
if v, found := m.Get(key); !found {
t.Errorf("Expected entry for key %s", key)
return
} else if v != value {
t.Errorf("Expected value %s for key %s, got %s", value, key, v)
if v, found := m.Get(key); !assert.True(found, "Expected entry for key %s", key) ||
!assert.Equal(value, v, "Unexpected value for key %s", key) {
return
}
}
}(x)
}
wg.Wait()
if m.Len() != concurrency {
t.Errorf("Expected %d entries, got %d", concurrency, m.Len())
}
assert.Equal(concurrency, m.Len())
}

View file

@ -22,10 +22,11 @@
package signaling
import (
"reflect"
"testing"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStringOptions(t *testing.T) {
@ -46,13 +47,8 @@ func TestStringOptions(t *testing.T) {
config.AddOption("default", "three", "3")
options, err := GetStringOptions(config, "foo", false)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, options) {
t.Errorf("expected %+v, got %+v", expected, options)
}
require.NoError(t, err)
assert.Equal(t, expected, options)
}
func TestStringOptionWithEnv(t *testing.T) {
@ -82,10 +78,8 @@ func TestStringOptionWithEnv(t *testing.T) {
}
for k, v := range expected {
value, err := GetStringOptionWithEnv(config, "test", k)
if err != nil {
t.Errorf("expected value for %s, got %s", k, err)
} else if value != v {
t.Errorf("expected value %s for %s, got %s", v, k, value)
if assert.NoError(t, err, "expected value for %s", k) {
assert.Equal(t, v, value, "unexpected value for %s", k)
}
}

View file

@ -24,6 +24,8 @@ package signaling
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestDeferredExecutor_MultiClose(t *testing.T) {
@ -53,9 +55,7 @@ func TestDeferredExecutor_QueueSize(t *testing.T) {
b := time.Now()
delta := b.Sub(a)
// Allow one millisecond less delay to account for time variance on CI runners.
if delta+time.Millisecond < delay {
t.Errorf("Expected a delay of %s, got %s", delay, delta)
}
assert.GreaterOrEqual(t, delta+time.Millisecond, delay)
}
func TestDeferredExecutor_Order(t *testing.T) {
@ -81,9 +81,7 @@ func TestDeferredExecutor_Order(t *testing.T) {
<-done
for x := 0; x < 10; x++ {
if entries[x] != x {
t.Errorf("Expected %d at position %d, got %d", x, x, entries[x])
}
assert.Equal(t, entries[x], x, "Unexpected at position %d", x)
}
}
@ -108,7 +106,7 @@ func TestDeferredExecutor_DeferAfterClose(t *testing.T) {
e.Close()
e.Execute(func() {
t.Error("method should not have been called")
assert.Fail(t, "method should not have been called")
})
}

View file

@ -30,6 +30,9 @@ import (
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockDnsLookup struct {
@ -82,20 +85,16 @@ func (m *mockDnsLookup) lookup(host string) ([]net.IP, error) {
func newDnsMonitorForTest(t *testing.T, interval time.Duration) *DnsMonitor {
t.Helper()
require := require.New(t)
monitor, err := NewDnsMonitor(interval)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
t.Cleanup(func() {
monitor.Stop()
})
if err := monitor.Start(); err != nil {
t.Fatal(err)
}
require.NoError(monitor.Start())
return monitor
}
@ -148,20 +147,18 @@ func (r *dnsMonitorReceiver) OnLookup(entry *DnsMonitorEntry, all, add, keep, re
expected := r.expected
r.expected = nil
if expected == expectNone {
r.t.Errorf("expected no event, got %v", received)
assert.Fail(r.t, "expected no event, got %v", received)
return
}
if expected == nil {
if r.received != nil && !r.received.Equal(received) {
r.t.Errorf("already received %v, got %v", r.received, received)
assert.Fail(r.t, "already received %v, got %v", r.received, received)
}
return
}
if !expected.Equal(received) {
r.t.Errorf("expected %v, got %v", expected, received)
}
assert.True(r.t, expected.Equal(received), "expected %v, got %v", expected, received)
r.received = nil
r.expected = nil
}
@ -178,7 +175,7 @@ func (r *dnsMonitorReceiver) WaitForExpected(ctx context.Context) {
select {
case <-ticker.C:
case <-ctx.Done():
r.t.Error(ctx.Err())
assert.NoError(r.t, ctx.Err())
abort = true
}
r.Lock()
@ -191,7 +188,7 @@ func (r *dnsMonitorReceiver) Expect(all, add, keep, remove []net.IP) {
defer r.Unlock()
if r.expected != nil && r.expected != expectNone {
r.t.Errorf("didn't get previously expected %v", r.expected)
assert.Fail(r.t, "didn't get previously expected %v", r.expected)
}
expected := &dnsMonitorReceiverRecord{
@ -214,7 +211,7 @@ func (r *dnsMonitorReceiver) ExpectNone() {
defer r.Unlock()
if r.expected != nil && r.expected != expectNone {
r.t.Errorf("didn't get previously expected %v", r.expected)
assert.Fail(r.t, "didn't get previously expected %v", r.expected)
}
r.expected = expectNone
@ -241,9 +238,7 @@ func TestDnsMonitor(t *testing.T) {
rec1.Expect(ips1, ips1, nil, nil)
entry1, err := monitor.Add("https://foo:12345", rec1.OnLookup)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer monitor.Remove(entry1)
rec1.WaitForExpected(ctx)
@ -306,9 +301,7 @@ func TestDnsMonitorIP(t *testing.T) {
rec1.Expect(ips, ips, nil, nil)
entry, err := monitor.Add(ip+":12345", rec1.OnLookup)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer monitor.Remove(entry)
rec1.WaitForExpected(ctx)
@ -328,9 +321,7 @@ func TestDnsMonitorNoLookupIfEmpty(t *testing.T) {
}
time.Sleep(10 * interval)
if checked.Load() {
t.Error("should not have checked hostnames")
}
assert.False(t, checked.Load(), "should not have checked hostnames")
}
type deadlockMonitorReceiver struct {
@ -355,8 +346,7 @@ func newDeadlockMonitorReceiver(t *testing.T, monitor *DnsMonitor) *deadlockMoni
}
func (r *deadlockMonitorReceiver) OnLookup(entry *DnsMonitorEntry, all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) {
if r.closed.Load() {
r.t.Error("received lookup after closed")
if !assert.False(r.t, r.closed.Load(), "received lookup after closed") {
return
}
@ -385,8 +375,7 @@ func (r *deadlockMonitorReceiver) Start() {
defer r.mu.Unlock()
entry, err := r.monitor.Add("foo", r.OnLookup)
if err != nil {
r.t.Errorf("error adding listener: %s", err)
if !assert.NoError(r.t, err) {
return
}
@ -422,7 +411,5 @@ func TestDnsMonitorDeadlock(t *testing.T) {
time.Sleep(10 * interval)
monitor.mu.Lock()
defer monitor.mu.Unlock()
if len(monitor.hostnames) > 0 {
t.Errorf("should have cleared hostnames, got %+v", monitor.hostnames)
}
assert.Empty(t, monitor.hostnames)
}

View file

@ -34,6 +34,8 @@ import (
"time"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/server/v3/embed"
@ -66,15 +68,14 @@ func isErrorAddressAlreadyInUse(err error) bool {
}
func NewEtcdForTest(t *testing.T) *embed.Etcd {
require := require.New(t)
cfg := embed.NewConfig()
cfg.Dir = t.TempDir()
os.Chmod(cfg.Dir, 0700) // nolint
cfg.LogLevel = "warn"
u, err := url.Parse(etcdListenUrl)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// Find a free port to bind the server to.
var etcd *embed.Etcd
@ -94,14 +95,12 @@ func NewEtcdForTest(t *testing.T) *embed.Etcd {
etcd, err = embed.StartEtcd(cfg)
if isErrorAddressAlreadyInUse(err) {
continue
} else if err != nil {
t.Fatal(err)
}
require.NoError(err)
break
}
if etcd == nil {
t.Fatal("could not find free port")
}
require.NotNil(etcd, "could not find free port")
t.Cleanup(func() {
etcd.Close()
@ -121,13 +120,9 @@ func NewEtcdClientForTest(t *testing.T) (*embed.Etcd, *EtcdClient) {
config.AddOption("etcd", "loglevel", "error")
client, err := NewEtcdClient(config, "")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
t.Cleanup(func() {
if err := client.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, client.Close())
})
return etcd, client
}
@ -149,54 +144,44 @@ func DeleteEtcdValue(etcd *embed.Etcd, key string) {
func Test_EtcdClient_Get(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
etcd, client := NewEtcdClientForTest(t)
if response, err := client.Get(context.Background(), "foo"); err != nil {
t.Error(err)
} else if response.Count != 0 {
t.Errorf("expected 0 response, got %d", response.Count)
if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) {
assert.EqualValues(0, response.Count)
}
SetEtcdValue(etcd, "foo", []byte("bar"))
if response, err := client.Get(context.Background(), "foo"); err != nil {
t.Error(err)
} else if response.Count != 1 {
t.Errorf("expected 1 responses, got %d", response.Count)
} else if string(response.Kvs[0].Key) != "foo" {
t.Errorf("expected key \"foo\", got \"%s\"", string(response.Kvs[0].Key))
} else if string(response.Kvs[0].Value) != "bar" {
t.Errorf("expected value \"bar\", got \"%s\"", string(response.Kvs[0].Value))
if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) {
if assert.EqualValues(1, response.Count) {
assert.Equal("foo", string(response.Kvs[0].Key))
assert.Equal("bar", string(response.Kvs[0].Value))
}
}
}
func Test_EtcdClient_GetPrefix(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
etcd, client := NewEtcdClientForTest(t)
if response, err := client.Get(context.Background(), "foo"); err != nil {
t.Error(err)
} else if response.Count != 0 {
t.Errorf("expected 0 response, got %d", response.Count)
if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) {
assert.EqualValues(0, response.Count)
}
SetEtcdValue(etcd, "foo", []byte("1"))
SetEtcdValue(etcd, "foo/lala", []byte("2"))
SetEtcdValue(etcd, "lala/foo", []byte("3"))
if response, err := client.Get(context.Background(), "foo", clientv3.WithPrefix()); err != nil {
t.Error(err)
} else if response.Count != 2 {
t.Errorf("expected 2 responses, got %d", response.Count)
} else if string(response.Kvs[0].Key) != "foo" {
t.Errorf("expected key \"foo\", got \"%s\"", string(response.Kvs[0].Key))
} else if string(response.Kvs[0].Value) != "1" {
t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value))
} else if string(response.Kvs[1].Key) != "foo/lala" {
t.Errorf("expected key \"foo/lala\", got \"%s\"", string(response.Kvs[1].Key))
} else if string(response.Kvs[1].Value) != "2" {
t.Errorf("expected value \"2\", got \"%s\"", string(response.Kvs[1].Value))
if response, err := client.Get(context.Background(), "foo", clientv3.WithPrefix()); assert.NoError(err) {
if assert.EqualValues(2, response.Count) {
assert.Equal("foo", string(response.Kvs[0].Key))
assert.Equal("1", string(response.Kvs[0].Value))
assert.Equal("foo/lala", string(response.Kvs[1].Key))
assert.Equal("2", string(response.Kvs[1].Value))
}
}
}
@ -237,8 +222,8 @@ func (l *EtcdClientTestListener) Close() {
func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) {
go func() {
if err := client.WaitForConnection(l.ctx); err != nil {
l.t.Errorf("error waiting for connection: %s", err)
assert := assert.New(l.t)
if err := client.WaitForConnection(l.ctx); !assert.NoError(err) {
return
}
@ -246,23 +231,17 @@ func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) {
defer cancel()
response, err := client.Get(ctx, "foo", clientv3.WithPrefix())
if err != nil {
l.t.Error(err)
} else if response.Count != 1 {
l.t.Errorf("expected 1 responses, got %d", response.Count)
} else if string(response.Kvs[0].Key) != "foo/a" {
l.t.Errorf("expected key \"foo/a\", got \"%s\"", string(response.Kvs[0].Key))
} else if string(response.Kvs[0].Value) != "1" {
l.t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value))
if assert.NoError(err) && assert.EqualValues(1, response.Count) {
assert.Equal("foo/a", string(response.Kvs[0].Key))
assert.Equal("1", string(response.Kvs[0].Value))
}
close(l.initial)
nextRevision := response.Header.Revision + 1
for l.ctx.Err() == nil {
var err error
if nextRevision, err = client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", nextRevision, l, clientv3.WithPrefix()); err != nil {
l.t.Error(err)
}
nextRevision, err = client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", nextRevision, l, clientv3.WithPrefix())
assert.NoError(err)
}
}()
}
@ -296,6 +275,7 @@ func (l *EtcdClientTestListener) EtcdKeyDeleted(client *EtcdClient, key string,
func Test_EtcdClient_Watch(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
etcd, client := NewEtcdClientForTest(t)
SetEtcdValue(etcd, "foo/a", []byte("1"))
@ -310,31 +290,19 @@ func Test_EtcdClient_Watch(t *testing.T) {
SetEtcdValue(etcd, "foo/b", []byte("2"))
event := <-listener.events
if event.t != clientv3.EventTypePut {
t.Errorf("expected type %d, got %d", clientv3.EventTypePut, event.t)
} else if event.key != "foo/b" {
t.Errorf("expected key %s, got %s", "foo/b", event.key)
} else if event.value != "2" {
t.Errorf("expected value %s, got %s", "2", event.value)
}
assert.Equal(clientv3.EventTypePut, event.t)
assert.Equal("foo/b", event.key)
assert.Equal("2", event.value)
SetEtcdValue(etcd, "foo/a", []byte("3"))
event = <-listener.events
if event.t != clientv3.EventTypePut {
t.Errorf("expected type %d, got %d", clientv3.EventTypePut, event.t)
} else if event.key != "foo/a" {
t.Errorf("expected key %s, got %s", "foo/a", event.key)
} else if event.value != "3" {
t.Errorf("expected value %s, got %s", "3", event.value)
}
assert.Equal(clientv3.EventTypePut, event.t)
assert.Equal("foo/a", event.key)
assert.Equal("3", event.value)
DeleteEtcdValue(etcd, "foo/a")
event = <-listener.events
if event.t != clientv3.EventTypeDelete {
t.Errorf("expected type %d, got %d", clientv3.EventTypeDelete, event.t)
} else if event.key != "foo/a" {
t.Errorf("expected key %s, got %s", "foo/a", event.key)
} else if event.prevValue != "3" {
t.Errorf("expected previous value %s, got %s", "3", event.prevValue)
}
assert.Equal(clientv3.EventTypeDelete, event.t)
assert.Equal("foo/a", event.key)
assert.Equal("3", event.prevValue)
}

View file

@ -330,10 +330,10 @@ func Test_Federation(t *testing.T) {
ctx2, cancel2 := context.WithTimeout(ctx, 200*time.Millisecond)
defer cancel2()
if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded {
t.Error(err)
} else {
assert.Nil(message)
if message, err := client2.RunUntilMessage(ctx2); err == nil {
assert.Fail("expected no message, got %+v", message)
} else if err != ErrNoMessageReceived && err != context.DeadlineExceeded {
assert.NoError(err)
}
}
@ -345,10 +345,10 @@ func Test_Federation(t *testing.T) {
ctx2, cancel2 := context.WithTimeout(ctx, 200*time.Millisecond)
defer cancel2()
if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded {
t.Error(err)
} else {
assert.Nil(message)
if message, err := client2.RunUntilMessage(ctx2); err == nil {
assert.Fail("expected no message, got %+v", message)
} else if err != ErrNoMessageReceived && err != context.DeadlineExceeded {
assert.NoError(err)
}
}

View file

@ -23,10 +23,12 @@ package signaling
import (
"context"
"errors"
"os"
"path"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
@ -34,38 +36,31 @@ var (
)
func TestFileWatcher_NotExist(t *testing.T) {
assert := assert.New(t)
tmpdir := t.TempDir()
w, err := NewFileWatcher(path.Join(tmpdir, "test.txt"), func(filename string) {})
if err == nil {
t.Error("should not be able to watch non-existing files")
if err := w.Close(); err != nil {
t.Error(err)
if w, err := NewFileWatcher(path.Join(tmpdir, "test.txt"), func(filename string) {}); !assert.ErrorIs(err, os.ErrNotExist) {
if w != nil {
assert.NoError(w.Close())
}
} else if !errors.Is(err, os.ErrNotExist) {
t.Error(err)
}
}
func TestFileWatcher_File(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
tmpdir := t.TempDir()
filename := path.Join(tmpdir, "test.txt")
if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644))
modified := make(chan struct{})
w, err := NewFileWatcher(filename, func(filename string) {
modified <- struct{}{}
})
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer w.Close()
if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(filename, []byte("Updated"), 0644))
<-modified
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
@ -73,13 +68,11 @@ func TestFileWatcher_File(t *testing.T) {
select {
case <-modified:
t.Error("should not have received another event")
assert.Fail("should not have received another event")
case <-ctxTimeout.Done():
}
if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(filename, []byte("Updated"), 0644))
<-modified
ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
@ -87,45 +80,39 @@ func TestFileWatcher_File(t *testing.T) {
select {
case <-modified:
t.Error("should not have received another event")
assert.Fail("should not have received another event")
case <-ctxTimeout.Done():
}
})
}
func TestFileWatcher_Rename(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
tmpdir := t.TempDir()
filename := path.Join(tmpdir, "test.txt")
if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644))
modified := make(chan struct{})
w, err := NewFileWatcher(filename, func(filename string) {
modified <- struct{}{}
})
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer w.Close()
filename2 := path.Join(tmpdir, "test.txt.tmp")
if err := os.WriteFile(filename2, []byte("Updated"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(filename2, []byte("Updated"), 0644))
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
defer cancel()
select {
case <-modified:
t.Error("should not have received another event")
assert.Fail("should not have received another event")
case <-ctxTimeout.Done():
}
if err := os.Rename(filename2, filename); err != nil {
t.Fatal(err)
}
require.NoError(os.Rename(filename2, filename))
<-modified
ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
@ -133,35 +120,29 @@ func TestFileWatcher_Rename(t *testing.T) {
select {
case <-modified:
t.Error("should not have received another event")
assert.Fail("should not have received another event")
case <-ctxTimeout.Done():
}
}
func TestFileWatcher_Symlink(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
tmpdir := t.TempDir()
sourceFilename := path.Join(tmpdir, "test1.txt")
if err := os.WriteFile(sourceFilename, []byte("Hello world!"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(sourceFilename, []byte("Hello world!"), 0644))
filename := path.Join(tmpdir, "symlink.txt")
if err := os.Symlink(sourceFilename, filename); err != nil {
t.Fatal(err)
}
require.NoError(os.Symlink(sourceFilename, filename))
modified := make(chan struct{})
w, err := NewFileWatcher(filename, func(filename string) {
modified <- struct{}{}
})
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer w.Close()
if err := os.WriteFile(sourceFilename, []byte("Updated"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(sourceFilename, []byte("Updated"), 0644))
<-modified
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
@ -169,44 +150,34 @@ func TestFileWatcher_Symlink(t *testing.T) {
select {
case <-modified:
t.Error("should not have received another event")
assert.Fail("should not have received another event")
case <-ctxTimeout.Done():
}
}
func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
tmpdir := t.TempDir()
sourceFilename1 := path.Join(tmpdir, "test1.txt")
if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644))
sourceFilename2 := path.Join(tmpdir, "test2.txt")
if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(sourceFilename2, []byte("Updated"), 0644))
filename := path.Join(tmpdir, "symlink.txt")
if err := os.Symlink(sourceFilename1, filename); err != nil {
t.Fatal(err)
}
require.NoError(os.Symlink(sourceFilename1, filename))
modified := make(chan struct{})
w, err := NewFileWatcher(filename, func(filename string) {
modified <- struct{}{}
})
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer w.Close()
// Replace symlink by creating new one and rename it to the original target.
if err := os.Symlink(sourceFilename2, filename+".tmp"); err != nil {
t.Fatal(err)
}
if err := os.Rename(filename+".tmp", filename); err != nil {
t.Fatal(err)
}
require.NoError(os.Symlink(sourceFilename2, filename+".tmp"))
require.NoError(os.Rename(filename+".tmp", filename))
<-modified
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
@ -214,89 +185,73 @@ func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) {
select {
case <-modified:
t.Error("should not have received another event")
assert.Fail("should not have received another event")
case <-ctxTimeout.Done():
}
}
func TestFileWatcher_OtherSymlink(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
tmpdir := t.TempDir()
sourceFilename1 := path.Join(tmpdir, "test1.txt")
if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644))
sourceFilename2 := path.Join(tmpdir, "test2.txt")
if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(sourceFilename2, []byte("Updated"), 0644))
filename := path.Join(tmpdir, "symlink.txt")
if err := os.Symlink(sourceFilename1, filename); err != nil {
t.Fatal(err)
}
require.NoError(os.Symlink(sourceFilename1, filename))
modified := make(chan struct{})
w, err := NewFileWatcher(filename, func(filename string) {
modified <- struct{}{}
})
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer w.Close()
if err := os.Symlink(sourceFilename2, filename+".tmp"); err != nil {
t.Fatal(err)
}
require.NoError(os.Symlink(sourceFilename2, filename+".tmp"))
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
defer cancel()
select {
case <-modified:
t.Error("should not have received event for other symlink")
assert.Fail("should not have received event for other symlink")
case <-ctxTimeout.Done():
}
}
func TestFileWatcher_RenameSymlinkTarget(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
tmpdir := t.TempDir()
sourceFilename1 := path.Join(tmpdir, "test1.txt")
if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644))
filename := path.Join(tmpdir, "test.txt")
if err := os.Symlink(sourceFilename1, filename); err != nil {
t.Fatal(err)
}
require.NoError(os.Symlink(sourceFilename1, filename))
modified := make(chan struct{})
w, err := NewFileWatcher(filename, func(filename string) {
modified <- struct{}{}
})
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer w.Close()
sourceFilename2 := path.Join(tmpdir, "test1.txt.tmp")
if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil {
t.Fatal(err)
}
require.NoError(os.WriteFile(sourceFilename2, []byte("Updated"), 0644))
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
defer cancel()
select {
case <-modified:
t.Error("should not have received another event")
assert.Fail("should not have received another event")
case <-ctxTimeout.Done():
}
if err := os.Rename(sourceFilename2, sourceFilename1); err != nil {
t.Fatal(err)
}
require.NoError(os.Rename(sourceFilename2, sourceFilename1))
<-modified
ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
@ -304,7 +259,7 @@ func TestFileWatcher_RenameSymlinkTarget(t *testing.T) {
select {
case <-modified:
t.Error("should not have received another event")
assert.Fail("should not have received another event")
case <-ctxTimeout.Done():
}
}

View file

@ -25,55 +25,28 @@ import (
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFlags(t *testing.T) {
assert := assert.New(t)
var f Flags
if f.Get() != 0 {
t.Fatalf("Expected flags 0, got %d", f.Get())
}
if !f.Add(1) {
t.Error("expected true")
}
if f.Get() != 1 {
t.Fatalf("Expected flags 1, got %d", f.Get())
}
if f.Add(1) {
t.Error("expected false")
}
if f.Get() != 1 {
t.Fatalf("Expected flags 1, got %d", f.Get())
}
if !f.Add(2) {
t.Error("expected true")
}
if f.Get() != 3 {
t.Fatalf("Expected flags 3, got %d", f.Get())
}
if !f.Remove(1) {
t.Error("expected true")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
if f.Remove(1) {
t.Error("expected false")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
if !f.Add(3) {
t.Error("expected true")
}
if f.Get() != 3 {
t.Fatalf("Expected flags 3, got %d", f.Get())
}
if !f.Remove(1) {
t.Error("expected true")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
assert.EqualValues(0, f.Get())
assert.True(f.Add(1))
assert.EqualValues(1, f.Get())
assert.False(f.Add(1))
assert.EqualValues(1, f.Get())
assert.True(f.Add(2))
assert.EqualValues(3, f.Get())
assert.True(f.Remove(1))
assert.EqualValues(2, f.Get())
assert.False(f.Remove(1))
assert.EqualValues(2, f.Get())
assert.True(f.Add(3))
assert.EqualValues(3, f.Get())
assert.True(f.Remove(1))
assert.EqualValues(2, f.Get())
}
func runConcurrentFlags(t *testing.T, count int, f func()) {
@ -106,9 +79,7 @@ func TestFlagsConcurrentAdd(t *testing.T) {
added.Add(1)
}
})
if added.Load() != 1 {
t.Errorf("expected only one successfull attempt, got %d", added.Load())
}
assert.EqualValues(t, 1, added.Load(), "expected only one successfull attempt")
}
func TestFlagsConcurrentRemove(t *testing.T) {
@ -122,9 +93,7 @@ func TestFlagsConcurrentRemove(t *testing.T) {
removed.Add(1)
}
})
if removed.Load() != 1 {
t.Errorf("expected only one successfull attempt, got %d", removed.Load())
}
assert.EqualValues(t, 1, removed.Load(), "expected only one successfull attempt")
}
func TestFlagsConcurrentSet(t *testing.T) {
@ -137,7 +106,5 @@ func TestFlagsConcurrentSet(t *testing.T) {
set.Add(1)
}
})
if set.Load() != 1 {
t.Errorf("expected only one successfull attempt, got %d", set.Load())
}
assert.EqualValues(t, 1, set.Load(), "expected only one successfull attempt")
}

View file

@ -32,6 +32,9 @@ import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testGeoLookupReader(t *testing.T, reader *GeoLookup) {
@ -47,14 +50,11 @@ func testGeoLookupReader(t *testing.T, reader *GeoLookup) {
expected := expected
t.Run(ip, func(t *testing.T) {
country, err := reader.LookupCountry(net.ParseIP(ip))
if err != nil {
t.Errorf("Could not lookup %s: %s", ip, err)
if !assert.NoError(t, err, "Could not lookup %s", ip) {
return
}
if country != expected {
t.Errorf("Expected %s for %s, got %s", expected, ip, country)
}
assert.Equal(t, expected, country, "Unexpected country for %s", ip)
})
}
}
@ -79,36 +79,28 @@ func GetGeoIpUrlForTest(t *testing.T) string {
func TestGeoLookup(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
reader, err := NewGeoLookupFromUrl(GetGeoIpUrlForTest(t))
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer reader.Close()
if err := reader.Update(); err != nil {
t.Fatal(err)
}
require.NoError(reader.Update())
testGeoLookupReader(t, reader)
}
func TestGeoLookupCaching(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
reader, err := NewGeoLookupFromUrl(GetGeoIpUrlForTest(t))
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer reader.Close()
if err := reader.Update(); err != nil {
t.Fatal(err)
}
require.NoError(reader.Update())
// Updating the second time will most likely return a "304 Not Modified".
// Make sure this doesn't trigger an error.
if err := reader.Update(); err != nil {
t.Fatal(err)
}
require.NoError(reader.Update())
}
func TestGeoLookupContinent(t *testing.T) {
@ -125,13 +117,11 @@ func TestGeoLookupContinent(t *testing.T) {
expected := expected
t.Run(country, func(t *testing.T) {
continents := LookupContinents(country)
if len(continents) != len(expected) {
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
if !assert.Equal(t, len(expected), len(continents), "Continents didn't match for %s: got %s, expected %s", country, continents, expected) {
return
}
for idx, c := range expected {
if continents[idx] != c {
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
if !assert.Equal(t, c, continents[idx], "Continents didn't match for %s: got %s, expected %s", country, continents, expected) {
break
}
}
@ -142,36 +132,29 @@ func TestGeoLookupContinent(t *testing.T) {
func TestGeoLookupCloseEmpty(t *testing.T) {
CatchLogForTest(t)
reader, err := NewGeoLookupFromUrl("ignore-url")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
reader.Close()
}
func TestGeoLookupFromFile(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
geoIpUrl := GetGeoIpUrlForTest(t)
resp, err := http.Get(geoIpUrl)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer resp.Body.Close()
body := resp.Body
url := geoIpUrl
if strings.HasSuffix(geoIpUrl, ".gz") {
body, err = gzip.NewReader(body)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
url = strings.TrimSuffix(url, ".gz")
}
tmpfile, err := os.CreateTemp("", "geoipdb")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
t.Cleanup(func() {
os.Remove(tmpfile.Name())
})
@ -183,9 +166,8 @@ func TestGeoLookupFromFile(t *testing.T) {
header, err := tarfile.Next()
if err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
require.NoError(err)
if !strings.HasSuffix(header.Name, ".mmdb") {
continue
@ -193,33 +175,25 @@ func TestGeoLookupFromFile(t *testing.T) {
if _, err := io.Copy(tmpfile, tarfile); err != nil {
tmpfile.Close()
t.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
require.NoError(err)
}
require.NoError(tmpfile.Close())
foundDatabase = true
break
}
} else {
if _, err := io.Copy(tmpfile, body); err != nil {
tmpfile.Close()
t.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
require.NoError(err)
}
require.NoError(tmpfile.Close())
foundDatabase = true
}
if !foundDatabase {
t.Fatalf("Did not find GeoIP database in download from %s", geoIpUrl)
}
require.True(foundDatabase, "Did not find GeoIP database in download from %s", geoIpUrl)
reader, err := NewGeoLookupFromFile(tmpfile.Name())
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer reader.Close()
testGeoLookupReader(t, reader)
@ -228,9 +202,7 @@ func TestGeoLookupFromFile(t *testing.T) {
func TestIsValidContinent(t *testing.T) {
for country, continents := range ContinentMap {
for _, continent := range continents {
if !IsValidContinent(continent) {
t.Errorf("Continent %s of country %s is not valid", continent, country)
}
assert.True(t, IsValidContinent(continent), "Continent %s of country %s is not valid", continent, country)
}
}
}

View file

@ -33,6 +33,8 @@ import (
"time"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed"
)
@ -52,9 +54,7 @@ func (c *GrpcClients) getWakeupChannelForTesting() <-chan struct{} {
func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, *DnsMonitor) {
dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually
client, err := NewGrpcClients(config, etcdClient, dnsMonitor)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
t.Cleanup(func() {
client.Close()
})
@ -78,13 +78,9 @@ func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) (*GrpcClients
config.AddOption("grpc", "targetprefix", "/grpctargets")
etcdClient, err := NewEtcdClient(config, "")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
t.Cleanup(func() {
if err := etcdClient.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, etcdClient.Close())
})
return NewGrpcClientsForTestWithConfig(t, config, etcdClient)
@ -107,7 +103,7 @@ func waitForEvent(ctx context.Context, t *testing.T, ch <-chan struct{}) {
case <-ch:
return
case <-ctx.Done():
t.Error("timeout waiting for event")
assert.Fail(t, "timeout waiting for event")
}
}
@ -125,19 +121,17 @@ func Test_GrpcClients_EtcdInitial(t *testing.T) {
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := client.WaitForInitialized(ctx); err != nil {
t.Fatal(err)
}
require.NoError(t, client.WaitForInitialized(ctx))
if clients := client.GetClients(); len(clients) != 2 {
t.Errorf("Expected two clients, got %+v", clients)
}
clients := client.GetClients()
assert.Len(t, clients, 2, "Expected two clients, got %+v", clients)
})
}
func Test_GrpcClients_EtcdUpdate(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
etcd := NewEtcdForTest(t)
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
ch := client.getWakeupChannelForTesting()
@ -145,55 +139,45 @@ func Test_GrpcClients_EtcdUpdate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if clients := client.GetClients(); len(clients) != 0 {
t.Errorf("Expected no clients, got %+v", clients)
}
assert.Empty(client.GetClients())
drainWakeupChannel(ch)
_, addr1 := NewGrpcServerForTest(t)
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != addr1 {
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr1, clients[0].Target())
}
drainWakeupChannel(ch)
_, addr2 := NewGrpcServerForTest(t)
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); len(clients) != 2 {
t.Errorf("Expected two clients, got %+v", clients)
} else if clients[0].Target() != addr1 {
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
} else if clients[1].Target() != addr2 {
t.Errorf("Expected target %s, got %s", addr2, clients[1].Target())
if clients := client.GetClients(); assert.Len(clients, 2) {
assert.Equal(addr1, clients[0].Target())
assert.Equal(addr2, clients[1].Target())
}
drainWakeupChannel(ch)
DeleteEtcdValue(etcd, "/grpctargets/one")
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != addr2 {
t.Errorf("Expected target %s, got %s", addr2, clients[0].Target())
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr2, clients[0].Target())
}
drainWakeupChannel(ch)
_, addr3 := NewGrpcServerForTest(t)
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr3+"\"}"))
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != addr3 {
t.Errorf("Expected target %s, got %s", addr3, clients[0].Target())
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr3, clients[0].Target())
}
}
func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
etcd := NewEtcdForTest(t)
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
ch := client.getWakeupChannelForTesting()
@ -201,18 +185,14 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if clients := client.GetClients(); len(clients) != 0 {
t.Errorf("Expected no clients, got %+v", clients)
}
assert.Empty(client.GetClients())
drainWakeupChannel(ch)
_, addr1 := NewGrpcServerForTest(t)
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != addr1 {
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr1, clients[0].Target())
}
drainWakeupChannel(ch)
@ -221,25 +201,22 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
waitForEvent(ctx, t, ch)
client.selfCheckWaitGroup.Wait()
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != addr1 {
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr1, clients[0].Target())
}
drainWakeupChannel(ch)
DeleteEtcdValue(etcd, "/grpctargets/two")
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != addr1 {
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr1, clients[0].Target())
}
}
func Test_GrpcClients_DnsDiscovery(t *testing.T) {
CatchLogForTest(t)
ensureNoGoroutinesLeak(t, func(t *testing.T) {
assert := assert.New(t)
lookup := newMockDnsLookupForTest(t)
target := "testgrpc:12345"
ip1 := net.ParseIP("192.168.0.1")
@ -254,12 +231,9 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
defer cancel()
dnsMonitor.checkHostnames()
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != targetWithIp1 {
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
} else if !clients[0].ip.Equal(ip1) {
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(targetWithIp1, clients[0].Target())
assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip)
}
lookup.Set("testgrpc", []net.IP{ip1, ip2})
@ -267,16 +241,11 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
dnsMonitor.checkHostnames()
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); len(clients) != 2 {
t.Errorf("Expected two client, got %+v", clients)
} else if clients[0].Target() != targetWithIp1 {
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
} else if !clients[0].ip.Equal(ip1) {
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
} else if clients[1].Target() != targetWithIp2 {
t.Errorf("Expected target %s, got %s", targetWithIp2, clients[1].Target())
} else if !clients[1].ip.Equal(ip2) {
t.Errorf("Expected IP %s, got %s", ip2, clients[1].ip)
if clients := client.GetClients(); assert.Len(clients, 2) {
assert.Equal(targetWithIp1, clients[0].Target())
assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip)
assert.Equal(targetWithIp2, clients[1].Target())
assert.True(clients[1].ip.Equal(ip2), "Expected IP %s, got %s", ip2, clients[1].ip)
}
lookup.Set("testgrpc", []net.IP{ip2})
@ -284,12 +253,9 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
dnsMonitor.checkHostnames()
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != targetWithIp2 {
t.Errorf("Expected target %s, got %s", targetWithIp2, clients[0].Target())
} else if !clients[0].ip.Equal(ip2) {
t.Errorf("Expected IP %s, got %s", ip2, clients[0].ip)
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(targetWithIp2, clients[0].Target())
assert.True(clients[0].ip.Equal(ip2), "Expected IP %s, got %s", ip2, clients[0].ip)
}
})
}
@ -297,6 +263,7 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
lookup := newMockDnsLookupForTest(t)
target := "testgrpc:12345"
ip1 := net.ParseIP("192.168.0.1")
@ -309,39 +276,29 @@ func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := client.WaitForInitialized(ctx); err != nil {
t.Fatal(err)
}
require.NoError(t, client.WaitForInitialized(ctx))
if clients := client.GetClients(); len(clients) != 0 {
t.Errorf("Expected no client, got %+v", clients)
}
assert.Empty(client.GetClients())
lookup.Set("testgrpc", []net.IP{ip1})
drainWakeupChannel(ch)
dnsMonitor.checkHostnames()
waitForEvent(testCtx, t, ch)
if clients := client.GetClients(); len(clients) != 1 {
t.Errorf("Expected one client, got %+v", clients)
} else if clients[0].Target() != targetWithIp1 {
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
} else if !clients[0].ip.Equal(ip1) {
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(targetWithIp1, clients[0].Target())
assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip)
}
}
func Test_GrpcClients_Encryption(t *testing.T) {
CatchLogForTest(t)
ensureNoGoroutinesLeak(t, func(t *testing.T) {
require := require.New(t)
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey)
clientCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Testing client", clientKey)
@ -376,14 +333,11 @@ func Test_GrpcClients_Encryption(t *testing.T) {
ctx, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
if err := clients.WaitForInitialized(ctx); err != nil {
t.Fatal(err)
}
require.NoError(clients.WaitForInitialized(ctx))
for _, client := range clients.GetClients() {
if _, err := client.GetServerId(ctx); err != nil {
t.Fatal(err)
}
_, err := client.GetServerId(ctx)
require.NoError(err)
}
})
}

View file

@ -35,6 +35,8 @@ import (
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error {
@ -72,9 +74,7 @@ func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organizatio
}
data, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
data = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
@ -108,23 +108,15 @@ func WritePublicKey(key *rsa.PublicKey, filename string) error {
func replaceFile(t *testing.T, filename string, data []byte, perm fs.FileMode) {
t.Helper()
require := require.New(t)
oldStat, err := os.Stat(filename)
if err != nil {
t.Fatalf("can't stat old file %s: %s", filename, err)
return
}
require.NoError(err, "can't stat old file %s", filename)
for {
if err := os.WriteFile(filename, data, perm); err != nil {
t.Fatalf("can't write file %s: %s", filename, err)
return
}
require.NoError(os.WriteFile(filename, data, perm), "can't write file %s", filename)
newStat, err := os.Stat(filename)
if err != nil {
t.Fatalf("can't stat new file %s: %s", filename, err)
return
}
require.NoError(err, "can't stat new file %s", filename)
// We need different modification times.
if !newStat.ModTime().Equal(oldStat.ModTime()) {

View file

@ -37,6 +37,8 @@ import (
"time"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
@ -67,23 +69,19 @@ func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (se
server, err = NewGrpcServer(config)
if isErrorAddressAlreadyInUse(err) {
continue
} else if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
break
}
if server == nil {
t.Fatal("could not find free port")
}
require.NotNil(t, server, "could not find free port")
// Don't match with own server id by default.
server.serverId = "dont-match"
go func() {
if err := server.Run(); err != nil {
t.Errorf("could not start GRPC server: %s", err)
}
assert.NoError(t, server.Run(), "could not start GRPC server")
}()
t.Cleanup(func() {
@ -99,10 +97,10 @@ func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) {
func Test_GrpcServer_ReloadCerts(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
org1 := "Testing certificate"
cert1 := GenerateSelfSignedCertificateForTesting(t, 1024, org1, key)
@ -124,24 +122,20 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
cp1 := x509.NewCertPool()
if !cp1.AppendCertsFromPEM(cert1) {
t.Fatalf("could not add certificate")
require.Fail("could not add certificate")
}
cfg1 := &tls.Config{
RootCAs: cp1,
}
conn1, err := tls.Dial("tcp", addr, cfg1)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer conn1.Close() // nolint
state1 := conn1.ConnectionState()
if certs := state1.PeerCertificates; len(certs) == 0 {
t.Errorf("expected certificates, got %+v", state1)
} else if len(certs[0].Subject.Organization) == 0 {
t.Errorf("expected organization, got %s", certs[0].Subject)
} else if certs[0].Subject.Organization[0] != org1 {
t.Errorf("expected organization %s, got %s", org1, certs[0].Subject)
if certs := state1.PeerCertificates; assert.NotEmpty(certs) {
if assert.NotEmpty(certs[0].Subject.Organization) {
assert.Equal(org1, certs[0].Subject.Organization[0])
}
}
org2 := "Updated certificate"
@ -151,43 +145,34 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := server.WaitForCertificateReload(ctx); err != nil {
t.Fatal(err)
}
require.NoError(server.WaitForCertificateReload(ctx))
cp2 := x509.NewCertPool()
if !cp2.AppendCertsFromPEM(cert2) {
t.Fatalf("could not add certificate")
require.Fail("could not add certificate")
}
cfg2 := &tls.Config{
RootCAs: cp2,
}
conn2, err := tls.Dial("tcp", addr, cfg2)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer conn2.Close() // nolint
state2 := conn2.ConnectionState()
if certs := state2.PeerCertificates; len(certs) == 0 {
t.Errorf("expected certificates, got %+v", state2)
} else if len(certs[0].Subject.Organization) == 0 {
t.Errorf("expected organization, got %s", certs[0].Subject)
} else if certs[0].Subject.Organization[0] != org2 {
t.Errorf("expected organization %s, got %s", org2, certs[0].Subject)
if certs := state2.PeerCertificates; assert.NotEmpty(certs) {
if assert.NotEmpty(certs[0].Subject.Organization) {
assert.Equal(org2, certs[0].Subject.Organization[0])
}
}
}
func Test_GrpcServer_ReloadCA(t *testing.T) {
CatchLogForTest(t)
require := require.New(t)
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey)
org1 := "Testing client"
@ -213,65 +198,53 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(serverCert) {
t.Fatalf("could not add certificate")
require.Fail("could not add certificate")
}
pair1, err := tls.X509KeyPair(clientCert1, pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
}))
if err != nil {
t.Fatal(err)
}
require.NoError(err)
cfg1 := &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{pair1},
}
client1, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1)))
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer client1.Close() // nolint
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
if _, err := client1.GetServerId(ctx1); err != nil {
t.Fatal(err)
}
_, err = client1.GetServerId(ctx1)
require.NoError(err)
org2 := "Updated client"
clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey)
replaceFile(t, caFile, clientCert2, 0755)
if err := server.WaitForCertPoolReload(ctx1); err != nil {
t.Fatal(err)
}
require.NoError(server.WaitForCertPoolReload(ctx1))
pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
}))
if err != nil {
t.Fatal(err)
}
require.NoError(err)
cfg2 := &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{pair2},
}
client2, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2)))
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer client2.Close() // nolint
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
defer cancel2()
// This will fail if the CA certificate has not been reloaded by the server.
if _, err := client2.GetServerId(ctx2); err != nil {
t.Fatal(err)
}
_, err = client2.GetServerId(ctx2)
require.NoError(err)
}

View file

@ -26,52 +26,42 @@ import (
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHttpClientPool(t *testing.T) {
t.Parallel()
if _, err := NewHttpClientPool(0, false); err == nil {
t.Error("should not be possible to create empty pool")
}
require := require.New(t)
assert := assert.New(t)
_, err := NewHttpClientPool(0, false)
assert.Error(err)
pool, err := NewHttpClientPool(1, false)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
u, err := url.Parse("http://localhost/foo/bar")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
ctx := context.Background()
if _, _, err := pool.Get(ctx, u); err != nil {
t.Fatal(err)
}
_, _, err = pool.Get(ctx, u)
require.NoError(err)
ctx2, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
defer cancel()
if _, _, err := pool.Get(ctx2, u); err == nil {
t.Error("fetching from empty pool should have timed out")
} else if err != context.DeadlineExceeded {
t.Errorf("fetching from empty pool should have timed out, got %s", err)
}
_, _, err = pool.Get(ctx2, u)
assert.ErrorIs(err, context.DeadlineExceeded)
// Pools are separated by hostname, so can get client for different host.
u2, err := url.Parse("http://local.host/foo/bar")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
if _, _, err := pool.Get(ctx, u2); err != nil {
t.Fatal(err)
}
_, _, err = pool.Get(ctx, u2)
require.NoError(err)
ctx3, cancel2 := context.WithTimeout(ctx, 10*time.Millisecond)
defer cancel2()
if _, _, err := pool.Get(ctx3, u2); err == nil {
t.Error("fetching from empty pool should have timed out")
} else if err != context.DeadlineExceeded {
t.Errorf("fetching from empty pool should have timed out, got %s", err)
}
_, _, err = pool.Get(ctx3, u2)
assert.ErrorIs(err, context.DeadlineExceeded)
}

File diff suppressed because it is too large Load diff

View file

@ -24,46 +24,36 @@ package signaling
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLruUnbound(t *testing.T) {
assert := assert.New(t)
lru := NewLruCache(0)
count := 10
for i := 0; i < count; i++ {
key := fmt.Sprintf("%d", i)
lru.Set(key, i)
}
if lru.Len() != count {
t.Errorf("Expected %d entries, got %d", count, lru.Len())
}
assert.Equal(count, lru.Len())
for i := 0; i < count; i++ {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if value == nil {
t.Errorf("No value found for %s", key)
continue
} else if value.(int) != i {
t.Errorf("Expected value to be %d, got %d", value.(int), i)
if value := lru.Get(key); assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
}
}
// The first key ("0") is now the oldest.
lru.RemoveOldest()
if lru.Len() != count-1 {
t.Errorf("Expected %d entries after RemoveOldest, got %d", count-1, lru.Len())
}
assert.Equal(count-1, lru.Len())
for i := 0; i < count; i++ {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if i == 0 {
if value != nil {
t.Errorf("The value for key %s should have been removed", key)
}
assert.Nil(value, "The value for key %s should have been removed", key)
continue
} else if value == nil {
t.Errorf("No value found for %s", key)
continue
} else if value.(int) != i {
t.Errorf("Expected value to be %d, got %d", value.(int), i)
} else if assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
}
}
@ -74,66 +64,47 @@ func TestLruUnbound(t *testing.T) {
key := fmt.Sprintf("%d", i)
lru.Set(key, i)
}
if lru.Len() != count-1 {
t.Errorf("Expected %d entries, got %d", count-1, lru.Len())
}
assert.Equal(count-1, lru.Len())
// NOTE: The same ordering as the Set calls above.
for i := count - 1; i >= 1; i-- {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if value == nil {
t.Errorf("No value found for %s", key)
continue
} else if value.(int) != i {
t.Errorf("Expected value to be %d, got %d", value.(int), i)
if value := lru.Get(key); assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
}
}
// The last key ("9") is now the oldest.
lru.RemoveOldest()
if lru.Len() != count-2 {
t.Errorf("Expected %d entries after RemoveOldest, got %d", count-2, lru.Len())
}
assert.Equal(count-2, lru.Len())
for i := 0; i < count; i++ {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if i == 0 || i == count-1 {
if value != nil {
t.Errorf("The value for key %s should have been removed", key)
}
assert.Nil(value, "The value for key %s should have been removed", key)
continue
} else if value == nil {
t.Errorf("No value found for %s", key)
continue
} else if value.(int) != i {
t.Errorf("Expected value to be %d, got %d", value.(int), i)
} else if assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
}
}
// Remove an arbitrary key from the cache
key := fmt.Sprintf("%d", count/2)
lru.Remove(key)
if lru.Len() != count-3 {
t.Errorf("Expected %d entries after RemoveOldest, got %d", count-3, lru.Len())
}
assert.Equal(count-3, lru.Len())
for i := 0; i < count; i++ {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if i == 0 || i == count-1 || i == count/2 {
if value != nil {
t.Errorf("The value for key %s should have been removed", key)
}
assert.Nil(value, "The value for key %s should have been removed", key)
continue
} else if value == nil {
t.Errorf("No value found for %s", key)
continue
} else if value.(int) != i {
t.Errorf("Expected value to be %d, got %d", value.(int), i)
} else if assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
}
}
}
func TestLruBound(t *testing.T) {
assert := assert.New(t)
size := 2
lru := NewLruCache(size)
count := 10
@ -141,23 +112,16 @@ func TestLruBound(t *testing.T) {
key := fmt.Sprintf("%d", i)
lru.Set(key, i)
}
if lru.Len() != size {
t.Errorf("Expected %d entries, got %d", size, lru.Len())
}
assert.Equal(size, lru.Len())
// Only the last "size" entries have been stored.
for i := 0; i < count; i++ {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if i < count-size {
if value != nil {
t.Errorf("The value for key %s should have been removed", key)
}
assert.Nil(value, "The value for key %s should have been removed", key)
continue
} else if value == nil {
t.Errorf("No value found for %s", key)
continue
} else if value.(int) != i {
t.Errorf("Expected value to be %d, got %d", value.(int), i)
} else if assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
}
}
}

View file

@ -23,9 +23,12 @@ package signaling
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetFmtpValueH264(t *testing.T) {
assert := assert.New(t)
testcases := []struct {
fmtp string
profile string
@ -51,16 +54,17 @@ func TestGetFmtpValueH264(t *testing.T) {
for _, tc := range testcases {
value, found := getFmtpValue(tc.fmtp, "profile-level-id")
if !found && tc.profile != "" {
t.Errorf("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp)
assert.Fail("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp)
} else if found && tc.profile == "" {
t.Errorf("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value)
assert.Fail("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value)
} else if found && tc.profile != value {
t.Errorf("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value)
assert.Fail("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value)
}
}
}
func TestGetFmtpValueVP9(t *testing.T) {
assert := assert.New(t)
testcases := []struct {
fmtp string
profile string
@ -82,11 +86,11 @@ func TestGetFmtpValueVP9(t *testing.T) {
for _, tc := range testcases {
value, found := getFmtpValue(tc.fmtp, "profile-id")
if !found && tc.profile != "" {
t.Errorf("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp)
assert.Fail("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp)
} else if found && tc.profile == "" {
t.Errorf("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value)
assert.Fail("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value)
} else if found && tc.profile != value {
t.Errorf("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value)
assert.Fail("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value)
}
}
}

View file

@ -41,6 +41,8 @@ import (
"github.com/dlintw/goconf"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed"
)
@ -105,7 +107,7 @@ func Test_sortConnectionsForCountry(t *testing.T) {
sorted := sortConnectionsForCountry(test[0], country, nil)
for idx, conn := range sorted {
if test[1][idx] != conn {
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
assert.Fail(t, "Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
}
}
})
@ -179,7 +181,7 @@ func Test_sortConnectionsForCountryWithOverride(t *testing.T) {
sorted := sortConnectionsForCountry(test[0], country, continentMap)
for idx, conn := range sorted {
if test[1][idx] != conn {
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
assert.Fail(t, "Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
}
}
})
@ -342,7 +344,7 @@ func (c *testProxyServerClient) handleSendMessageError(fmt string, msg *ProxySer
c.t.Helper()
if !errors.Is(err, websocket.ErrCloseSent) || msg.Type != "event" || msg.Event.Type != "update-load" {
c.t.Errorf(fmt, msg, err)
assert.Fail(c.t, fmt, msg, err)
}
}
@ -385,6 +387,7 @@ func (c *testProxyServerClient) run() {
c.ws = nil
}()
c.processMessage = c.processHello
assert := assert.New(c.t)
for {
c.mu.Lock()
ws := c.ws
@ -396,36 +399,31 @@ func (c *testProxyServerClient) run() {
msgType, reader, err := ws.NextReader()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
c.t.Error(err)
assert.NoError(err)
}
return
}
body, err := io.ReadAll(reader)
if err != nil {
c.t.Error(err)
if !assert.NoError(err) {
continue
}
if msgType != websocket.TextMessage {
c.t.Errorf("unexpected message type %q (%s)", msgType, string(body))
if !assert.Equal(websocket.TextMessage, msgType, "unexpected message type for %s", string(body)) {
continue
}
var msg ProxyClientMessage
if err := json.Unmarshal(body, &msg); err != nil {
c.t.Errorf("could not decode message %s: %s", string(body), err)
if err := json.Unmarshal(body, &msg); !assert.NoError(err, "could not decode message %s", string(body)) {
continue
}
if err := msg.CheckValid(); err != nil {
c.t.Errorf("invalid message %s: %s", string(body), err)
if err := msg.CheckValid(); !assert.NoError(err, "invalid message %s", string(body)) {
continue
}
response, err := c.processMessage(&msg)
if err != nil {
c.t.Error(err)
if !assert.NoError(err) {
continue
}
@ -605,8 +603,7 @@ func (h *TestProxyServerHandler) removeClient(client *testProxyServerClient) {
func (h *TestProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ws, err := h.upgrader.Upgrade(w, r, nil)
if err != nil {
h.t.Error(err)
if !assert.NoError(h.t, err) {
return
}
@ -658,15 +655,14 @@ type proxyTestOptions struct {
func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuProxy {
t.Helper()
require := require.New(t)
if options.etcd == nil {
options.etcd = NewEtcdForTest(t)
}
grpcClients, dnsMonitor := NewGrpcClientsWithEtcdForTest(t, options.etcd)
tokenKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
dir := t.TempDir()
privkeyFile := path.Join(dir, "privkey.pem")
pubkeyFile := path.Join(dir, "pubkey.pem")
@ -696,19 +692,13 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP
etcdConfig.AddOption("etcd", "loglevel", "error")
etcdClient, err := NewEtcdClient(etcdConfig, "")
if err != nil {
t.Fatal(err)
}
require.NoError(err)
t.Cleanup(func() {
if err := etcdClient.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, etcdClient.Close())
})
mcu, err := NewMcuProxy(cfg, etcdClient, grpcClients, dnsMonitor)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
t.Cleanup(func() {
mcu.Stop()
})
@ -716,20 +706,14 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
require.NoError(mcu.Start(ctx))
proxy := mcu.(*mcuProxy)
if err := proxy.WaitForConnections(ctx); err != nil {
t.Fatal(err)
}
require.NoError(proxy.WaitForConnections(ctx))
for len(waitingMap) > 0 {
if err := ctx.Err(); err != nil {
t.Fatal(err)
}
require.NoError(ctx.Err())
for u := range waitingMap {
proxy.connectionsMu.RLock()
@ -782,9 +766,7 @@ func Test_ProxyPublisherSubscriber(t *testing.T) {
}
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub.Close(context.Background())
@ -795,9 +777,7 @@ func Test_ProxyPublisherSubscriber(t *testing.T) {
country: "DE",
}
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer sub.Close(context.Background())
}
@ -829,8 +809,7 @@ func Test_ProxyWaitForPublisher(t *testing.T) {
go func() {
defer close(done)
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Error(err)
if !assert.NoError(t, err) {
return
}
@ -841,14 +820,12 @@ func Test_ProxyWaitForPublisher(t *testing.T) {
time.Sleep(100 * time.Millisecond)
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
select {
case <-done:
case <-ctx.Done():
t.Error(ctx.Err())
assert.NoError(t, ctx.Err())
}
defer pub.Close(context.Background())
}
@ -875,9 +852,7 @@ func Test_ProxyPublisherBandwidth(t *testing.T) {
country: "DE",
}
pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub1.Close(context.Background())
@ -914,15 +889,11 @@ func Test_ProxyPublisherBandwidth(t *testing.T) {
country: "DE",
}
pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub2.Close(context.Background())
if pub1.(*mcuProxyPublisher).conn.rawUrl == pub2.(*mcuProxyPublisher).conn.rawUrl {
t.Errorf("servers should be different, got %s", pub1.(*mcuProxyPublisher).conn.rawUrl)
}
assert.NotEqual(t, pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl)
}
func Test_ProxyPublisherBandwidthOverload(t *testing.T) {
@ -947,9 +918,7 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) {
country: "DE",
}
pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub1.Close(context.Background())
@ -989,15 +958,11 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) {
country: "DE",
}
pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub2.Close(context.Background())
if pub1.(*mcuProxyPublisher).conn.rawUrl != pub2.(*mcuProxyPublisher).conn.rawUrl {
t.Errorf("servers should be the same, got %s / %s", pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl)
}
assert.Equal(t, pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl)
}
func Test_ProxyPublisherLoad(t *testing.T) {
@ -1022,9 +987,7 @@ func Test_ProxyPublisherLoad(t *testing.T) {
country: "DE",
}
pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub1.Close(context.Background())
@ -1041,15 +1004,11 @@ func Test_ProxyPublisherLoad(t *testing.T) {
country: "DE",
}
pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub2.Close(context.Background())
if pub1.(*mcuProxyPublisher).conn.rawUrl == pub2.(*mcuProxyPublisher).conn.rawUrl {
t.Errorf("servers should be different, got %s", pub1.(*mcuProxyPublisher).conn.rawUrl)
}
assert.NotEqual(t, pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl)
}
func Test_ProxyPublisherCountry(t *testing.T) {
@ -1074,15 +1033,11 @@ func Test_ProxyPublisherCountry(t *testing.T) {
country: "DE",
}
pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pubDE.Close(context.Background())
if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl)
}
assert.Equal(t, serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl)
pubUSId := "the-publisher-us"
pubUSSid := "1234567890"
@ -1093,15 +1048,11 @@ func Test_ProxyPublisherCountry(t *testing.T) {
country: "US",
}
pubUS, err := mcu.NewPublisher(ctx, pubUSListener, pubUSId, pubUSSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubUSInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pubUS.Close(context.Background())
if pubUS.(*mcuProxyPublisher).conn.rawUrl != serverUS.URL {
t.Errorf("expected server %s, go %s", serverUS.URL, pubUS.(*mcuProxyPublisher).conn.rawUrl)
}
assert.Equal(t, serverUS.URL, pubUS.(*mcuProxyPublisher).conn.rawUrl)
}
func Test_ProxyPublisherContinent(t *testing.T) {
@ -1126,15 +1077,11 @@ func Test_ProxyPublisherContinent(t *testing.T) {
country: "DE",
}
pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pubDE.Close(context.Background())
if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl)
}
assert.Equal(t, serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl)
pubFRId := "the-publisher-fr"
pubFRSid := "1234567890"
@ -1145,15 +1092,11 @@ func Test_ProxyPublisherContinent(t *testing.T) {
country: "FR",
}
pubFR, err := mcu.NewPublisher(ctx, pubFRListener, pubFRId, pubFRSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubFRInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pubFR.Close(context.Background())
if pubFR.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl)
}
assert.Equal(t, serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl)
}
func Test_ProxySubscriberCountry(t *testing.T) {
@ -1178,15 +1121,11 @@ func Test_ProxySubscriberCountry(t *testing.T) {
country: "DE",
}
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub.Close(context.Background())
if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
}
assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
subListener := &MockMcuListener{
publicId: "subscriber-public",
@ -1195,15 +1134,11 @@ func Test_ProxySubscriberCountry(t *testing.T) {
country: "US",
}
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer sub.Close(context.Background())
if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL {
t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
assert.Equal(t, serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
func Test_ProxySubscriberContinent(t *testing.T) {
@ -1228,15 +1163,11 @@ func Test_ProxySubscriberContinent(t *testing.T) {
country: "DE",
}
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub.Close(context.Background())
if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
}
assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
subListener := &MockMcuListener{
publicId: "subscriber-public",
@ -1245,15 +1176,11 @@ func Test_ProxySubscriberContinent(t *testing.T) {
country: "FR",
}
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer sub.Close(context.Background())
if sub.(*mcuProxySubscriber).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
assert.Equal(t, serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
func Test_ProxySubscriberBandwidth(t *testing.T) {
@ -1278,15 +1205,11 @@ func Test_ProxySubscriberBandwidth(t *testing.T) {
country: "DE",
}
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub.Close(context.Background())
if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
}
assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
serverDE.UpdateBandwidth(0, 100)
@ -1315,15 +1238,11 @@ func Test_ProxySubscriberBandwidth(t *testing.T) {
country: "US",
}
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer sub.Close(context.Background())
if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL {
t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
assert.Equal(t, serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
func Test_ProxySubscriberBandwidthOverload(t *testing.T) {
@ -1348,15 +1267,11 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) {
country: "DE",
}
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub.Close(context.Background())
if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
}
assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
serverDE.UpdateBandwidth(0, 100)
serverUS.UpdateBandwidth(0, 102)
@ -1386,15 +1301,11 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) {
country: "US",
}
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer sub.Close(context.Background())
if sub.(*mcuProxySubscriber).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
assert.Equal(t, serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
type mockGrpcServerHub struct {
@ -1490,9 +1401,7 @@ func Test_ProxyRemotePublisher(t *testing.T) {
defer hub1.removeSession(session1)
pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub.Close(context.Background())
@ -1508,9 +1417,7 @@ func Test_ProxyRemotePublisher(t *testing.T) {
country: "DE",
}
sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer sub.Close(context.Background())
}
@ -1580,8 +1487,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) {
go func() {
defer close(done)
sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Error(err)
if !assert.NoError(t, err) {
return
}
@ -1592,9 +1498,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) {
time.Sleep(100 * time.Millisecond)
pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub.Close(context.Background())
@ -1606,7 +1510,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) {
select {
case <-done:
case <-ctx.Done():
t.Error(ctx.Err())
assert.NoError(t, ctx.Err())
}
}
@ -1663,9 +1567,7 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) {
defer hub1.removeSession(session1)
pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer pub.Close(context.Background())
@ -1677,9 +1579,7 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) {
mcu2.connectionsMu.RLock()
count := len(mcu2.connections)
mcu2.connectionsMu.RUnlock()
if expected := 1; count != expected {
t.Errorf("expected %d connections, got %+v", expected, count)
}
assert.Equal(t, 1, count)
subListener := &MockMcuListener{
publicId: "subscriber-public",
@ -1688,23 +1588,17 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) {
country: "DE",
}
sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer sub.Close(context.Background())
if sub.(*mcuProxySubscriber).conn.rawUrl != server1.URL {
t.Errorf("expected server %s, go %s", server1.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
assert.Equal(t, server1.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
// The temporary connection has been added
mcu2.connectionsMu.RLock()
count = len(mcu2.connections)
mcu2.connectionsMu.RUnlock()
if expected := 2; count != expected {
t.Errorf("expected %d connections, got %+v", expected, count)
}
assert.Equal(t, 2, count)
sub.Close(context.Background())
@ -1713,7 +1607,7 @@ loop:
for {
select {
case <-ctx.Done():
t.Error(ctx.Err())
assert.NoError(t, ctx.Err())
default:
mcu2.connectionsMu.RLock()
count = len(mcu2.connections)

View file

@ -25,6 +25,9 @@ import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *testing.T) {
@ -39,7 +42,7 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t
select {
case <-ctx.Done():
c.mu.Lock()
t.Errorf("Error waiting for subscriptions %+v to terminate: %s", c.subscriptions, ctx.Err())
assert.NoError(t, ctx.Err(), "Error waiting for subscriptions %+v to terminate", c.subscriptions)
c.mu.Unlock()
return
default:
@ -50,9 +53,7 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t
func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient {
result, err := NewLoopbackNatsClient()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
t.Cleanup(func() {
result.Close()
})

View file

@ -27,6 +27,8 @@ import (
"time"
"github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
natsserver "github.com/nats-io/nats-server/v2/test"
)
@ -46,9 +48,7 @@ func startLocalNatsServer(t *testing.T) string {
func CreateLocalNatsClientForTest(t *testing.T) NatsClient {
url := startLocalNatsServer(t)
result, err := NewNatsClient(url)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
t.Cleanup(func() {
result.Close()
})
@ -56,11 +56,11 @@ func CreateLocalNatsClientForTest(t *testing.T) NatsClient {
}
func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
require := require.New(t)
assert := assert.New(t)
dest := make(chan *nats.Msg)
sub, err := client.Subscribe("foo", dest)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
ch := make(chan struct{})
var received atomic.Int32
@ -75,9 +75,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
case <-dest:
total := received.Add(1)
if total == max {
err := sub.Unsubscribe()
if err != nil {
t.Errorf("Unsubscribe failed with err: %s", err)
if err := sub.Unsubscribe(); !assert.NoError(err) {
return
}
close(ch)
@ -89,18 +87,14 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
}()
<-ready
for i := int32(0); i < max; i++ {
if err := client.Publish("foo", []byte("hello")); err != nil {
t.Error(err)
}
assert.NoError(client.Publish("foo", []byte("hello")))
// Allow NATS goroutines to process messages.
time.Sleep(10 * time.Millisecond)
}
<-ch
if r := received.Load(); r != max {
t.Fatalf("Received wrong # of messages: %d vs %d", r, max)
}
require.EqualValues(max, received.Load(), "Received wrong # of messages")
}
func TestNatsClient_Subscribe(t *testing.T) {
@ -115,9 +109,7 @@ func TestNatsClient_Subscribe(t *testing.T) {
func testNatsClient_PublishAfterClose(t *testing.T, client NatsClient) {
client.Close()
if err := client.Publish("foo", "bar"); err != nats.ErrConnectionClosed {
t.Errorf("Expected %v, got %v", nats.ErrConnectionClosed, err)
}
assert.ErrorIs(t, client.Publish("foo", "bar"), nats.ErrConnectionClosed)
}
func TestNatsClient_PublishAfterClose(t *testing.T) {
@ -133,9 +125,8 @@ func testNatsClient_SubscribeAfterClose(t *testing.T, client NatsClient) {
client.Close()
ch := make(chan *nats.Msg)
if _, err := client.Subscribe("foo", ch); err != nats.ErrConnectionClosed {
t.Errorf("Expected %v, got %v", nats.ErrConnectionClosed, err)
}
_, err := client.Subscribe("foo", ch)
assert.ErrorIs(t, err, nats.ErrConnectionClosed)
}
func TestNatsClient_SubscribeAfterClose(t *testing.T) {
@ -148,6 +139,7 @@ func TestNatsClient_SubscribeAfterClose(t *testing.T) {
}
func testNatsClient_BadSubjects(t *testing.T, client NatsClient) {
assert := assert.New(t)
subjects := []string{
"foo bar",
"foo.",
@ -155,9 +147,8 @@ func testNatsClient_BadSubjects(t *testing.T, client NatsClient) {
ch := make(chan *nats.Msg)
for _, s := range subjects {
if _, err := client.Subscribe(s, ch); err != nats.ErrBadSubject {
t.Errorf("Expected %v for subject %s, got %v", nats.ErrBadSubject, s, err)
}
_, err := client.Subscribe(s, ch)
assert.ErrorIs(err, nats.ErrBadSubject, "Expected error for subject %s", s)
}
}

View file

@ -26,6 +26,8 @@ import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNotifierNoWaiter(t *testing.T) {
@ -48,9 +50,7 @@ func TestNotifierSimple(t *testing.T) {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := waiter.Wait(ctx); err != nil {
t.Error(err)
}
assert.NoError(t, waiter.Wait(ctx))
}()
notifier.Notify("foo")
@ -74,9 +74,7 @@ func TestNotifierWaitClosed(t *testing.T) {
waiter := notifier.NewWaiter("foo")
notifier.Release(waiter)
if err := waiter.Wait(context.Background()); err != nil {
t.Error(err)
}
assert.NoError(t, waiter.Wait(context.Background()))
}
func TestNotifierWaitClosedMulti(t *testing.T) {
@ -87,12 +85,8 @@ func TestNotifierWaitClosedMulti(t *testing.T) {
notifier.Release(waiter1)
notifier.Release(waiter2)
if err := waiter1.Wait(context.Background()); err != nil {
t.Error(err)
}
if err := waiter2.Wait(context.Background()); err != nil {
t.Error(err)
}
assert.NoError(t, waiter1.Wait(context.Background()))
assert.NoError(t, waiter2.Wait(context.Background()))
}
func TestNotifierResetWillNotify(t *testing.T) {
@ -108,9 +102,7 @@ func TestNotifierResetWillNotify(t *testing.T) {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := waiter.Wait(ctx); err != nil {
t.Error(err)
}
assert.NoError(t, waiter.Wait(ctx))
}()
notifier.Reset()
@ -137,9 +129,7 @@ func TestNotifierDuplicate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := waiter.Wait(ctx); err != nil {
t.Error(err)
}
assert.NoError(t, waiter.Wait(ctx))
}()
}

View file

@ -38,6 +38,8 @@ import (
"github.com/golang-jwt/jwt/v4"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
signaling "github.com/strukturag/nextcloud-spreed-signaling"
)
@ -57,6 +59,7 @@ func getWebsocketUrl(url string) string {
}
func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httptest.Server) {
require := require.New(t)
tempdir := t.TempDir()
var proxy *ProxyServer
t.Cleanup(func() {
@ -67,43 +70,32 @@ func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httpte
r := mux.NewRouter()
key, err := rsa.GenerateKey(rand.Reader, KeypairSizeForTest)
if err != nil {
t.Fatalf("could not generate key: %s", err)
}
require.NoError(err)
priv := &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
}
privkey, err := os.CreateTemp(tempdir, "privkey*.pem")
if err != nil {
t.Fatalf("could not create temporary file for private key: %s", err)
}
if err := pem.Encode(privkey, priv); err != nil {
t.Fatalf("could not encode private key: %s", err)
}
require.NoError(err)
require.NoError(pem.Encode(privkey, priv))
require.NoError(privkey.Close())
pubData, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
if err != nil {
t.Fatalf("could not marshal public key: %s", err)
}
require.NoError(err)
pub := &pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: pubData,
}
pubkey, err := os.CreateTemp(tempdir, "pubkey*.pem")
if err != nil {
t.Fatalf("could not create temporary file for public key: %s", err)
}
if err := pem.Encode(pubkey, pub); err != nil {
t.Fatalf("could not encode public key: %s", err)
}
require.NoError(err)
require.NoError(pem.Encode(pubkey, pub))
require.NoError(pubkey.Close())
config := goconf.NewConfigFile()
config.AddOption("tokens", TokenIdForTest, pubkey.Name())
if proxy, err = NewProxyServer(r, "0.0", config); err != nil {
t.Fatalf("could not create proxy server: %s", err)
}
proxy, err = NewProxyServer(r, "0.0", config)
require.NoError(err)
server := httptest.NewServer(r)
t.Cleanup(func() {
@ -125,19 +117,14 @@ func TestTokenValid(t *testing.T) {
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(key)
if err != nil {
t.Fatalf("could not create token: %s", err)
}
require.NoError(t, err)
hello := &signaling.HelloProxyClientMessage{
Version: "1.0",
Token: tokenString,
}
session, err := proxy.NewSession(hello)
if session != nil {
if session, err := proxy.NewSession(hello); assert.NoError(t, err) {
defer session.Close()
} else if err != nil {
t.Error(err)
}
}
@ -153,20 +140,16 @@ func TestTokenNotSigned(t *testing.T) {
}
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
if err != nil {
t.Fatalf("could not create token: %s", err)
}
require.NoError(t, err)
hello := &signaling.HelloProxyClientMessage{
Version: "1.0",
Token: tokenString,
}
session, err := proxy.NewSession(hello)
if session != nil {
defer session.Close()
t.Errorf("should not have created session")
} else if err != TokenAuthFailed {
t.Errorf("could have failed with TokenAuthFailed, got %s", err)
if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenAuthFailed) {
if session != nil {
defer session.Close()
}
}
}
@ -182,20 +165,16 @@ func TestTokenUnknown(t *testing.T) {
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(key)
if err != nil {
t.Fatalf("could not create token: %s", err)
}
require.NoError(t, err)
hello := &signaling.HelloProxyClientMessage{
Version: "1.0",
Token: tokenString,
}
session, err := proxy.NewSession(hello)
if session != nil {
defer session.Close()
t.Errorf("should not have created session")
} else if err != TokenAuthFailed {
t.Errorf("could have failed with TokenAuthFailed, got %s", err)
if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenAuthFailed) {
if session != nil {
defer session.Close()
}
}
}
@ -211,20 +190,16 @@ func TestTokenInFuture(t *testing.T) {
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(key)
if err != nil {
t.Fatalf("could not create token: %s", err)
}
require.NoError(t, err)
hello := &signaling.HelloProxyClientMessage{
Version: "1.0",
Token: tokenString,
}
session, err := proxy.NewSession(hello)
if session != nil {
defer session.Close()
t.Errorf("should not have created session")
} else if err != TokenNotValidYet {
t.Errorf("could have failed with TokenNotValidYet, got %s", err)
if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenNotValidYet) {
if session != nil {
defer session.Close()
}
}
}
@ -240,24 +215,21 @@ func TestTokenExpired(t *testing.T) {
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(key)
if err != nil {
t.Fatalf("could not create token: %s", err)
}
require.NoError(t, err)
hello := &signaling.HelloProxyClientMessage{
Version: "1.0",
Token: tokenString,
}
session, err := proxy.NewSession(hello)
if session != nil {
defer session.Close()
t.Errorf("should not have created session")
} else if err != TokenExpired {
t.Errorf("could have failed with TokenExpired, got %s", err)
if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenExpired) {
if session != nil {
defer session.Close()
}
}
}
func TestPublicIPs(t *testing.T) {
assert := assert.New(t)
public := []string{
"8.8.8.8",
"172.15.1.2",
@ -275,35 +247,30 @@ func TestPublicIPs(t *testing.T) {
}
for _, s := range public {
ip := net.ParseIP(s)
if len(ip) == 0 {
t.Errorf("invalid IP: %s", s)
} else if !IsPublicIP(ip) {
t.Errorf("should be public IP: %s", s)
if assert.NotEmpty(ip, "invalid IP: %s", s) {
assert.True(IsPublicIP(ip), "should be public IP: %s", s)
}
}
for _, s := range private {
ip := net.ParseIP(s)
if len(ip) == 0 {
t.Errorf("invalid IP: %s", s)
} else if IsPublicIP(ip) {
t.Errorf("should be private IP: %s", s)
if assert.NotEmpty(ip, "invalid IP: %s", s) {
assert.False(IsPublicIP(ip), "should be private IP: %s", s)
}
}
}
func TestWebsocketFeatures(t *testing.T) {
signaling.CatchLogForTest(t)
assert := assert.New(t)
_, _, server := newProxyServerForTest(t)
conn, response, err := websocket.DefaultDialer.DialContext(context.Background(), getWebsocketUrl(server.URL), nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer conn.Close() // nolint
if server := response.Header.Get("Server"); !strings.HasPrefix(server, "nextcloud-spreed-signaling-proxy/") {
t.Errorf("expected valid server header, got \"%s\"", server)
assert.Fail("expected valid server header, got \"%s\"", server)
}
features := response.Header.Get("X-Spreed-Signaling-Features")
featuresList := make(map[string]bool)
@ -311,19 +278,15 @@ func TestWebsocketFeatures(t *testing.T) {
f = strings.TrimSpace(f)
if f != "" {
if _, found := featuresList[f]; found {
t.Errorf("duplicate feature id \"%s\" in \"%s\"", f, features)
assert.Fail("duplicate feature id \"%s\" in \"%s\"", f, features)
}
featuresList[f] = true
}
}
if len(featuresList) == 0 {
t.Errorf("expected valid features header, got \"%s\"", features)
}
assert.NotEmpty(featuresList, "expected valid features header, got \"%s\"", features)
if _, found := featuresList["remote-streams"]; !found {
t.Errorf("expected feature \"remote-streams\", got \"%s\"", features)
assert.Fail("expected feature \"remote-streams\", got \"%s\"", features)
}
if err := conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}); err != nil {
t.Errorf("could not write close message: %s", err)
}
assert.NoError(conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}))
}

View file

@ -37,6 +37,8 @@ import (
"testing"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed"
"go.etcd.io/etcd/server/v3/lease"
@ -73,9 +75,7 @@ func newEtcdForTesting(t *testing.T) *embed.Etcd {
cfg.LogLevel = "warn"
u, err := url.Parse(etcdListenUrl)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
// Find a free port to bind the server to.
var etcd *embed.Etcd
@ -91,14 +91,12 @@ func newEtcdForTesting(t *testing.T) *embed.Etcd {
etcd, err = embed.StartEtcd(cfg)
if isErrorAddressAlreadyInUse(err) {
continue
} else if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
break
}
if etcd == nil {
t.Fatal("could not find free port")
}
require.NotNil(t, etcd, "could not find free port")
t.Cleanup(func() {
etcd.Close()
@ -118,9 +116,7 @@ func newTokensEtcdForTesting(t *testing.T) (*tokensEtcd, *embed.Etcd) {
cfg.AddOption("tokens", "keyformat", "/%s, /testing/%s/key")
tokens, err := NewProxyTokensEtcd(cfg)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
t.Cleanup(func() {
tokens.Close()
})
@ -134,11 +130,9 @@ func storeKey(t *testing.T, etcd *embed.Etcd, key string, pubkey crypto.PublicKe
switch pubkey := pubkey.(type) {
case rsa.PublicKey:
data, err = x509.MarshalPKIXPublicKey(&pubkey)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
default:
t.Fatalf("unknown key type %T in %+v", pubkey, pubkey)
require.Fail(t, "unknown key type %T in %+v", pubkey, pubkey)
}
data = pem.EncodeToMemory(&pem.Block{
@ -154,9 +148,7 @@ func storeKey(t *testing.T, etcd *embed.Etcd, key string, pubkey crypto.PublicKe
func generateAndSaveKey(t *testing.T, etcd *embed.Etcd, name string) *rsa.PrivateKey {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
storeKey(t, etcd, name, key.PublicKey)
return key
@ -164,24 +156,17 @@ func generateAndSaveKey(t *testing.T, etcd *embed.Etcd, name string) *rsa.Privat
func TestProxyTokensEtcd(t *testing.T) {
signaling.CatchLogForTest(t)
assert := assert.New(t)
tokens, etcd := newTokensEtcdForTesting(t)
key1 := generateAndSaveKey(t, etcd, "/foo")
key2 := generateAndSaveKey(t, etcd, "/testing/bar/key")
if token, err := tokens.Get("foo"); err != nil {
t.Error(err)
} else if token == nil {
t.Error("could not get token")
} else if !key1.PublicKey.Equal(token.key) {
t.Error("token keys mismatch")
if token, err := tokens.Get("foo"); assert.NoError(err) && assert.NotNil(token) {
assert.True(key1.PublicKey.Equal(token.key))
}
if token, err := tokens.Get("bar"); err != nil {
t.Error(err)
} else if token == nil {
t.Error("could not get token")
} else if !key2.PublicKey.Equal(token.key) {
t.Error("token keys mismatch")
if token, err := tokens.Get("bar"); assert.NoError(err) && assert.NotNil(token) {
assert.True(key2.PublicKey.Equal(token.key))
}
}

View file

@ -28,6 +28,7 @@ import (
"time"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed"
)
@ -43,9 +44,7 @@ func newProxyConfigEtcd(t *testing.T, proxy McuProxy) (*embed.Etcd, ProxyConfig)
cfg := goconf.NewConfigFile()
cfg.AddOption("mcu", "keyprefix", "proxies/")
p, err := NewProxyConfigEtcd(cfg, client, proxy)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
t.Cleanup(func() {
p.Stop()
})
@ -55,9 +54,7 @@ func newProxyConfigEtcd(t *testing.T, proxy McuProxy) (*embed.Etcd, ProxyConfig)
func SetEtcdProxy(t *testing.T, etcd *embed.Etcd, path string, proxy *TestProxyInformationEtcd) {
t.Helper()
data, err := json.Marshal(proxy)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
SetEtcdValue(etcd, path, data)
}
@ -74,9 +71,7 @@ func TestProxyConfigEtcd(t *testing.T) {
Address: "https://foo/",
})
proxy.Expect("add", "https://foo/")
if err := config.Start(); err != nil {
t.Fatal(err)
}
require.NoError(t, config.Start())
proxy.WaitForEvents(ctx)
proxy.Expect("add", "https://bar/")

View file

@ -28,6 +28,7 @@ import (
"time"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/require"
)
func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string) (ProxyConfig, *DnsMonitor) {
@ -38,9 +39,7 @@ func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string
}
dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually
p, err := NewProxyConfigStatic(cfg, proxy, dnsMonitor)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
t.Cleanup(func() {
p.Stop()
})
@ -53,9 +52,7 @@ func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls ..
if dns {
cfg.AddOption("mcu", "dnsdiscovery", "true")
}
if err := config.Reload(cfg); err != nil {
t.Fatal(err)
}
require.NoError(t, config.Reload(cfg))
}
func TestProxyConfigStaticSimple(t *testing.T) {
@ -63,9 +60,7 @@ func TestProxyConfigStaticSimple(t *testing.T) {
proxy := newMcuProxyForConfig(t)
config, _ := newProxyConfigStatic(t, proxy, false, "https://foo/")
proxy.Expect("add", "https://foo/")
if err := config.Start(); err != nil {
t.Fatal(err)
}
require.NoError(t, config.Start())
proxy.Expect("keep", "https://foo/")
proxy.Expect("add", "https://bar/")
@ -82,9 +77,7 @@ func TestProxyConfigStaticDNS(t *testing.T) {
lookup := newMockDnsLookupForTest(t)
proxy := newMcuProxyForConfig(t)
config, dnsMonitor := newProxyConfigStatic(t, proxy, true, "https://foo/")
if err := config.Start(); err != nil {
t.Fatal(err)
}
require.NoError(t, config.Start())
time.Sleep(time.Millisecond)

View file

@ -29,6 +29,8 @@ import (
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
var (
@ -61,9 +63,7 @@ func newMcuProxyForConfig(t *testing.T) *mcuProxyForConfig {
t: t,
}
t.Cleanup(func() {
if len(proxy.expected) > 0 {
t.Errorf("expected events %+v were not triggered", proxy.expected)
}
assert.Empty(t, proxy.expected)
})
return proxy
}
@ -99,7 +99,7 @@ func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) {
defer p.mu.Lock()
select {
case <-ctx.Done():
p.t.Error(ctx.Err())
assert.NoError(p.t, ctx.Err())
case <-waiter:
}
}
@ -125,7 +125,7 @@ func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) {
defer p.mu.Unlock()
if len(p.expected) == 0 {
p.t.Errorf("no event expected, got %+v from %s:%d", event, caller.File, caller.Line)
assert.Fail(p.t, "no event expected, got %+v from %s:%d", event, caller.File, caller.Line)
return
}
@ -145,7 +145,7 @@ func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) {
expected := p.expected[0]
p.expected = p.expected[1:]
if !reflect.DeepEqual(expected, *event) {
p.t.Errorf("expected %+v, got %+v from %s:%d", expected, event, caller.File, caller.Line)
assert.Fail(p.t, "expected %+v, got %+v from %s:%d", expected, event, caller.File, caller.Line)
}
}

View file

@ -28,9 +28,12 @@ import (
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) {
require := require.New(t)
r := mux.NewRouter()
registerBackendHandler(t, r)
@ -40,30 +43,23 @@ func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) {
})
config, err := getTestConfig(server)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
backend, err := NewBackendClient(config, 1, "0.0", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
p, err := NewRoomPing(backend, backend.capabilities)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
u, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
return u, p
}
func TestSingleRoomPing(t *testing.T) {
CatchLogForTest(t)
assert := assert.New(t)
u, ping := NewRoomPingForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
@ -78,13 +74,9 @@ func TestSingleRoomPing(t *testing.T) {
SessionId: "123",
},
}
if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil {
t.Error(err)
}
if requests := getPingRequests(t); len(requests) != 1 {
t.Errorf("expected one ping request, got %+v", requests)
} else if len(requests[0].Ping.Entries) != 1 {
t.Errorf("expected one entry, got %+v", requests[0].Ping.Entries)
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1))
if requests := getPingRequests(t); assert.Len(requests, 1) {
assert.Len(requests[0].Ping.Entries, 1)
}
clearPingRequests(t)
@ -97,24 +89,19 @@ func TestSingleRoomPing(t *testing.T) {
SessionId: "456",
},
}
if err := ping.SendPings(ctx, room2.Id(), u, entries2); err != nil {
t.Error(err)
}
if requests := getPingRequests(t); len(requests) != 1 {
t.Errorf("expected one ping request, got %+v", requests)
} else if len(requests[0].Ping.Entries) != 1 {
t.Errorf("expected one entry, got %+v", requests[0].Ping.Entries)
assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2))
if requests := getPingRequests(t); assert.Len(requests, 1) {
assert.Len(requests[0].Ping.Entries, 1)
}
clearPingRequests(t)
ping.publishActiveSessions()
if requests := getPingRequests(t); len(requests) != 0 {
t.Errorf("expected no ping requests, got %+v", requests)
}
assert.Empty(getPingRequests(t))
}
func TestMultiRoomPing(t *testing.T) {
CatchLogForTest(t)
assert := assert.New(t)
u, ping := NewRoomPingForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
@ -129,12 +116,8 @@ func TestMultiRoomPing(t *testing.T) {
SessionId: "123",
},
}
if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil {
t.Error(err)
}
if requests := getPingRequests(t); len(requests) != 0 {
t.Errorf("expected no ping requests, got %+v", requests)
}
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1))
assert.Empty(getPingRequests(t))
room2 := &Room{
id: "sample-room-2",
@ -145,23 +128,18 @@ func TestMultiRoomPing(t *testing.T) {
SessionId: "456",
},
}
if err := ping.SendPings(ctx, room2.Id(), u, entries2); err != nil {
t.Error(err)
}
if requests := getPingRequests(t); len(requests) != 0 {
t.Errorf("expected no ping requests, got %+v", requests)
}
assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2))
assert.Empty(getPingRequests(t))
ping.publishActiveSessions()
if requests := getPingRequests(t); len(requests) != 1 {
t.Errorf("expected one ping request, got %+v", requests)
} else if len(requests[0].Ping.Entries) != 2 {
t.Errorf("expected two entries, got %+v", requests[0].Ping.Entries)
if requests := getPingRequests(t); assert.Len(requests, 1) {
assert.Len(requests[0].Ping.Entries, 2)
}
}
func TestMultiRoomPing_Separate(t *testing.T) {
CatchLogForTest(t)
assert := assert.New(t)
u, ping := NewRoomPingForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
@ -176,35 +154,26 @@ func TestMultiRoomPing_Separate(t *testing.T) {
SessionId: "123",
},
}
if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil {
t.Error(err)
}
if requests := getPingRequests(t); len(requests) != 0 {
t.Errorf("expected no ping requests, got %+v", requests)
}
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1))
assert.Empty(getPingRequests(t))
entries2 := []BackendPingEntry{
{
UserId: "bar",
SessionId: "456",
},
}
if err := ping.SendPings(ctx, room1.Id(), u, entries2); err != nil {
t.Error(err)
}
if requests := getPingRequests(t); len(requests) != 0 {
t.Errorf("expected no ping requests, got %+v", requests)
}
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries2))
assert.Empty(getPingRequests(t))
ping.publishActiveSessions()
if requests := getPingRequests(t); len(requests) != 1 {
t.Errorf("expected one ping request, got %+v", requests)
} else if len(requests[0].Ping.Entries) != 2 {
t.Errorf("expected two entries, got %+v", requests[0].Ping.Entries)
if requests := getPingRequests(t); assert.Len(requests, 1) {
assert.Len(requests[0].Ping.Entries, 2)
}
}
func TestMultiRoomPing_DeleteRoom(t *testing.T) {
CatchLogForTest(t)
assert := assert.New(t)
u, ping := NewRoomPingForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
@ -219,12 +188,8 @@ func TestMultiRoomPing_DeleteRoom(t *testing.T) {
SessionId: "123",
},
}
if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil {
t.Error(err)
}
if requests := getPingRequests(t); len(requests) != 0 {
t.Errorf("expected no ping requests, got %+v", requests)
}
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1))
assert.Empty(getPingRequests(t))
room2 := &Room{
id: "sample-room-2",
@ -235,19 +200,13 @@ func TestMultiRoomPing_DeleteRoom(t *testing.T) {
SessionId: "456",
},
}
if err := ping.SendPings(ctx, room2.Id(), u, entries2); err != nil {
t.Error(err)
}
if requests := getPingRequests(t); len(requests) != 0 {
t.Errorf("expected no ping requests, got %+v", requests)
}
assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2))
assert.Empty(getPingRequests(t))
ping.DeleteRoom(room2.Id())
ping.publishActiveSessions()
if requests := getPingRequests(t); len(requests) != 1 {
t.Errorf("expected one ping request, got %+v", requests)
} else if len(requests[0].Ping.Entries) != 1 {
t.Errorf("expected two entries, got %+v", requests[0].Ping.Entries)
if requests := getPingRequests(t); assert.Len(requests, 1) {
assert.Len(requests[0].Ping.Entries, 1)
}
}

View file

@ -27,11 +27,14 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRoom_InCall(t *testing.T) {
@ -63,59 +66,47 @@ func TestRoom_InCall(t *testing.T) {
}
for _, test := range tests {
inCall, ok := IsInCall(test.Value)
if ok != test.Valid {
t.Errorf("%+v should be valid %v, got %v", test.Value, test.Valid, ok)
}
if inCall != test.InCall {
t.Errorf("%+v should convert to %v, got %v", test.Value, test.InCall, inCall)
if test.Valid {
assert.True(t, ok, "%+v should be valid", test.Value)
} else {
assert.False(t, ok, "%+v should not be valid", test.Value)
}
assert.EqualValues(t, test.InCall, inCall, "conversion failed for %+v", test.Value)
}
}
func TestRoom_Update(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
hub, _, router, server := CreateHubForTest(t)
config, err := getTestConfig(server)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
b, err := NewBackendServer(config, hub, "no-version")
if err != nil {
t.Fatal(err)
}
if err := b.Start(router); err != nil {
t.Fatal(err)
}
require.NoError(err)
require.NoError(b.Start(router))
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
if err := client.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
require.NoError(client.SendHello(testDefaultUserId))
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// Join room by id.
roomId := "test-room"
if room, err := client.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err := client.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
// We will receive a "joined" event.
if err := client.RunUntilJoined(ctx, hello.Hello); err != nil {
t.Error(err)
}
assert.NoError(client.RunUntilJoined(ctx, hello.Hello))
// Simulate backend request from Nextcloud to update the room.
roomProperties := json.RawMessage("{\"foo\":\"bar\"}")
@ -130,54 +121,32 @@ func TestRoom_Update(t *testing.T) {
}
data, err := json.Marshal(msg)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
t.Error(err)
}
if res.StatusCode != 200 {
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
}
assert.NoError(err)
assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body))
// The client receives a roomlist update and a changed room event. The
// ordering is not defined because messages are sent by asynchronous event
// handlers.
message1, err := client.RunUntilMessage(ctx)
if err != nil {
t.Error(err)
}
assert.NoError(err)
message2, err := client.RunUntilMessage(ctx)
if err != nil {
t.Error(err)
}
assert.NoError(err)
if msg, err := checkMessageRoomlistUpdate(message1); err != nil {
if err := checkMessageRoomId(message1, roomId); err != nil {
t.Error(err)
}
if msg, err := checkMessageRoomlistUpdate(message2); err != nil {
t.Error(err)
} else if msg.RoomId != roomId {
t.Errorf("Expected room id %s, got %+v", roomId, msg)
} else if len(msg.Properties) == 0 || !bytes.Equal(msg.Properties, roomProperties) {
t.Errorf("Expected room properties %s, got %+v", string(roomProperties), msg)
assert.NoError(checkMessageRoomId(message1, roomId))
if msg, err := checkMessageRoomlistUpdate(message2); assert.NoError(err) {
assert.Equal(roomId, msg.RoomId)
assert.Equal(string(roomProperties), string(msg.Properties))
}
} else {
if msg.RoomId != roomId {
t.Errorf("Expected room id %s, got %+v", roomId, msg)
} else if len(msg.Properties) == 0 || !bytes.Equal(msg.Properties, roomProperties) {
t.Errorf("Expected room properties %s, got %+v", string(roomProperties), msg)
}
if err := checkMessageRoomId(message2, roomId); err != nil {
t.Error(err)
}
assert.Equal(roomId, msg.RoomId)
assert.Equal(string(roomProperties), string(msg.Properties))
assert.NoError(checkMessageRoomId(message2, roomId))
}
// Allow up to 100 milliseconds for asynchronous event processing.
@ -206,55 +175,41 @@ loop:
time.Sleep(time.Millisecond)
}
if err != nil {
t.Error(err)
}
assert.NoError(err)
}
func TestRoom_Delete(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
hub, _, router, server := CreateHubForTest(t)
config, err := getTestConfig(server)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
b, err := NewBackendServer(config, hub, "no-version")
if err != nil {
t.Fatal(err)
}
if err := b.Start(router); err != nil {
t.Fatal(err)
}
require.NoError(err)
require.NoError(b.Start(router))
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
if err := client.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
require.NoError(client.SendHello(testDefaultUserId))
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// Join room by id.
roomId := "test-room"
if room, err := client.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err := client.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
// We will receive a "joined" event.
if err := client.RunUntilJoined(ctx, hello.Hello); err != nil {
t.Error(err)
}
assert.NoError(client.RunUntilJoined(ctx, hello.Hello))
// Simulate backend request from Nextcloud to update the room.
msg := &BackendServerRoomRequest{
@ -267,47 +222,31 @@ func TestRoom_Delete(t *testing.T) {
}
data, err := json.Marshal(msg)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
t.Error(err)
}
if res.StatusCode != 200 {
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
}
assert.NoError(err)
assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body))
// The client is no longer invited to the room and leaves it. The ordering
// of messages is not defined as they get published through events and handled
// by asynchronous channels.
message1, err := client.RunUntilMessage(ctx)
if err != nil {
t.Error(err)
}
assert.NoError(err)
if err := checkMessageType(message1, "event"); err != nil {
// Ordering should be "leave room", "disinvited".
if err := checkMessageRoomId(message1, ""); err != nil {
t.Error(err)
}
message2, err := client.RunUntilMessage(ctx)
if err != nil {
t.Error(err)
}
if _, err := checkMessageRoomlistDisinvite(message2); err != nil {
t.Error(err)
assert.NoError(checkMessageRoomId(message1, ""))
if message2, err := client.RunUntilMessage(ctx); assert.NoError(err) {
_, err := checkMessageRoomlistDisinvite(message2)
assert.NoError(err)
}
} else {
// Ordering should be "disinvited", "leave room".
if _, err := checkMessageRoomlistDisinvite(message1); err != nil {
t.Error(err)
}
_, err := checkMessageRoomlistDisinvite(message1)
assert.NoError(err)
message2, err := client.RunUntilMessage(ctx)
if err != nil {
// The connection should get closed after the "disinvited".
@ -315,10 +254,10 @@ func TestRoom_Delete(t *testing.T) {
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived) {
t.Error(err)
assert.NoError(err)
}
} else if err := checkMessageRoomId(message2, ""); err != nil {
t.Error(err)
} else {
assert.NoError(checkMessageRoomId(message2, ""))
}
}
@ -350,151 +289,106 @@ loop:
time.Sleep(time.Millisecond)
}
if err != nil {
t.Error(err)
}
assert.NoError(err)
}
func TestRoom_RoomSessionData(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
hub, _, router, server := CreateHubForTest(t)
config, err := getTestConfig(server)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
b, err := NewBackendServer(config, hub, "no-version")
if err != nil {
t.Fatal(err)
}
if err := b.Start(router); err != nil {
t.Fatal(err)
}
require.NoError(err)
require.NoError(b.Start(router))
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
if err := client.SendHello(authAnonymousUserId); err != nil {
t.Fatal(err)
}
require.NoError(client.SendHello(authAnonymousUserId))
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// Join room by id.
roomId := "test-room-with-sessiondata"
if room, err := client.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err := client.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
// We will receive a "joined" event with the userid from the room session data.
expected := "userid-from-sessiondata"
if message, err := client.RunUntilMessage(ctx); err != nil {
t.Error(err)
} else if err := client.checkMessageJoinedSession(message, hello.Hello.SessionId, expected); err != nil {
t.Error(err)
} else if message.Event.Join[0].RoomSessionId != roomId+"-"+hello.Hello.SessionId {
t.Errorf("Expected join room session id %s, got %+v", roomId+"-"+hello.Hello.SessionId, message.Event.Join[0])
if message, err := client.RunUntilMessage(ctx); assert.NoError(err) {
if assert.NoError(client.checkMessageJoinedSession(message, hello.Hello.SessionId, expected)) {
assert.Equal(roomId+"-"+hello.Hello.SessionId, message.Event.Join[0].RoomSessionId)
}
}
session := hub.GetSessionByPublicId(hello.Hello.SessionId)
if session == nil {
t.Fatalf("Could not find session %s", hello.Hello.SessionId)
}
if userid := session.UserId(); userid != expected {
t.Errorf("Expected userid %s, got %s", expected, userid)
}
require.NotNil(session, "Could not find session %s", hello.Hello.SessionId)
assert.Equal(expected, session.UserId())
room := hub.getRoom(roomId)
if room == nil {
t.Fatalf("Room not found")
}
assert.NotNil(room, "Room not found")
entries, wg := room.publishActiveSessions()
if entries != 1 {
t.Errorf("expected 1 entries, got %d", entries)
}
assert.Equal(1, entries)
wg.Wait()
}
func TestRoom_InCallAll(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
hub, _, router, server := CreateHubForTest(t)
config, err := getTestConfig(server)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
b, err := NewBackendServer(config, hub, "no-version")
if err != nil {
t.Fatal(err)
}
if err := b.Start(router); err != nil {
t.Fatal(err)
}
require.NoError(err)
require.NoError(b.Start(router))
client1 := NewTestClient(t, server, hub)
defer client1.CloseWithBye()
if err := client1.SendHello(testDefaultUserId + "1"); err != nil {
t.Fatal(err)
}
require.NoError(client1.SendHello(testDefaultUserId + "1"))
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello1, err := client1.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
client2 := NewTestClient(t, server, hub)
defer client2.CloseWithBye()
if err := client2.SendHello(testDefaultUserId + "2"); err != nil {
t.Fatal(err)
}
require.NoError(client2.SendHello(testDefaultUserId + "2"))
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// Join room by id.
roomId := "test-room"
if room, err := client1.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err := client1.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil {
t.Error(err)
}
assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello))
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err = client2.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil {
t.Error(err)
}
assert.NoError(client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello))
if err := client1.RunUntilJoined(ctx, hello2.Hello); err != nil {
t.Error(err)
}
assert.NoError(client1.RunUntilJoined(ctx, hello2.Hello))
// Simulate backend request from Nextcloud to update the "inCall" flag of all participants.
msg1 := &BackendServerRoomRequest{
@ -506,32 +400,20 @@ func TestRoom_InCallAll(t *testing.T) {
}
data1, err := json.Marshal(msg1)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
res1, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data1)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer res1.Body.Close()
body1, err := io.ReadAll(res1.Body)
if err != nil {
t.Error(err)
}
if res1.StatusCode != 200 {
t.Errorf("Expected successful request, got %s: %s", res1.Status, string(body1))
assert.NoError(err)
assert.Equal(http.StatusOK, res1.StatusCode, "Expected successful request, got %s", string(body1))
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
assert.NoError(checkMessageInCallAll(msg, roomId, FlagInCall))
}
if msg, err := client1.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else if err := checkMessageInCallAll(msg, roomId, FlagInCall); err != nil {
t.Fatal(err)
}
if msg, err := client2.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else if err := checkMessageInCallAll(msg, roomId, FlagInCall); err != nil {
t.Fatal(err)
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
assert.NoError(checkMessageInCallAll(msg, roomId, FlagInCall))
}
// Simulate backend request from Nextcloud to update the "inCall" flag of all participants.
@ -544,31 +426,19 @@ func TestRoom_InCallAll(t *testing.T) {
}
data2, err := json.Marshal(msg2)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
res2, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data2)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
defer res2.Body.Close()
body2, err := io.ReadAll(res2.Body)
if err != nil {
t.Error(err)
}
if res2.StatusCode != 200 {
t.Errorf("Expected successful request, got %s: %s", res2.Status, string(body2))
assert.NoError(err)
assert.Equal(http.StatusOK, res2.StatusCode, "Expected successful request, got %s", string(body2))
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
assert.NoError(checkMessageInCallAll(msg, roomId, 0))
}
if msg, err := client1.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else if err := checkMessageInCallAll(msg, roomId, 0); err != nil {
t.Fatal(err)
}
if msg, err := client2.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else if err := checkMessageInCallAll(msg, roomId, 0); err != nil {
t.Fatal(err)
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
assert.NoError(checkMessageInCallAll(msg, roomId, 0))
}
}

View file

@ -23,13 +23,13 @@ package signaling
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestBuiltinRoomSessions(t *testing.T) {
sessions, err := NewBuiltinRoomSessions(nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
testRoomSessions(t, sessions)
}

View file

@ -24,9 +24,11 @@ package signaling
import (
"context"
"encoding/json"
"errors"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type DummySession struct {
@ -107,70 +109,53 @@ func checkSession(t *testing.T, sessions RoomSessions, sessionId string, roomSes
session := &DummySession{
publicId: sessionId,
}
if err := sessions.SetRoomSession(session, roomSessionId); err != nil {
t.Fatalf("Expected no error, got %s", err)
}
if sid, err := sessions.GetSessionId(roomSessionId); err != nil {
t.Errorf("Expected session id %s, got error %s", sessionId, err)
} else if sid != sessionId {
t.Errorf("Expected session id %s, got %s", sessionId, sid)
require.NoError(t, sessions.SetRoomSession(session, roomSessionId))
if sid, err := sessions.GetSessionId(roomSessionId); assert.NoError(t, err) {
assert.Equal(t, sessionId, sid)
}
return session
}
func testRoomSessions(t *testing.T, sessions RoomSessions) {
if sid, err := sessions.GetSessionId("unknown"); err != nil && err != ErrNoSuchRoomSession {
t.Errorf("Expected error about invalid room session, got %s", err)
} else if err == nil {
t.Errorf("Expected error about invalid room session, got session id %s", sid)
assert := assert.New(t)
if sid, err := sessions.GetSessionId("unknown"); err == nil {
assert.Fail("Expected error about invalid room session, got session id %s", sid)
} else {
assert.ErrorIs(err, ErrNoSuchRoomSession)
}
s1 := checkSession(t, sessions, "session1", "room1")
s2 := checkSession(t, sessions, "session2", "room2")
if sid, err := sessions.GetSessionId("room1"); err != nil {
t.Errorf("Expected session id %s, got error %s", s1.PublicId(), err)
} else if sid != s1.PublicId() {
t.Errorf("Expected session id %s, got %s", s1.PublicId(), sid)
if sid, err := sessions.GetSessionId("room1"); assert.NoError(err) {
assert.Equal(s1.PublicId(), sid)
}
sessions.DeleteRoomSession(s1)
if sid, err := sessions.GetSessionId("room1"); err != nil && err != ErrNoSuchRoomSession {
t.Errorf("Expected error about invalid room session, got %s", err)
} else if err == nil {
t.Errorf("Expected error about invalid room session, got session id %s", sid)
if sid, err := sessions.GetSessionId("room1"); err == nil {
assert.Fail("Expected error about invalid room session, got session id %s", sid)
} else {
assert.ErrorIs(err, ErrNoSuchRoomSession)
}
if sid, err := sessions.GetSessionId("room2"); err != nil {
t.Errorf("Expected session id %s, got error %s", s2.PublicId(), err)
} else if sid != s2.PublicId() {
t.Errorf("Expected session id %s, got %s", s2.PublicId(), sid)
if sid, err := sessions.GetSessionId("room2"); assert.NoError(err) {
assert.Equal(s2.PublicId(), sid)
}
if err := sessions.SetRoomSession(s1, "room-session"); err != nil {
t.Error(err)
}
if err := sessions.SetRoomSession(s2, "room-session"); err != nil {
t.Error(err)
}
assert.NoError(sessions.SetRoomSession(s1, "room-session"))
assert.NoError(sessions.SetRoomSession(s2, "room-session"))
sessions.DeleteRoomSession(s1)
if sid, err := sessions.GetSessionId("room-session"); err != nil {
t.Errorf("Expected session id %s, got error %s", s2.PublicId(), err)
} else if sid != s2.PublicId() {
t.Errorf("Expected session id %s, got %s", s2.PublicId(), sid)
if sid, err := sessions.GetSessionId("room-session"); assert.NoError(err) {
assert.Equal(s2.PublicId(), sid)
}
if err := sessions.SetRoomSession(s2, "room-session2"); err != nil {
t.Error(err)
}
assert.NoError(sessions.SetRoomSession(s2, "room-session2"))
if sid, err := sessions.GetSessionId("room-session"); err == nil {
t.Errorf("expected error %s, got sid %s", ErrNoSuchRoomSession, sid)
} else if !errors.Is(err, ErrNoSuchRoomSession) {
t.Errorf("expected %s, got %s", ErrNoSuchRoomSession, err)
assert.Fail("Expected error about invalid room session, got session id %s", sid)
} else {
assert.ErrorIs(err, ErrNoSuchRoomSession)
}
if sid, err := sessions.GetSessionId("room-session2"); err != nil {
t.Errorf("Expected session id %s, got error %s", s2.PublicId(), err)
} else if sid != s2.PublicId() {
t.Errorf("Expected session id %s, got %s", s2.PublicId(), sid)
if sid, err := sessions.GetSessionId("room-session2"); assert.NoError(err) {
assert.Equal(s2.PublicId(), sid)
}
}

View file

@ -23,18 +23,16 @@ package signaling
import (
"testing"
"github.com/stretchr/testify/assert"
)
func assertSessionHasPermission(t *testing.T, session Session, permission Permission) {
t.Helper()
if !session.HasPermission(permission) {
t.Errorf("Session %s doesn't have permission %s", session.PublicId(), permission)
}
assert.True(t, session.HasPermission(permission), "Session %s doesn't have permission %s", session.PublicId(), permission)
}
func assertSessionHasNotPermission(t *testing.T, session Session, permission Permission) {
t.Helper()
if session.HasPermission(permission) {
t.Errorf("Session %s has permission %s but shouldn't", session.PublicId(), permission)
}
assert.False(t, session.HasPermission(permission), "Session %s has permission %s but shouldn't", session.PublicId(), permission)
}

View file

@ -26,6 +26,8 @@ import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestSingleNotifierNoWaiter(t *testing.T) {
@ -48,9 +50,7 @@ func TestSingleNotifierSimple(t *testing.T) {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := waiter.Wait(ctx); err != nil {
t.Error(err)
}
assert.NoError(t, waiter.Wait(ctx))
}()
notifier.Notify()
@ -74,9 +74,7 @@ func TestSingleNotifierWaitClosed(t *testing.T) {
waiter := notifier.NewWaiter()
notifier.Release(waiter)
if err := waiter.Wait(context.Background()); err != nil {
t.Error(err)
}
assert.NoError(t, waiter.Wait(context.Background()))
}
func TestSingleNotifierWaitClosedMulti(t *testing.T) {
@ -87,12 +85,8 @@ func TestSingleNotifierWaitClosedMulti(t *testing.T) {
notifier.Release(waiter1)
notifier.Release(waiter2)
if err := waiter1.Wait(context.Background()); err != nil {
t.Error(err)
}
if err := waiter2.Wait(context.Background()); err != nil {
t.Error(err)
}
assert.NoError(t, waiter1.Wait(context.Background()))
assert.NoError(t, waiter2.Wait(context.Background()))
}
func TestSingleNotifierResetWillNotify(t *testing.T) {
@ -108,9 +102,7 @@ func TestSingleNotifierResetWillNotify(t *testing.T) {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := waiter.Wait(ctx); err != nil {
t.Error(err)
}
assert.NoError(t, waiter.Wait(ctx))
}()
notifier.Reset()
@ -137,9 +129,7 @@ func TestSingleNotifierDuplicate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := waiter.Wait(ctx); err != nil {
t.Error(err)
}
assert.NoError(t, waiter.Wait(ctx))
}()
}

View file

@ -29,6 +29,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
)
func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64) {
@ -37,10 +38,11 @@ func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64
desc := <-ch
v := testutil.ToFloat64(collector)
if v != value {
assert := assert.New(t)
pc := make([]uintptr, 10)
n := runtime.Callers(2, pc)
if n == 0 {
t.Errorf("Expected value %f for %s, got %f", value, desc, v)
assert.Fail("Expected value %f for %s, got %f", value, desc, v)
return
}
@ -57,20 +59,20 @@ func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64
break
}
}
t.Errorf("Expected value %f for %s, got %f at\n%s", value, desc, v, stack)
assert.Fail("Expected value %f for %s, got %f at\n%s", value, desc, v, stack)
}
}
func collectAndLint(t *testing.T, collectors ...prometheus.Collector) {
assert := assert.New(t)
for _, collector := range collectors {
problems, err := testutil.CollectAndLint(collector)
if err != nil {
t.Errorf("Error linting %+v: %s", collector, err)
if !assert.NoError(err) {
continue
}
for _, problem := range problems {
t.Errorf("Problem with %s: %s", problem.Metric, problem.Text)
assert.Fail("Problem with %s: %s", problem.Metric, problem.Text)
}
}
}

View file

@ -40,6 +40,8 @@ import (
"github.com/golang-jwt/jwt/v4"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
@ -225,9 +227,7 @@ type TestClient struct {
func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Server, hub *Hub) *TestClient {
// Reference "hub" to prevent compiler error.
conn, _, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
messageChan := make(chan []byte)
readErrorChan := make(chan error, 1)
@ -238,8 +238,7 @@ func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Se
if err != nil {
readErrorChan <- err
return
} else if messageType != websocket.TextMessage {
t.Errorf("Expect text message, got %d", messageType)
} else if !assert.Equal(t, websocket.TextMessage, messageType) {
return
}
@ -266,13 +265,8 @@ func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient
client := NewTestClientContext(ctx, t, server, hub)
msg, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
if msg.Type != "welcome" {
t.Errorf("Expected welcome message, got %+v", msg)
}
require.NoError(t, err)
assert.Equal(t, "welcome", msg.Type)
return client
}
@ -380,9 +374,7 @@ func (c *TestClient) WriteJSON(data interface{}) error {
}
func (c *TestClient) EnsuerWriteJSON(data interface{}) {
if err := c.WriteJSON(data); err != nil {
c.t.Fatalf("Could not write JSON %+v: %s", data, err)
}
require.NoError(c.t, c.WriteJSON(data), "Could not write JSON %+v", data)
}
func (c *TestClient) SendHello(userid string) error {
@ -443,9 +435,7 @@ func (c *TestClient) CreateHelloV2Token(userid string, issuedAt time.Time, expir
func (c *TestClient) SendHelloV2WithTimes(userid string, issuedAt time.Time, expiresAt time.Time) error {
tokenString, err := c.CreateHelloV2Token(userid, issuedAt, expiresAt)
if err != nil {
c.t.Fatal(err)
}
require.NoError(c.t, err)
params := HelloV2AuthParams{
Token: tokenString,
@ -493,9 +483,7 @@ func (c *TestClient) SendHelloInternalWithFeatures(features []string) error {
func (c *TestClient) SendHelloParams(url string, version string, clientType string, features []string, params interface{}) error {
data, err := json.Marshal(params)
if err != nil {
c.t.Fatal(err)
}
require.NoError(c.t, err)
hello := &ClientMessage{
Id: "1234",
@ -524,9 +512,7 @@ func (c *TestClient) SendBye() error {
func (c *TestClient) SendMessage(recipient MessageClientMessageRecipient, data interface{}) error {
payload, err := json.Marshal(data)
if err != nil {
c.t.Fatal(err)
}
require.NoError(c.t, err)
message := &ClientMessage{
Id: "abcd",
@ -541,9 +527,7 @@ func (c *TestClient) SendMessage(recipient MessageClientMessageRecipient, data i
func (c *TestClient) SendControl(recipient MessageClientMessageRecipient, data interface{}) error {
payload, err := json.Marshal(data)
if err != nil {
c.t.Fatal(err)
}
require.NoError(c.t, err)
message := &ClientMessage{
Id: "abcd",
@ -608,9 +592,7 @@ func (c *TestClient) SendInternalDialout(msg *DialoutInternalClientMessage) erro
func (c *TestClient) SetTransientData(key string, value interface{}, ttl time.Duration) error {
payload, err := json.Marshal(value)
if err != nil {
c.t.Fatal(err)
}
require.NoError(c.t, err)
message := &ClientMessage{
Id: "efgh",

View file

@ -30,6 +30,8 @@ import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
var listenSignalOnce sync.Once
@ -76,7 +78,7 @@ func ensureNoGoroutinesLeak(t *testing.T, f func(t *testing.T)) {
if after != before {
io.Copy(os.Stderr, &prev) // nolint
dumpGoroutines("After:", os.Stderr)
t.Fatalf("Number of Go routines has changed from %d to %d", before, after)
require.Equal(t, before, after, "Number of Go routines has changed")
}
}

View file

@ -25,14 +25,15 @@ import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newMemoryThrottlerForTest(t *testing.T) *memoryThrottler {
t.Helper()
result, err := NewMemoryThrottler()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
t.Cleanup(func() {
result.Close()
@ -54,12 +55,11 @@ func (t *throttlerTiming) getNow() time.Time {
func (t *throttlerTiming) doDelay(ctx context.Context, duration time.Duration) {
t.t.Helper()
if duration != t.expectedSleep {
t.t.Errorf("expected sleep %s, got %s", t.expectedSleep, duration)
}
assert.Equal(t.t, t.expectedSleep, duration)
}
func TestThrottler(t *testing.T) {
assert := assert.New(t)
timing := &throttlerTiming{
t: t,
now: time.Now(),
@ -71,38 +71,31 @@ func TestThrottler(t *testing.T) {
ctx := context.Background()
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle1(ctx)
timing.now = timing.now.Add(time.Millisecond)
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 200 * time.Millisecond
throttle2(ctx)
timing.now = timing.now.Add(time.Millisecond)
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle3(ctx)
timing.now = timing.now.Add(time.Millisecond)
throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle4(ctx)
}
func TestThrottlerIPv6(t *testing.T) {
assert := assert.New(t)
timing := &throttlerTiming{
t: t,
now: time.Now(),
@ -115,40 +108,33 @@ func TestThrottlerIPv6(t *testing.T) {
// Make sure full /64 subnets are throttled for IPv6.
throttle1, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle1(ctx)
timing.now = timing.now.Add(time.Millisecond)
throttle2, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::2", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 200 * time.Millisecond
throttle2(ctx)
// A diffent /64 subnet is not throttled yet.
timing.now = timing.now.Add(time.Millisecond)
throttle3, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0013::1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle3(ctx)
// A different action is not throttled.
timing.now = timing.now.Add(time.Millisecond)
throttle4, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::1", "action2")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle4(ctx)
}
func TestThrottler_Bruteforce(t *testing.T) {
assert := assert.New(t)
timing := &throttlerTiming{
t: t,
now: time.Now(),
@ -162,9 +148,7 @@ func TestThrottler_Bruteforce(t *testing.T) {
for i := 0; i < maxBruteforceAttempts; i++ {
timing.now = timing.now.Add(time.Millisecond)
throttle, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
if i == 0 {
timing.expectedSleep = 100 * time.Millisecond
} else {
@ -177,14 +161,12 @@ func TestThrottler_Bruteforce(t *testing.T) {
}
timing.now = timing.now.Add(time.Millisecond)
if _, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1"); err == nil {
t.Error("expected bruteforce error")
} else if err != ErrBruteforceDetected {
t.Errorf("expected error %s, got %s", ErrBruteforceDetected, err)
}
_, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
assert.ErrorIs(err, ErrBruteforceDetected)
}
func TestThrottler_Cleanup(t *testing.T) {
assert := assert.New(t)
timing := &throttlerTiming{
t: t,
now: time.Now(),
@ -196,59 +178,46 @@ func TestThrottler_Cleanup(t *testing.T) {
ctx := context.Background()
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle1(ctx)
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle2(ctx)
timing.now = timing.now.Add(time.Hour)
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle3(ctx)
throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 200 * time.Millisecond
throttle4(ctx)
timing.now = timing.now.Add(-time.Hour).Add(maxBruteforceAge).Add(time.Second)
th.cleanup(timing.now)
if entries := th.getEntries("192.168.0.1", "action1"); len(entries) != 1 {
t.Errorf("should have removed one entry, got %+v", entries)
}
if entries := th.getEntries("192.168.0.1", "action2"); len(entries) != 1 {
t.Errorf("should have kept entry, got %+v", entries)
}
assert.Len(th.getEntries("192.168.0.1", "action1"), 1)
assert.Len(th.getEntries("192.168.0.1", "action2"), 1)
th.mu.RLock()
if _, found := th.clients["192.168.0.2"]; found {
t.Error("should have removed client \"192.168.0.2\"")
assert.Fail("should have removed client \"192.168.0.2\"")
}
th.mu.RUnlock()
throttle5, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 200 * time.Millisecond
throttle5(ctx)
}
func TestThrottler_ExpirePartial(t *testing.T) {
assert := assert.New(t)
timing := &throttlerTiming{
t: t,
now: time.Now(),
@ -260,32 +229,27 @@ func TestThrottler_ExpirePartial(t *testing.T) {
ctx := context.Background()
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle1(ctx)
timing.now = timing.now.Add(time.Minute)
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 200 * time.Millisecond
throttle2(ctx)
timing.now = timing.now.Add(maxBruteforceAge).Add(-time.Minute + time.Second)
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 200 * time.Millisecond
throttle3(ctx)
}
func TestThrottler_ExpireAll(t *testing.T) {
assert := assert.New(t)
timing := &throttlerTiming{
t: t,
now: time.Now(),
@ -297,32 +261,27 @@ func TestThrottler_ExpireAll(t *testing.T) {
ctx := context.Background()
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle1(ctx)
timing.now = timing.now.Add(time.Millisecond)
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 200 * time.Millisecond
throttle2(ctx)
timing.now = timing.now.Add(maxBruteforceAge).Add(time.Second)
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
assert.NoError(err)
timing.expectedSleep = 100 * time.Millisecond
throttle3(ctx)
}
func TestThrottler_Negative(t *testing.T) {
assert := assert.New(t)
timing := &throttlerTiming{
t: t,
now: time.Now(),
@ -336,8 +295,8 @@ func TestThrottler_Negative(t *testing.T) {
for i := 0; i < maxBruteforceAttempts*10; i++ {
timing.now = timing.now.Add(time.Millisecond)
throttle, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil && err != ErrBruteforceDetected {
t.Error(err)
if err != nil {
assert.ErrorIs(err, ErrBruteforceDetected)
}
if i == 0 {
timing.expectedSleep = 100 * time.Millisecond

View file

@ -25,6 +25,9 @@ import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func (t *TransientData) SetTTLChannel(ch chan<- struct{}) {
@ -35,106 +38,57 @@ func (t *TransientData) SetTTLChannel(ch chan<- struct{}) {
}
func Test_TransientData(t *testing.T) {
assert := assert.New(t)
data := NewTransientData()
if data.Set("foo", nil) {
t.Errorf("should not have set value")
}
if !data.Set("foo", "bar") {
t.Errorf("should have set value")
}
if data.Set("foo", "bar") {
t.Errorf("should not have set value")
}
if !data.Set("foo", "baz") {
t.Errorf("should have set value")
}
if data.CompareAndSet("foo", "bar", "lala") {
t.Errorf("should not have set value")
}
if !data.CompareAndSet("foo", "baz", "lala") {
t.Errorf("should have set value")
}
if data.CompareAndSet("test", nil, nil) {
t.Errorf("should not have set value")
}
if !data.CompareAndSet("test", nil, "123") {
t.Errorf("should have set value")
}
if data.CompareAndSet("test", nil, "456") {
t.Errorf("should not have set value")
}
if data.CompareAndRemove("test", "1234") {
t.Errorf("should not have removed value")
}
if !data.CompareAndRemove("test", "123") {
t.Errorf("should have removed value")
}
if data.Remove("lala") {
t.Errorf("should not have removed value")
}
if !data.Remove("foo") {
t.Errorf("should have removed value")
}
assert.False(data.Set("foo", nil))
assert.True(data.Set("foo", "bar"))
assert.False(data.Set("foo", "bar"))
assert.True(data.Set("foo", "baz"))
assert.False(data.CompareAndSet("foo", "bar", "lala"))
assert.True(data.CompareAndSet("foo", "baz", "lala"))
assert.False(data.CompareAndSet("test", nil, nil))
assert.True(data.CompareAndSet("test", nil, "123"))
assert.False(data.CompareAndSet("test", nil, "456"))
assert.False(data.CompareAndRemove("test", "1234"))
assert.True(data.CompareAndRemove("test", "123"))
assert.False(data.Remove("lala"))
assert.True(data.Remove("foo"))
ttlCh := make(chan struct{})
data.SetTTLChannel(ttlCh)
if !data.SetTTL("test", "1234", time.Millisecond) {
t.Errorf("should have set value")
}
if value := data.GetData()["test"]; value != "1234" {
t.Errorf("expected 1234, got %v", value)
}
assert.True(data.SetTTL("test", "1234", time.Millisecond))
assert.Equal("1234", data.GetData()["test"])
// Data is removed after the TTL
<-ttlCh
if value := data.GetData()["test"]; value != nil {
t.Errorf("expected no value, got %v", value)
}
assert.Nil(data.GetData()["test"])
if !data.SetTTL("test", "1234", time.Millisecond) {
t.Errorf("should have set value")
}
if value := data.GetData()["test"]; value != "1234" {
t.Errorf("expected 1234, got %v", value)
}
if !data.SetTTL("test", "2345", 3*time.Millisecond) {
t.Errorf("should have set value")
}
if value := data.GetData()["test"]; value != "2345" {
t.Errorf("expected 2345, got %v", value)
}
assert.True(data.SetTTL("test", "1234", time.Millisecond))
assert.Equal("1234", data.GetData()["test"])
assert.True(data.SetTTL("test", "2345", 3*time.Millisecond))
assert.Equal("2345", data.GetData()["test"])
// Data is removed after the TTL only if the value still matches
time.Sleep(2 * time.Millisecond)
if value := data.GetData()["test"]; value != "2345" {
t.Errorf("expected 2345, got %v", value)
}
assert.Equal("2345", data.GetData()["test"])
// Data is removed after the (second) TTL
<-ttlCh
if value := data.GetData()["test"]; value != nil {
t.Errorf("expected no value, got %v", value)
}
assert.Nil(data.GetData()["test"])
// Setting existing key will update the TTL
if !data.SetTTL("test", "1234", time.Millisecond) {
t.Errorf("should have set value")
}
if data.SetTTL("test", "1234", 3*time.Millisecond) {
t.Errorf("should not have set value")
}
assert.True(data.SetTTL("test", "1234", time.Millisecond))
assert.False(data.SetTTL("test", "1234", 3*time.Millisecond))
// Data still exists after the first TTL
time.Sleep(2 * time.Millisecond)
if value := data.GetData()["test"]; value != "1234" {
t.Errorf("expected 1234, got %v", value)
}
assert.Equal("1234", data.GetData()["test"])
// Data is removed after the (updated) TTL
<-ttlCh
if value := data.GetData()["test"]; value != nil {
t.Errorf("expected no value, got %v", value)
}
assert.Nil(data.GetData()["test"])
}
func Test_TransientMessages(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
@ -142,226 +96,139 @@ func Test_TransientMessages(t *testing.T) {
client1 := NewTestClient(t, server, hub)
defer client1.CloseWithBye()
if err := client1.SendHello(testDefaultUserId + "1"); err != nil {
t.Fatal(err)
}
require.NoError(client1.SendHello(testDefaultUserId + "1"))
hello1, err := client1.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
if err := client1.SetTransientData("foo", "bar", 0); err != nil {
t.Fatal(err)
}
if msg, err := client1.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else {
if err := checkMessageError(msg, "not_in_room"); err != nil {
t.Fatal(err)
}
require.NoError(client1.SetTransientData("foo", "bar", 0))
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageError(msg, "not_in_room"))
}
client2 := NewTestClient(t, server, hub)
defer client2.CloseWithBye()
if err := client2.SendHello(testDefaultUserId + "2"); err != nil {
t.Fatal(err)
}
require.NoError(client2.SendHello(testDefaultUserId + "2"))
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// Join room by id.
roomId := "test-room"
if room, err := client1.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err := client1.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
// Give message processing some time.
time.Sleep(10 * time.Millisecond)
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err = client2.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2)
session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession)
if session1 == nil {
t.Fatalf("Session %s does not exist", hello1.Hello.SessionId)
}
require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId)
session2 := hub.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession)
if session2 == nil {
t.Fatalf("Session %s does not exist", hello2.Hello.SessionId)
}
require.NotNil(session2, "Session %s does not exist", hello2.Hello.SessionId)
// Client 1 may modify transient data.
session1.SetPermissions([]Permission{PERMISSION_TRANSIENT_DATA})
// Client 2 may not modify transient data.
session2.SetPermissions([]Permission{})
if err := client2.SetTransientData("foo", "bar", 0); err != nil {
t.Fatal(err)
}
if msg, err := client2.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else {
if err := checkMessageError(msg, "not_allowed"); err != nil {
t.Fatal(err)
}
require.NoError(client2.SetTransientData("foo", "bar", 0))
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageError(msg, "not_allowed"))
}
if err := client1.SetTransientData("foo", "bar", 0); err != nil {
t.Fatal(err)
require.NoError(client1.SetTransientData("foo", "bar", 0))
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageTransientSet(msg, "foo", "bar", nil))
}
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageTransientSet(msg, "foo", "bar", nil))
}
if msg, err := client1.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else {
if err := checkMessageTransientSet(msg, "foo", "bar", nil); err != nil {
t.Fatal(err)
}
}
if msg, err := client2.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else {
if err := checkMessageTransientSet(msg, "foo", "bar", nil); err != nil {
t.Fatal(err)
}
}
if err := client2.RemoveTransientData("foo"); err != nil {
t.Fatal(err)
}
if msg, err := client2.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else {
if err := checkMessageError(msg, "not_allowed"); err != nil {
t.Fatal(err)
}
require.NoError(client2.RemoveTransientData("foo"))
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageError(msg, "not_allowed"))
}
// Setting the same value is ignored by the server.
if err := client1.SetTransientData("foo", "bar", 0); err != nil {
t.Fatal(err)
}
require.NoError(client1.SetTransientData("foo", "bar", 0))
ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel2()
if msg, err := client1.RunUntilMessage(ctx2); err != nil {
if err != context.DeadlineExceeded {
t.Fatal(err)
}
if msg, err := client1.RunUntilMessage(ctx2); err == nil {
assert.Fail("Expected no payload, got %+v", msg)
} else {
t.Errorf("Expected no payload, got %+v", msg)
require.ErrorIs(err, context.DeadlineExceeded)
}
data := map[string]interface{}{
"hello": "world",
}
if err := client1.SetTransientData("foo", data, 0); err != nil {
t.Fatal(err)
require.NoError(client1.SetTransientData("foo", data, 0))
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageTransientSet(msg, "foo", data, "bar"))
}
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageTransientSet(msg, "foo", data, "bar"))
}
if msg, err := client1.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else {
if err := checkMessageTransientSet(msg, "foo", data, "bar"); err != nil {
t.Fatal(err)
}
}
if msg, err := client2.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else {
if err := checkMessageTransientSet(msg, "foo", data, "bar"); err != nil {
t.Fatal(err)
}
}
require.NoError(client1.RemoveTransientData("foo"))
if err := client1.RemoveTransientData("foo"); err != nil {
t.Fatal(err)
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageTransientRemove(msg, "foo", data))
}
if msg, err := client1.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else {
if err := checkMessageTransientRemove(msg, "foo", data); err != nil {
t.Fatal(err)
}
}
if msg, err := client2.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else {
if err := checkMessageTransientRemove(msg, "foo", data); err != nil {
t.Fatal(err)
}
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageTransientRemove(msg, "foo", data))
}
// Removing a non-existing key is ignored by the server.
if err := client1.RemoveTransientData("foo"); err != nil {
t.Fatal(err)
}
require.NoError(client1.RemoveTransientData("foo"))
ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel3()
if msg, err := client1.RunUntilMessage(ctx3); err != nil {
if err != context.DeadlineExceeded {
t.Fatal(err)
}
if msg, err := client1.RunUntilMessage(ctx3); err == nil {
assert.Fail("Expected no payload, got %+v", msg)
} else {
t.Errorf("Expected no payload, got %+v", msg)
require.ErrorIs(err, context.DeadlineExceeded)
}
if err := client1.SetTransientData("abc", data, 10*time.Millisecond); err != nil {
t.Fatal(err)
}
require.NoError(client1.SetTransientData("abc", data, 10*time.Millisecond))
client3 := NewTestClient(t, server, hub)
defer client3.CloseWithBye()
if err := client3.SendHello(testDefaultUserId + "3"); err != nil {
t.Fatal(err)
}
require.NoError(client3.SendHello(testDefaultUserId + "3"))
hello3, err := client3.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
if room, err := client3.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err = client3.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
_, ignored, err := client3.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello, hello3.Hello)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
var msg *ServerMessage
if len(ignored) == 0 {
if msg, err = client3.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
}
msg, err = client3.RunUntilMessage(ctx)
require.NoError(err)
} else if len(ignored) == 1 {
msg = ignored[0]
} else {
t.Fatalf("Received too many messages: %+v", ignored)
require.Fail("Received too many messages: %+v", ignored)
}
if err := checkMessageTransientInitial(msg, map[string]interface{}{
require.NoError(checkMessageTransientInitial(msg, map[string]interface{}{
"abc": data,
}); err != nil {
t.Fatal(err)
}
}))
time.Sleep(10 * time.Millisecond)
if msg, err = client3.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else if err := checkMessageTransientRemove(msg, "abc", data); err != nil {
t.Fatal(err)
if msg, err = client3.RunUntilMessage(ctx); assert.NoError(err) {
require.NoError(checkMessageTransientRemove(msg, "abc", data))
}
}

View file

@ -27,11 +27,16 @@ import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVirtualSession(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
hub, _, _, server := CreateHubForTest(t)
roomId := "the-room-id"
@ -41,54 +46,34 @@ func TestVirtualSession(t *testing.T) {
compat: true,
}
room, err := hub.createRoom(roomId, emptyProperties, backend)
if err != nil {
t.Fatalf("Could not create room: %s", err)
}
require.NoError(err)
defer room.Close()
clientInternal := NewTestClient(t, server, hub)
defer clientInternal.CloseWithBye()
if err := clientInternal.SendHelloInternal(); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.SendHelloInternal())
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
if err := client.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
require.NoError(client.SendHello(testDefaultUserId))
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if hello, err := clientInternal.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != "" {
t.Errorf("Expected empty user id, got %+v", hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
if hello.Hello.ResumeId == "" {
t.Errorf("Expected resume id, got %+v", hello.Hello)
}
if hello, err := clientInternal.RunUntilHello(ctx); assert.NoError(err) {
assert.Empty(hello.Hello.UserId)
assert.NotEmpty(hello.Hello.SessionId)
assert.NotEmpty(hello.Hello.ResumeId)
}
hello, err := client.RunUntilHello(ctx)
if err != nil {
t.Error(err)
}
assert.NoError(err)
if room, err := client.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err := client.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
// Ignore "join" events.
if err := client.DrainMessages(ctx); err != nil {
t.Error(err)
}
assert.NoError(client.DrainMessages(ctx))
internalSessionId := "session1"
userId := "user1"
@ -106,64 +91,39 @@ func TestVirtualSession(t *testing.T) {
},
},
}
if err := clientInternal.WriteJSON(msgAdd); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.WriteJSON(msgAdd))
msg1, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// The public session id will be generated by the server, so don't check for it.
if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil {
t.Fatal(err)
}
require.NoError(client.checkMessageJoinedSession(msg1, "", userId))
sessionId := msg1.Event.Join[0].SessionId
session := hub.GetSessionByPublicId(sessionId)
if session == nil {
t.Fatalf("Could not get virtual session %s", sessionId)
}
if session.ClientType() != HelloClientTypeVirtual {
t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType())
}
if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId {
t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid)
if assert.NotNil(session, "Could not get virtual session %s", sessionId) {
assert.Equal(HelloClientTypeVirtual, session.ClientType())
sid := session.(*VirtualSession).SessionId()
assert.Equal(internalSessionId, sid)
}
// Also a participants update event will be triggered for the virtual user.
msg2, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
updateMsg, err := checkMessageParticipantsInCall(msg2)
if err != nil {
t.Error(err)
} else if updateMsg.RoomId != roomId {
t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId)
} else if len(updateMsg.Users) != 1 {
t.Errorf("Expected one user, got %+v", updateMsg.Users)
} else if sid, ok := updateMsg.Users[0]["sessionId"].(string); !ok || sid != sessionId {
t.Errorf("Expected session id %s, got %+v", sessionId, updateMsg.Users[0])
} else if virtual, ok := updateMsg.Users[0]["virtual"].(bool); !ok || !virtual {
t.Errorf("Expected virtual user, got %+v", updateMsg.Users[0])
} else if inCall, ok := updateMsg.Users[0]["inCall"].(float64); !ok || inCall != (FlagInCall|FlagWithPhone) {
t.Errorf("Expected user in call with phone, got %+v", updateMsg.Users[0])
require.NoError(err)
if updateMsg, err := checkMessageParticipantsInCall(msg2); assert.NoError(err) {
assert.Equal(roomId, updateMsg.RoomId)
if assert.Len(updateMsg.Users, 1) {
assert.Equal(sessionId, updateMsg.Users[0]["sessionId"])
assert.Equal(true, updateMsg.Users[0]["virtual"])
assert.EqualValues((FlagInCall | FlagWithPhone), updateMsg.Users[0]["inCall"])
}
}
msg3, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
flagsMsg, err := checkMessageParticipantFlags(msg3)
if err != nil {
t.Error(err)
} else if flagsMsg.RoomId != roomId {
t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId)
} else if flagsMsg.SessionId != sessionId {
t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId)
} else if flagsMsg.Flags != FLAG_MUTED_SPEAKING {
t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, flagsMsg.Flags)
if flagsMsg, err := checkMessageParticipantFlags(msg3); assert.NoError(err) {
assert.Equal(roomId, flagsMsg.RoomId)
assert.Equal(sessionId, flagsMsg.SessionId)
assert.EqualValues(FLAG_MUTED_SPEAKING, flagsMsg.Flags)
}
newFlags := uint32(FLAG_TALKING)
@ -180,49 +140,35 @@ func TestVirtualSession(t *testing.T) {
},
},
}
if err := clientInternal.WriteJSON(msgFlags); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.WriteJSON(msgFlags))
msg4, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
flagsMsg, err = checkMessageParticipantFlags(msg4)
if err != nil {
t.Error(err)
} else if flagsMsg.RoomId != roomId {
t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId)
} else if flagsMsg.SessionId != sessionId {
t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId)
} else if flagsMsg.Flags != newFlags {
t.Errorf("Expected flags %d, got %+v", newFlags, flagsMsg.Flags)
if flagsMsg, err := checkMessageParticipantFlags(msg4); assert.NoError(err) {
assert.Equal(roomId, flagsMsg.RoomId)
assert.Equal(sessionId, flagsMsg.SessionId)
assert.EqualValues(newFlags, flagsMsg.Flags)
}
// A new client will receive the initial flags of the virtual session.
client2 := NewTestClient(t, server, hub)
defer client2.CloseWithBye()
if err := client2.SendHello(testDefaultUserId + "2"); err != nil {
t.Fatal(err)
}
require.NoError(client2.SendHello(testDefaultUserId + "2"))
if _, err := client2.RunUntilHello(ctx); err != nil {
t.Error(err)
}
_, err = client2.RunUntilHello(ctx)
require.NoError(err)
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err = client2.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
gotFlags := false
var receivedMessages []*ServerMessage
for !gotFlags {
messages, err := client2.GetPendingMessages(ctx)
if err != nil {
t.Error(err)
assert.NoError(err)
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
break
}
@ -234,26 +180,18 @@ func TestVirtualSession(t *testing.T) {
continue
}
if msg.Event.Flags.RoomId != roomId {
t.Errorf("Expected flags in room %s, got %s", roomId, msg.Event.Flags.RoomId)
} else if msg.Event.Flags.SessionId != sessionId {
t.Errorf("Expected flags for session %s, got %s", sessionId, msg.Event.Flags.SessionId)
} else if msg.Event.Flags.Flags != newFlags {
t.Errorf("Expected flags %d, got %d", newFlags, msg.Event.Flags.Flags)
} else {
if assert.Equal(roomId, msg.Event.Flags.RoomId) &&
assert.Equal(sessionId, msg.Event.Flags.SessionId) &&
assert.EqualValues(newFlags, msg.Event.Flags.Flags) {
gotFlags = true
break
}
}
}
if !gotFlags {
t.Errorf("Didn't receive initial flags in %+v", receivedMessages)
}
assert.True(gotFlags, "Didn't receive initial flags in %+v", receivedMessages)
// Ignore "join" messages from second client
if err := client.DrainMessages(ctx); err != nil {
t.Error(err)
}
assert.NoError(client.DrainMessages(ctx))
// When sending to a virtual session, the message is sent to the actual
// client and contains a "Recipient" block with the internal session id.
@ -263,32 +201,21 @@ func TestVirtualSession(t *testing.T) {
}
data := "from-client-to-virtual"
if err := client.SendMessage(recipient, data); err != nil {
t.Fatal(err)
}
require.NoError(client.SendMessage(recipient, data))
msg2, err = clientInternal.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
} else if err := checkMessageType(msg2, "message"); err != nil {
t.Fatal(err)
} else if err := checkMessageSender(hub, msg2.Message.Sender, "session", hello.Hello); err != nil {
t.Error(err)
}
require.NoError(err)
require.NoError(checkMessageType(msg2, "message"))
require.NoError(checkMessageSender(hub, msg2.Message.Sender, "session", hello.Hello))
if msg2.Message.Recipient == nil {
t.Errorf("Expected recipient, got none")
} else if msg2.Message.Recipient.Type != "session" {
t.Errorf("Expected recipient type session, got %s", msg2.Message.Recipient.Type)
} else if msg2.Message.Recipient.SessionId != internalSessionId {
t.Errorf("Expected recipient %s, got %s", internalSessionId, msg2.Message.Recipient.SessionId)
if assert.NotNil(msg2.Message.Recipient) {
assert.Equal("session", msg2.Message.Recipient.Type)
assert.Equal(internalSessionId, msg2.Message.Recipient.SessionId)
}
var payload string
if err := json.Unmarshal(msg2.Message.Data, &payload); err != nil {
t.Error(err)
} else if payload != data {
t.Errorf("Expected payload %s, got %s", data, payload)
if err := json.Unmarshal(msg2.Message.Data, &payload); assert.NoError(err) {
assert.Equal(data, payload)
}
msgRemove := &ClientMessage{
@ -303,16 +230,10 @@ func TestVirtualSession(t *testing.T) {
},
},
}
if err := clientInternal.WriteJSON(msgRemove); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.WriteJSON(msgRemove))
msg5, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
if err := client.checkMessageRoomLeaveSession(msg5, sessionId); err != nil {
t.Error(err)
if msg5, err := client.RunUntilMessage(ctx); assert.NoError(err) {
assert.NoError(client.checkMessageRoomLeaveSession(msg5, sessionId))
}
}
@ -342,6 +263,8 @@ func checkHasEntryWithInCall(message *RoomEventServerMessage, sessionId string,
func TestVirtualSessionCustomInCall(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
hub, _, _, server := CreateHubForTest(t)
roomId := "the-room-id"
@ -351,9 +274,7 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
compat: true,
}
room, err := hub.createRoom(roomId, emptyProperties, backend)
if err != nil {
t.Fatalf("Could not create room: %s", err)
}
require.NoError(err)
defer room.Close()
clientInternal := NewTestClient(t, server, hub)
@ -361,67 +282,40 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
features := []string{
ClientFeatureInternalInCall,
}
if err := clientInternal.SendHelloInternalWithFeatures(features); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.SendHelloInternalWithFeatures(features))
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
if err := client.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
require.NoError(client.SendHello(testDefaultUserId))
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
helloInternal, err := clientInternal.RunUntilHello(ctx)
if err != nil {
t.Error(err)
} else {
if helloInternal.Hello.UserId != "" {
t.Errorf("Expected empty user id, got %+v", helloInternal.Hello)
}
if helloInternal.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", helloInternal.Hello)
}
if helloInternal.Hello.ResumeId == "" {
t.Errorf("Expected resume id, got %+v", helloInternal.Hello)
}
}
if room, err := clientInternal.JoinRoomWithRoomSession(ctx, roomId, ""); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
if assert.NoError(err) {
assert.Empty(helloInternal.Hello.UserId)
assert.NotEmpty(helloInternal.Hello.SessionId)
assert.NotEmpty(helloInternal.Hello.ResumeId)
}
roomMsg, err := clientInternal.JoinRoomWithRoomSession(ctx, roomId, "")
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
hello, err := client.RunUntilHello(ctx)
if err != nil {
t.Error(err)
}
if room, err := client.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
assert.NoError(err)
roomMsg, err = client.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
if _, additional, err := clientInternal.RunUntilJoinedAndReturn(ctx, helloInternal.Hello, hello.Hello); err != nil {
t.Error(err)
} else if len(additional) != 1 {
t.Errorf("expected one additional message, got %+v", additional)
} else if additional[0].Type != "event" {
t.Errorf("expected event message, got %+v", additional[0])
} else if additional[0].Event.Target != "participants" {
t.Errorf("expected event participants message, got %+v", additional[0])
} else if additional[0].Event.Type != "update" {
t.Errorf("expected event participants update message, got %+v", additional[0])
} else if additional[0].Event.Update.Users[0]["sessionId"].(string) != helloInternal.Hello.SessionId {
t.Errorf("expected event update message for internal session, got %+v", additional[0])
} else if additional[0].Event.Update.Users[0]["inCall"].(float64) != 0 {
t.Errorf("expected event update message with session not in call, got %+v", additional[0])
}
if err := client.RunUntilJoined(ctx, helloInternal.Hello, hello.Hello); err != nil {
t.Error(err)
if _, additional, err := clientInternal.RunUntilJoinedAndReturn(ctx, helloInternal.Hello, hello.Hello); assert.NoError(err) {
if assert.Len(additional, 1) && assert.Equal("event", additional[0].Type) {
assert.Equal("participants", additional[0].Event.Target)
assert.Equal("update", additional[0].Event.Type)
assert.Equal(helloInternal.Hello.SessionId, additional[0].Event.Update.Users[0]["sessionId"])
assert.EqualValues(0, additional[0].Event.Update.Users[0]["inCall"])
}
}
assert.NoError(client.RunUntilJoined(ctx, helloInternal.Hello, hello.Hello))
internalSessionId := "session1"
userId := "user1"
@ -439,65 +333,38 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
},
},
}
if err := clientInternal.WriteJSON(msgAdd); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.WriteJSON(msgAdd))
msg1, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// The public session id will be generated by the server, so don't check for it.
if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil {
t.Fatal(err)
}
require.NoError(client.checkMessageJoinedSession(msg1, "", userId))
sessionId := msg1.Event.Join[0].SessionId
session := hub.GetSessionByPublicId(sessionId)
if session == nil {
t.Fatalf("Could not get virtual session %s", sessionId)
}
if session.ClientType() != HelloClientTypeVirtual {
t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType())
}
if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId {
t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid)
if assert.NotNil(session) {
assert.Equal(HelloClientTypeVirtual, session.ClientType())
sid := session.(*VirtualSession).SessionId()
assert.Equal(internalSessionId, sid)
}
// Also a participants update event will be triggered for the virtual user.
msg2, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
updateMsg, err := checkMessageParticipantsInCall(msg2)
if err != nil {
t.Error(err)
} else if updateMsg.RoomId != roomId {
t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId)
} else if len(updateMsg.Users) != 2 {
t.Errorf("Expected two users, got %+v", updateMsg.Users)
}
require.NoError(err)
if updateMsg, err := checkMessageParticipantsInCall(msg2); assert.NoError(err) {
assert.Equal(roomId, updateMsg.RoomId)
assert.Len(updateMsg.Users, 2)
if err := checkHasEntryWithInCall(updateMsg, sessionId, "virtual", 0); err != nil {
t.Error(err)
}
if err := checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", 0); err != nil {
t.Error(err)
assert.NoError(checkHasEntryWithInCall(updateMsg, sessionId, "virtual", 0))
assert.NoError(checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", 0))
}
msg3, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
flagsMsg, err := checkMessageParticipantFlags(msg3)
if err != nil {
t.Error(err)
} else if flagsMsg.RoomId != roomId {
t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId)
} else if flagsMsg.SessionId != sessionId {
t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId)
} else if flagsMsg.Flags != FLAG_MUTED_SPEAKING {
t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, flagsMsg.Flags)
if flagsMsg, err := checkMessageParticipantFlags(msg3); assert.NoError(err) {
assert.Equal(roomId, flagsMsg.RoomId)
assert.Equal(sessionId, flagsMsg.SessionId)
assert.EqualValues(FLAG_MUTED_SPEAKING, flagsMsg.Flags)
}
// The internal session can change its "inCall" flags
@ -510,27 +377,15 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
},
},
}
if err := clientInternal.WriteJSON(msgInCall); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.WriteJSON(msgInCall))
msg4, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
updateMsg2, err := checkMessageParticipantsInCall(msg4)
if err != nil {
t.Error(err)
} else if updateMsg2.RoomId != roomId {
t.Errorf("Expected room %s, got %s", roomId, updateMsg2.RoomId)
} else if len(updateMsg2.Users) != 2 {
t.Errorf("Expected two users, got %+v", updateMsg2.Users)
}
if err := checkHasEntryWithInCall(updateMsg2, sessionId, "virtual", 0); err != nil {
t.Error(err)
}
if err := checkHasEntryWithInCall(updateMsg2, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio); err != nil {
t.Error(err)
require.NoError(err)
if updateMsg, err := checkMessageParticipantsInCall(msg4); assert.NoError(err) {
assert.Equal(roomId, updateMsg.RoomId)
assert.Len(updateMsg.Users, 2)
assert.NoError(checkHasEntryWithInCall(updateMsg, sessionId, "virtual", 0))
assert.NoError(checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio))
}
// The internal session can change the "inCall" flags of a virtual session
@ -548,33 +403,23 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
},
},
}
if err := clientInternal.WriteJSON(msgInCall2); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.WriteJSON(msgInCall2))
msg5, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
updateMsg3, err := checkMessageParticipantsInCall(msg5)
if err != nil {
t.Error(err)
} else if updateMsg3.RoomId != roomId {
t.Errorf("Expected room %s, got %s", roomId, updateMsg3.RoomId)
} else if len(updateMsg3.Users) != 2 {
t.Errorf("Expected two users, got %+v", updateMsg3.Users)
}
if err := checkHasEntryWithInCall(updateMsg3, sessionId, "virtual", newInCall); err != nil {
t.Error(err)
}
if err := checkHasEntryWithInCall(updateMsg3, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio); err != nil {
t.Error(err)
require.NoError(err)
if updateMsg, err := checkMessageParticipantsInCall(msg5); assert.NoError(err) {
assert.Equal(roomId, updateMsg.RoomId)
assert.Len(updateMsg.Users, 2)
assert.NoError(checkHasEntryWithInCall(updateMsg, sessionId, "virtual", newInCall))
assert.NoError(checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio))
}
}
func TestVirtualSessionCleanup(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
require := require.New(t)
assert := assert.New(t)
hub, _, _, server := CreateHubForTest(t)
roomId := "the-room-id"
@ -584,53 +429,34 @@ func TestVirtualSessionCleanup(t *testing.T) {
compat: true,
}
room, err := hub.createRoom(roomId, emptyProperties, backend)
if err != nil {
t.Fatalf("Could not create room: %s", err)
}
require.NoError(err)
defer room.Close()
clientInternal := NewTestClient(t, server, hub)
defer clientInternal.CloseWithBye()
if err := clientInternal.SendHelloInternal(); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.SendHelloInternal())
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
if err := client.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
require.NoError(client.SendHello(testDefaultUserId))
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if hello, err := clientInternal.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != "" {
t.Errorf("Expected empty user id, got %+v", hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
if hello.Hello.ResumeId == "" {
t.Errorf("Expected resume id, got %+v", hello.Hello)
}
}
if _, err := client.RunUntilHello(ctx); err != nil {
t.Error(err)
if hello, err := clientInternal.RunUntilHello(ctx); assert.NoError(err) {
assert.Empty(hello.Hello.UserId)
assert.NotEmpty(hello.Hello.SessionId)
assert.NotEmpty(hello.Hello.ResumeId)
}
_, err = client.RunUntilHello(ctx)
assert.NoError(err)
if room, err := client.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
roomMsg, err := client.JoinRoom(ctx, roomId)
require.NoError(err)
require.Equal(roomId, roomMsg.Room.RoomId)
// Ignore "join" events.
if err := client.DrainMessages(ctx); err != nil {
t.Error(err)
}
assert.NoError(client.DrainMessages(ctx))
internalSessionId := "session1"
userId := "user1"
@ -648,72 +474,45 @@ func TestVirtualSessionCleanup(t *testing.T) {
},
},
}
if err := clientInternal.WriteJSON(msgAdd); err != nil {
t.Fatal(err)
}
require.NoError(clientInternal.WriteJSON(msgAdd))
msg1, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
// The public session id will be generated by the server, so don't check for it.
if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil {
t.Fatal(err)
}
require.NoError(client.checkMessageJoinedSession(msg1, "", userId))
sessionId := msg1.Event.Join[0].SessionId
session := hub.GetSessionByPublicId(sessionId)
if session == nil {
t.Fatalf("Could not get virtual session %s", sessionId)
}
if session.ClientType() != HelloClientTypeVirtual {
t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType())
}
if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId {
t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid)
if assert.NotNil(session) {
assert.Equal(HelloClientTypeVirtual, session.ClientType())
sid := session.(*VirtualSession).SessionId()
assert.Equal(internalSessionId, sid)
}
// Also a participants update event will be triggered for the virtual user.
msg2, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
updateMsg, err := checkMessageParticipantsInCall(msg2)
if err != nil {
t.Error(err)
} else if updateMsg.RoomId != roomId {
t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId)
} else if len(updateMsg.Users) != 1 {
t.Errorf("Expected one user, got %+v", updateMsg.Users)
} else if sid, ok := updateMsg.Users[0]["sessionId"].(string); !ok || sid != sessionId {
t.Errorf("Expected session id %s, got %+v", sessionId, updateMsg.Users[0])
} else if virtual, ok := updateMsg.Users[0]["virtual"].(bool); !ok || !virtual {
t.Errorf("Expected virtual user, got %+v", updateMsg.Users[0])
} else if inCall, ok := updateMsg.Users[0]["inCall"].(float64); !ok || inCall != (FlagInCall|FlagWithPhone) {
t.Errorf("Expected user in call with phone, got %+v", updateMsg.Users[0])
require.NoError(err)
if updateMsg, err := checkMessageParticipantsInCall(msg2); assert.NoError(err) {
assert.Equal(roomId, updateMsg.RoomId)
if assert.Len(updateMsg.Users, 1) {
assert.Equal(sessionId, updateMsg.Users[0]["sessionId"])
assert.Equal(true, updateMsg.Users[0]["virtual"])
assert.EqualValues((FlagInCall | FlagWithPhone), updateMsg.Users[0]["inCall"])
}
}
msg3, err := client.RunUntilMessage(ctx)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
flagsMsg, err := checkMessageParticipantFlags(msg3)
if err != nil {
t.Error(err)
} else if flagsMsg.RoomId != roomId {
t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId)
} else if flagsMsg.SessionId != sessionId {
t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId)
} else if flagsMsg.Flags != FLAG_MUTED_SPEAKING {
t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, flagsMsg.Flags)
if flagsMsg, err := checkMessageParticipantFlags(msg3); assert.NoError(err) {
assert.Equal(roomId, flagsMsg.RoomId)
assert.Equal(sessionId, flagsMsg.SessionId)
assert.EqualValues(FLAG_MUTED_SPEAKING, flagsMsg.Flags)
}
// The virtual sessions are closed when the parent session is deleted.
clientInternal.CloseWithBye()
if msg2, err := client.RunUntilMessage(ctx); err != nil {
t.Fatal(err)
} else if err := client.checkMessageRoomLeaveSession(msg2, sessionId); err != nil {
t.Error(err)
}
msg2, err = client.RunUntilMessage(ctx)
require.NoError(err)
assert.NoError(client.checkMessageRoomLeaveSession(msg2, sessionId))
}