mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2026-03-14 14:35:44 +01:00
Switch to "github.com/stretchr/testify" for tests.
This commit is contained in:
parent
9fdd61758a
commit
03cad99b8d
50 changed files with 3082 additions and 6234 deletions
|
|
@ -24,19 +24,17 @@ package signaling
|
|||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAllowedIps(t *testing.T) {
|
||||
require := require.New(t)
|
||||
a, err := ParseAllowedIps("127.0.0.1, 192.168.0.1, 192.168.1.1/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if a.Empty() {
|
||||
t.Fatal("should not be empty")
|
||||
}
|
||||
if expected := `[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`; a.String() != expected {
|
||||
t.Errorf("expected %s, got %s", expected, a.String())
|
||||
}
|
||||
require.NoError(err)
|
||||
require.False(a.Empty())
|
||||
require.Equal(`[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`, a.String())
|
||||
|
||||
allowed := []string{
|
||||
"127.0.0.1",
|
||||
|
|
@ -51,22 +49,18 @@ func TestAllowedIps(t *testing.T) {
|
|||
|
||||
for _, addr := range allowed {
|
||||
t.Run(addr, func(t *testing.T) {
|
||||
ip := net.ParseIP(addr)
|
||||
if ip == nil {
|
||||
t.Errorf("error parsing %s", addr)
|
||||
} else if !a.Allowed(ip) {
|
||||
t.Errorf("should allow %s", addr)
|
||||
assert := assert.New(t)
|
||||
if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) {
|
||||
assert.True(a.Allowed(ip), "should allow %s", addr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
for _, addr := range notAllowed {
|
||||
t.Run(addr, func(t *testing.T) {
|
||||
ip := net.ParseIP(addr)
|
||||
if ip == nil {
|
||||
t.Errorf("error parsing %s", addr)
|
||||
} else if a.Allowed(ip) {
|
||||
t.Errorf("should not allow %s", addr)
|
||||
assert := assert.New(t)
|
||||
if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) {
|
||||
assert.False(a.Allowed(ip), "should not allow %s", addr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,42 +24,36 @@ package signaling
|
|||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBackendChecksum(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert := assert.New(t)
|
||||
rnd := newRandomString(32)
|
||||
body := []byte{1, 2, 3, 4, 5}
|
||||
secret := []byte("shared-secret")
|
||||
|
||||
check1 := CalculateBackendChecksum(rnd, body, secret)
|
||||
check2 := CalculateBackendChecksum(rnd, body, secret)
|
||||
if check1 != check2 {
|
||||
t.Errorf("Expected equal checksums, got %s and %s", check1, check2)
|
||||
}
|
||||
assert.Equal(check1, check2, "Expected equal checksums")
|
||||
|
||||
if !ValidateBackendChecksumValue(check1, rnd, body, secret) {
|
||||
t.Errorf("Checksum %s could not be validated", check1)
|
||||
}
|
||||
if ValidateBackendChecksumValue(check1[1:], rnd, body, secret) {
|
||||
t.Errorf("Checksum %s should not be valid", check1[1:])
|
||||
}
|
||||
if ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret) {
|
||||
t.Errorf("Checksum %s should not be valid", check1[:len(check1)-1])
|
||||
}
|
||||
assert.True(ValidateBackendChecksumValue(check1, rnd, body, secret), "Checksum should be valid")
|
||||
assert.False(ValidateBackendChecksumValue(check1[1:], rnd, body, secret), "Checksum should not be valid")
|
||||
assert.False(ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret), "Checksum should not be valid")
|
||||
|
||||
request := &http.Request{
|
||||
Header: make(http.Header),
|
||||
}
|
||||
request.Header.Set("Spreed-Signaling-Random", rnd)
|
||||
request.Header.Set("Spreed-Signaling-Checksum", check1)
|
||||
if !ValidateBackendChecksum(request, body, secret) {
|
||||
t.Errorf("Checksum %s could not be validated from request", check1)
|
||||
}
|
||||
assert.True(ValidateBackendChecksum(request, body, secret), "Checksum could not be validated from request")
|
||||
}
|
||||
|
||||
func TestValidNumbers(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert := assert.New(t)
|
||||
valid := []string{
|
||||
"+12",
|
||||
"+12345",
|
||||
|
|
@ -72,13 +66,9 @@ func TestValidNumbers(t *testing.T) {
|
|||
"+123-45",
|
||||
}
|
||||
for _, number := range valid {
|
||||
if !isValidNumber(number) {
|
||||
t.Errorf("number %s should be valid", number)
|
||||
}
|
||||
assert.True(isValidNumber(number), "number %s should be valid", number)
|
||||
}
|
||||
for _, number := range invalid {
|
||||
if isValidNumber(number) {
|
||||
t.Errorf("number %s should not be valid", number)
|
||||
}
|
||||
assert.False(isValidNumber(number), "number %s should not be valid", number)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,9 +24,10 @@ package signaling
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testCheckValid interface {
|
||||
|
|
@ -53,40 +54,34 @@ func wrapMessage(messageType string, msg testCheckValid) *ClientMessage {
|
|||
}
|
||||
|
||||
func testMessages(t *testing.T, messageType string, valid_messages []testCheckValid, invalid_messages []testCheckValid) {
|
||||
t.Helper()
|
||||
assert := assert.New(t)
|
||||
for _, msg := range valid_messages {
|
||||
if err := msg.CheckValid(); err != nil {
|
||||
t.Errorf("Message %+v should be valid, got %s", msg, err)
|
||||
}
|
||||
assert.NoError(msg.CheckValid(), "Message %+v should be valid", msg)
|
||||
|
||||
// If the inner message is valid, it should also be valid in a wrapped
|
||||
// ClientMessage.
|
||||
if wrapped := wrapMessage(messageType, msg); wrapped == nil {
|
||||
t.Errorf("Unknown message type: %s", messageType)
|
||||
} else if err := wrapped.CheckValid(); err != nil {
|
||||
t.Errorf("Message %+v should be valid, got %s", wrapped, err)
|
||||
if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) {
|
||||
assert.NoError(wrapped.CheckValid(), "Message %+v should be valid", wrapped)
|
||||
}
|
||||
}
|
||||
for _, msg := range invalid_messages {
|
||||
if err := msg.CheckValid(); err == nil {
|
||||
t.Errorf("Message %+v should not be valid", msg)
|
||||
}
|
||||
assert.Error(msg.CheckValid(), "Message %+v should not be valid", msg)
|
||||
|
||||
// If the inner message is invalid, it should also be invalid in a
|
||||
// wrapped ClientMessage.
|
||||
if wrapped := wrapMessage(messageType, msg); wrapped == nil {
|
||||
t.Errorf("Unknown message type: %s", messageType)
|
||||
} else if err := wrapped.CheckValid(); err == nil {
|
||||
t.Errorf("Message %+v should not be valid", wrapped)
|
||||
if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) {
|
||||
assert.Error(wrapped.CheckValid(), "Message %+v should not be valid", wrapped)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert := assert.New(t)
|
||||
// The message needs a type.
|
||||
msg := ClientMessage{}
|
||||
if err := msg.CheckValid(); err == nil {
|
||||
t.Errorf("Message %+v should not be valid", msg)
|
||||
}
|
||||
assert.Error(msg.CheckValid())
|
||||
}
|
||||
|
||||
func TestHelloClientMessage(t *testing.T) {
|
||||
|
|
@ -229,9 +224,8 @@ func TestHelloClientMessage(t *testing.T) {
|
|||
msg := ClientMessage{
|
||||
Type: "hello",
|
||||
}
|
||||
if err := msg.CheckValid(); err == nil {
|
||||
t.Errorf("Message %+v should not be valid", msg)
|
||||
}
|
||||
assert := assert.New(t)
|
||||
assert.Error(msg.CheckValid())
|
||||
}
|
||||
|
||||
func TestMessageClientMessage(t *testing.T) {
|
||||
|
|
@ -311,9 +305,8 @@ func TestMessageClientMessage(t *testing.T) {
|
|||
msg := ClientMessage{
|
||||
Type: "message",
|
||||
}
|
||||
if err := msg.CheckValid(); err == nil {
|
||||
t.Errorf("Message %+v should not be valid", msg)
|
||||
}
|
||||
assert := assert.New(t)
|
||||
assert.Error(msg.CheckValid())
|
||||
}
|
||||
|
||||
func TestByeClientMessage(t *testing.T) {
|
||||
|
|
@ -330,9 +323,8 @@ func TestByeClientMessage(t *testing.T) {
|
|||
msg := ClientMessage{
|
||||
Type: "bye",
|
||||
}
|
||||
if err := msg.CheckValid(); err != nil {
|
||||
t.Errorf("Message %+v should be valid, got %s", msg, err)
|
||||
}
|
||||
assert := assert.New(t)
|
||||
assert.NoError(msg.CheckValid())
|
||||
}
|
||||
|
||||
func TestRoomClientMessage(t *testing.T) {
|
||||
|
|
@ -349,42 +341,31 @@ func TestRoomClientMessage(t *testing.T) {
|
|||
msg := ClientMessage{
|
||||
Type: "room",
|
||||
}
|
||||
if err := msg.CheckValid(); err == nil {
|
||||
t.Errorf("Message %+v should not be valid", msg)
|
||||
}
|
||||
assert := assert.New(t)
|
||||
assert.Error(msg.CheckValid())
|
||||
}
|
||||
|
||||
func TestErrorMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert := assert.New(t)
|
||||
id := "request-id"
|
||||
msg := ClientMessage{
|
||||
Id: id,
|
||||
}
|
||||
err1 := msg.NewErrorServerMessage(&Error{})
|
||||
if err1.Id != id {
|
||||
t.Errorf("Expected id %s, got %+v", id, err1)
|
||||
}
|
||||
if err1.Type != "error" || err1.Error == nil {
|
||||
t.Errorf("Expected type \"error\", got %+v", err1)
|
||||
}
|
||||
assert.Equal(id, err1.Id, "%+v", err1)
|
||||
assert.Equal("error", err1.Type, "%+v", err1)
|
||||
assert.NotNil(err1.Error, "%+v", err1)
|
||||
|
||||
err2 := msg.NewWrappedErrorServerMessage(fmt.Errorf("test-error"))
|
||||
if err2.Id != id {
|
||||
t.Errorf("Expected id %s, got %+v", id, err2)
|
||||
}
|
||||
if err2.Type != "error" || err2.Error == nil {
|
||||
t.Errorf("Expected type \"error\", got %+v", err2)
|
||||
}
|
||||
if err2.Error.Code != "internal_error" {
|
||||
t.Errorf("Expected code \"internal_error\", got %+v", err2)
|
||||
}
|
||||
if err2.Error.Message != "test-error" {
|
||||
t.Errorf("Expected message \"test-error\", got %+v", err2)
|
||||
assert.Equal(id, err2.Id, "%+v", err2)
|
||||
assert.Equal("error", err2.Type, "%+v", err2)
|
||||
if assert.NotNil(err2.Error, "%+v", err2) {
|
||||
assert.Equal("internal_error", err2.Error.Code, "%+v", err2)
|
||||
assert.Equal("test-error", err2.Error.Message, "%+v", err2)
|
||||
}
|
||||
// Test "error" interface
|
||||
if err2.Error.Error() != "test-error" {
|
||||
t.Errorf("Expected error string \"test-error\", got %+v", err2)
|
||||
}
|
||||
assert.Equal("test-error", err2.Error.Error(), "%+v", err2)
|
||||
}
|
||||
|
||||
func TestIsChatRefresh(t *testing.T) {
|
||||
|
|
@ -397,9 +378,7 @@ func TestIsChatRefresh(t *testing.T) {
|
|||
Data: data_true,
|
||||
},
|
||||
}
|
||||
if !msg.IsChatRefresh() {
|
||||
t.Error("message should be detected as chat refresh")
|
||||
}
|
||||
assert.True(t, msg.IsChatRefresh())
|
||||
|
||||
data_false := []byte("{\"type\":\"chat\",\"chat\":{\"refresh\":false}}")
|
||||
msg = ServerMessage{
|
||||
|
|
@ -408,9 +387,7 @@ func TestIsChatRefresh(t *testing.T) {
|
|||
Data: data_false,
|
||||
},
|
||||
}
|
||||
if msg.IsChatRefresh() {
|
||||
t.Error("message should not be detected as chat refresh")
|
||||
}
|
||||
assert.False(t, msg.IsChatRefresh())
|
||||
}
|
||||
|
||||
func assertEqualStrings(t *testing.T, expected, result []string) {
|
||||
|
|
@ -427,27 +404,22 @@ func assertEqualStrings(t *testing.T, expected, result []string) {
|
|||
sort.Strings(result)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Errorf("Expected %+v, got %+v", expected, result)
|
||||
}
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func Test_Welcome_AddRemoveFeature(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert := assert.New(t)
|
||||
var msg WelcomeServerMessage
|
||||
assertEqualStrings(t, []string{}, msg.Features)
|
||||
|
||||
msg.AddFeature("one", "two", "one")
|
||||
assertEqualStrings(t, []string{"one", "two"}, msg.Features)
|
||||
if !sort.StringsAreSorted(msg.Features) {
|
||||
t.Errorf("features should be sorted, got %+v", msg.Features)
|
||||
}
|
||||
assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features)
|
||||
|
||||
msg.AddFeature("three")
|
||||
assertEqualStrings(t, []string{"one", "two", "three"}, msg.Features)
|
||||
if !sort.StringsAreSorted(msg.Features) {
|
||||
t.Errorf("features should be sorted, got %+v", msg.Features)
|
||||
}
|
||||
assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features)
|
||||
|
||||
msg.RemoveFeature("three", "one")
|
||||
assertEqualStrings(t, []string{"two"}, msg.Features)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ import (
|
|||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -51,7 +53,7 @@ func getRealAsyncEventsForTest(t *testing.T) AsyncEvents {
|
|||
url := startLocalNatsServer(t)
|
||||
events, err := NewAsyncEvents(url)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
|
@ -59,7 +61,7 @@ func getRealAsyncEventsForTest(t *testing.T) AsyncEvents {
|
|||
func getLoopbackAsyncEventsForTest(t *testing.T) AsyncEvents {
|
||||
events, err := NewAsyncEvents(NatsLoopbackUrl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
|
|
|
|||
|
|
@ -24,17 +24,17 @@ package signaling
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) {
|
||||
|
|
@ -57,37 +57,29 @@ func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) {
|
|||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write(data); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = w.Write(data)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPostOnRedirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/ocs/v2.php/two", http.StatusTemporaryRedirect)
|
||||
})
|
||||
r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
var request map[string]string
|
||||
if err := json.Unmarshal(body, &request); err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(body, &request)
|
||||
require.NoError(err)
|
||||
|
||||
returnOCS(t, w, body)
|
||||
})
|
||||
|
|
@ -96,9 +88,7 @@ func TestPostOnRedirect(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("backend", "allowed", u.Host)
|
||||
|
|
@ -107,9 +97,7 @@ func TestPostOnRedirect(t *testing.T) {
|
|||
config.AddOption("backend", "allowhttp", "true")
|
||||
}
|
||||
client, err := NewBackendClient(config, 1, "0.0", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
request := map[string]string{
|
||||
|
|
@ -117,18 +105,17 @@ func TestPostOnRedirect(t *testing.T) {
|
|||
}
|
||||
var response map[string]string
|
||||
err = client.PerformJSONRequest(ctx, u, request, &response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
if response == nil || !reflect.DeepEqual(request, response) {
|
||||
t.Errorf("Expected %+v, got %+v", request, response)
|
||||
if assert.NotNil(t, response) {
|
||||
assert.Equal(t, request, response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostOnRedirectDifferentHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "http://domain.invalid/ocs/v2.php/two", http.StatusTemporaryRedirect)
|
||||
|
|
@ -137,9 +124,7 @@ func TestPostOnRedirectDifferentHost(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("backend", "allowed", u.Host)
|
||||
|
|
@ -148,9 +133,7 @@ func TestPostOnRedirectDifferentHost(t *testing.T) {
|
|||
config.AddOption("backend", "allowhttp", "true")
|
||||
}
|
||||
client, err := NewBackendClient(config, 1, "0.0", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
request := map[string]string{
|
||||
|
|
@ -160,41 +143,33 @@ func TestPostOnRedirectDifferentHost(t *testing.T) {
|
|||
err = client.PerformJSONRequest(ctx, u, request, &response)
|
||||
if err != nil {
|
||||
// The redirect to a different host should have failed.
|
||||
if !errors.Is(err, ErrNotRedirecting) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.ErrorIs(err, ErrNotRedirecting)
|
||||
} else {
|
||||
t.Fatal("The redirect should have failed")
|
||||
require.Fail("The redirect should have failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostOnRedirectStatusFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/ocs/v2.php/two", http.StatusFound)
|
||||
})
|
||||
r.HandleFunc("/ocs/v2.php/two", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) > 0 {
|
||||
t.Errorf("Should not have received any body, got %s", string(body))
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
assert.Empty(string(body), "Should not have received any body, got %s", string(body))
|
||||
returnOCS(t, w, []byte("{}"))
|
||||
})
|
||||
server := httptest.NewServer(r)
|
||||
defer server.Close()
|
||||
|
||||
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("backend", "allowed", u.Host)
|
||||
|
|
@ -203,9 +178,7 @@ func TestPostOnRedirectStatusFound(t *testing.T) {
|
|||
config.AddOption("backend", "allowhttp", "true")
|
||||
}
|
||||
client, err := NewBackendClient(config, 1, "0.0", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
request := map[string]string{
|
||||
|
|
@ -213,18 +186,16 @@ func TestPostOnRedirectStatusFound(t *testing.T) {
|
|||
}
|
||||
var response map[string]string
|
||||
err = client.PerformJSONRequest(ctx, u, request, &response)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if len(response) > 0 {
|
||||
t.Errorf("Expected empty response, got %+v", response)
|
||||
if assert.NoError(err) {
|
||||
assert.Empty(response, "Expected empty response, got %+v", response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleThrottled(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) {
|
||||
returnOCS(t, w, []byte("[]"))
|
||||
|
|
@ -233,9 +204,7 @@ func TestHandleThrottled(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
u, err := url.Parse(server.URL + "/ocs/v2.php/one")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("backend", "allowed", u.Host)
|
||||
|
|
@ -244,9 +213,7 @@ func TestHandleThrottled(t *testing.T) {
|
|||
config.AddOption("backend", "allowhttp", "true")
|
||||
}
|
||||
client, err := NewBackendClient(config, 1, "0.0", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
request := map[string]string{
|
||||
|
|
@ -254,9 +221,7 @@ func TestHandleThrottled(t *testing.T) {
|
|||
}
|
||||
var response map[string]string
|
||||
err = client.PerformJSONRequest(ctx, u, request, &response)
|
||||
if err == nil {
|
||||
t.Error("should have triggered an error")
|
||||
} else if !errors.Is(err, ErrThrottledResponse) {
|
||||
t.Error(err)
|
||||
if assert.Error(err) {
|
||||
assert.ErrorIs(err, ErrThrottledResponse)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
|
@ -31,32 +30,30 @@ import (
|
|||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testUrls(t *testing.T, config *BackendConfiguration, valid_urls []string, invalid_urls []string) {
|
||||
for _, u := range valid_urls {
|
||||
u := u
|
||||
t.Run(u, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
parsed, err := url.ParseRequestURI(u)
|
||||
if err != nil {
|
||||
t.Errorf("The url %s should be valid, got %s", u, err)
|
||||
if !assert.NoError(err, "The url %s should be valid", u) {
|
||||
return
|
||||
}
|
||||
if !config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should be allowed", u)
|
||||
}
|
||||
if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) {
|
||||
t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret))
|
||||
}
|
||||
assert.True(config.IsUrlAllowed(parsed), "The url %s should be allowed", u)
|
||||
secret := config.GetSecret(parsed)
|
||||
assert.Equal(string(testBackendSecret), string(secret), "Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret))
|
||||
})
|
||||
}
|
||||
for _, u := range invalid_urls {
|
||||
u := u
|
||||
t.Run(u, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
parsed, _ := url.ParseRequestURI(u)
|
||||
if config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should not be allowed", u)
|
||||
}
|
||||
assert.False(config.IsUrlAllowed(parsed), "The url %s should not be allowed", u)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -65,28 +62,24 @@ func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]str
|
|||
for _, entry := range valid_urls {
|
||||
entry := entry
|
||||
t.Run(entry[0], func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
u := entry[0]
|
||||
parsed, err := url.ParseRequestURI(u)
|
||||
if err != nil {
|
||||
t.Errorf("The url %s should be valid, got %s", u, err)
|
||||
if !assert.NoError(err, "The url %s should be valid", u) {
|
||||
return
|
||||
}
|
||||
if !config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should be allowed", u)
|
||||
}
|
||||
assert.True(config.IsUrlAllowed(parsed), "The url %s should be allowed", u)
|
||||
s := entry[1]
|
||||
if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) {
|
||||
t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret))
|
||||
}
|
||||
secret := config.GetSecret(parsed)
|
||||
assert.Equal(s, string(secret), "Expected secret %s for url %s, got %s", s, u, string(secret))
|
||||
})
|
||||
}
|
||||
for _, u := range invalid_urls {
|
||||
u := u
|
||||
t.Run(u, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
parsed, _ := url.ParseRequestURI(u)
|
||||
if config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should not be allowed", u)
|
||||
}
|
||||
assert.False(config.IsUrlAllowed(parsed), "The url %s should not be allowed", u)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -108,9 +101,7 @@ func TestIsUrlAllowed_Compat(t *testing.T) {
|
|||
config.AddOption("backend", "allowhttp", "true")
|
||||
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
testUrls(t, cfg, valid_urls, invalid_urls)
|
||||
}
|
||||
|
||||
|
|
@ -130,9 +121,7 @@ func TestIsUrlAllowed_CompatForceHttps(t *testing.T) {
|
|||
config.AddOption("backend", "allowed", "domain.invalid")
|
||||
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
testUrls(t, cfg, valid_urls, invalid_urls)
|
||||
}
|
||||
|
||||
|
|
@ -176,9 +165,7 @@ func TestIsUrlAllowed(t *testing.T) {
|
|||
config.AddOption("lala", "url", "https://otherdomain.invalid/")
|
||||
config.AddOption("lala", "secret", string(testBackendSecret)+"-lala")
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
testBackends(t, cfg, valid_urls, invalid_urls)
|
||||
}
|
||||
|
||||
|
|
@ -194,9 +181,7 @@ func TestIsUrlAllowed_EmptyAllowlist(t *testing.T) {
|
|||
config.AddOption("backend", "allowed", "")
|
||||
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
testUrls(t, cfg, valid_urls, invalid_urls)
|
||||
}
|
||||
|
||||
|
|
@ -215,9 +200,7 @@ func TestIsUrlAllowed_AllowAll(t *testing.T) {
|
|||
config.AddOption("backend", "allowed", "")
|
||||
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
testUrls(t, cfg, valid_urls, invalid_urls)
|
||||
}
|
||||
|
||||
|
|
@ -238,16 +221,16 @@ func TestParseBackendIds(t *testing.T) {
|
|||
{"backend1,backend2, backend1", []string{"backend1", "backend2"}},
|
||||
}
|
||||
|
||||
assert := assert.New(t)
|
||||
for _, test := range testcases {
|
||||
ids := getConfiguredBackendIDs(test.s)
|
||||
if !reflect.DeepEqual(ids, test.ids) {
|
||||
t.Errorf("List of ids differs, expected %+v, got %+v", test.ids, ids)
|
||||
}
|
||||
assert.Equal(test.ids, ids, "List of ids differs for \"%s\"", test.s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendReloadNoChange(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
current := testutil.ToFloat64(statsBackendsCurrent)
|
||||
original_config := goconf.NewConfigFile()
|
||||
original_config.AddOption("backend", "backends", "backend1, backend2")
|
||||
|
|
@ -257,9 +240,7 @@ func TestBackendReloadNoChange(t *testing.T) {
|
|||
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
checkStatsValue(t, statsBackendsCurrent, current+2)
|
||||
|
||||
new_config := goconf.NewConfigFile()
|
||||
|
|
@ -270,20 +251,19 @@ func TestBackendReloadNoChange(t *testing.T) {
|
|||
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+4)
|
||||
o_cfg.Reload(original_config)
|
||||
checkStatsValue(t, statsBackendsCurrent, current+4)
|
||||
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
||||
t.Error("BackendConfiguration should be equal after Reload")
|
||||
assert.Fail(t, "BackendConfiguration should be equal after Reload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendReloadChangeExistingURL(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
current := testutil.ToFloat64(statsBackendsCurrent)
|
||||
original_config := goconf.NewConfigFile()
|
||||
original_config.AddOption("backend", "backends", "backend1, backend2")
|
||||
|
|
@ -293,9 +273,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
|
|||
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+2)
|
||||
new_config := goconf.NewConfigFile()
|
||||
|
|
@ -307,9 +285,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
|
|||
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+4)
|
||||
original_config.RemoveOption("backend1", "url")
|
||||
|
|
@ -319,12 +295,13 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
|
|||
o_cfg.Reload(original_config)
|
||||
checkStatsValue(t, statsBackendsCurrent, current+4)
|
||||
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
||||
t.Error("BackendConfiguration should be equal after Reload")
|
||||
assert.Fail(t, "BackendConfiguration should be equal after Reload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendReloadChangeSecret(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
current := testutil.ToFloat64(statsBackendsCurrent)
|
||||
original_config := goconf.NewConfigFile()
|
||||
original_config.AddOption("backend", "backends", "backend1, backend2")
|
||||
|
|
@ -334,9 +311,7 @@ func TestBackendReloadChangeSecret(t *testing.T) {
|
|||
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+2)
|
||||
new_config := goconf.NewConfigFile()
|
||||
|
|
@ -347,9 +322,7 @@ func TestBackendReloadChangeSecret(t *testing.T) {
|
|||
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+4)
|
||||
original_config.RemoveOption("backend1", "secret")
|
||||
|
|
@ -358,12 +331,13 @@ func TestBackendReloadChangeSecret(t *testing.T) {
|
|||
o_cfg.Reload(original_config)
|
||||
checkStatsValue(t, statsBackendsCurrent, current+4)
|
||||
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
||||
t.Error("BackendConfiguration should be equal after Reload")
|
||||
assert.Fail(t, "BackendConfiguration should be equal after Reload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendReloadAddBackend(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
current := testutil.ToFloat64(statsBackendsCurrent)
|
||||
original_config := goconf.NewConfigFile()
|
||||
original_config.AddOption("backend", "backends", "backend1")
|
||||
|
|
@ -371,9 +345,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
|
|||
original_config.AddOption("backend1", "url", "http://domain1.invalid")
|
||||
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+1)
|
||||
new_config := goconf.NewConfigFile()
|
||||
|
|
@ -385,9 +357,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
|
|||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
new_config.AddOption("backend2", "sessionlimit", "10")
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+3)
|
||||
original_config.RemoveOption("backend", "backends")
|
||||
|
|
@ -399,12 +369,13 @@ func TestBackendReloadAddBackend(t *testing.T) {
|
|||
o_cfg.Reload(original_config)
|
||||
checkStatsValue(t, statsBackendsCurrent, current+4)
|
||||
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
||||
t.Error("BackendConfiguration should be equal after Reload")
|
||||
assert.Fail(t, "BackendConfiguration should be equal after Reload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendReloadRemoveHost(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
current := testutil.ToFloat64(statsBackendsCurrent)
|
||||
original_config := goconf.NewConfigFile()
|
||||
original_config.AddOption("backend", "backends", "backend1, backend2")
|
||||
|
|
@ -414,9 +385,7 @@ func TestBackendReloadRemoveHost(t *testing.T) {
|
|||
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+2)
|
||||
new_config := goconf.NewConfigFile()
|
||||
|
|
@ -425,9 +394,7 @@ func TestBackendReloadRemoveHost(t *testing.T) {
|
|||
new_config.AddOption("backend1", "url", "http://domain1.invalid")
|
||||
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+3)
|
||||
original_config.RemoveOption("backend", "backends")
|
||||
|
|
@ -437,12 +404,13 @@ func TestBackendReloadRemoveHost(t *testing.T) {
|
|||
o_cfg.Reload(original_config)
|
||||
checkStatsValue(t, statsBackendsCurrent, current+2)
|
||||
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
||||
t.Error("BackendConfiguration should be equal after Reload")
|
||||
assert.Fail(t, "BackendConfiguration should be equal after Reload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
current := testutil.ToFloat64(statsBackendsCurrent)
|
||||
original_config := goconf.NewConfigFile()
|
||||
original_config.AddOption("backend", "backends", "backend1, backend2")
|
||||
|
|
@ -452,9 +420,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
|
|||
original_config.AddOption("backend2", "url", "http://domain1.invalid/bar/")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+2)
|
||||
new_config := goconf.NewConfigFile()
|
||||
|
|
@ -463,9 +429,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
|
|||
new_config.AddOption("backend1", "url", "http://domain1.invalid/foo/")
|
||||
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
checkStatsValue(t, statsBackendsCurrent, current+3)
|
||||
original_config.RemoveOption("backend", "backends")
|
||||
|
|
@ -475,7 +439,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
|
|||
o_cfg.Reload(original_config)
|
||||
checkStatsValue(t, statsBackendsCurrent, current+2)
|
||||
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
||||
t.Error("BackendConfiguration should be equal after Reload")
|
||||
assert.Fail(t, "BackendConfiguration should be equal after Reload")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -500,6 +464,8 @@ func mustParse(s string) *url.URL {
|
|||
func TestBackendConfiguration_Etcd(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
etcd, client := NewEtcdClientForTest(t)
|
||||
|
||||
url1 := "https://domain1.invalid/foo"
|
||||
|
|
@ -513,9 +479,7 @@ func TestBackendConfiguration_Etcd(t *testing.T) {
|
|||
config.AddOption("backend", "backendprefix", "/backends")
|
||||
|
||||
cfg, err := NewBackendConfiguration(config, client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer cfg.Close()
|
||||
|
||||
storage := cfg.storage.(*backendStorageEtcd)
|
||||
|
|
@ -524,31 +488,25 @@ func TestBackendConfiguration_Etcd(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := storage.WaitForInitialized(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(storage.WaitForInitialized(ctx))
|
||||
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
|
||||
t.Errorf("Expected one backend, got %+v", backends)
|
||||
} else if backends[0].url != url1 {
|
||||
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
|
||||
} else if string(backends[0].secret) != initialSecret1 {
|
||||
t.Errorf("Expected backend secret %s, got %s", initialSecret1, string(backends[0].secret))
|
||||
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
|
||||
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) &&
|
||||
assert.Equal(url1, backends[0].url) &&
|
||||
assert.Equal(initialSecret1, string(backends[0].secret)) {
|
||||
if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
assert.Fail("Expected backend %+v, got %+v", backends[0], backend)
|
||||
}
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
SetEtcdValue(etcd, "/backends/1_one", []byte("{\"url\":\""+url1+"\",\"secret\":\""+secret1+"\"}"))
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
|
||||
t.Errorf("Expected one backend, got %+v", backends)
|
||||
} else if backends[0].url != url1 {
|
||||
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret1 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
|
||||
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
|
||||
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) &&
|
||||
assert.Equal(url1, backends[0].url) &&
|
||||
assert.Equal(secret1, string(backends[0].secret)) {
|
||||
if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
assert.Fail("Expected backend %+v, got %+v", backends[0], backend)
|
||||
}
|
||||
}
|
||||
|
||||
url2 := "https://domain1.invalid/bar"
|
||||
|
|
@ -557,20 +515,16 @@ func TestBackendConfiguration_Etcd(t *testing.T) {
|
|||
drainWakeupChannel(ch)
|
||||
SetEtcdValue(etcd, "/backends/2_two", []byte("{\"url\":\""+url2+"\",\"secret\":\""+secret2+"\"}"))
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 {
|
||||
t.Errorf("Expected two backends, got %+v", backends)
|
||||
} else if backends[0].url != url1 {
|
||||
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret1 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
|
||||
} else if backends[1].url != url2 {
|
||||
t.Errorf("Expected backend url %s, got %s", url2, backends[1].url)
|
||||
} else if string(backends[1].secret) != secret2 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret))
|
||||
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
|
||||
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[1], backend)
|
||||
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 2) &&
|
||||
assert.Equal(url1, backends[0].url) &&
|
||||
assert.Equal(secret1, string(backends[0].secret)) &&
|
||||
assert.Equal(url2, backends[1].url) &&
|
||||
assert.Equal(secret2, string(backends[1].secret)) {
|
||||
if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
assert.Fail("Expected backend %+v, got %+v", backends[0], backend)
|
||||
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
|
||||
assert.Fail("Expected backend %+v, got %+v", backends[1], backend)
|
||||
}
|
||||
}
|
||||
|
||||
url3 := "https://domain2.invalid/foo"
|
||||
|
|
@ -579,70 +533,54 @@ func TestBackendConfiguration_Etcd(t *testing.T) {
|
|||
drainWakeupChannel(ch)
|
||||
SetEtcdValue(etcd, "/backends/3_three", []byte("{\"url\":\""+url3+"\",\"secret\":\""+secret3+"\"}"))
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 3 {
|
||||
t.Errorf("Expected three backends, got %+v", backends)
|
||||
} else if backends[0].url != url1 {
|
||||
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret1 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
|
||||
} else if backends[1].url != url2 {
|
||||
t.Errorf("Expected backend url %s, got %s", url2, backends[1].url)
|
||||
} else if string(backends[1].secret) != secret2 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret))
|
||||
} else if backends[2].url != url3 {
|
||||
t.Errorf("Expected backend url %s, got %s", url3, backends[2].url)
|
||||
} else if string(backends[2].secret) != secret3 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[2].secret))
|
||||
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
|
||||
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[1], backend)
|
||||
} else if backend := cfg.GetBackend(mustParse(url3)); backend != backends[2] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[2], backend)
|
||||
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 3) &&
|
||||
assert.Equal(url1, backends[0].url) &&
|
||||
assert.Equal(secret1, string(backends[0].secret)) &&
|
||||
assert.Equal(url2, backends[1].url) &&
|
||||
assert.Equal(secret2, string(backends[1].secret)) &&
|
||||
assert.Equal(url3, backends[2].url) &&
|
||||
assert.Equal(secret3, string(backends[2].secret)) {
|
||||
if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
assert.Fail("Expected backend %+v, got %+v", backends[0], backend)
|
||||
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
|
||||
assert.Fail("Expected backend %+v, got %+v", backends[1], backend)
|
||||
} else if backend := cfg.GetBackend(mustParse(url3)); backend != backends[2] {
|
||||
assert.Fail("Expected backend %+v, got %+v", backends[2], backend)
|
||||
}
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
DeleteEtcdValue(etcd, "/backends/1_one")
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 {
|
||||
t.Errorf("Expected two backends, got %+v", backends)
|
||||
} else if backends[0].url != url2 {
|
||||
t.Errorf("Expected backend url %s, got %s", url2, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret2 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[0].secret))
|
||||
} else if backends[1].url != url3 {
|
||||
t.Errorf("Expected backend url %s, got %s", url3, backends[1].url)
|
||||
} else if string(backends[1].secret) != secret3 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[1].secret))
|
||||
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 2) {
|
||||
assert.Equal(url2, backends[0].url)
|
||||
assert.Equal(secret2, string(backends[0].secret))
|
||||
assert.Equal(url3, backends[1].url)
|
||||
assert.Equal(secret3, string(backends[1].secret))
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
DeleteEtcdValue(etcd, "/backends/2_two")
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
|
||||
t.Errorf("Expected one backend, got %+v", backends)
|
||||
} else if backends[0].url != url3 {
|
||||
t.Errorf("Expected backend url %s, got %s", url3, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret3 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[0].secret))
|
||||
if backends := sortBackends(cfg.GetBackends()); assert.Len(backends, 1) {
|
||||
assert.Equal(url3, backends[0].url)
|
||||
assert.Equal(secret3, string(backends[0].secret))
|
||||
}
|
||||
|
||||
if _, found := storage.backends["domain1.invalid"]; found {
|
||||
t.Errorf("Should have removed host information for %s", "domain1.invalid")
|
||||
assert.Fail("Should have removed host information for %s", "domain1.invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendCommonSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
u1, err := url.Parse("http://domain1.invalid")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
u2, err := url.Parse("http://domain2.invalid")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
original_config := goconf.NewConfigFile()
|
||||
original_config.AddOption("backend", "backends", "backend1, backend2")
|
||||
original_config.AddOption("backend", "secret", string(testBackendSecret))
|
||||
|
|
@ -650,19 +588,13 @@ func TestBackendCommonSecret(t *testing.T) {
|
|||
original_config.AddOption("backend2", "url", u2.String())
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
if b1 := cfg.GetBackend(u1); b1 == nil {
|
||||
t.Error("didn't get backend")
|
||||
} else if !bytes.Equal(b1.Secret(), testBackendSecret) {
|
||||
t.Errorf("expected secret %s, got %s", string(testBackendSecret), string(b1.Secret()))
|
||||
if b1 := cfg.GetBackend(u1); assert.NotNil(b1) {
|
||||
assert.Equal(string(testBackendSecret), string(b1.Secret()))
|
||||
}
|
||||
if b2 := cfg.GetBackend(u2); b2 == nil {
|
||||
t.Error("didn't get backend")
|
||||
} else if !bytes.Equal(b2.Secret(), []byte(string(testBackendSecret)+"-backend2")) {
|
||||
t.Errorf("expected secret %s, got %s", string(testBackendSecret)+"-backend2", string(b2.Secret()))
|
||||
if b2 := cfg.GetBackend(u2); assert.NotNil(b2) {
|
||||
assert.Equal(string(testBackendSecret)+"-backend2", string(b2.Secret()))
|
||||
}
|
||||
|
||||
updated_config := goconf.NewConfigFile()
|
||||
|
|
@ -673,14 +605,10 @@ func TestBackendCommonSecret(t *testing.T) {
|
|||
updated_config.AddOption("backend2", "url", u2.String())
|
||||
cfg.Reload(updated_config)
|
||||
|
||||
if b1 := cfg.GetBackend(u1); b1 == nil {
|
||||
t.Error("didn't get backend")
|
||||
} else if !bytes.Equal(b1.Secret(), []byte(string(testBackendSecret)+"-backend1")) {
|
||||
t.Errorf("expected secret %s, got %s", string(testBackendSecret)+"-backend1", string(b1.Secret()))
|
||||
if b1 := cfg.GetBackend(u1); assert.NotNil(b1) {
|
||||
assert.Equal(string(testBackendSecret)+"-backend1", string(b1.Secret()))
|
||||
}
|
||||
if b2 := cfg.GetBackend(u2); b2 == nil {
|
||||
t.Error("didn't get backend")
|
||||
} else if !bytes.Equal(b2.Secret(), testBackendSecret) {
|
||||
t.Errorf("expected secret %s, got %s", string(testBackendSecret), string(b2.Secret()))
|
||||
if b2 := cfg.GetBackend(u2); assert.NotNil(b2) {
|
||||
assert.Equal(string(testBackendSecret), string(b2.Secret()))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -25,6 +25,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
)
|
||||
|
||||
|
|
@ -67,9 +68,7 @@ func Test_BackendStorageEtcdNoLeak(t *testing.T) {
|
|||
config.AddOption("backend", "backendprefix", "/backends")
|
||||
|
||||
cfg, err := NewBackendConfiguration(config, client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
<-tl.closed
|
||||
cfg.Close()
|
||||
|
|
|
|||
|
|
@ -25,17 +25,20 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBackoff_Exponential(t *testing.T) {
|
||||
t.Parallel()
|
||||
backoff, err := NewExponentialBackoff(100*time.Millisecond, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert := assert.New(t)
|
||||
minWait := 100 * time.Millisecond
|
||||
backoff, err := NewExponentialBackoff(minWait, 500*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
waitTimes := []time.Duration{
|
||||
100 * time.Millisecond,
|
||||
minWait,
|
||||
200 * time.Millisecond,
|
||||
400 * time.Millisecond,
|
||||
500 * time.Millisecond,
|
||||
|
|
@ -43,23 +46,17 @@ func TestBackoff_Exponential(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, wait := range waitTimes {
|
||||
if backoff.NextWait() != wait {
|
||||
t.Errorf("Wait time should be %s, got %s", wait, backoff.NextWait())
|
||||
}
|
||||
assert.Equal(wait, backoff.NextWait())
|
||||
|
||||
a := time.Now()
|
||||
backoff.Wait(context.Background())
|
||||
b := time.Now()
|
||||
if b.Sub(a) < wait {
|
||||
t.Errorf("Should have waited %s, got %s", wait, b.Sub(a))
|
||||
}
|
||||
assert.GreaterOrEqual(b.Sub(a), wait)
|
||||
}
|
||||
|
||||
backoff.Reset()
|
||||
a := time.Now()
|
||||
backoff.Wait(context.Background())
|
||||
b := time.Now()
|
||||
if b.Sub(a) < 100*time.Millisecond {
|
||||
t.Errorf("Should have waited %s, got %s", 100*time.Millisecond, b.Sub(a))
|
||||
}
|
||||
assert.GreaterOrEqual(b.Sub(a), minWait)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,17 +37,16 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse, http.ResponseWriter) error) (*url.URL, *Capabilities) {
|
||||
require := require.New(t)
|
||||
pool, err := NewHttpClientPool(1, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
capabilities, err := NewCapabilities("0.0", pool)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
r := mux.NewRouter()
|
||||
server := httptest.NewServer(r)
|
||||
|
|
@ -56,9 +55,7 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie
|
|||
})
|
||||
|
||||
u, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
handleCapabilitiesFunc := func(w http.ResponseWriter, r *http.Request) {
|
||||
features := []string{
|
||||
|
|
@ -91,9 +88,7 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie
|
|||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Errorf("Could not marshal %+v: %s", response, err)
|
||||
}
|
||||
assert.NoError(t, err, "Could not marshal %+v", response)
|
||||
|
||||
var ocs OcsResponse
|
||||
ocs.Ocs = &OcsBody{
|
||||
|
|
@ -104,9 +99,9 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie
|
|||
},
|
||||
Data: data,
|
||||
}
|
||||
if data, err = json.Marshal(ocs); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err = json.Marshal(ocs)
|
||||
require.NoError(err)
|
||||
|
||||
var cc []string
|
||||
if !strings.Contains(t.Name(), "NoCache") {
|
||||
if strings.Contains(t.Name(), "ShortCache") {
|
||||
|
|
@ -177,70 +172,50 @@ func SetCapabilitiesGetNow(t *testing.T, capabilities *Capabilities, f func() ti
|
|||
func TestCapabilities(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
url, capabilities := NewCapabilitiesForTest(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
if !capabilities.HasCapabilityFeature(ctx, url, "foo") {
|
||||
t.Error("should have capability \"foo\"")
|
||||
}
|
||||
if capabilities.HasCapabilityFeature(ctx, url, "lala") {
|
||||
t.Error("should not have capability \"lala\"")
|
||||
}
|
||||
assert.True(capabilities.HasCapabilityFeature(ctx, url, "foo"))
|
||||
assert.False(capabilities.HasCapabilityFeature(ctx, url, "lala"))
|
||||
|
||||
expectedString := "bar"
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.True(cached)
|
||||
}
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found {
|
||||
t.Errorf("should not have found value for \"baz\", got %s", value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); assert.False(found, "should not have found value for \"baz\", got %s", value) {
|
||||
assert.True(cached)
|
||||
}
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found {
|
||||
t.Errorf("should not have found value for \"invalid\", got %s", value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); assert.False(found, "should not have found value for \"invalid\", got %s", value) {
|
||||
assert.True(cached)
|
||||
}
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found {
|
||||
t.Errorf("should not have found value for \"baz\", got %s", value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); assert.False(found, "should not have found value for \"baz\", got %s", value) {
|
||||
assert.True(cached)
|
||||
}
|
||||
|
||||
expectedInt := 42
|
||||
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found {
|
||||
t.Error("could not find value for \"baz\"")
|
||||
} else if value != expectedInt {
|
||||
t.Errorf("expected value %d, got %d", expectedInt, value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); assert.True(found) {
|
||||
assert.Equal(expectedInt, value)
|
||||
assert.True(cached)
|
||||
}
|
||||
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found {
|
||||
t.Errorf("should not have found value for \"foo\", got %d", value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); assert.False(found, "should not have found value for \"foo\", got %d", value) {
|
||||
assert.True(cached)
|
||||
}
|
||||
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found {
|
||||
t.Errorf("should not have found value for \"invalid\", got %d", value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); assert.False(found, "should not have found value for \"invalid\", got %d", value) {
|
||||
assert.True(cached)
|
||||
}
|
||||
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found {
|
||||
t.Errorf("should not have found value for \"baz\", got %d", value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); assert.False(found, "should not have found value for \"baz\", got %d", value) {
|
||||
assert.True(cached)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidateCapabilities(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
var called atomic.Uint32
|
||||
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
|
||||
called.Add(1)
|
||||
|
|
@ -251,47 +226,35 @@ func TestInvalidateCapabilities(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
expectedString := "bar"
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 1 {
|
||||
t.Errorf("expected called %d, got %d", 1, value)
|
||||
}
|
||||
value := called.Load()
|
||||
assert.EqualValues(1, value)
|
||||
|
||||
// Invalidating will cause the capabilities to be reloaded.
|
||||
capabilities.InvalidateCapabilities(url)
|
||||
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 2 {
|
||||
t.Errorf("expected called %d, got %d", 2, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(2, value)
|
||||
|
||||
// Invalidating is throttled to about once per minute.
|
||||
capabilities.InvalidateCapabilities(url)
|
||||
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.True(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 2 {
|
||||
t.Errorf("expected called %d, got %d", 2, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(2, value)
|
||||
|
||||
// At a later time, invalidating can be done again.
|
||||
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
|
||||
|
|
@ -300,22 +263,19 @@ func TestInvalidateCapabilities(t *testing.T) {
|
|||
|
||||
capabilities.InvalidateCapabilities(url)
|
||||
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 3 {
|
||||
t.Errorf("expected called %d, got %d", 3, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(3, value)
|
||||
}
|
||||
|
||||
func TestCapabilitiesNoCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
var called atomic.Uint32
|
||||
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
|
||||
called.Add(1)
|
||||
|
|
@ -326,51 +286,40 @@ func TestCapabilitiesNoCache(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
expectedString := "bar"
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 1 {
|
||||
t.Errorf("expected called %d, got %d", 1, value)
|
||||
}
|
||||
value := called.Load()
|
||||
assert.EqualValues(1, value)
|
||||
|
||||
// Capabilities are cached for some time if no "Cache-Control" header is set.
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.True(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 1 {
|
||||
t.Errorf("expected called %d, got %d", 1, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(1, value)
|
||||
|
||||
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
|
||||
return time.Now().Add(minCapabilitiesCacheDuration)
|
||||
})
|
||||
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 2 {
|
||||
t.Errorf("expected called %d, got %d", 2, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(2, value)
|
||||
}
|
||||
|
||||
func TestCapabilitiesShortCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
var called atomic.Uint32
|
||||
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
|
||||
called.Add(1)
|
||||
|
|
@ -381,76 +330,58 @@ func TestCapabilitiesShortCache(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
expectedString := "bar"
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 1 {
|
||||
t.Errorf("expected called %d, got %d", 1, value)
|
||||
}
|
||||
value := called.Load()
|
||||
assert.EqualValues(1, value)
|
||||
|
||||
// Capabilities are cached for some time if no "Cache-Control" header is set.
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.True(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 1 {
|
||||
t.Errorf("expected called %d, got %d", 1, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(1, value)
|
||||
|
||||
// The capabilities are cached for a minumum duration.
|
||||
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
|
||||
return time.Now().Add(minCapabilitiesCacheDuration / 2)
|
||||
})
|
||||
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if !cached {
|
||||
t.Errorf("expected cached response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.True(cached)
|
||||
}
|
||||
|
||||
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
|
||||
return time.Now().Add(minCapabilitiesCacheDuration)
|
||||
})
|
||||
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 2 {
|
||||
t.Errorf("expected called %d, got %d", 2, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(2, value)
|
||||
}
|
||||
|
||||
func TestCapabilitiesNoCacheETag(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
var called atomic.Uint32
|
||||
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
|
||||
ct := w.Header().Get("Content-Type")
|
||||
switch called.Add(1) {
|
||||
case 1:
|
||||
if ct == "" {
|
||||
t.Error("expected content-type on first request")
|
||||
}
|
||||
assert.NotEmpty(ct, "expected content-type on first request")
|
||||
case 2:
|
||||
if ct != "" {
|
||||
t.Errorf("expected no content-type on second request, got %s", ct)
|
||||
}
|
||||
assert.Empty(ct, "expected no content-type on second request")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
|
@ -459,38 +390,31 @@ func TestCapabilitiesNoCacheETag(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
expectedString := "bar"
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 1 {
|
||||
t.Errorf("expected called %d, got %d", 1, value)
|
||||
}
|
||||
value := called.Load()
|
||||
assert.EqualValues(1, value)
|
||||
|
||||
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
|
||||
return time.Now().Add(minCapabilitiesCacheDuration)
|
||||
})
|
||||
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 2 {
|
||||
t.Errorf("expected called %d, got %d", 2, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(2, value)
|
||||
}
|
||||
|
||||
func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
var called atomic.Uint32
|
||||
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
|
||||
if called.Add(1) == 2 {
|
||||
|
|
@ -504,17 +428,13 @@ func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
expectedString := "bar"
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 1 {
|
||||
t.Errorf("expected called %d, got %d", 1, value)
|
||||
}
|
||||
value := called.Load()
|
||||
assert.EqualValues(1, value)
|
||||
|
||||
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
|
||||
return time.Now().Add(time.Minute)
|
||||
|
|
@ -522,22 +442,19 @@ func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) {
|
|||
|
||||
// Expired capabilities can still be used even in case of update errors if
|
||||
// "must-revalidate" is not set.
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 2 {
|
||||
t.Errorf("expected called %d, got %d", 2, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(2, value)
|
||||
}
|
||||
|
||||
func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
var called atomic.Uint32
|
||||
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
|
||||
if called.Add(1) == 2 {
|
||||
|
|
@ -551,17 +468,13 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
expectedString := "bar"
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 1 {
|
||||
t.Errorf("expected called %d, got %d", 1, value)
|
||||
}
|
||||
value := called.Load()
|
||||
assert.EqualValues(1, value)
|
||||
|
||||
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
|
||||
return time.Now().Add(minCapabilitiesCacheDuration)
|
||||
|
|
@ -569,22 +482,19 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) {
|
|||
|
||||
// Expired capabilities can still be used even in case of update errors if
|
||||
// "must-revalidate" is not set.
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 2 {
|
||||
t.Errorf("expected called %d, got %d", 2, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(2, value)
|
||||
}
|
||||
|
||||
func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
var called atomic.Uint32
|
||||
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
|
||||
if called.Add(1) == 2 {
|
||||
|
|
@ -598,17 +508,13 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
expectedString := "bar"
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||
t.Error("could not find value for \"foo\"")
|
||||
} else if value != expectedString {
|
||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||
} else if cached {
|
||||
t.Errorf("expected direct response")
|
||||
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
|
||||
assert.Equal(expectedString, value)
|
||||
assert.False(cached)
|
||||
}
|
||||
|
||||
if value := called.Load(); value != 1 {
|
||||
t.Errorf("expected called %d, got %d", 1, value)
|
||||
}
|
||||
value := called.Load()
|
||||
assert.EqualValues(1, value)
|
||||
|
||||
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
|
||||
return time.Now().Add(minCapabilitiesCacheDuration)
|
||||
|
|
@ -616,11 +522,9 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) {
|
|||
|
||||
// Capabilities will be cleared if "must-revalidate" is set and an error
|
||||
// occurs while fetching the updated data.
|
||||
if value, _, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); found {
|
||||
t.Errorf("should not have found value for \"foo\", got %s", value)
|
||||
}
|
||||
capaValue, _, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo")
|
||||
assert.False(found, "should not have found value for \"foo\", got %s", capaValue)
|
||||
|
||||
if value := called.Load(); value != 2 {
|
||||
t.Errorf("expected called %d, got %d", 2, value)
|
||||
}
|
||||
value = called.Load()
|
||||
assert.EqualValues(2, value)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,8 @@ package signaling
|
|||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestChannelWaiters(t *testing.T) {
|
||||
|
|
@ -42,9 +44,9 @@ func TestChannelWaiters(t *testing.T) {
|
|||
|
||||
select {
|
||||
case <-ch1:
|
||||
t.Error("should have not received another event")
|
||||
assert.Fail(t, "should have not received another event")
|
||||
case <-ch2:
|
||||
t.Error("should have not received another event")
|
||||
assert.Fail(t, "should have not received another event")
|
||||
default:
|
||||
}
|
||||
|
||||
|
|
@ -60,7 +62,7 @@ func TestChannelWaiters(t *testing.T) {
|
|||
<-ch2
|
||||
select {
|
||||
case <-ch3:
|
||||
t.Error("should have not received another event")
|
||||
assert.Fail(t, "should have not received another event")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ import (
|
|||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -119,9 +122,7 @@ func Test_permissionsEqual(t *testing.T) {
|
|||
t.Run(strconv.Itoa(idx), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
equal := permissionsEqual(test.a, test.b)
|
||||
if equal != test.equal {
|
||||
t.Errorf("Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal])
|
||||
}
|
||||
assert.Equal(t, test.equal, equal, "Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -129,17 +130,16 @@ func Test_permissionsEqual(t *testing.T) {
|
|||
func TestBandwidth_Client(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
hub, _, _, server := CreateHubForTest(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
mcu, err := NewTestMCU()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := mcu.Start(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NoError(mcu.Start(ctx))
|
||||
defer mcu.Stop()
|
||||
|
||||
hub.SetMcu(mcu)
|
||||
|
|
@ -147,31 +147,23 @@ func TestBandwidth_Client(t *testing.T) {
|
|||
client := NewTestClient(t, server, hub)
|
||||
defer client.CloseWithBye()
|
||||
|
||||
if err := client.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.SendHello(testDefaultUserId))
|
||||
|
||||
hello, err := client.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err := client.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
// We will receive a "joined" event.
|
||||
if err := client.RunUntilJoined(ctx, hello.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client.RunUntilJoined(ctx, hello.Hello))
|
||||
|
||||
// Client may not send an offer with audio and video.
|
||||
bitrate := 10000
|
||||
if err := client.SendMessage(MessageClientMessageRecipient{
|
||||
require.NoError(client.SendMessage(MessageClientMessageRecipient{
|
||||
Type: "session",
|
||||
SessionId: hello.Hello.SessionId,
|
||||
}, MessageClientMessageData{
|
||||
|
|
@ -182,22 +174,13 @@ func TestBandwidth_Client(t *testing.T) {
|
|||
Payload: map[string]interface{}{
|
||||
"sdp": MockSdpOfferAudioAndVideo,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}))
|
||||
|
||||
if err := client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo))
|
||||
|
||||
pub := mcu.GetPublisher(hello.Hello.SessionId)
|
||||
if pub == nil {
|
||||
t.Fatal("Could not find publisher")
|
||||
}
|
||||
|
||||
if pub.bitrate != bitrate {
|
||||
t.Errorf("Expected bitrate %d, got %d", bitrate, pub.bitrate)
|
||||
}
|
||||
require.NotNil(pub)
|
||||
assert.Equal(bitrate, pub.bitrate)
|
||||
}
|
||||
|
||||
func TestBandwidth_Backend(t *testing.T) {
|
||||
|
|
@ -206,13 +189,9 @@ func TestBandwidth_Backend(t *testing.T) {
|
|||
hub, _, _, server := CreateHubWithMultipleBackendsForTest(t)
|
||||
|
||||
u, err := url.Parse(server.URL + "/one")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
backend := hub.backend.GetBackend(u)
|
||||
if backend == nil {
|
||||
t.Fatal("Could not get backend")
|
||||
}
|
||||
require.NotNil(t, backend, "Could not get backend")
|
||||
|
||||
backend.maxScreenBitrate = 1000
|
||||
backend.maxStreamBitrate = 2000
|
||||
|
|
@ -221,11 +200,8 @@ func TestBandwidth_Backend(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
mcu, err := NewTestMCU()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := mcu.Start(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mcu.Start(ctx))
|
||||
defer mcu.Stop()
|
||||
|
||||
hub.SetMcu(mcu)
|
||||
|
|
@ -237,37 +213,31 @@ func TestBandwidth_Backend(t *testing.T) {
|
|||
|
||||
for _, streamType := range streamTypes {
|
||||
t.Run(string(streamType), func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
client := NewTestClient(t, server, hub)
|
||||
defer client.CloseWithBye()
|
||||
|
||||
params := TestBackendClientAuthParams{
|
||||
UserId: testDefaultUserId,
|
||||
}
|
||||
if err := client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.SendHelloParams(server.URL+"/one", HelloVersionV1, "client", nil, params))
|
||||
|
||||
hello, err := client.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err := client.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
// We will receive a "joined" event.
|
||||
if err := client.RunUntilJoined(ctx, hello.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.NoError(client.RunUntilJoined(ctx, hello.Hello))
|
||||
|
||||
// Client may not send an offer with audio and video.
|
||||
bitrate := 10000
|
||||
if err := client.SendMessage(MessageClientMessageRecipient{
|
||||
require.NoError(client.SendMessage(MessageClientMessageRecipient{
|
||||
Type: "session",
|
||||
SessionId: hello.Hello.SessionId,
|
||||
}, MessageClientMessageData{
|
||||
|
|
@ -278,18 +248,12 @@ func TestBandwidth_Backend(t *testing.T) {
|
|||
Payload: map[string]interface{}{
|
||||
"sdp": MockSdpOfferAudioAndVideo,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}))
|
||||
|
||||
if err := client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo))
|
||||
|
||||
pub := mcu.GetPublisher(hello.Hello.SessionId)
|
||||
if pub == nil {
|
||||
t.Fatal("Could not find publisher")
|
||||
}
|
||||
require.NotNil(pub, "Could not find publisher")
|
||||
|
||||
var expectBitrate int
|
||||
if streamType == StreamTypeVideo {
|
||||
|
|
@ -297,9 +261,7 @@ func TestBandwidth_Backend(t *testing.T) {
|
|||
} else {
|
||||
expectBitrate = backend.maxScreenBitrate
|
||||
}
|
||||
if pub.bitrate != expectBitrate {
|
||||
t.Errorf("Expected bitrate %d, got %d", expectBitrate, pub.bitrate)
|
||||
}
|
||||
assert.Equal(expectBitrate, pub.bitrate)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ package signaling
|
|||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCloserMulti(t *testing.T) {
|
||||
|
|
@ -39,24 +41,16 @@ func TestCloserMulti(t *testing.T) {
|
|||
}()
|
||||
}
|
||||
|
||||
if closer.IsClosed() {
|
||||
t.Error("should not be closed")
|
||||
}
|
||||
assert.False(t, closer.IsClosed())
|
||||
closer.Close()
|
||||
if !closer.IsClosed() {
|
||||
t.Error("should be closed")
|
||||
}
|
||||
assert.True(t, closer.IsClosed())
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCloserCloseBeforeWait(t *testing.T) {
|
||||
closer := NewCloser()
|
||||
closer.Close()
|
||||
if !closer.IsClosed() {
|
||||
t.Error("should be closed")
|
||||
}
|
||||
assert.True(t, closer.IsClosed())
|
||||
<-closer.C
|
||||
if !closer.IsClosed() {
|
||||
t.Error("should be closed")
|
||||
}
|
||||
assert.True(t, closer.IsClosed())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,75 +25,52 @@ import (
|
|||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConcurrentStringStringMap(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
var m ConcurrentStringStringMap
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("Expected %d entries, got %d", 0, m.Len())
|
||||
}
|
||||
if v, found := m.Get("foo"); found {
|
||||
t.Errorf("Expected missing entry, got %s", v)
|
||||
}
|
||||
assert.Equal(0, m.Len())
|
||||
v, found := m.Get("foo")
|
||||
assert.False(found, "Expected missing entry, got %s", v)
|
||||
|
||||
m.Set("foo", "bar")
|
||||
if m.Len() != 1 {
|
||||
t.Errorf("Expected %d entries, got %d", 1, m.Len())
|
||||
}
|
||||
if v, found := m.Get("foo"); !found {
|
||||
t.Errorf("Expected entry")
|
||||
} else if v != "bar" {
|
||||
t.Errorf("Expected bar, got %s", v)
|
||||
assert.Equal(1, m.Len())
|
||||
if v, found := m.Get("foo"); assert.True(found) {
|
||||
assert.Equal("bar", v)
|
||||
}
|
||||
|
||||
m.Set("foo", "baz")
|
||||
if m.Len() != 1 {
|
||||
t.Errorf("Expected %d entries, got %d", 1, m.Len())
|
||||
}
|
||||
if v, found := m.Get("foo"); !found {
|
||||
t.Errorf("Expected entry")
|
||||
} else if v != "baz" {
|
||||
t.Errorf("Expected baz, got %s", v)
|
||||
assert.Equal(1, m.Len())
|
||||
if v, found := m.Get("foo"); assert.True(found) {
|
||||
assert.Equal("baz", v)
|
||||
}
|
||||
|
||||
m.Set("lala", "lolo")
|
||||
if m.Len() != 2 {
|
||||
t.Errorf("Expected %d entries, got %d", 2, m.Len())
|
||||
}
|
||||
if v, found := m.Get("lala"); !found {
|
||||
t.Errorf("Expected entry")
|
||||
} else if v != "lolo" {
|
||||
t.Errorf("Expected lolo, got %s", v)
|
||||
assert.Equal(2, m.Len())
|
||||
if v, found := m.Get("lala"); assert.True(found) {
|
||||
assert.Equal("lolo", v)
|
||||
}
|
||||
|
||||
// Deleting missing entries doesn't do anything.
|
||||
m.Del("xyz")
|
||||
if m.Len() != 2 {
|
||||
t.Errorf("Expected %d entries, got %d", 2, m.Len())
|
||||
assert.Equal(2, m.Len())
|
||||
if v, found := m.Get("foo"); assert.True(found) {
|
||||
assert.Equal("baz", v)
|
||||
}
|
||||
if v, found := m.Get("foo"); !found {
|
||||
t.Errorf("Expected entry")
|
||||
} else if v != "baz" {
|
||||
t.Errorf("Expected baz, got %s", v)
|
||||
}
|
||||
if v, found := m.Get("lala"); !found {
|
||||
t.Errorf("Expected entry")
|
||||
} else if v != "lolo" {
|
||||
t.Errorf("Expected lolo, got %s", v)
|
||||
if v, found := m.Get("lala"); assert.True(found) {
|
||||
assert.Equal("lolo", v)
|
||||
}
|
||||
|
||||
m.Del("lala")
|
||||
if m.Len() != 1 {
|
||||
t.Errorf("Expected %d entries, got %d", 2, m.Len())
|
||||
}
|
||||
if v, found := m.Get("foo"); !found {
|
||||
t.Errorf("Expected entry")
|
||||
} else if v != "baz" {
|
||||
t.Errorf("Expected baz, got %s", v)
|
||||
}
|
||||
if v, found := m.Get("lala"); found {
|
||||
t.Errorf("Expected missing entry, got %s", v)
|
||||
assert.Equal(1, m.Len())
|
||||
if v, found := m.Get("foo"); assert.True(found) {
|
||||
assert.Equal("baz", v)
|
||||
}
|
||||
v, found = m.Get("lala")
|
||||
assert.False(found, "Expected missing entry, got %s", v)
|
||||
m.Clear()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
|
@ -108,18 +85,13 @@ func TestConcurrentStringStringMap(t *testing.T) {
|
|||
for y := 0; y < count; y = y + 1 {
|
||||
value := newRandomString(32)
|
||||
m.Set(key, value)
|
||||
if v, found := m.Get(key); !found {
|
||||
t.Errorf("Expected entry for key %s", key)
|
||||
return
|
||||
} else if v != value {
|
||||
t.Errorf("Expected value %s for key %s, got %s", value, key, v)
|
||||
if v, found := m.Get(key); !assert.True(found, "Expected entry for key %s", key) ||
|
||||
!assert.Equal(value, v, "Unexpected value for key %s", key) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}(x)
|
||||
}
|
||||
wg.Wait()
|
||||
if m.Len() != concurrency {
|
||||
t.Errorf("Expected %d entries, got %d", concurrency, m.Len())
|
||||
}
|
||||
assert.Equal(concurrency, m.Len())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,10 +22,11 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStringOptions(t *testing.T) {
|
||||
|
|
@ -46,13 +47,8 @@ func TestStringOptions(t *testing.T) {
|
|||
config.AddOption("default", "three", "3")
|
||||
|
||||
options, err := GetStringOptions(config, "foo", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, options) {
|
||||
t.Errorf("expected %+v, got %+v", expected, options)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, options)
|
||||
}
|
||||
|
||||
func TestStringOptionWithEnv(t *testing.T) {
|
||||
|
|
@ -82,10 +78,8 @@ func TestStringOptionWithEnv(t *testing.T) {
|
|||
}
|
||||
for k, v := range expected {
|
||||
value, err := GetStringOptionWithEnv(config, "test", k)
|
||||
if err != nil {
|
||||
t.Errorf("expected value for %s, got %s", k, err)
|
||||
} else if value != v {
|
||||
t.Errorf("expected value %s for %s, got %s", v, k, value)
|
||||
if assert.NoError(t, err, "expected value for %s", k) {
|
||||
assert.Equal(t, v, value, "unexpected value for %s", k)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ package signaling
|
|||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeferredExecutor_MultiClose(t *testing.T) {
|
||||
|
|
@ -53,9 +55,7 @@ func TestDeferredExecutor_QueueSize(t *testing.T) {
|
|||
b := time.Now()
|
||||
delta := b.Sub(a)
|
||||
// Allow one millisecond less delay to account for time variance on CI runners.
|
||||
if delta+time.Millisecond < delay {
|
||||
t.Errorf("Expected a delay of %s, got %s", delay, delta)
|
||||
}
|
||||
assert.GreaterOrEqual(t, delta+time.Millisecond, delay)
|
||||
}
|
||||
|
||||
func TestDeferredExecutor_Order(t *testing.T) {
|
||||
|
|
@ -81,9 +81,7 @@ func TestDeferredExecutor_Order(t *testing.T) {
|
|||
<-done
|
||||
|
||||
for x := 0; x < 10; x++ {
|
||||
if entries[x] != x {
|
||||
t.Errorf("Expected %d at position %d, got %d", x, x, entries[x])
|
||||
}
|
||||
assert.Equal(t, entries[x], x, "Unexpected at position %d", x)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -108,7 +106,7 @@ func TestDeferredExecutor_DeferAfterClose(t *testing.T) {
|
|||
e.Close()
|
||||
|
||||
e.Execute(func() {
|
||||
t.Error("method should not have been called")
|
||||
assert.Fail(t, "method should not have been called")
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,9 @@ import (
|
|||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockDnsLookup struct {
|
||||
|
|
@ -82,20 +85,16 @@ func (m *mockDnsLookup) lookup(host string) ([]net.IP, error) {
|
|||
|
||||
func newDnsMonitorForTest(t *testing.T, interval time.Duration) *DnsMonitor {
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
|
||||
monitor, err := NewDnsMonitor(interval)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
monitor.Stop()
|
||||
})
|
||||
|
||||
if err := monitor.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
require.NoError(monitor.Start())
|
||||
return monitor
|
||||
}
|
||||
|
||||
|
|
@ -148,20 +147,18 @@ func (r *dnsMonitorReceiver) OnLookup(entry *DnsMonitorEntry, all, add, keep, re
|
|||
expected := r.expected
|
||||
r.expected = nil
|
||||
if expected == expectNone {
|
||||
r.t.Errorf("expected no event, got %v", received)
|
||||
assert.Fail(r.t, "expected no event, got %v", received)
|
||||
return
|
||||
}
|
||||
|
||||
if expected == nil {
|
||||
if r.received != nil && !r.received.Equal(received) {
|
||||
r.t.Errorf("already received %v, got %v", r.received, received)
|
||||
assert.Fail(r.t, "already received %v, got %v", r.received, received)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !expected.Equal(received) {
|
||||
r.t.Errorf("expected %v, got %v", expected, received)
|
||||
}
|
||||
assert.True(r.t, expected.Equal(received), "expected %v, got %v", expected, received)
|
||||
r.received = nil
|
||||
r.expected = nil
|
||||
}
|
||||
|
|
@ -178,7 +175,7 @@ func (r *dnsMonitorReceiver) WaitForExpected(ctx context.Context) {
|
|||
select {
|
||||
case <-ticker.C:
|
||||
case <-ctx.Done():
|
||||
r.t.Error(ctx.Err())
|
||||
assert.NoError(r.t, ctx.Err())
|
||||
abort = true
|
||||
}
|
||||
r.Lock()
|
||||
|
|
@ -191,7 +188,7 @@ func (r *dnsMonitorReceiver) Expect(all, add, keep, remove []net.IP) {
|
|||
defer r.Unlock()
|
||||
|
||||
if r.expected != nil && r.expected != expectNone {
|
||||
r.t.Errorf("didn't get previously expected %v", r.expected)
|
||||
assert.Fail(r.t, "didn't get previously expected %v", r.expected)
|
||||
}
|
||||
|
||||
expected := &dnsMonitorReceiverRecord{
|
||||
|
|
@ -214,7 +211,7 @@ func (r *dnsMonitorReceiver) ExpectNone() {
|
|||
defer r.Unlock()
|
||||
|
||||
if r.expected != nil && r.expected != expectNone {
|
||||
r.t.Errorf("didn't get previously expected %v", r.expected)
|
||||
assert.Fail(r.t, "didn't get previously expected %v", r.expected)
|
||||
}
|
||||
|
||||
r.expected = expectNone
|
||||
|
|
@ -241,9 +238,7 @@ func TestDnsMonitor(t *testing.T) {
|
|||
rec1.Expect(ips1, ips1, nil, nil)
|
||||
|
||||
entry1, err := monitor.Add("https://foo:12345", rec1.OnLookup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer monitor.Remove(entry1)
|
||||
|
||||
rec1.WaitForExpected(ctx)
|
||||
|
|
@ -306,9 +301,7 @@ func TestDnsMonitorIP(t *testing.T) {
|
|||
rec1.Expect(ips, ips, nil, nil)
|
||||
|
||||
entry, err := monitor.Add(ip+":12345", rec1.OnLookup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer monitor.Remove(entry)
|
||||
|
||||
rec1.WaitForExpected(ctx)
|
||||
|
|
@ -328,9 +321,7 @@ func TestDnsMonitorNoLookupIfEmpty(t *testing.T) {
|
|||
}
|
||||
|
||||
time.Sleep(10 * interval)
|
||||
if checked.Load() {
|
||||
t.Error("should not have checked hostnames")
|
||||
}
|
||||
assert.False(t, checked.Load(), "should not have checked hostnames")
|
||||
}
|
||||
|
||||
type deadlockMonitorReceiver struct {
|
||||
|
|
@ -355,8 +346,7 @@ func newDeadlockMonitorReceiver(t *testing.T, monitor *DnsMonitor) *deadlockMoni
|
|||
}
|
||||
|
||||
func (r *deadlockMonitorReceiver) OnLookup(entry *DnsMonitorEntry, all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) {
|
||||
if r.closed.Load() {
|
||||
r.t.Error("received lookup after closed")
|
||||
if !assert.False(r.t, r.closed.Load(), "received lookup after closed") {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -385,8 +375,7 @@ func (r *deadlockMonitorReceiver) Start() {
|
|||
defer r.mu.Unlock()
|
||||
|
||||
entry, err := r.monitor.Add("foo", r.OnLookup)
|
||||
if err != nil {
|
||||
r.t.Errorf("error adding listener: %s", err)
|
||||
if !assert.NoError(r.t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -422,7 +411,5 @@ func TestDnsMonitorDeadlock(t *testing.T) {
|
|||
time.Sleep(10 * interval)
|
||||
monitor.mu.Lock()
|
||||
defer monitor.mu.Unlock()
|
||||
if len(monitor.hostnames) > 0 {
|
||||
t.Errorf("should have cleared hostnames, got %+v", monitor.hostnames)
|
||||
}
|
||||
assert.Empty(t, monitor.hostnames)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/etcd/api/v3/mvccpb"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
|
|
@ -66,15 +68,14 @@ func isErrorAddressAlreadyInUse(err error) bool {
|
|||
}
|
||||
|
||||
func NewEtcdForTest(t *testing.T) *embed.Etcd {
|
||||
require := require.New(t)
|
||||
cfg := embed.NewConfig()
|
||||
cfg.Dir = t.TempDir()
|
||||
os.Chmod(cfg.Dir, 0700) // nolint
|
||||
cfg.LogLevel = "warn"
|
||||
|
||||
u, err := url.Parse(etcdListenUrl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Find a free port to bind the server to.
|
||||
var etcd *embed.Etcd
|
||||
|
|
@ -94,14 +95,12 @@ func NewEtcdForTest(t *testing.T) *embed.Etcd {
|
|||
etcd, err = embed.StartEtcd(cfg)
|
||||
if isErrorAddressAlreadyInUse(err) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
require.NoError(err)
|
||||
break
|
||||
}
|
||||
if etcd == nil {
|
||||
t.Fatal("could not find free port")
|
||||
}
|
||||
require.NotNil(etcd, "could not find free port")
|
||||
|
||||
t.Cleanup(func() {
|
||||
etcd.Close()
|
||||
|
|
@ -121,13 +120,9 @@ func NewEtcdClientForTest(t *testing.T) (*embed.Etcd, *EtcdClient) {
|
|||
config.AddOption("etcd", "loglevel", "error")
|
||||
|
||||
client, err := NewEtcdClient(config, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
return etcd, client
|
||||
}
|
||||
|
|
@ -149,54 +144,44 @@ func DeleteEtcdValue(etcd *embed.Etcd, key string) {
|
|||
func Test_EtcdClient_Get(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
etcd, client := NewEtcdClientForTest(t)
|
||||
|
||||
if response, err := client.Get(context.Background(), "foo"); err != nil {
|
||||
t.Error(err)
|
||||
} else if response.Count != 0 {
|
||||
t.Errorf("expected 0 response, got %d", response.Count)
|
||||
if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) {
|
||||
assert.EqualValues(0, response.Count)
|
||||
}
|
||||
|
||||
SetEtcdValue(etcd, "foo", []byte("bar"))
|
||||
|
||||
if response, err := client.Get(context.Background(), "foo"); err != nil {
|
||||
t.Error(err)
|
||||
} else if response.Count != 1 {
|
||||
t.Errorf("expected 1 responses, got %d", response.Count)
|
||||
} else if string(response.Kvs[0].Key) != "foo" {
|
||||
t.Errorf("expected key \"foo\", got \"%s\"", string(response.Kvs[0].Key))
|
||||
} else if string(response.Kvs[0].Value) != "bar" {
|
||||
t.Errorf("expected value \"bar\", got \"%s\"", string(response.Kvs[0].Value))
|
||||
if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) {
|
||||
if assert.EqualValues(1, response.Count) {
|
||||
assert.Equal("foo", string(response.Kvs[0].Key))
|
||||
assert.Equal("bar", string(response.Kvs[0].Value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_EtcdClient_GetPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
etcd, client := NewEtcdClientForTest(t)
|
||||
|
||||
if response, err := client.Get(context.Background(), "foo"); err != nil {
|
||||
t.Error(err)
|
||||
} else if response.Count != 0 {
|
||||
t.Errorf("expected 0 response, got %d", response.Count)
|
||||
if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) {
|
||||
assert.EqualValues(0, response.Count)
|
||||
}
|
||||
|
||||
SetEtcdValue(etcd, "foo", []byte("1"))
|
||||
SetEtcdValue(etcd, "foo/lala", []byte("2"))
|
||||
SetEtcdValue(etcd, "lala/foo", []byte("3"))
|
||||
|
||||
if response, err := client.Get(context.Background(), "foo", clientv3.WithPrefix()); err != nil {
|
||||
t.Error(err)
|
||||
} else if response.Count != 2 {
|
||||
t.Errorf("expected 2 responses, got %d", response.Count)
|
||||
} else if string(response.Kvs[0].Key) != "foo" {
|
||||
t.Errorf("expected key \"foo\", got \"%s\"", string(response.Kvs[0].Key))
|
||||
} else if string(response.Kvs[0].Value) != "1" {
|
||||
t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value))
|
||||
} else if string(response.Kvs[1].Key) != "foo/lala" {
|
||||
t.Errorf("expected key \"foo/lala\", got \"%s\"", string(response.Kvs[1].Key))
|
||||
} else if string(response.Kvs[1].Value) != "2" {
|
||||
t.Errorf("expected value \"2\", got \"%s\"", string(response.Kvs[1].Value))
|
||||
if response, err := client.Get(context.Background(), "foo", clientv3.WithPrefix()); assert.NoError(err) {
|
||||
if assert.EqualValues(2, response.Count) {
|
||||
assert.Equal("foo", string(response.Kvs[0].Key))
|
||||
assert.Equal("1", string(response.Kvs[0].Value))
|
||||
assert.Equal("foo/lala", string(response.Kvs[1].Key))
|
||||
assert.Equal("2", string(response.Kvs[1].Value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -237,8 +222,8 @@ func (l *EtcdClientTestListener) Close() {
|
|||
|
||||
func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) {
|
||||
go func() {
|
||||
if err := client.WaitForConnection(l.ctx); err != nil {
|
||||
l.t.Errorf("error waiting for connection: %s", err)
|
||||
assert := assert.New(l.t)
|
||||
if err := client.WaitForConnection(l.ctx); !assert.NoError(err) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -246,23 +231,17 @@ func (l *EtcdClientTestListener) EtcdClientCreated(client *EtcdClient) {
|
|||
defer cancel()
|
||||
|
||||
response, err := client.Get(ctx, "foo", clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
l.t.Error(err)
|
||||
} else if response.Count != 1 {
|
||||
l.t.Errorf("expected 1 responses, got %d", response.Count)
|
||||
} else if string(response.Kvs[0].Key) != "foo/a" {
|
||||
l.t.Errorf("expected key \"foo/a\", got \"%s\"", string(response.Kvs[0].Key))
|
||||
} else if string(response.Kvs[0].Value) != "1" {
|
||||
l.t.Errorf("expected value \"1\", got \"%s\"", string(response.Kvs[0].Value))
|
||||
if assert.NoError(err) && assert.EqualValues(1, response.Count) {
|
||||
assert.Equal("foo/a", string(response.Kvs[0].Key))
|
||||
assert.Equal("1", string(response.Kvs[0].Value))
|
||||
}
|
||||
|
||||
close(l.initial)
|
||||
nextRevision := response.Header.Revision + 1
|
||||
for l.ctx.Err() == nil {
|
||||
var err error
|
||||
if nextRevision, err = client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", nextRevision, l, clientv3.WithPrefix()); err != nil {
|
||||
l.t.Error(err)
|
||||
}
|
||||
nextRevision, err = client.Watch(clientv3.WithRequireLeader(l.ctx), "foo", nextRevision, l, clientv3.WithPrefix())
|
||||
assert.NoError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
@ -296,6 +275,7 @@ func (l *EtcdClientTestListener) EtcdKeyDeleted(client *EtcdClient, key string,
|
|||
func Test_EtcdClient_Watch(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
etcd, client := NewEtcdClientForTest(t)
|
||||
|
||||
SetEtcdValue(etcd, "foo/a", []byte("1"))
|
||||
|
|
@ -310,31 +290,19 @@ func Test_EtcdClient_Watch(t *testing.T) {
|
|||
|
||||
SetEtcdValue(etcd, "foo/b", []byte("2"))
|
||||
event := <-listener.events
|
||||
if event.t != clientv3.EventTypePut {
|
||||
t.Errorf("expected type %d, got %d", clientv3.EventTypePut, event.t)
|
||||
} else if event.key != "foo/b" {
|
||||
t.Errorf("expected key %s, got %s", "foo/b", event.key)
|
||||
} else if event.value != "2" {
|
||||
t.Errorf("expected value %s, got %s", "2", event.value)
|
||||
}
|
||||
assert.Equal(clientv3.EventTypePut, event.t)
|
||||
assert.Equal("foo/b", event.key)
|
||||
assert.Equal("2", event.value)
|
||||
|
||||
SetEtcdValue(etcd, "foo/a", []byte("3"))
|
||||
event = <-listener.events
|
||||
if event.t != clientv3.EventTypePut {
|
||||
t.Errorf("expected type %d, got %d", clientv3.EventTypePut, event.t)
|
||||
} else if event.key != "foo/a" {
|
||||
t.Errorf("expected key %s, got %s", "foo/a", event.key)
|
||||
} else if event.value != "3" {
|
||||
t.Errorf("expected value %s, got %s", "3", event.value)
|
||||
}
|
||||
assert.Equal(clientv3.EventTypePut, event.t)
|
||||
assert.Equal("foo/a", event.key)
|
||||
assert.Equal("3", event.value)
|
||||
|
||||
DeleteEtcdValue(etcd, "foo/a")
|
||||
event = <-listener.events
|
||||
if event.t != clientv3.EventTypeDelete {
|
||||
t.Errorf("expected type %d, got %d", clientv3.EventTypeDelete, event.t)
|
||||
} else if event.key != "foo/a" {
|
||||
t.Errorf("expected key %s, got %s", "foo/a", event.key)
|
||||
} else if event.prevValue != "3" {
|
||||
t.Errorf("expected previous value %s, got %s", "3", event.prevValue)
|
||||
}
|
||||
assert.Equal(clientv3.EventTypeDelete, event.t)
|
||||
assert.Equal("foo/a", event.key)
|
||||
assert.Equal("3", event.prevValue)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -330,10 +330,10 @@ func Test_Federation(t *testing.T) {
|
|||
ctx2, cancel2 := context.WithTimeout(ctx, 200*time.Millisecond)
|
||||
defer cancel2()
|
||||
|
||||
if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded {
|
||||
t.Error(err)
|
||||
} else {
|
||||
assert.Nil(message)
|
||||
if message, err := client2.RunUntilMessage(ctx2); err == nil {
|
||||
assert.Fail("expected no message, got %+v", message)
|
||||
} else if err != ErrNoMessageReceived && err != context.DeadlineExceeded {
|
||||
assert.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -345,10 +345,10 @@ func Test_Federation(t *testing.T) {
|
|||
ctx2, cancel2 := context.WithTimeout(ctx, 200*time.Millisecond)
|
||||
defer cancel2()
|
||||
|
||||
if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded {
|
||||
t.Error(err)
|
||||
} else {
|
||||
assert.Nil(message)
|
||||
if message, err := client2.RunUntilMessage(ctx2); err == nil {
|
||||
assert.Fail("expected no message, got %+v", message)
|
||||
} else if err != ErrNoMessageReceived && err != context.DeadlineExceeded {
|
||||
assert.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,10 +23,12 @@ package signaling
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -34,38 +36,31 @@ var (
|
|||
)
|
||||
|
||||
func TestFileWatcher_NotExist(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
tmpdir := t.TempDir()
|
||||
w, err := NewFileWatcher(path.Join(tmpdir, "test.txt"), func(filename string) {})
|
||||
if err == nil {
|
||||
t.Error("should not be able to watch non-existing files")
|
||||
if err := w.Close(); err != nil {
|
||||
t.Error(err)
|
||||
if w, err := NewFileWatcher(path.Join(tmpdir, "test.txt"), func(filename string) {}); !assert.ErrorIs(err, os.ErrNotExist) {
|
||||
if w != nil {
|
||||
assert.NoError(w.Close())
|
||||
}
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileWatcher_File(t *testing.T) {
|
||||
ensureNoGoroutinesLeak(t, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
tmpdir := t.TempDir()
|
||||
filename := path.Join(tmpdir, "test.txt")
|
||||
if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644))
|
||||
|
||||
modified := make(chan struct{})
|
||||
w, err := NewFileWatcher(filename, func(filename string) {
|
||||
modified <- struct{}{}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer w.Close()
|
||||
|
||||
if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(filename, []byte("Updated"), 0644))
|
||||
<-modified
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
|
||||
|
|
@ -73,13 +68,11 @@ func TestFileWatcher_File(t *testing.T) {
|
|||
|
||||
select {
|
||||
case <-modified:
|
||||
t.Error("should not have received another event")
|
||||
assert.Fail("should not have received another event")
|
||||
case <-ctxTimeout.Done():
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filename, []byte("Updated"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(filename, []byte("Updated"), 0644))
|
||||
<-modified
|
||||
|
||||
ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
|
||||
|
|
@ -87,45 +80,39 @@ func TestFileWatcher_File(t *testing.T) {
|
|||
|
||||
select {
|
||||
case <-modified:
|
||||
t.Error("should not have received another event")
|
||||
assert.Fail("should not have received another event")
|
||||
case <-ctxTimeout.Done():
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileWatcher_Rename(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
tmpdir := t.TempDir()
|
||||
filename := path.Join(tmpdir, "test.txt")
|
||||
if err := os.WriteFile(filename, []byte("Hello world!"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644))
|
||||
|
||||
modified := make(chan struct{})
|
||||
w, err := NewFileWatcher(filename, func(filename string) {
|
||||
modified <- struct{}{}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer w.Close()
|
||||
|
||||
filename2 := path.Join(tmpdir, "test.txt.tmp")
|
||||
if err := os.WriteFile(filename2, []byte("Updated"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(filename2, []byte("Updated"), 0644))
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-modified:
|
||||
t.Error("should not have received another event")
|
||||
assert.Fail("should not have received another event")
|
||||
case <-ctxTimeout.Done():
|
||||
}
|
||||
|
||||
if err := os.Rename(filename2, filename); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.Rename(filename2, filename))
|
||||
<-modified
|
||||
|
||||
ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
|
||||
|
|
@ -133,35 +120,29 @@ func TestFileWatcher_Rename(t *testing.T) {
|
|||
|
||||
select {
|
||||
case <-modified:
|
||||
t.Error("should not have received another event")
|
||||
assert.Fail("should not have received another event")
|
||||
case <-ctxTimeout.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileWatcher_Symlink(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
tmpdir := t.TempDir()
|
||||
sourceFilename := path.Join(tmpdir, "test1.txt")
|
||||
if err := os.WriteFile(sourceFilename, []byte("Hello world!"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(sourceFilename, []byte("Hello world!"), 0644))
|
||||
|
||||
filename := path.Join(tmpdir, "symlink.txt")
|
||||
if err := os.Symlink(sourceFilename, filename); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.Symlink(sourceFilename, filename))
|
||||
|
||||
modified := make(chan struct{})
|
||||
w, err := NewFileWatcher(filename, func(filename string) {
|
||||
modified <- struct{}{}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer w.Close()
|
||||
|
||||
if err := os.WriteFile(sourceFilename, []byte("Updated"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(sourceFilename, []byte("Updated"), 0644))
|
||||
<-modified
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
|
||||
|
|
@ -169,44 +150,34 @@ func TestFileWatcher_Symlink(t *testing.T) {
|
|||
|
||||
select {
|
||||
case <-modified:
|
||||
t.Error("should not have received another event")
|
||||
assert.Fail("should not have received another event")
|
||||
case <-ctxTimeout.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
tmpdir := t.TempDir()
|
||||
sourceFilename1 := path.Join(tmpdir, "test1.txt")
|
||||
if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644))
|
||||
|
||||
sourceFilename2 := path.Join(tmpdir, "test2.txt")
|
||||
if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(sourceFilename2, []byte("Updated"), 0644))
|
||||
|
||||
filename := path.Join(tmpdir, "symlink.txt")
|
||||
if err := os.Symlink(sourceFilename1, filename); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.Symlink(sourceFilename1, filename))
|
||||
|
||||
modified := make(chan struct{})
|
||||
w, err := NewFileWatcher(filename, func(filename string) {
|
||||
modified <- struct{}{}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer w.Close()
|
||||
|
||||
// Replace symlink by creating new one and rename it to the original target.
|
||||
if err := os.Symlink(sourceFilename2, filename+".tmp"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.Rename(filename+".tmp", filename); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.Symlink(sourceFilename2, filename+".tmp"))
|
||||
require.NoError(os.Rename(filename+".tmp", filename))
|
||||
<-modified
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
|
||||
|
|
@ -214,89 +185,73 @@ func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) {
|
|||
|
||||
select {
|
||||
case <-modified:
|
||||
t.Error("should not have received another event")
|
||||
assert.Fail("should not have received another event")
|
||||
case <-ctxTimeout.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileWatcher_OtherSymlink(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
tmpdir := t.TempDir()
|
||||
sourceFilename1 := path.Join(tmpdir, "test1.txt")
|
||||
if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644))
|
||||
|
||||
sourceFilename2 := path.Join(tmpdir, "test2.txt")
|
||||
if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(sourceFilename2, []byte("Updated"), 0644))
|
||||
|
||||
filename := path.Join(tmpdir, "symlink.txt")
|
||||
if err := os.Symlink(sourceFilename1, filename); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.Symlink(sourceFilename1, filename))
|
||||
|
||||
modified := make(chan struct{})
|
||||
w, err := NewFileWatcher(filename, func(filename string) {
|
||||
modified <- struct{}{}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer w.Close()
|
||||
|
||||
if err := os.Symlink(sourceFilename2, filename+".tmp"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.Symlink(sourceFilename2, filename+".tmp"))
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-modified:
|
||||
t.Error("should not have received event for other symlink")
|
||||
assert.Fail("should not have received event for other symlink")
|
||||
case <-ctxTimeout.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileWatcher_RenameSymlinkTarget(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
tmpdir := t.TempDir()
|
||||
sourceFilename1 := path.Join(tmpdir, "test1.txt")
|
||||
if err := os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(sourceFilename1, []byte("Hello world!"), 0644))
|
||||
|
||||
filename := path.Join(tmpdir, "test.txt")
|
||||
if err := os.Symlink(sourceFilename1, filename); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.Symlink(sourceFilename1, filename))
|
||||
|
||||
modified := make(chan struct{})
|
||||
w, err := NewFileWatcher(filename, func(filename string) {
|
||||
modified <- struct{}{}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer w.Close()
|
||||
|
||||
sourceFilename2 := path.Join(tmpdir, "test1.txt.tmp")
|
||||
if err := os.WriteFile(sourceFilename2, []byte("Updated"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.WriteFile(sourceFilename2, []byte("Updated"), 0644))
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-modified:
|
||||
t.Error("should not have received another event")
|
||||
assert.Fail("should not have received another event")
|
||||
case <-ctxTimeout.Done():
|
||||
}
|
||||
|
||||
if err := os.Rename(sourceFilename2, sourceFilename1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(os.Rename(sourceFilename2, sourceFilename1))
|
||||
<-modified
|
||||
|
||||
ctxTimeout, cancel = context.WithTimeout(context.Background(), testWatcherNoEventTimeout)
|
||||
|
|
@ -304,7 +259,7 @@ func TestFileWatcher_RenameSymlinkTarget(t *testing.T) {
|
|||
|
||||
select {
|
||||
case <-modified:
|
||||
t.Error("should not have received another event")
|
||||
assert.Fail("should not have received another event")
|
||||
case <-ctxTimeout.Done():
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,55 +25,28 @@ import (
|
|||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFlags(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
var f Flags
|
||||
if f.Get() != 0 {
|
||||
t.Fatalf("Expected flags 0, got %d", f.Get())
|
||||
}
|
||||
if !f.Add(1) {
|
||||
t.Error("expected true")
|
||||
}
|
||||
if f.Get() != 1 {
|
||||
t.Fatalf("Expected flags 1, got %d", f.Get())
|
||||
}
|
||||
if f.Add(1) {
|
||||
t.Error("expected false")
|
||||
}
|
||||
if f.Get() != 1 {
|
||||
t.Fatalf("Expected flags 1, got %d", f.Get())
|
||||
}
|
||||
if !f.Add(2) {
|
||||
t.Error("expected true")
|
||||
}
|
||||
if f.Get() != 3 {
|
||||
t.Fatalf("Expected flags 3, got %d", f.Get())
|
||||
}
|
||||
if !f.Remove(1) {
|
||||
t.Error("expected true")
|
||||
}
|
||||
if f.Get() != 2 {
|
||||
t.Fatalf("Expected flags 2, got %d", f.Get())
|
||||
}
|
||||
if f.Remove(1) {
|
||||
t.Error("expected false")
|
||||
}
|
||||
if f.Get() != 2 {
|
||||
t.Fatalf("Expected flags 2, got %d", f.Get())
|
||||
}
|
||||
if !f.Add(3) {
|
||||
t.Error("expected true")
|
||||
}
|
||||
if f.Get() != 3 {
|
||||
t.Fatalf("Expected flags 3, got %d", f.Get())
|
||||
}
|
||||
if !f.Remove(1) {
|
||||
t.Error("expected true")
|
||||
}
|
||||
if f.Get() != 2 {
|
||||
t.Fatalf("Expected flags 2, got %d", f.Get())
|
||||
}
|
||||
assert.EqualValues(0, f.Get())
|
||||
assert.True(f.Add(1))
|
||||
assert.EqualValues(1, f.Get())
|
||||
assert.False(f.Add(1))
|
||||
assert.EqualValues(1, f.Get())
|
||||
assert.True(f.Add(2))
|
||||
assert.EqualValues(3, f.Get())
|
||||
assert.True(f.Remove(1))
|
||||
assert.EqualValues(2, f.Get())
|
||||
assert.False(f.Remove(1))
|
||||
assert.EqualValues(2, f.Get())
|
||||
assert.True(f.Add(3))
|
||||
assert.EqualValues(3, f.Get())
|
||||
assert.True(f.Remove(1))
|
||||
assert.EqualValues(2, f.Get())
|
||||
}
|
||||
|
||||
func runConcurrentFlags(t *testing.T, count int, f func()) {
|
||||
|
|
@ -106,9 +79,7 @@ func TestFlagsConcurrentAdd(t *testing.T) {
|
|||
added.Add(1)
|
||||
}
|
||||
})
|
||||
if added.Load() != 1 {
|
||||
t.Errorf("expected only one successfull attempt, got %d", added.Load())
|
||||
}
|
||||
assert.EqualValues(t, 1, added.Load(), "expected only one successfull attempt")
|
||||
}
|
||||
|
||||
func TestFlagsConcurrentRemove(t *testing.T) {
|
||||
|
|
@ -122,9 +93,7 @@ func TestFlagsConcurrentRemove(t *testing.T) {
|
|||
removed.Add(1)
|
||||
}
|
||||
})
|
||||
if removed.Load() != 1 {
|
||||
t.Errorf("expected only one successfull attempt, got %d", removed.Load())
|
||||
}
|
||||
assert.EqualValues(t, 1, removed.Load(), "expected only one successfull attempt")
|
||||
}
|
||||
|
||||
func TestFlagsConcurrentSet(t *testing.T) {
|
||||
|
|
@ -137,7 +106,5 @@ func TestFlagsConcurrentSet(t *testing.T) {
|
|||
set.Add(1)
|
||||
}
|
||||
})
|
||||
if set.Load() != 1 {
|
||||
t.Errorf("expected only one successfull attempt, got %d", set.Load())
|
||||
}
|
||||
assert.EqualValues(t, 1, set.Load(), "expected only one successfull attempt")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,6 +32,9 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testGeoLookupReader(t *testing.T, reader *GeoLookup) {
|
||||
|
|
@ -47,14 +50,11 @@ func testGeoLookupReader(t *testing.T, reader *GeoLookup) {
|
|||
expected := expected
|
||||
t.Run(ip, func(t *testing.T) {
|
||||
country, err := reader.LookupCountry(net.ParseIP(ip))
|
||||
if err != nil {
|
||||
t.Errorf("Could not lookup %s: %s", ip, err)
|
||||
if !assert.NoError(t, err, "Could not lookup %s", ip) {
|
||||
return
|
||||
}
|
||||
|
||||
if country != expected {
|
||||
t.Errorf("Expected %s for %s, got %s", expected, ip, country)
|
||||
}
|
||||
assert.Equal(t, expected, country, "Unexpected country for %s", ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -79,36 +79,28 @@ func GetGeoIpUrlForTest(t *testing.T) string {
|
|||
|
||||
func TestGeoLookup(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
reader, err := NewGeoLookupFromUrl(GetGeoIpUrlForTest(t))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer reader.Close()
|
||||
|
||||
if err := reader.Update(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(reader.Update())
|
||||
|
||||
testGeoLookupReader(t, reader)
|
||||
}
|
||||
|
||||
func TestGeoLookupCaching(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
reader, err := NewGeoLookupFromUrl(GetGeoIpUrlForTest(t))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer reader.Close()
|
||||
|
||||
if err := reader.Update(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(reader.Update())
|
||||
|
||||
// Updating the second time will most likely return a "304 Not Modified".
|
||||
// Make sure this doesn't trigger an error.
|
||||
if err := reader.Update(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(reader.Update())
|
||||
}
|
||||
|
||||
func TestGeoLookupContinent(t *testing.T) {
|
||||
|
|
@ -125,13 +117,11 @@ func TestGeoLookupContinent(t *testing.T) {
|
|||
expected := expected
|
||||
t.Run(country, func(t *testing.T) {
|
||||
continents := LookupContinents(country)
|
||||
if len(continents) != len(expected) {
|
||||
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
|
||||
if !assert.Equal(t, len(expected), len(continents), "Continents didn't match for %s: got %s, expected %s", country, continents, expected) {
|
||||
return
|
||||
}
|
||||
for idx, c := range expected {
|
||||
if continents[idx] != c {
|
||||
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
|
||||
if !assert.Equal(t, c, continents[idx], "Continents didn't match for %s: got %s, expected %s", country, continents, expected) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -142,36 +132,29 @@ func TestGeoLookupContinent(t *testing.T) {
|
|||
func TestGeoLookupCloseEmpty(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
reader, err := NewGeoLookupFromUrl("ignore-url")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
func TestGeoLookupFromFile(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
geoIpUrl := GetGeoIpUrlForTest(t)
|
||||
|
||||
resp, err := http.Get(geoIpUrl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body := resp.Body
|
||||
url := geoIpUrl
|
||||
if strings.HasSuffix(geoIpUrl, ".gz") {
|
||||
body, err = gzip.NewReader(body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
url = strings.TrimSuffix(url, ".gz")
|
||||
}
|
||||
|
||||
tmpfile, err := os.CreateTemp("", "geoipdb")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
t.Cleanup(func() {
|
||||
os.Remove(tmpfile.Name())
|
||||
})
|
||||
|
|
@ -183,9 +166,8 @@ func TestGeoLookupFromFile(t *testing.T) {
|
|||
header, err := tarfile.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
if !strings.HasSuffix(header.Name, ".mmdb") {
|
||||
continue
|
||||
|
|
@ -193,33 +175,25 @@ func TestGeoLookupFromFile(t *testing.T) {
|
|||
|
||||
if _, err := io.Copy(tmpfile, tarfile); err != nil {
|
||||
tmpfile.Close()
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
require.NoError(err)
|
||||
}
|
||||
require.NoError(tmpfile.Close())
|
||||
foundDatabase = true
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if _, err := io.Copy(tmpfile, body); err != nil {
|
||||
tmpfile.Close()
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
require.NoError(err)
|
||||
}
|
||||
require.NoError(tmpfile.Close())
|
||||
foundDatabase = true
|
||||
}
|
||||
|
||||
if !foundDatabase {
|
||||
t.Fatalf("Did not find GeoIP database in download from %s", geoIpUrl)
|
||||
}
|
||||
require.True(foundDatabase, "Did not find GeoIP database in download from %s", geoIpUrl)
|
||||
|
||||
reader, err := NewGeoLookupFromFile(tmpfile.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer reader.Close()
|
||||
|
||||
testGeoLookupReader(t, reader)
|
||||
|
|
@ -228,9 +202,7 @@ func TestGeoLookupFromFile(t *testing.T) {
|
|||
func TestIsValidContinent(t *testing.T) {
|
||||
for country, continents := range ContinentMap {
|
||||
for _, continent := range continents {
|
||||
if !IsValidContinent(continent) {
|
||||
t.Errorf("Continent %s of country %s is not valid", continent, country)
|
||||
}
|
||||
assert.True(t, IsValidContinent(continent), "Continent %s of country %s is not valid", continent, country)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
)
|
||||
|
||||
|
|
@ -52,9 +54,7 @@ func (c *GrpcClients) getWakeupChannelForTesting() <-chan struct{} {
|
|||
func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, *DnsMonitor) {
|
||||
dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually
|
||||
client, err := NewGrpcClients(config, etcdClient, dnsMonitor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
client.Close()
|
||||
})
|
||||
|
|
@ -78,13 +78,9 @@ func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) (*GrpcClients
|
|||
config.AddOption("grpc", "targetprefix", "/grpctargets")
|
||||
|
||||
etcdClient, err := NewEtcdClient(config, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
if err := etcdClient.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, etcdClient.Close())
|
||||
})
|
||||
|
||||
return NewGrpcClientsForTestWithConfig(t, config, etcdClient)
|
||||
|
|
@ -107,7 +103,7 @@ func waitForEvent(ctx context.Context, t *testing.T, ch <-chan struct{}) {
|
|||
case <-ch:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
t.Error("timeout waiting for event")
|
||||
assert.Fail(t, "timeout waiting for event")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -125,19 +121,17 @@ func Test_GrpcClients_EtcdInitial(t *testing.T) {
|
|||
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := client.WaitForInitialized(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, client.WaitForInitialized(ctx))
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 2 {
|
||||
t.Errorf("Expected two clients, got %+v", clients)
|
||||
}
|
||||
clients := client.GetClients()
|
||||
assert.Len(t, clients, 2, "Expected two clients, got %+v", clients)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_GrpcClients_EtcdUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
etcd := NewEtcdForTest(t)
|
||||
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
ch := client.getWakeupChannelForTesting()
|
||||
|
|
@ -145,55 +139,45 @@ func Test_GrpcClients_EtcdUpdate(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 0 {
|
||||
t.Errorf("Expected no clients, got %+v", clients)
|
||||
}
|
||||
assert.Empty(client.GetClients())
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
_, addr1 := NewGrpcServerForTest(t)
|
||||
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
|
||||
waitForEvent(ctx, t, ch)
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
if clients := client.GetClients(); assert.Len(clients, 1) {
|
||||
assert.Equal(addr1, clients[0].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
_, addr2 := NewGrpcServerForTest(t)
|
||||
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
|
||||
waitForEvent(ctx, t, ch)
|
||||
if clients := client.GetClients(); len(clients) != 2 {
|
||||
t.Errorf("Expected two clients, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
} else if clients[1].Target() != addr2 {
|
||||
t.Errorf("Expected target %s, got %s", addr2, clients[1].Target())
|
||||
if clients := client.GetClients(); assert.Len(clients, 2) {
|
||||
assert.Equal(addr1, clients[0].Target())
|
||||
assert.Equal(addr2, clients[1].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
DeleteEtcdValue(etcd, "/grpctargets/one")
|
||||
waitForEvent(ctx, t, ch)
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr2 {
|
||||
t.Errorf("Expected target %s, got %s", addr2, clients[0].Target())
|
||||
if clients := client.GetClients(); assert.Len(clients, 1) {
|
||||
assert.Equal(addr2, clients[0].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
_, addr3 := NewGrpcServerForTest(t)
|
||||
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr3+"\"}"))
|
||||
waitForEvent(ctx, t, ch)
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr3 {
|
||||
t.Errorf("Expected target %s, got %s", addr3, clients[0].Target())
|
||||
if clients := client.GetClients(); assert.Len(clients, 1) {
|
||||
assert.Equal(addr3, clients[0].Target())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
etcd := NewEtcdForTest(t)
|
||||
client, _ := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
ch := client.getWakeupChannelForTesting()
|
||||
|
|
@ -201,18 +185,14 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 0 {
|
||||
t.Errorf("Expected no clients, got %+v", clients)
|
||||
}
|
||||
assert.Empty(client.GetClients())
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
_, addr1 := NewGrpcServerForTest(t)
|
||||
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
|
||||
waitForEvent(ctx, t, ch)
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
if clients := client.GetClients(); assert.Len(clients, 1) {
|
||||
assert.Equal(addr1, clients[0].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
|
|
@ -221,25 +201,22 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
|
|||
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
|
||||
waitForEvent(ctx, t, ch)
|
||||
client.selfCheckWaitGroup.Wait()
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
if clients := client.GetClients(); assert.Len(clients, 1) {
|
||||
assert.Equal(addr1, clients[0].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
DeleteEtcdValue(etcd, "/grpctargets/two")
|
||||
waitForEvent(ctx, t, ch)
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
if clients := client.GetClients(); assert.Len(clients, 1) {
|
||||
assert.Equal(addr1, clients[0].Target())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
ensureNoGoroutinesLeak(t, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
lookup := newMockDnsLookupForTest(t)
|
||||
target := "testgrpc:12345"
|
||||
ip1 := net.ParseIP("192.168.0.1")
|
||||
|
|
@ -254,12 +231,9 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
dnsMonitor.checkHostnames()
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != targetWithIp1 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
|
||||
} else if !clients[0].ip.Equal(ip1) {
|
||||
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
if clients := client.GetClients(); assert.Len(clients, 1) {
|
||||
assert.Equal(targetWithIp1, clients[0].Target())
|
||||
assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
}
|
||||
|
||||
lookup.Set("testgrpc", []net.IP{ip1, ip2})
|
||||
|
|
@ -267,16 +241,11 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
|||
dnsMonitor.checkHostnames()
|
||||
waitForEvent(ctx, t, ch)
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 2 {
|
||||
t.Errorf("Expected two client, got %+v", clients)
|
||||
} else if clients[0].Target() != targetWithIp1 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
|
||||
} else if !clients[0].ip.Equal(ip1) {
|
||||
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
} else if clients[1].Target() != targetWithIp2 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp2, clients[1].Target())
|
||||
} else if !clients[1].ip.Equal(ip2) {
|
||||
t.Errorf("Expected IP %s, got %s", ip2, clients[1].ip)
|
||||
if clients := client.GetClients(); assert.Len(clients, 2) {
|
||||
assert.Equal(targetWithIp1, clients[0].Target())
|
||||
assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
assert.Equal(targetWithIp2, clients[1].Target())
|
||||
assert.True(clients[1].ip.Equal(ip2), "Expected IP %s, got %s", ip2, clients[1].ip)
|
||||
}
|
||||
|
||||
lookup.Set("testgrpc", []net.IP{ip2})
|
||||
|
|
@ -284,12 +253,9 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
|||
dnsMonitor.checkHostnames()
|
||||
waitForEvent(ctx, t, ch)
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != targetWithIp2 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp2, clients[0].Target())
|
||||
} else if !clients[0].ip.Equal(ip2) {
|
||||
t.Errorf("Expected IP %s, got %s", ip2, clients[0].ip)
|
||||
if clients := client.GetClients(); assert.Len(clients, 1) {
|
||||
assert.Equal(targetWithIp2, clients[0].Target())
|
||||
assert.True(clients[0].ip.Equal(ip2), "Expected IP %s, got %s", ip2, clients[0].ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -297,6 +263,7 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
|||
func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
lookup := newMockDnsLookupForTest(t)
|
||||
target := "testgrpc:12345"
|
||||
ip1 := net.ParseIP("192.168.0.1")
|
||||
|
|
@ -309,39 +276,29 @@ func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
|
|||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := client.WaitForInitialized(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, client.WaitForInitialized(ctx))
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 0 {
|
||||
t.Errorf("Expected no client, got %+v", clients)
|
||||
}
|
||||
assert.Empty(client.GetClients())
|
||||
|
||||
lookup.Set("testgrpc", []net.IP{ip1})
|
||||
drainWakeupChannel(ch)
|
||||
dnsMonitor.checkHostnames()
|
||||
waitForEvent(testCtx, t, ch)
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != targetWithIp1 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
|
||||
} else if !clients[0].ip.Equal(ip1) {
|
||||
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
if clients := client.GetClients(); assert.Len(clients, 1) {
|
||||
assert.Equal(targetWithIp1, clients[0].Target())
|
||||
assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcClients_Encryption(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
ensureNoGoroutinesLeak(t, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey)
|
||||
clientCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Testing client", clientKey)
|
||||
|
|
@ -376,14 +333,11 @@ func Test_GrpcClients_Encryption(t *testing.T) {
|
|||
ctx, cancel1 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel1()
|
||||
|
||||
if err := clients.WaitForInitialized(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clients.WaitForInitialized(ctx))
|
||||
|
||||
for _, client := range clients.GetClients() {
|
||||
if _, err := client.GetServerId(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err := client.GetServerId(ctx)
|
||||
require.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ import (
|
|||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func (c *reloadableCredentials) WaitForCertificateReload(ctx context.Context) error {
|
||||
|
|
@ -72,9 +74,7 @@ func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organizatio
|
|||
}
|
||||
|
||||
data, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
data = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
|
|
@ -108,23 +108,15 @@ func WritePublicKey(key *rsa.PublicKey, filename string) error {
|
|||
|
||||
func replaceFile(t *testing.T, filename string, data []byte, perm fs.FileMode) {
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
oldStat, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
t.Fatalf("can't stat old file %s: %s", filename, err)
|
||||
return
|
||||
}
|
||||
require.NoError(err, "can't stat old file %s", filename)
|
||||
|
||||
for {
|
||||
if err := os.WriteFile(filename, data, perm); err != nil {
|
||||
t.Fatalf("can't write file %s: %s", filename, err)
|
||||
return
|
||||
}
|
||||
require.NoError(os.WriteFile(filename, data, perm), "can't write file %s", filename)
|
||||
|
||||
newStat, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
t.Fatalf("can't stat new file %s: %s", filename, err)
|
||||
return
|
||||
}
|
||||
require.NoError(err, "can't stat new file %s", filename)
|
||||
|
||||
// We need different modification times.
|
||||
if !newStat.ModTime().Equal(oldStat.ModTime()) {
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
|
@ -67,23 +69,19 @@ func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (se
|
|||
server, err = NewGrpcServer(config)
|
||||
if isErrorAddressAlreadyInUse(err) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
break
|
||||
}
|
||||
|
||||
if server == nil {
|
||||
t.Fatal("could not find free port")
|
||||
}
|
||||
require.NotNil(t, server, "could not find free port")
|
||||
|
||||
// Don't match with own server id by default.
|
||||
server.serverId = "dont-match"
|
||||
|
||||
go func() {
|
||||
if err := server.Run(); err != nil {
|
||||
t.Errorf("could not start GRPC server: %s", err)
|
||||
}
|
||||
assert.NoError(t, server.Run(), "could not start GRPC server")
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
|
|
@ -99,10 +97,10 @@ func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) {
|
|||
|
||||
func Test_GrpcServer_ReloadCerts(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
org1 := "Testing certificate"
|
||||
cert1 := GenerateSelfSignedCertificateForTesting(t, 1024, org1, key)
|
||||
|
|
@ -124,24 +122,20 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
|
|||
|
||||
cp1 := x509.NewCertPool()
|
||||
if !cp1.AppendCertsFromPEM(cert1) {
|
||||
t.Fatalf("could not add certificate")
|
||||
require.Fail("could not add certificate")
|
||||
}
|
||||
|
||||
cfg1 := &tls.Config{
|
||||
RootCAs: cp1,
|
||||
}
|
||||
conn1, err := tls.Dial("tcp", addr, cfg1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer conn1.Close() // nolint
|
||||
state1 := conn1.ConnectionState()
|
||||
if certs := state1.PeerCertificates; len(certs) == 0 {
|
||||
t.Errorf("expected certificates, got %+v", state1)
|
||||
} else if len(certs[0].Subject.Organization) == 0 {
|
||||
t.Errorf("expected organization, got %s", certs[0].Subject)
|
||||
} else if certs[0].Subject.Organization[0] != org1 {
|
||||
t.Errorf("expected organization %s, got %s", org1, certs[0].Subject)
|
||||
if certs := state1.PeerCertificates; assert.NotEmpty(certs) {
|
||||
if assert.NotEmpty(certs[0].Subject.Organization) {
|
||||
assert.Equal(org1, certs[0].Subject.Organization[0])
|
||||
}
|
||||
}
|
||||
|
||||
org2 := "Updated certificate"
|
||||
|
|
@ -151,43 +145,34 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.WaitForCertificateReload(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(server.WaitForCertificateReload(ctx))
|
||||
|
||||
cp2 := x509.NewCertPool()
|
||||
if !cp2.AppendCertsFromPEM(cert2) {
|
||||
t.Fatalf("could not add certificate")
|
||||
require.Fail("could not add certificate")
|
||||
}
|
||||
|
||||
cfg2 := &tls.Config{
|
||||
RootCAs: cp2,
|
||||
}
|
||||
conn2, err := tls.Dial("tcp", addr, cfg2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer conn2.Close() // nolint
|
||||
state2 := conn2.ConnectionState()
|
||||
if certs := state2.PeerCertificates; len(certs) == 0 {
|
||||
t.Errorf("expected certificates, got %+v", state2)
|
||||
} else if len(certs[0].Subject.Organization) == 0 {
|
||||
t.Errorf("expected organization, got %s", certs[0].Subject)
|
||||
} else if certs[0].Subject.Organization[0] != org2 {
|
||||
t.Errorf("expected organization %s, got %s", org2, certs[0].Subject)
|
||||
if certs := state2.PeerCertificates; assert.NotEmpty(certs) {
|
||||
if assert.NotEmpty(certs[0].Subject.Organization) {
|
||||
assert.Equal(org2, certs[0].Subject.Organization[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcServer_ReloadCA(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey)
|
||||
org1 := "Testing client"
|
||||
|
|
@ -213,65 +198,53 @@ func Test_GrpcServer_ReloadCA(t *testing.T) {
|
|||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(serverCert) {
|
||||
t.Fatalf("could not add certificate")
|
||||
require.Fail("could not add certificate")
|
||||
}
|
||||
|
||||
pair1, err := tls.X509KeyPair(clientCert1, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
cfg1 := &tls.Config{
|
||||
RootCAs: pool,
|
||||
Certificates: []tls.Certificate{pair1},
|
||||
}
|
||||
client1, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer client1.Close() // nolint
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel1()
|
||||
|
||||
if _, err := client1.GetServerId(ctx1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = client1.GetServerId(ctx1)
|
||||
require.NoError(err)
|
||||
|
||||
org2 := "Updated client"
|
||||
clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey)
|
||||
replaceFile(t, caFile, clientCert2, 0755)
|
||||
|
||||
if err := server.WaitForCertPoolReload(ctx1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(server.WaitForCertPoolReload(ctx1))
|
||||
|
||||
pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
cfg2 := &tls.Config{
|
||||
RootCAs: pool,
|
||||
Certificates: []tls.Certificate{pair2},
|
||||
}
|
||||
client2, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer client2.Close() // nolint
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel2()
|
||||
|
||||
// This will fail if the CA certificate has not been reloaded by the server.
|
||||
if _, err := client2.GetServerId(ctx2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = client2.GetServerId(ctx2)
|
||||
require.NoError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,52 +26,42 @@ import (
|
|||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHttpClientPool(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, err := NewHttpClientPool(0, false); err == nil {
|
||||
t.Error("should not be possible to create empty pool")
|
||||
}
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
_, err := NewHttpClientPool(0, false)
|
||||
assert.Error(err)
|
||||
|
||||
pool, err := NewHttpClientPool(1, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
u, err := url.Parse("http://localhost/foo/bar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
if _, _, err := pool.Get(ctx, u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, err = pool.Get(ctx, u)
|
||||
require.NoError(err)
|
||||
|
||||
ctx2, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
|
||||
defer cancel()
|
||||
if _, _, err := pool.Get(ctx2, u); err == nil {
|
||||
t.Error("fetching from empty pool should have timed out")
|
||||
} else if err != context.DeadlineExceeded {
|
||||
t.Errorf("fetching from empty pool should have timed out, got %s", err)
|
||||
}
|
||||
_, _, err = pool.Get(ctx2, u)
|
||||
assert.ErrorIs(err, context.DeadlineExceeded)
|
||||
|
||||
// Pools are separated by hostname, so can get client for different host.
|
||||
u2, err := url.Parse("http://local.host/foo/bar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
if _, _, err := pool.Get(ctx, u2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, err = pool.Get(ctx, u2)
|
||||
require.NoError(err)
|
||||
|
||||
ctx3, cancel2 := context.WithTimeout(ctx, 10*time.Millisecond)
|
||||
defer cancel2()
|
||||
if _, _, err := pool.Get(ctx3, u2); err == nil {
|
||||
t.Error("fetching from empty pool should have timed out")
|
||||
} else if err != context.DeadlineExceeded {
|
||||
t.Errorf("fetching from empty pool should have timed out, got %s", err)
|
||||
}
|
||||
_, _, err = pool.Get(ctx3, u2)
|
||||
assert.ErrorIs(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
|
|
|||
3841
hub_test.go
3841
hub_test.go
File diff suppressed because it is too large
Load diff
88
lru_test.go
88
lru_test.go
|
|
@ -24,46 +24,36 @@ package signaling
|
|||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLruUnbound(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
lru := NewLruCache(0)
|
||||
count := 10
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("%d", i)
|
||||
lru.Set(key, i)
|
||||
}
|
||||
if lru.Len() != count {
|
||||
t.Errorf("Expected %d entries, got %d", count, lru.Len())
|
||||
}
|
||||
assert.Equal(count, lru.Len())
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("%d", i)
|
||||
value := lru.Get(key)
|
||||
if value == nil {
|
||||
t.Errorf("No value found for %s", key)
|
||||
continue
|
||||
} else if value.(int) != i {
|
||||
t.Errorf("Expected value to be %d, got %d", value.(int), i)
|
||||
if value := lru.Get(key); assert.NotNil(value, "No value found for %s", key) {
|
||||
assert.EqualValues(i, value)
|
||||
}
|
||||
}
|
||||
// The first key ("0") is now the oldest.
|
||||
lru.RemoveOldest()
|
||||
if lru.Len() != count-1 {
|
||||
t.Errorf("Expected %d entries after RemoveOldest, got %d", count-1, lru.Len())
|
||||
}
|
||||
assert.Equal(count-1, lru.Len())
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("%d", i)
|
||||
value := lru.Get(key)
|
||||
if i == 0 {
|
||||
if value != nil {
|
||||
t.Errorf("The value for key %s should have been removed", key)
|
||||
}
|
||||
assert.Nil(value, "The value for key %s should have been removed", key)
|
||||
continue
|
||||
} else if value == nil {
|
||||
t.Errorf("No value found for %s", key)
|
||||
continue
|
||||
} else if value.(int) != i {
|
||||
t.Errorf("Expected value to be %d, got %d", value.(int), i)
|
||||
} else if assert.NotNil(value, "No value found for %s", key) {
|
||||
assert.EqualValues(i, value)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -74,66 +64,47 @@ func TestLruUnbound(t *testing.T) {
|
|||
key := fmt.Sprintf("%d", i)
|
||||
lru.Set(key, i)
|
||||
}
|
||||
if lru.Len() != count-1 {
|
||||
t.Errorf("Expected %d entries, got %d", count-1, lru.Len())
|
||||
}
|
||||
assert.Equal(count-1, lru.Len())
|
||||
// NOTE: The same ordering as the Set calls above.
|
||||
for i := count - 1; i >= 1; i-- {
|
||||
key := fmt.Sprintf("%d", i)
|
||||
value := lru.Get(key)
|
||||
if value == nil {
|
||||
t.Errorf("No value found for %s", key)
|
||||
continue
|
||||
} else if value.(int) != i {
|
||||
t.Errorf("Expected value to be %d, got %d", value.(int), i)
|
||||
if value := lru.Get(key); assert.NotNil(value, "No value found for %s", key) {
|
||||
assert.EqualValues(i, value)
|
||||
}
|
||||
}
|
||||
|
||||
// The last key ("9") is now the oldest.
|
||||
lru.RemoveOldest()
|
||||
if lru.Len() != count-2 {
|
||||
t.Errorf("Expected %d entries after RemoveOldest, got %d", count-2, lru.Len())
|
||||
}
|
||||
assert.Equal(count-2, lru.Len())
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("%d", i)
|
||||
value := lru.Get(key)
|
||||
if i == 0 || i == count-1 {
|
||||
if value != nil {
|
||||
t.Errorf("The value for key %s should have been removed", key)
|
||||
}
|
||||
assert.Nil(value, "The value for key %s should have been removed", key)
|
||||
continue
|
||||
} else if value == nil {
|
||||
t.Errorf("No value found for %s", key)
|
||||
continue
|
||||
} else if value.(int) != i {
|
||||
t.Errorf("Expected value to be %d, got %d", value.(int), i)
|
||||
} else if assert.NotNil(value, "No value found for %s", key) {
|
||||
assert.EqualValues(i, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove an arbitrary key from the cache
|
||||
key := fmt.Sprintf("%d", count/2)
|
||||
lru.Remove(key)
|
||||
if lru.Len() != count-3 {
|
||||
t.Errorf("Expected %d entries after RemoveOldest, got %d", count-3, lru.Len())
|
||||
}
|
||||
assert.Equal(count-3, lru.Len())
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("%d", i)
|
||||
value := lru.Get(key)
|
||||
if i == 0 || i == count-1 || i == count/2 {
|
||||
if value != nil {
|
||||
t.Errorf("The value for key %s should have been removed", key)
|
||||
}
|
||||
assert.Nil(value, "The value for key %s should have been removed", key)
|
||||
continue
|
||||
} else if value == nil {
|
||||
t.Errorf("No value found for %s", key)
|
||||
continue
|
||||
} else if value.(int) != i {
|
||||
t.Errorf("Expected value to be %d, got %d", value.(int), i)
|
||||
} else if assert.NotNil(value, "No value found for %s", key) {
|
||||
assert.EqualValues(i, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLruBound(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
size := 2
|
||||
lru := NewLruCache(size)
|
||||
count := 10
|
||||
|
|
@ -141,23 +112,16 @@ func TestLruBound(t *testing.T) {
|
|||
key := fmt.Sprintf("%d", i)
|
||||
lru.Set(key, i)
|
||||
}
|
||||
if lru.Len() != size {
|
||||
t.Errorf("Expected %d entries, got %d", size, lru.Len())
|
||||
}
|
||||
assert.Equal(size, lru.Len())
|
||||
// Only the last "size" entries have been stored.
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("%d", i)
|
||||
value := lru.Get(key)
|
||||
if i < count-size {
|
||||
if value != nil {
|
||||
t.Errorf("The value for key %s should have been removed", key)
|
||||
}
|
||||
assert.Nil(value, "The value for key %s should have been removed", key)
|
||||
continue
|
||||
} else if value == nil {
|
||||
t.Errorf("No value found for %s", key)
|
||||
continue
|
||||
} else if value.(int) != i {
|
||||
t.Errorf("Expected value to be %d, got %d", value.(int), i)
|
||||
} else if assert.NotNil(value, "No value found for %s", key) {
|
||||
assert.EqualValues(i, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,9 +23,12 @@ package signaling
|
|||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetFmtpValueH264(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testcases := []struct {
|
||||
fmtp string
|
||||
profile string
|
||||
|
|
@ -51,16 +54,17 @@ func TestGetFmtpValueH264(t *testing.T) {
|
|||
for _, tc := range testcases {
|
||||
value, found := getFmtpValue(tc.fmtp, "profile-level-id")
|
||||
if !found && tc.profile != "" {
|
||||
t.Errorf("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp)
|
||||
assert.Fail("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp)
|
||||
} else if found && tc.profile == "" {
|
||||
t.Errorf("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value)
|
||||
assert.Fail("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value)
|
||||
} else if found && tc.profile != value {
|
||||
t.Errorf("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value)
|
||||
assert.Fail("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFmtpValueVP9(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testcases := []struct {
|
||||
fmtp string
|
||||
profile string
|
||||
|
|
@ -82,11 +86,11 @@ func TestGetFmtpValueVP9(t *testing.T) {
|
|||
for _, tc := range testcases {
|
||||
value, found := getFmtpValue(tc.fmtp, "profile-id")
|
||||
if !found && tc.profile != "" {
|
||||
t.Errorf("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp)
|
||||
assert.Fail("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp)
|
||||
} else if found && tc.profile == "" {
|
||||
t.Errorf("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value)
|
||||
assert.Fail("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value)
|
||||
} else if found && tc.profile != value {
|
||||
t.Errorf("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value)
|
||||
assert.Fail("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,6 +41,8 @@ import (
|
|||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
)
|
||||
|
||||
|
|
@ -105,7 +107,7 @@ func Test_sortConnectionsForCountry(t *testing.T) {
|
|||
sorted := sortConnectionsForCountry(test[0], country, nil)
|
||||
for idx, conn := range sorted {
|
||||
if test[1][idx] != conn {
|
||||
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
|
||||
assert.Fail(t, "Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
@ -179,7 +181,7 @@ func Test_sortConnectionsForCountryWithOverride(t *testing.T) {
|
|||
sorted := sortConnectionsForCountry(test[0], country, continentMap)
|
||||
for idx, conn := range sorted {
|
||||
if test[1][idx] != conn {
|
||||
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
|
||||
assert.Fail(t, "Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
@ -342,7 +344,7 @@ func (c *testProxyServerClient) handleSendMessageError(fmt string, msg *ProxySer
|
|||
c.t.Helper()
|
||||
|
||||
if !errors.Is(err, websocket.ErrCloseSent) || msg.Type != "event" || msg.Event.Type != "update-load" {
|
||||
c.t.Errorf(fmt, msg, err)
|
||||
assert.Fail(c.t, fmt, msg, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -385,6 +387,7 @@ func (c *testProxyServerClient) run() {
|
|||
c.ws = nil
|
||||
}()
|
||||
c.processMessage = c.processHello
|
||||
assert := assert.New(c.t)
|
||||
for {
|
||||
c.mu.Lock()
|
||||
ws := c.ws
|
||||
|
|
@ -396,36 +399,31 @@ func (c *testProxyServerClient) run() {
|
|||
msgType, reader, err := ws.NextReader()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
|
||||
c.t.Error(err)
|
||||
assert.NoError(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
c.t.Error(err)
|
||||
if !assert.NoError(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
if msgType != websocket.TextMessage {
|
||||
c.t.Errorf("unexpected message type %q (%s)", msgType, string(body))
|
||||
if !assert.Equal(websocket.TextMessage, msgType, "unexpected message type for %s", string(body)) {
|
||||
continue
|
||||
}
|
||||
|
||||
var msg ProxyClientMessage
|
||||
if err := json.Unmarshal(body, &msg); err != nil {
|
||||
c.t.Errorf("could not decode message %s: %s", string(body), err)
|
||||
if err := json.Unmarshal(body, &msg); !assert.NoError(err, "could not decode message %s", string(body)) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := msg.CheckValid(); err != nil {
|
||||
c.t.Errorf("invalid message %s: %s", string(body), err)
|
||||
if err := msg.CheckValid(); !assert.NoError(err, "invalid message %s", string(body)) {
|
||||
continue
|
||||
}
|
||||
|
||||
response, err := c.processMessage(&msg)
|
||||
if err != nil {
|
||||
c.t.Error(err)
|
||||
if !assert.NoError(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
@ -605,8 +603,7 @@ func (h *TestProxyServerHandler) removeClient(client *testProxyServerClient) {
|
|||
|
||||
func (h *TestProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ws, err := h.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.t.Error(err)
|
||||
if !assert.NoError(h.t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -658,15 +655,14 @@ type proxyTestOptions struct {
|
|||
|
||||
func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuProxy {
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
if options.etcd == nil {
|
||||
options.etcd = NewEtcdForTest(t)
|
||||
}
|
||||
grpcClients, dnsMonitor := NewGrpcClientsWithEtcdForTest(t, options.etcd)
|
||||
|
||||
tokenKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
dir := t.TempDir()
|
||||
privkeyFile := path.Join(dir, "privkey.pem")
|
||||
pubkeyFile := path.Join(dir, "pubkey.pem")
|
||||
|
|
@ -696,19 +692,13 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP
|
|||
etcdConfig.AddOption("etcd", "loglevel", "error")
|
||||
|
||||
etcdClient, err := NewEtcdClient(etcdConfig, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
t.Cleanup(func() {
|
||||
if err := etcdClient.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, etcdClient.Close())
|
||||
})
|
||||
|
||||
mcu, err := NewMcuProxy(cfg, etcdClient, grpcClients, dnsMonitor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
t.Cleanup(func() {
|
||||
mcu.Stop()
|
||||
})
|
||||
|
|
@ -716,20 +706,14 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := mcu.Start(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(mcu.Start(ctx))
|
||||
|
||||
proxy := mcu.(*mcuProxy)
|
||||
|
||||
if err := proxy.WaitForConnections(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(proxy.WaitForConnections(ctx))
|
||||
|
||||
for len(waitingMap) > 0 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(ctx.Err())
|
||||
|
||||
for u := range waitingMap {
|
||||
proxy.connectionsMu.RLock()
|
||||
|
|
@ -782,9 +766,7 @@ func Test_ProxyPublisherSubscriber(t *testing.T) {
|
|||
}
|
||||
|
||||
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub.Close(context.Background())
|
||||
|
||||
|
|
@ -795,9 +777,7 @@ func Test_ProxyPublisherSubscriber(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer sub.Close(context.Background())
|
||||
}
|
||||
|
|
@ -829,8 +809,7 @@ func Test_ProxyWaitForPublisher(t *testing.T) {
|
|||
go func() {
|
||||
defer close(done)
|
||||
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -841,14 +820,12 @@ func Test_ProxyWaitForPublisher(t *testing.T) {
|
|||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
t.Error(ctx.Err())
|
||||
assert.NoError(t, ctx.Err())
|
||||
}
|
||||
defer pub.Close(context.Background())
|
||||
}
|
||||
|
|
@ -875,9 +852,7 @@ func Test_ProxyPublisherBandwidth(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub1.Close(context.Background())
|
||||
|
||||
|
|
@ -914,15 +889,11 @@ func Test_ProxyPublisherBandwidth(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub2.Close(context.Background())
|
||||
|
||||
if pub1.(*mcuProxyPublisher).conn.rawUrl == pub2.(*mcuProxyPublisher).conn.rawUrl {
|
||||
t.Errorf("servers should be different, got %s", pub1.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.NotEqual(t, pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
|
||||
func Test_ProxyPublisherBandwidthOverload(t *testing.T) {
|
||||
|
|
@ -947,9 +918,7 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub1.Close(context.Background())
|
||||
|
||||
|
|
@ -989,15 +958,11 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub2.Close(context.Background())
|
||||
|
||||
if pub1.(*mcuProxyPublisher).conn.rawUrl != pub2.(*mcuProxyPublisher).conn.rawUrl {
|
||||
t.Errorf("servers should be the same, got %s / %s", pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
|
||||
func Test_ProxyPublisherLoad(t *testing.T) {
|
||||
|
|
@ -1022,9 +987,7 @@ func Test_ProxyPublisherLoad(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub1.Close(context.Background())
|
||||
|
||||
|
|
@ -1041,15 +1004,11 @@ func Test_ProxyPublisherLoad(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub2.Close(context.Background())
|
||||
|
||||
if pub1.(*mcuProxyPublisher).conn.rawUrl == pub2.(*mcuProxyPublisher).conn.rawUrl {
|
||||
t.Errorf("servers should be different, got %s", pub1.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.NotEqual(t, pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
|
||||
func Test_ProxyPublisherCountry(t *testing.T) {
|
||||
|
|
@ -1074,15 +1033,11 @@ func Test_ProxyPublisherCountry(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pubDE.Close(context.Background())
|
||||
|
||||
if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
|
||||
t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl)
|
||||
|
||||
pubUSId := "the-publisher-us"
|
||||
pubUSSid := "1234567890"
|
||||
|
|
@ -1093,15 +1048,11 @@ func Test_ProxyPublisherCountry(t *testing.T) {
|
|||
country: "US",
|
||||
}
|
||||
pubUS, err := mcu.NewPublisher(ctx, pubUSListener, pubUSId, pubUSSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubUSInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pubUS.Close(context.Background())
|
||||
|
||||
if pubUS.(*mcuProxyPublisher).conn.rawUrl != serverUS.URL {
|
||||
t.Errorf("expected server %s, go %s", serverUS.URL, pubUS.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverUS.URL, pubUS.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
|
||||
func Test_ProxyPublisherContinent(t *testing.T) {
|
||||
|
|
@ -1126,15 +1077,11 @@ func Test_ProxyPublisherContinent(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pubDE.Close(context.Background())
|
||||
|
||||
if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
|
||||
t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl)
|
||||
|
||||
pubFRId := "the-publisher-fr"
|
||||
pubFRSid := "1234567890"
|
||||
|
|
@ -1145,15 +1092,11 @@ func Test_ProxyPublisherContinent(t *testing.T) {
|
|||
country: "FR",
|
||||
}
|
||||
pubFR, err := mcu.NewPublisher(ctx, pubFRListener, pubFRId, pubFRSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubFRInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pubFR.Close(context.Background())
|
||||
|
||||
if pubFR.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
|
||||
t.Errorf("expected server %s, go %s", serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
|
||||
func Test_ProxySubscriberCountry(t *testing.T) {
|
||||
|
|
@ -1178,15 +1121,11 @@ func Test_ProxySubscriberCountry(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub.Close(context.Background())
|
||||
|
||||
if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
|
||||
t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
|
||||
|
||||
subListener := &MockMcuListener{
|
||||
publicId: "subscriber-public",
|
||||
|
|
@ -1195,15 +1134,11 @@ func Test_ProxySubscriberCountry(t *testing.T) {
|
|||
country: "US",
|
||||
}
|
||||
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer sub.Close(context.Background())
|
||||
|
||||
if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL {
|
||||
t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
}
|
||||
|
||||
func Test_ProxySubscriberContinent(t *testing.T) {
|
||||
|
|
@ -1228,15 +1163,11 @@ func Test_ProxySubscriberContinent(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub.Close(context.Background())
|
||||
|
||||
if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
|
||||
t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
|
||||
|
||||
subListener := &MockMcuListener{
|
||||
publicId: "subscriber-public",
|
||||
|
|
@ -1245,15 +1176,11 @@ func Test_ProxySubscriberContinent(t *testing.T) {
|
|||
country: "FR",
|
||||
}
|
||||
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer sub.Close(context.Background())
|
||||
|
||||
if sub.(*mcuProxySubscriber).conn.rawUrl != serverDE.URL {
|
||||
t.Errorf("expected server %s, go %s", serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
}
|
||||
|
||||
func Test_ProxySubscriberBandwidth(t *testing.T) {
|
||||
|
|
@ -1278,15 +1205,11 @@ func Test_ProxySubscriberBandwidth(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub.Close(context.Background())
|
||||
|
||||
if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
|
||||
t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
|
||||
|
||||
serverDE.UpdateBandwidth(0, 100)
|
||||
|
||||
|
|
@ -1315,15 +1238,11 @@ func Test_ProxySubscriberBandwidth(t *testing.T) {
|
|||
country: "US",
|
||||
}
|
||||
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer sub.Close(context.Background())
|
||||
|
||||
if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL {
|
||||
t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
}
|
||||
|
||||
func Test_ProxySubscriberBandwidthOverload(t *testing.T) {
|
||||
|
|
@ -1348,15 +1267,11 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub.Close(context.Background())
|
||||
|
||||
if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
|
||||
t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
|
||||
|
||||
serverDE.UpdateBandwidth(0, 100)
|
||||
serverUS.UpdateBandwidth(0, 102)
|
||||
|
|
@ -1386,15 +1301,11 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) {
|
|||
country: "US",
|
||||
}
|
||||
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer sub.Close(context.Background())
|
||||
|
||||
if sub.(*mcuProxySubscriber).conn.rawUrl != serverDE.URL {
|
||||
t.Errorf("expected server %s, go %s", serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
}
|
||||
|
||||
type mockGrpcServerHub struct {
|
||||
|
|
@ -1490,9 +1401,7 @@ func Test_ProxyRemotePublisher(t *testing.T) {
|
|||
defer hub1.removeSession(session1)
|
||||
|
||||
pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub.Close(context.Background())
|
||||
|
||||
|
|
@ -1508,9 +1417,7 @@ func Test_ProxyRemotePublisher(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer sub.Close(context.Background())
|
||||
}
|
||||
|
|
@ -1580,8 +1487,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) {
|
|||
go func() {
|
||||
defer close(done)
|
||||
sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -1592,9 +1498,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) {
|
|||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub.Close(context.Background())
|
||||
|
||||
|
|
@ -1606,7 +1510,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) {
|
|||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
t.Error(ctx.Err())
|
||||
assert.NoError(t, ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1663,9 +1567,7 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) {
|
|||
defer hub1.removeSession(session1)
|
||||
|
||||
pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer pub.Close(context.Background())
|
||||
|
||||
|
|
@ -1677,9 +1579,7 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) {
|
|||
mcu2.connectionsMu.RLock()
|
||||
count := len(mcu2.connections)
|
||||
mcu2.connectionsMu.RUnlock()
|
||||
if expected := 1; count != expected {
|
||||
t.Errorf("expected %d connections, got %+v", expected, count)
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
|
||||
subListener := &MockMcuListener{
|
||||
publicId: "subscriber-public",
|
||||
|
|
@ -1688,23 +1588,17 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) {
|
|||
country: "DE",
|
||||
}
|
||||
sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer sub.Close(context.Background())
|
||||
|
||||
if sub.(*mcuProxySubscriber).conn.rawUrl != server1.URL {
|
||||
t.Errorf("expected server %s, go %s", server1.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
}
|
||||
assert.Equal(t, server1.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
|
||||
|
||||
// The temporary connection has been added
|
||||
mcu2.connectionsMu.RLock()
|
||||
count = len(mcu2.connections)
|
||||
mcu2.connectionsMu.RUnlock()
|
||||
if expected := 2; count != expected {
|
||||
t.Errorf("expected %d connections, got %+v", expected, count)
|
||||
}
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
sub.Close(context.Background())
|
||||
|
||||
|
|
@ -1713,7 +1607,7 @@ loop:
|
|||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error(ctx.Err())
|
||||
assert.NoError(t, ctx.Err())
|
||||
default:
|
||||
mcu2.connectionsMu.RLock()
|
||||
count = len(mcu2.connections)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *testing.T) {
|
||||
|
|
@ -39,7 +42,7 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
c.mu.Lock()
|
||||
t.Errorf("Error waiting for subscriptions %+v to terminate: %s", c.subscriptions, ctx.Err())
|
||||
assert.NoError(t, ctx.Err(), "Error waiting for subscriptions %+v to terminate", c.subscriptions)
|
||||
c.mu.Unlock()
|
||||
return
|
||||
default:
|
||||
|
|
@ -50,9 +53,7 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t
|
|||
|
||||
func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient {
|
||||
result, err := NewLoopbackNatsClient()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
result.Close()
|
||||
})
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
natsserver "github.com/nats-io/nats-server/v2/test"
|
||||
)
|
||||
|
|
@ -46,9 +48,7 @@ func startLocalNatsServer(t *testing.T) string {
|
|||
func CreateLocalNatsClientForTest(t *testing.T) NatsClient {
|
||||
url := startLocalNatsServer(t)
|
||||
result, err := NewNatsClient(url)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
result.Close()
|
||||
})
|
||||
|
|
@ -56,11 +56,11 @@ func CreateLocalNatsClientForTest(t *testing.T) NatsClient {
|
|||
}
|
||||
|
||||
func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
dest := make(chan *nats.Msg)
|
||||
sub, err := client.Subscribe("foo", dest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
ch := make(chan struct{})
|
||||
|
||||
var received atomic.Int32
|
||||
|
|
@ -75,9 +75,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
|
|||
case <-dest:
|
||||
total := received.Add(1)
|
||||
if total == max {
|
||||
err := sub.Unsubscribe()
|
||||
if err != nil {
|
||||
t.Errorf("Unsubscribe failed with err: %s", err)
|
||||
if err := sub.Unsubscribe(); !assert.NoError(err) {
|
||||
return
|
||||
}
|
||||
close(ch)
|
||||
|
|
@ -89,18 +87,14 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
|
|||
}()
|
||||
<-ready
|
||||
for i := int32(0); i < max; i++ {
|
||||
if err := client.Publish("foo", []byte("hello")); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client.Publish("foo", []byte("hello")))
|
||||
|
||||
// Allow NATS goroutines to process messages.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
<-ch
|
||||
|
||||
if r := received.Load(); r != max {
|
||||
t.Fatalf("Received wrong # of messages: %d vs %d", r, max)
|
||||
}
|
||||
require.EqualValues(max, received.Load(), "Received wrong # of messages")
|
||||
}
|
||||
|
||||
func TestNatsClient_Subscribe(t *testing.T) {
|
||||
|
|
@ -115,9 +109,7 @@ func TestNatsClient_Subscribe(t *testing.T) {
|
|||
func testNatsClient_PublishAfterClose(t *testing.T, client NatsClient) {
|
||||
client.Close()
|
||||
|
||||
if err := client.Publish("foo", "bar"); err != nats.ErrConnectionClosed {
|
||||
t.Errorf("Expected %v, got %v", nats.ErrConnectionClosed, err)
|
||||
}
|
||||
assert.ErrorIs(t, client.Publish("foo", "bar"), nats.ErrConnectionClosed)
|
||||
}
|
||||
|
||||
func TestNatsClient_PublishAfterClose(t *testing.T) {
|
||||
|
|
@ -133,9 +125,8 @@ func testNatsClient_SubscribeAfterClose(t *testing.T, client NatsClient) {
|
|||
client.Close()
|
||||
|
||||
ch := make(chan *nats.Msg)
|
||||
if _, err := client.Subscribe("foo", ch); err != nats.ErrConnectionClosed {
|
||||
t.Errorf("Expected %v, got %v", nats.ErrConnectionClosed, err)
|
||||
}
|
||||
_, err := client.Subscribe("foo", ch)
|
||||
assert.ErrorIs(t, err, nats.ErrConnectionClosed)
|
||||
}
|
||||
|
||||
func TestNatsClient_SubscribeAfterClose(t *testing.T) {
|
||||
|
|
@ -148,6 +139,7 @@ func TestNatsClient_SubscribeAfterClose(t *testing.T) {
|
|||
}
|
||||
|
||||
func testNatsClient_BadSubjects(t *testing.T, client NatsClient) {
|
||||
assert := assert.New(t)
|
||||
subjects := []string{
|
||||
"foo bar",
|
||||
"foo.",
|
||||
|
|
@ -155,9 +147,8 @@ func testNatsClient_BadSubjects(t *testing.T, client NatsClient) {
|
|||
|
||||
ch := make(chan *nats.Msg)
|
||||
for _, s := range subjects {
|
||||
if _, err := client.Subscribe(s, ch); err != nats.ErrBadSubject {
|
||||
t.Errorf("Expected %v for subject %s, got %v", nats.ErrBadSubject, s, err)
|
||||
}
|
||||
_, err := client.Subscribe(s, ch)
|
||||
assert.ErrorIs(err, nats.ErrBadSubject, "Expected error for subject %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ import (
|
|||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNotifierNoWaiter(t *testing.T) {
|
||||
|
|
@ -48,9 +50,7 @@ func TestNotifierSimple(t *testing.T) {
|
|||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := waiter.Wait(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter.Wait(ctx))
|
||||
}()
|
||||
|
||||
notifier.Notify("foo")
|
||||
|
|
@ -74,9 +74,7 @@ func TestNotifierWaitClosed(t *testing.T) {
|
|||
waiter := notifier.NewWaiter("foo")
|
||||
notifier.Release(waiter)
|
||||
|
||||
if err := waiter.Wait(context.Background()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter.Wait(context.Background()))
|
||||
}
|
||||
|
||||
func TestNotifierWaitClosedMulti(t *testing.T) {
|
||||
|
|
@ -87,12 +85,8 @@ func TestNotifierWaitClosedMulti(t *testing.T) {
|
|||
notifier.Release(waiter1)
|
||||
notifier.Release(waiter2)
|
||||
|
||||
if err := waiter1.Wait(context.Background()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := waiter2.Wait(context.Background()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter1.Wait(context.Background()))
|
||||
assert.NoError(t, waiter2.Wait(context.Background()))
|
||||
}
|
||||
|
||||
func TestNotifierResetWillNotify(t *testing.T) {
|
||||
|
|
@ -108,9 +102,7 @@ func TestNotifierResetWillNotify(t *testing.T) {
|
|||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := waiter.Wait(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter.Wait(ctx))
|
||||
}()
|
||||
|
||||
notifier.Reset()
|
||||
|
|
@ -137,9 +129,7 @@ func TestNotifierDuplicate(t *testing.T) {
|
|||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := waiter.Wait(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter.Wait(ctx))
|
||||
}()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ import (
|
|||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
signaling "github.com/strukturag/nextcloud-spreed-signaling"
|
||||
)
|
||||
|
||||
|
|
@ -57,6 +59,7 @@ func getWebsocketUrl(url string) string {
|
|||
}
|
||||
|
||||
func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httptest.Server) {
|
||||
require := require.New(t)
|
||||
tempdir := t.TempDir()
|
||||
var proxy *ProxyServer
|
||||
t.Cleanup(func() {
|
||||
|
|
@ -67,43 +70,32 @@ func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httpte
|
|||
|
||||
r := mux.NewRouter()
|
||||
key, err := rsa.GenerateKey(rand.Reader, KeypairSizeForTest)
|
||||
if err != nil {
|
||||
t.Fatalf("could not generate key: %s", err)
|
||||
}
|
||||
require.NoError(err)
|
||||
priv := &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
}
|
||||
privkey, err := os.CreateTemp(tempdir, "privkey*.pem")
|
||||
if err != nil {
|
||||
t.Fatalf("could not create temporary file for private key: %s", err)
|
||||
}
|
||||
if err := pem.Encode(privkey, priv); err != nil {
|
||||
t.Fatalf("could not encode private key: %s", err)
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NoError(pem.Encode(privkey, priv))
|
||||
require.NoError(privkey.Close())
|
||||
|
||||
pubData, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
|
||||
if err != nil {
|
||||
t.Fatalf("could not marshal public key: %s", err)
|
||||
}
|
||||
require.NoError(err)
|
||||
pub := &pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: pubData,
|
||||
}
|
||||
pubkey, err := os.CreateTemp(tempdir, "pubkey*.pem")
|
||||
if err != nil {
|
||||
t.Fatalf("could not create temporary file for public key: %s", err)
|
||||
}
|
||||
if err := pem.Encode(pubkey, pub); err != nil {
|
||||
t.Fatalf("could not encode public key: %s", err)
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NoError(pem.Encode(pubkey, pub))
|
||||
require.NoError(pubkey.Close())
|
||||
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("tokens", TokenIdForTest, pubkey.Name())
|
||||
|
||||
if proxy, err = NewProxyServer(r, "0.0", config); err != nil {
|
||||
t.Fatalf("could not create proxy server: %s", err)
|
||||
}
|
||||
proxy, err = NewProxyServer(r, "0.0", config)
|
||||
require.NoError(err)
|
||||
|
||||
server := httptest.NewServer(r)
|
||||
t.Cleanup(func() {
|
||||
|
|
@ -125,19 +117,14 @@ func TestTokenValid(t *testing.T) {
|
|||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
tokenString, err := token.SignedString(key)
|
||||
if err != nil {
|
||||
t.Fatalf("could not create token: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
hello := &signaling.HelloProxyClientMessage{
|
||||
Version: "1.0",
|
||||
Token: tokenString,
|
||||
}
|
||||
session, err := proxy.NewSession(hello)
|
||||
if session != nil {
|
||||
if session, err := proxy.NewSession(hello); assert.NoError(t, err) {
|
||||
defer session.Close()
|
||||
} else if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -153,20 +140,16 @@ func TestTokenNotSigned(t *testing.T) {
|
|||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
|
||||
tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
|
||||
if err != nil {
|
||||
t.Fatalf("could not create token: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
hello := &signaling.HelloProxyClientMessage{
|
||||
Version: "1.0",
|
||||
Token: tokenString,
|
||||
}
|
||||
session, err := proxy.NewSession(hello)
|
||||
if session != nil {
|
||||
defer session.Close()
|
||||
t.Errorf("should not have created session")
|
||||
} else if err != TokenAuthFailed {
|
||||
t.Errorf("could have failed with TokenAuthFailed, got %s", err)
|
||||
if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenAuthFailed) {
|
||||
if session != nil {
|
||||
defer session.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -182,20 +165,16 @@ func TestTokenUnknown(t *testing.T) {
|
|||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
tokenString, err := token.SignedString(key)
|
||||
if err != nil {
|
||||
t.Fatalf("could not create token: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
hello := &signaling.HelloProxyClientMessage{
|
||||
Version: "1.0",
|
||||
Token: tokenString,
|
||||
}
|
||||
session, err := proxy.NewSession(hello)
|
||||
if session != nil {
|
||||
defer session.Close()
|
||||
t.Errorf("should not have created session")
|
||||
} else if err != TokenAuthFailed {
|
||||
t.Errorf("could have failed with TokenAuthFailed, got %s", err)
|
||||
if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenAuthFailed) {
|
||||
if session != nil {
|
||||
defer session.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -211,20 +190,16 @@ func TestTokenInFuture(t *testing.T) {
|
|||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
tokenString, err := token.SignedString(key)
|
||||
if err != nil {
|
||||
t.Fatalf("could not create token: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
hello := &signaling.HelloProxyClientMessage{
|
||||
Version: "1.0",
|
||||
Token: tokenString,
|
||||
}
|
||||
session, err := proxy.NewSession(hello)
|
||||
if session != nil {
|
||||
defer session.Close()
|
||||
t.Errorf("should not have created session")
|
||||
} else if err != TokenNotValidYet {
|
||||
t.Errorf("could have failed with TokenNotValidYet, got %s", err)
|
||||
if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenNotValidYet) {
|
||||
if session != nil {
|
||||
defer session.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -240,24 +215,21 @@ func TestTokenExpired(t *testing.T) {
|
|||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
tokenString, err := token.SignedString(key)
|
||||
if err != nil {
|
||||
t.Fatalf("could not create token: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
hello := &signaling.HelloProxyClientMessage{
|
||||
Version: "1.0",
|
||||
Token: tokenString,
|
||||
}
|
||||
session, err := proxy.NewSession(hello)
|
||||
if session != nil {
|
||||
defer session.Close()
|
||||
t.Errorf("should not have created session")
|
||||
} else if err != TokenExpired {
|
||||
t.Errorf("could have failed with TokenExpired, got %s", err)
|
||||
if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenExpired) {
|
||||
if session != nil {
|
||||
defer session.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicIPs(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
public := []string{
|
||||
"8.8.8.8",
|
||||
"172.15.1.2",
|
||||
|
|
@ -275,35 +247,30 @@ func TestPublicIPs(t *testing.T) {
|
|||
}
|
||||
for _, s := range public {
|
||||
ip := net.ParseIP(s)
|
||||
if len(ip) == 0 {
|
||||
t.Errorf("invalid IP: %s", s)
|
||||
} else if !IsPublicIP(ip) {
|
||||
t.Errorf("should be public IP: %s", s)
|
||||
if assert.NotEmpty(ip, "invalid IP: %s", s) {
|
||||
assert.True(IsPublicIP(ip), "should be public IP: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range private {
|
||||
ip := net.ParseIP(s)
|
||||
if len(ip) == 0 {
|
||||
t.Errorf("invalid IP: %s", s)
|
||||
} else if IsPublicIP(ip) {
|
||||
t.Errorf("should be private IP: %s", s)
|
||||
if assert.NotEmpty(ip, "invalid IP: %s", s) {
|
||||
assert.False(IsPublicIP(ip), "should be private IP: %s", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketFeatures(t *testing.T) {
|
||||
signaling.CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
_, _, server := newProxyServerForTest(t)
|
||||
|
||||
conn, response, err := websocket.DefaultDialer.DialContext(context.Background(), getWebsocketUrl(server.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer conn.Close() // nolint
|
||||
|
||||
if server := response.Header.Get("Server"); !strings.HasPrefix(server, "nextcloud-spreed-signaling-proxy/") {
|
||||
t.Errorf("expected valid server header, got \"%s\"", server)
|
||||
assert.Fail("expected valid server header, got \"%s\"", server)
|
||||
}
|
||||
features := response.Header.Get("X-Spreed-Signaling-Features")
|
||||
featuresList := make(map[string]bool)
|
||||
|
|
@ -311,19 +278,15 @@ func TestWebsocketFeatures(t *testing.T) {
|
|||
f = strings.TrimSpace(f)
|
||||
if f != "" {
|
||||
if _, found := featuresList[f]; found {
|
||||
t.Errorf("duplicate feature id \"%s\" in \"%s\"", f, features)
|
||||
assert.Fail("duplicate feature id \"%s\" in \"%s\"", f, features)
|
||||
}
|
||||
featuresList[f] = true
|
||||
}
|
||||
}
|
||||
if len(featuresList) == 0 {
|
||||
t.Errorf("expected valid features header, got \"%s\"", features)
|
||||
}
|
||||
assert.NotEmpty(featuresList, "expected valid features header, got \"%s\"", features)
|
||||
if _, found := featuresList["remote-streams"]; !found {
|
||||
t.Errorf("expected feature \"remote-streams\", got \"%s\"", features)
|
||||
assert.Fail("expected feature \"remote-streams\", got \"%s\"", features)
|
||||
}
|
||||
|
||||
if err := conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}); err != nil {
|
||||
t.Errorf("could not write close message: %s", err)
|
||||
}
|
||||
assert.NoError(conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
"go.etcd.io/etcd/server/v3/lease"
|
||||
|
||||
|
|
@ -73,9 +75,7 @@ func newEtcdForTesting(t *testing.T) *embed.Etcd {
|
|||
cfg.LogLevel = "warn"
|
||||
|
||||
u, err := url.Parse(etcdListenUrl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find a free port to bind the server to.
|
||||
var etcd *embed.Etcd
|
||||
|
|
@ -91,14 +91,12 @@ func newEtcdForTesting(t *testing.T) *embed.Etcd {
|
|||
etcd, err = embed.StartEtcd(cfg)
|
||||
if isErrorAddressAlreadyInUse(err) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
break
|
||||
}
|
||||
if etcd == nil {
|
||||
t.Fatal("could not find free port")
|
||||
}
|
||||
require.NotNil(t, etcd, "could not find free port")
|
||||
|
||||
t.Cleanup(func() {
|
||||
etcd.Close()
|
||||
|
|
@ -118,9 +116,7 @@ func newTokensEtcdForTesting(t *testing.T) (*tokensEtcd, *embed.Etcd) {
|
|||
cfg.AddOption("tokens", "keyformat", "/%s, /testing/%s/key")
|
||||
|
||||
tokens, err := NewProxyTokensEtcd(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
tokens.Close()
|
||||
})
|
||||
|
|
@ -134,11 +130,9 @@ func storeKey(t *testing.T, etcd *embed.Etcd, key string, pubkey crypto.PublicKe
|
|||
switch pubkey := pubkey.(type) {
|
||||
case rsa.PublicKey:
|
||||
data, err = x509.MarshalPKIXPublicKey(&pubkey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
default:
|
||||
t.Fatalf("unknown key type %T in %+v", pubkey, pubkey)
|
||||
require.Fail(t, "unknown key type %T in %+v", pubkey, pubkey)
|
||||
}
|
||||
|
||||
data = pem.EncodeToMemory(&pem.Block{
|
||||
|
|
@ -154,9 +148,7 @@ func storeKey(t *testing.T, etcd *embed.Etcd, key string, pubkey crypto.PublicKe
|
|||
|
||||
func generateAndSaveKey(t *testing.T, etcd *embed.Etcd, name string) *rsa.PrivateKey {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
storeKey(t, etcd, name, key.PublicKey)
|
||||
return key
|
||||
|
|
@ -164,24 +156,17 @@ func generateAndSaveKey(t *testing.T, etcd *embed.Etcd, name string) *rsa.Privat
|
|||
|
||||
func TestProxyTokensEtcd(t *testing.T) {
|
||||
signaling.CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
tokens, etcd := newTokensEtcdForTesting(t)
|
||||
|
||||
key1 := generateAndSaveKey(t, etcd, "/foo")
|
||||
key2 := generateAndSaveKey(t, etcd, "/testing/bar/key")
|
||||
|
||||
if token, err := tokens.Get("foo"); err != nil {
|
||||
t.Error(err)
|
||||
} else if token == nil {
|
||||
t.Error("could not get token")
|
||||
} else if !key1.PublicKey.Equal(token.key) {
|
||||
t.Error("token keys mismatch")
|
||||
if token, err := tokens.Get("foo"); assert.NoError(err) && assert.NotNil(token) {
|
||||
assert.True(key1.PublicKey.Equal(token.key))
|
||||
}
|
||||
|
||||
if token, err := tokens.Get("bar"); err != nil {
|
||||
t.Error(err)
|
||||
} else if token == nil {
|
||||
t.Error("could not get token")
|
||||
} else if !key2.PublicKey.Equal(token.key) {
|
||||
t.Error("token keys mismatch")
|
||||
if token, err := tokens.Get("bar"); assert.NoError(err) && assert.NotNil(token) {
|
||||
assert.True(key2.PublicKey.Equal(token.key))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
)
|
||||
|
||||
|
|
@ -43,9 +44,7 @@ func newProxyConfigEtcd(t *testing.T, proxy McuProxy) (*embed.Etcd, ProxyConfig)
|
|||
cfg := goconf.NewConfigFile()
|
||||
cfg.AddOption("mcu", "keyprefix", "proxies/")
|
||||
p, err := NewProxyConfigEtcd(cfg, client, proxy)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
p.Stop()
|
||||
})
|
||||
|
|
@ -55,9 +54,7 @@ func newProxyConfigEtcd(t *testing.T, proxy McuProxy) (*embed.Etcd, ProxyConfig)
|
|||
func SetEtcdProxy(t *testing.T, etcd *embed.Etcd, path string, proxy *TestProxyInformationEtcd) {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(proxy)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
SetEtcdValue(etcd, path, data)
|
||||
}
|
||||
|
||||
|
|
@ -74,9 +71,7 @@ func TestProxyConfigEtcd(t *testing.T) {
|
|||
Address: "https://foo/",
|
||||
})
|
||||
proxy.Expect("add", "https://foo/")
|
||||
if err := config.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, config.Start())
|
||||
proxy.WaitForEvents(ctx)
|
||||
|
||||
proxy.Expect("add", "https://bar/")
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string) (ProxyConfig, *DnsMonitor) {
|
||||
|
|
@ -38,9 +39,7 @@ func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string
|
|||
}
|
||||
dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually
|
||||
p, err := NewProxyConfigStatic(cfg, proxy, dnsMonitor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
p.Stop()
|
||||
})
|
||||
|
|
@ -53,9 +52,7 @@ func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls ..
|
|||
if dns {
|
||||
cfg.AddOption("mcu", "dnsdiscovery", "true")
|
||||
}
|
||||
if err := config.Reload(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, config.Reload(cfg))
|
||||
}
|
||||
|
||||
func TestProxyConfigStaticSimple(t *testing.T) {
|
||||
|
|
@ -63,9 +60,7 @@ func TestProxyConfigStaticSimple(t *testing.T) {
|
|||
proxy := newMcuProxyForConfig(t)
|
||||
config, _ := newProxyConfigStatic(t, proxy, false, "https://foo/")
|
||||
proxy.Expect("add", "https://foo/")
|
||||
if err := config.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, config.Start())
|
||||
|
||||
proxy.Expect("keep", "https://foo/")
|
||||
proxy.Expect("add", "https://bar/")
|
||||
|
|
@ -82,9 +77,7 @@ func TestProxyConfigStaticDNS(t *testing.T) {
|
|||
lookup := newMockDnsLookupForTest(t)
|
||||
proxy := newMcuProxyForConfig(t)
|
||||
config, dnsMonitor := newProxyConfigStatic(t, proxy, true, "https://foo/")
|
||||
if err := config.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, config.Start())
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -61,9 +63,7 @@ func newMcuProxyForConfig(t *testing.T) *mcuProxyForConfig {
|
|||
t: t,
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if len(proxy.expected) > 0 {
|
||||
t.Errorf("expected events %+v were not triggered", proxy.expected)
|
||||
}
|
||||
assert.Empty(t, proxy.expected)
|
||||
})
|
||||
return proxy
|
||||
}
|
||||
|
|
@ -99,7 +99,7 @@ func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) {
|
|||
defer p.mu.Lock()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
p.t.Error(ctx.Err())
|
||||
assert.NoError(p.t, ctx.Err())
|
||||
case <-waiter:
|
||||
}
|
||||
}
|
||||
|
|
@ -125,7 +125,7 @@ func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) {
|
|||
defer p.mu.Unlock()
|
||||
|
||||
if len(p.expected) == 0 {
|
||||
p.t.Errorf("no event expected, got %+v from %s:%d", event, caller.File, caller.Line)
|
||||
assert.Fail(p.t, "no event expected, got %+v from %s:%d", event, caller.File, caller.Line)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -145,7 +145,7 @@ func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) {
|
|||
expected := p.expected[0]
|
||||
p.expected = p.expected[1:]
|
||||
if !reflect.DeepEqual(expected, *event) {
|
||||
p.t.Errorf("expected %+v, got %+v from %s:%d", expected, event, caller.File, caller.Line)
|
||||
assert.Fail(p.t, "expected %+v, got %+v from %s:%d", expected, event, caller.File, caller.Line)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,9 +28,12 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) {
|
||||
require := require.New(t)
|
||||
r := mux.NewRouter()
|
||||
registerBackendHandler(t, r)
|
||||
|
||||
|
|
@ -40,30 +43,23 @@ func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) {
|
|||
})
|
||||
|
||||
config, err := getTestConfig(server)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
backend, err := NewBackendClient(config, 1, "0.0", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
p, err := NewRoomPing(backend, backend.capabilities)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
u, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
return u, p
|
||||
}
|
||||
|
||||
func TestSingleRoomPing(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
u, ping := NewRoomPingForTest(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
|
|
@ -78,13 +74,9 @@ func TestSingleRoomPing(t *testing.T) {
|
|||
SessionId: "123",
|
||||
},
|
||||
}
|
||||
if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if requests := getPingRequests(t); len(requests) != 1 {
|
||||
t.Errorf("expected one ping request, got %+v", requests)
|
||||
} else if len(requests[0].Ping.Entries) != 1 {
|
||||
t.Errorf("expected one entry, got %+v", requests[0].Ping.Entries)
|
||||
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1))
|
||||
if requests := getPingRequests(t); assert.Len(requests, 1) {
|
||||
assert.Len(requests[0].Ping.Entries, 1)
|
||||
}
|
||||
clearPingRequests(t)
|
||||
|
||||
|
|
@ -97,24 +89,19 @@ func TestSingleRoomPing(t *testing.T) {
|
|||
SessionId: "456",
|
||||
},
|
||||
}
|
||||
if err := ping.SendPings(ctx, room2.Id(), u, entries2); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if requests := getPingRequests(t); len(requests) != 1 {
|
||||
t.Errorf("expected one ping request, got %+v", requests)
|
||||
} else if len(requests[0].Ping.Entries) != 1 {
|
||||
t.Errorf("expected one entry, got %+v", requests[0].Ping.Entries)
|
||||
assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2))
|
||||
if requests := getPingRequests(t); assert.Len(requests, 1) {
|
||||
assert.Len(requests[0].Ping.Entries, 1)
|
||||
}
|
||||
clearPingRequests(t)
|
||||
|
||||
ping.publishActiveSessions()
|
||||
if requests := getPingRequests(t); len(requests) != 0 {
|
||||
t.Errorf("expected no ping requests, got %+v", requests)
|
||||
}
|
||||
assert.Empty(getPingRequests(t))
|
||||
}
|
||||
|
||||
func TestMultiRoomPing(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
u, ping := NewRoomPingForTest(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
|
|
@ -129,12 +116,8 @@ func TestMultiRoomPing(t *testing.T) {
|
|||
SessionId: "123",
|
||||
},
|
||||
}
|
||||
if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if requests := getPingRequests(t); len(requests) != 0 {
|
||||
t.Errorf("expected no ping requests, got %+v", requests)
|
||||
}
|
||||
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1))
|
||||
assert.Empty(getPingRequests(t))
|
||||
|
||||
room2 := &Room{
|
||||
id: "sample-room-2",
|
||||
|
|
@ -145,23 +128,18 @@ func TestMultiRoomPing(t *testing.T) {
|
|||
SessionId: "456",
|
||||
},
|
||||
}
|
||||
if err := ping.SendPings(ctx, room2.Id(), u, entries2); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if requests := getPingRequests(t); len(requests) != 0 {
|
||||
t.Errorf("expected no ping requests, got %+v", requests)
|
||||
}
|
||||
assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2))
|
||||
assert.Empty(getPingRequests(t))
|
||||
|
||||
ping.publishActiveSessions()
|
||||
if requests := getPingRequests(t); len(requests) != 1 {
|
||||
t.Errorf("expected one ping request, got %+v", requests)
|
||||
} else if len(requests[0].Ping.Entries) != 2 {
|
||||
t.Errorf("expected two entries, got %+v", requests[0].Ping.Entries)
|
||||
if requests := getPingRequests(t); assert.Len(requests, 1) {
|
||||
assert.Len(requests[0].Ping.Entries, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiRoomPing_Separate(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
u, ping := NewRoomPingForTest(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
|
|
@ -176,35 +154,26 @@ func TestMultiRoomPing_Separate(t *testing.T) {
|
|||
SessionId: "123",
|
||||
},
|
||||
}
|
||||
if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if requests := getPingRequests(t); len(requests) != 0 {
|
||||
t.Errorf("expected no ping requests, got %+v", requests)
|
||||
}
|
||||
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1))
|
||||
assert.Empty(getPingRequests(t))
|
||||
entries2 := []BackendPingEntry{
|
||||
{
|
||||
UserId: "bar",
|
||||
SessionId: "456",
|
||||
},
|
||||
}
|
||||
if err := ping.SendPings(ctx, room1.Id(), u, entries2); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if requests := getPingRequests(t); len(requests) != 0 {
|
||||
t.Errorf("expected no ping requests, got %+v", requests)
|
||||
}
|
||||
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries2))
|
||||
assert.Empty(getPingRequests(t))
|
||||
|
||||
ping.publishActiveSessions()
|
||||
if requests := getPingRequests(t); len(requests) != 1 {
|
||||
t.Errorf("expected one ping request, got %+v", requests)
|
||||
} else if len(requests[0].Ping.Entries) != 2 {
|
||||
t.Errorf("expected two entries, got %+v", requests[0].Ping.Entries)
|
||||
if requests := getPingRequests(t); assert.Len(requests, 1) {
|
||||
assert.Len(requests[0].Ping.Entries, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiRoomPing_DeleteRoom(t *testing.T) {
|
||||
CatchLogForTest(t)
|
||||
assert := assert.New(t)
|
||||
u, ping := NewRoomPingForTest(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
|
|
@ -219,12 +188,8 @@ func TestMultiRoomPing_DeleteRoom(t *testing.T) {
|
|||
SessionId: "123",
|
||||
},
|
||||
}
|
||||
if err := ping.SendPings(ctx, room1.Id(), u, entries1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if requests := getPingRequests(t); len(requests) != 0 {
|
||||
t.Errorf("expected no ping requests, got %+v", requests)
|
||||
}
|
||||
assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries1))
|
||||
assert.Empty(getPingRequests(t))
|
||||
|
||||
room2 := &Room{
|
||||
id: "sample-room-2",
|
||||
|
|
@ -235,19 +200,13 @@ func TestMultiRoomPing_DeleteRoom(t *testing.T) {
|
|||
SessionId: "456",
|
||||
},
|
||||
}
|
||||
if err := ping.SendPings(ctx, room2.Id(), u, entries2); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if requests := getPingRequests(t); len(requests) != 0 {
|
||||
t.Errorf("expected no ping requests, got %+v", requests)
|
||||
}
|
||||
assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2))
|
||||
assert.Empty(getPingRequests(t))
|
||||
|
||||
ping.DeleteRoom(room2.Id())
|
||||
|
||||
ping.publishActiveSessions()
|
||||
if requests := getPingRequests(t); len(requests) != 1 {
|
||||
t.Errorf("expected one ping request, got %+v", requests)
|
||||
} else if len(requests[0].Ping.Entries) != 1 {
|
||||
t.Errorf("expected two entries, got %+v", requests[0].Ping.Entries)
|
||||
if requests := getPingRequests(t); assert.Len(requests, 1) {
|
||||
assert.Len(requests[0].Ping.Entries, 1)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
356
room_test.go
356
room_test.go
|
|
@ -27,11 +27,14 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRoom_InCall(t *testing.T) {
|
||||
|
|
@ -63,59 +66,47 @@ func TestRoom_InCall(t *testing.T) {
|
|||
}
|
||||
for _, test := range tests {
|
||||
inCall, ok := IsInCall(test.Value)
|
||||
if ok != test.Valid {
|
||||
t.Errorf("%+v should be valid %v, got %v", test.Value, test.Valid, ok)
|
||||
}
|
||||
if inCall != test.InCall {
|
||||
t.Errorf("%+v should convert to %v, got %v", test.Value, test.InCall, inCall)
|
||||
if test.Valid {
|
||||
assert.True(t, ok, "%+v should be valid", test.Value)
|
||||
} else {
|
||||
assert.False(t, ok, "%+v should not be valid", test.Value)
|
||||
}
|
||||
assert.EqualValues(t, test.InCall, inCall, "conversion failed for %+v", test.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoom_Update(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
hub, _, router, server := CreateHubForTest(t)
|
||||
|
||||
config, err := getTestConfig(server)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
b, err := NewBackendServer(config, hub, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Start(router); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NoError(b.Start(router))
|
||||
|
||||
client := NewTestClient(t, server, hub)
|
||||
defer client.CloseWithBye()
|
||||
|
||||
if err := client.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.SendHello(testDefaultUserId))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
hello, err := client.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err := client.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
// We will receive a "joined" event.
|
||||
if err := client.RunUntilJoined(ctx, hello.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client.RunUntilJoined(ctx, hello.Hello))
|
||||
|
||||
// Simulate backend request from Nextcloud to update the room.
|
||||
roomProperties := json.RawMessage("{\"foo\":\"bar\"}")
|
||||
|
|
@ -130,54 +121,32 @@ func TestRoom_Update(t *testing.T) {
|
|||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body))
|
||||
|
||||
// The client receives a roomlist update and a changed room event. The
|
||||
// ordering is not defined because messages are sent by asynchronous event
|
||||
// handlers.
|
||||
message1, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
message2, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
|
||||
if msg, err := checkMessageRoomlistUpdate(message1); err != nil {
|
||||
if err := checkMessageRoomId(message1, roomId); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if msg, err := checkMessageRoomlistUpdate(message2); err != nil {
|
||||
t.Error(err)
|
||||
} else if msg.RoomId != roomId {
|
||||
t.Errorf("Expected room id %s, got %+v", roomId, msg)
|
||||
} else if len(msg.Properties) == 0 || !bytes.Equal(msg.Properties, roomProperties) {
|
||||
t.Errorf("Expected room properties %s, got %+v", string(roomProperties), msg)
|
||||
assert.NoError(checkMessageRoomId(message1, roomId))
|
||||
if msg, err := checkMessageRoomlistUpdate(message2); assert.NoError(err) {
|
||||
assert.Equal(roomId, msg.RoomId)
|
||||
assert.Equal(string(roomProperties), string(msg.Properties))
|
||||
}
|
||||
} else {
|
||||
if msg.RoomId != roomId {
|
||||
t.Errorf("Expected room id %s, got %+v", roomId, msg)
|
||||
} else if len(msg.Properties) == 0 || !bytes.Equal(msg.Properties, roomProperties) {
|
||||
t.Errorf("Expected room properties %s, got %+v", string(roomProperties), msg)
|
||||
}
|
||||
if err := checkMessageRoomId(message2, roomId); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Equal(roomId, msg.RoomId)
|
||||
assert.Equal(string(roomProperties), string(msg.Properties))
|
||||
assert.NoError(checkMessageRoomId(message2, roomId))
|
||||
}
|
||||
|
||||
// Allow up to 100 milliseconds for asynchronous event processing.
|
||||
|
|
@ -206,55 +175,41 @@ loop:
|
|||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
func TestRoom_Delete(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
hub, _, router, server := CreateHubForTest(t)
|
||||
|
||||
config, err := getTestConfig(server)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
b, err := NewBackendServer(config, hub, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Start(router); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NoError(b.Start(router))
|
||||
|
||||
client := NewTestClient(t, server, hub)
|
||||
defer client.CloseWithBye()
|
||||
|
||||
if err := client.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.SendHello(testDefaultUserId))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
hello, err := client.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err := client.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
// We will receive a "joined" event.
|
||||
if err := client.RunUntilJoined(ctx, hello.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client.RunUntilJoined(ctx, hello.Hello))
|
||||
|
||||
// Simulate backend request from Nextcloud to update the room.
|
||||
msg := &BackendServerRoomRequest{
|
||||
|
|
@ -267,47 +222,31 @@ func TestRoom_Delete(t *testing.T) {
|
|||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.Equal(http.StatusOK, res.StatusCode, "Expected successful request, got %s", string(body))
|
||||
|
||||
// The client is no longer invited to the room and leaves it. The ordering
|
||||
// of messages is not defined as they get published through events and handled
|
||||
// by asynchronous channels.
|
||||
message1, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
|
||||
if err := checkMessageType(message1, "event"); err != nil {
|
||||
// Ordering should be "leave room", "disinvited".
|
||||
if err := checkMessageRoomId(message1, ""); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
message2, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if _, err := checkMessageRoomlistDisinvite(message2); err != nil {
|
||||
t.Error(err)
|
||||
assert.NoError(checkMessageRoomId(message1, ""))
|
||||
if message2, err := client.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
_, err := checkMessageRoomlistDisinvite(message2)
|
||||
assert.NoError(err)
|
||||
}
|
||||
} else {
|
||||
// Ordering should be "disinvited", "leave room".
|
||||
if _, err := checkMessageRoomlistDisinvite(message1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err := checkMessageRoomlistDisinvite(message1)
|
||||
assert.NoError(err)
|
||||
message2, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
// The connection should get closed after the "disinvited".
|
||||
|
|
@ -315,10 +254,10 @@ func TestRoom_Delete(t *testing.T) {
|
|||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNoStatusReceived) {
|
||||
t.Error(err)
|
||||
assert.NoError(err)
|
||||
}
|
||||
} else if err := checkMessageRoomId(message2, ""); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
assert.NoError(checkMessageRoomId(message2, ""))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -350,151 +289,106 @@ loop:
|
|||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
func TestRoom_RoomSessionData(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
hub, _, router, server := CreateHubForTest(t)
|
||||
|
||||
config, err := getTestConfig(server)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
b, err := NewBackendServer(config, hub, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Start(router); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NoError(b.Start(router))
|
||||
|
||||
client := NewTestClient(t, server, hub)
|
||||
defer client.CloseWithBye()
|
||||
|
||||
if err := client.SendHello(authAnonymousUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.SendHello(authAnonymousUserId))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
hello, err := client.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room-with-sessiondata"
|
||||
if room, err := client.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err := client.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
// We will receive a "joined" event with the userid from the room session data.
|
||||
expected := "userid-from-sessiondata"
|
||||
if message, err := client.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if err := client.checkMessageJoinedSession(message, hello.Hello.SessionId, expected); err != nil {
|
||||
t.Error(err)
|
||||
} else if message.Event.Join[0].RoomSessionId != roomId+"-"+hello.Hello.SessionId {
|
||||
t.Errorf("Expected join room session id %s, got %+v", roomId+"-"+hello.Hello.SessionId, message.Event.Join[0])
|
||||
if message, err := client.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
if assert.NoError(client.checkMessageJoinedSession(message, hello.Hello.SessionId, expected)) {
|
||||
assert.Equal(roomId+"-"+hello.Hello.SessionId, message.Event.Join[0].RoomSessionId)
|
||||
}
|
||||
}
|
||||
|
||||
session := hub.GetSessionByPublicId(hello.Hello.SessionId)
|
||||
if session == nil {
|
||||
t.Fatalf("Could not find session %s", hello.Hello.SessionId)
|
||||
}
|
||||
|
||||
if userid := session.UserId(); userid != expected {
|
||||
t.Errorf("Expected userid %s, got %s", expected, userid)
|
||||
}
|
||||
require.NotNil(session, "Could not find session %s", hello.Hello.SessionId)
|
||||
assert.Equal(expected, session.UserId())
|
||||
|
||||
room := hub.getRoom(roomId)
|
||||
if room == nil {
|
||||
t.Fatalf("Room not found")
|
||||
}
|
||||
assert.NotNil(room, "Room not found")
|
||||
|
||||
entries, wg := room.publishActiveSessions()
|
||||
if entries != 1 {
|
||||
t.Errorf("expected 1 entries, got %d", entries)
|
||||
}
|
||||
assert.Equal(1, entries)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRoom_InCallAll(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
hub, _, router, server := CreateHubForTest(t)
|
||||
|
||||
config, err := getTestConfig(server)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
b, err := NewBackendServer(config, hub, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b.Start(router); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NoError(b.Start(router))
|
||||
|
||||
client1 := NewTestClient(t, server, hub)
|
||||
defer client1.CloseWithBye()
|
||||
|
||||
if err := client1.SendHello(testDefaultUserId + "1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client1.SendHello(testDefaultUserId + "1"))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
hello1, err := client1.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
client2 := NewTestClient(t, server, hub)
|
||||
defer client2.CloseWithBye()
|
||||
|
||||
if err := client2.SendHello(testDefaultUserId + "2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client2.SendHello(testDefaultUserId + "2"))
|
||||
|
||||
hello2, err := client2.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client1.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err := client1.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client1.RunUntilJoined(ctx, hello1.Hello))
|
||||
|
||||
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err = client2.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello))
|
||||
|
||||
if err := client1.RunUntilJoined(ctx, hello2.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client1.RunUntilJoined(ctx, hello2.Hello))
|
||||
|
||||
// Simulate backend request from Nextcloud to update the "inCall" flag of all participants.
|
||||
msg1 := &BackendServerRoomRequest{
|
||||
|
|
@ -506,32 +400,20 @@ func TestRoom_InCallAll(t *testing.T) {
|
|||
}
|
||||
|
||||
data1, err := json.Marshal(msg1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
res1, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer res1.Body.Close()
|
||||
body1, err := io.ReadAll(res1.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if res1.StatusCode != 200 {
|
||||
t.Errorf("Expected successful request, got %s: %s", res1.Status, string(body1))
|
||||
assert.NoError(err)
|
||||
assert.Equal(http.StatusOK, res1.StatusCode, "Expected successful request, got %s", string(body1))
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
assert.NoError(checkMessageInCallAll(msg, roomId, FlagInCall))
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := checkMessageInCallAll(msg, roomId, FlagInCall); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := checkMessageInCallAll(msg, roomId, FlagInCall); err != nil {
|
||||
t.Fatal(err)
|
||||
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
assert.NoError(checkMessageInCallAll(msg, roomId, FlagInCall))
|
||||
}
|
||||
|
||||
// Simulate backend request from Nextcloud to update the "inCall" flag of all participants.
|
||||
|
|
@ -544,31 +426,19 @@ func TestRoom_InCallAll(t *testing.T) {
|
|||
}
|
||||
|
||||
data2, err := json.Marshal(msg2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
res2, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer res2.Body.Close()
|
||||
body2, err := io.ReadAll(res2.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if res2.StatusCode != 200 {
|
||||
t.Errorf("Expected successful request, got %s: %s", res2.Status, string(body2))
|
||||
assert.NoError(err)
|
||||
assert.Equal(http.StatusOK, res2.StatusCode, "Expected successful request, got %s", string(body2))
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
assert.NoError(checkMessageInCallAll(msg, roomId, 0))
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := checkMessageInCallAll(msg, roomId, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := checkMessageInCallAll(msg, roomId, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
assert.NoError(checkMessageInCallAll(msg, roomId, 0))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,13 +23,13 @@ package signaling
|
|||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuiltinRoomSessions(t *testing.T) {
|
||||
sessions, err := NewBuiltinRoomSessions(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
testRoomSessions(t, sessions)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,9 +24,11 @@ package signaling
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type DummySession struct {
|
||||
|
|
@ -107,70 +109,53 @@ func checkSession(t *testing.T, sessions RoomSessions, sessionId string, roomSes
|
|||
session := &DummySession{
|
||||
publicId: sessionId,
|
||||
}
|
||||
if err := sessions.SetRoomSession(session, roomSessionId); err != nil {
|
||||
t.Fatalf("Expected no error, got %s", err)
|
||||
}
|
||||
if sid, err := sessions.GetSessionId(roomSessionId); err != nil {
|
||||
t.Errorf("Expected session id %s, got error %s", sessionId, err)
|
||||
} else if sid != sessionId {
|
||||
t.Errorf("Expected session id %s, got %s", sessionId, sid)
|
||||
require.NoError(t, sessions.SetRoomSession(session, roomSessionId))
|
||||
if sid, err := sessions.GetSessionId(roomSessionId); assert.NoError(t, err) {
|
||||
assert.Equal(t, sessionId, sid)
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func testRoomSessions(t *testing.T, sessions RoomSessions) {
|
||||
if sid, err := sessions.GetSessionId("unknown"); err != nil && err != ErrNoSuchRoomSession {
|
||||
t.Errorf("Expected error about invalid room session, got %s", err)
|
||||
} else if err == nil {
|
||||
t.Errorf("Expected error about invalid room session, got session id %s", sid)
|
||||
assert := assert.New(t)
|
||||
if sid, err := sessions.GetSessionId("unknown"); err == nil {
|
||||
assert.Fail("Expected error about invalid room session, got session id %s", sid)
|
||||
} else {
|
||||
assert.ErrorIs(err, ErrNoSuchRoomSession)
|
||||
}
|
||||
|
||||
s1 := checkSession(t, sessions, "session1", "room1")
|
||||
s2 := checkSession(t, sessions, "session2", "room2")
|
||||
|
||||
if sid, err := sessions.GetSessionId("room1"); err != nil {
|
||||
t.Errorf("Expected session id %s, got error %s", s1.PublicId(), err)
|
||||
} else if sid != s1.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %s", s1.PublicId(), sid)
|
||||
if sid, err := sessions.GetSessionId("room1"); assert.NoError(err) {
|
||||
assert.Equal(s1.PublicId(), sid)
|
||||
}
|
||||
|
||||
sessions.DeleteRoomSession(s1)
|
||||
if sid, err := sessions.GetSessionId("room1"); err != nil && err != ErrNoSuchRoomSession {
|
||||
t.Errorf("Expected error about invalid room session, got %s", err)
|
||||
} else if err == nil {
|
||||
t.Errorf("Expected error about invalid room session, got session id %s", sid)
|
||||
if sid, err := sessions.GetSessionId("room1"); err == nil {
|
||||
assert.Fail("Expected error about invalid room session, got session id %s", sid)
|
||||
} else {
|
||||
assert.ErrorIs(err, ErrNoSuchRoomSession)
|
||||
}
|
||||
if sid, err := sessions.GetSessionId("room2"); err != nil {
|
||||
t.Errorf("Expected session id %s, got error %s", s2.PublicId(), err)
|
||||
} else if sid != s2.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %s", s2.PublicId(), sid)
|
||||
if sid, err := sessions.GetSessionId("room2"); assert.NoError(err) {
|
||||
assert.Equal(s2.PublicId(), sid)
|
||||
}
|
||||
|
||||
if err := sessions.SetRoomSession(s1, "room-session"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := sessions.SetRoomSession(s2, "room-session"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(sessions.SetRoomSession(s1, "room-session"))
|
||||
assert.NoError(sessions.SetRoomSession(s2, "room-session"))
|
||||
sessions.DeleteRoomSession(s1)
|
||||
if sid, err := sessions.GetSessionId("room-session"); err != nil {
|
||||
t.Errorf("Expected session id %s, got error %s", s2.PublicId(), err)
|
||||
} else if sid != s2.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %s", s2.PublicId(), sid)
|
||||
if sid, err := sessions.GetSessionId("room-session"); assert.NoError(err) {
|
||||
assert.Equal(s2.PublicId(), sid)
|
||||
}
|
||||
|
||||
if err := sessions.SetRoomSession(s2, "room-session2"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(sessions.SetRoomSession(s2, "room-session2"))
|
||||
if sid, err := sessions.GetSessionId("room-session"); err == nil {
|
||||
t.Errorf("expected error %s, got sid %s", ErrNoSuchRoomSession, sid)
|
||||
} else if !errors.Is(err, ErrNoSuchRoomSession) {
|
||||
t.Errorf("expected %s, got %s", ErrNoSuchRoomSession, err)
|
||||
assert.Fail("Expected error about invalid room session, got session id %s", sid)
|
||||
} else {
|
||||
assert.ErrorIs(err, ErrNoSuchRoomSession)
|
||||
}
|
||||
|
||||
if sid, err := sessions.GetSessionId("room-session2"); err != nil {
|
||||
t.Errorf("Expected session id %s, got error %s", s2.PublicId(), err)
|
||||
} else if sid != s2.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %s", s2.PublicId(), sid)
|
||||
if sid, err := sessions.GetSessionId("room-session2"); assert.NoError(err) {
|
||||
assert.Equal(s2.PublicId(), sid)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,18 +23,16 @@ package signaling
|
|||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func assertSessionHasPermission(t *testing.T, session Session, permission Permission) {
|
||||
t.Helper()
|
||||
if !session.HasPermission(permission) {
|
||||
t.Errorf("Session %s doesn't have permission %s", session.PublicId(), permission)
|
||||
}
|
||||
assert.True(t, session.HasPermission(permission), "Session %s doesn't have permission %s", session.PublicId(), permission)
|
||||
}
|
||||
|
||||
func assertSessionHasNotPermission(t *testing.T, session Session, permission Permission) {
|
||||
t.Helper()
|
||||
if session.HasPermission(permission) {
|
||||
t.Errorf("Session %s has permission %s but shouldn't", session.PublicId(), permission)
|
||||
}
|
||||
assert.False(t, session.HasPermission(permission), "Session %s has permission %s but shouldn't", session.PublicId(), permission)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ import (
|
|||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSingleNotifierNoWaiter(t *testing.T) {
|
||||
|
|
@ -48,9 +50,7 @@ func TestSingleNotifierSimple(t *testing.T) {
|
|||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := waiter.Wait(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter.Wait(ctx))
|
||||
}()
|
||||
|
||||
notifier.Notify()
|
||||
|
|
@ -74,9 +74,7 @@ func TestSingleNotifierWaitClosed(t *testing.T) {
|
|||
waiter := notifier.NewWaiter()
|
||||
notifier.Release(waiter)
|
||||
|
||||
if err := waiter.Wait(context.Background()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter.Wait(context.Background()))
|
||||
}
|
||||
|
||||
func TestSingleNotifierWaitClosedMulti(t *testing.T) {
|
||||
|
|
@ -87,12 +85,8 @@ func TestSingleNotifierWaitClosedMulti(t *testing.T) {
|
|||
notifier.Release(waiter1)
|
||||
notifier.Release(waiter2)
|
||||
|
||||
if err := waiter1.Wait(context.Background()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := waiter2.Wait(context.Background()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter1.Wait(context.Background()))
|
||||
assert.NoError(t, waiter2.Wait(context.Background()))
|
||||
}
|
||||
|
||||
func TestSingleNotifierResetWillNotify(t *testing.T) {
|
||||
|
|
@ -108,9 +102,7 @@ func TestSingleNotifierResetWillNotify(t *testing.T) {
|
|||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := waiter.Wait(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter.Wait(ctx))
|
||||
}()
|
||||
|
||||
notifier.Reset()
|
||||
|
|
@ -137,9 +129,7 @@ func TestSingleNotifierDuplicate(t *testing.T) {
|
|||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := waiter.Wait(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, waiter.Wait(ctx))
|
||||
}()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64) {
|
||||
|
|
@ -37,10 +38,11 @@ func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64
|
|||
desc := <-ch
|
||||
v := testutil.ToFloat64(collector)
|
||||
if v != value {
|
||||
assert := assert.New(t)
|
||||
pc := make([]uintptr, 10)
|
||||
n := runtime.Callers(2, pc)
|
||||
if n == 0 {
|
||||
t.Errorf("Expected value %f for %s, got %f", value, desc, v)
|
||||
assert.Fail("Expected value %f for %s, got %f", value, desc, v)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -57,20 +59,20 @@ func checkStatsValue(t *testing.T, collector prometheus.Collector, value float64
|
|||
break
|
||||
}
|
||||
}
|
||||
t.Errorf("Expected value %f for %s, got %f at\n%s", value, desc, v, stack)
|
||||
assert.Fail("Expected value %f for %s, got %f at\n%s", value, desc, v, stack)
|
||||
}
|
||||
}
|
||||
|
||||
func collectAndLint(t *testing.T, collectors ...prometheus.Collector) {
|
||||
assert := assert.New(t)
|
||||
for _, collector := range collectors {
|
||||
problems, err := testutil.CollectAndLint(collector)
|
||||
if err != nil {
|
||||
t.Errorf("Error linting %+v: %s", collector, err)
|
||||
if !assert.NoError(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, problem := range problems {
|
||||
t.Errorf("Problem with %s: %s", problem.Metric, problem.Text)
|
||||
assert.Fail("Problem with %s: %s", problem.Metric, problem.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,6 +40,8 @@ import (
|
|||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -225,9 +227,7 @@ type TestClient struct {
|
|||
func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Server, hub *Hub) *TestClient {
|
||||
// Reference "hub" to prevent compiler error.
|
||||
conn, _, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
messageChan := make(chan []byte)
|
||||
readErrorChan := make(chan error, 1)
|
||||
|
|
@ -238,8 +238,7 @@ func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Se
|
|||
if err != nil {
|
||||
readErrorChan <- err
|
||||
return
|
||||
} else if messageType != websocket.TextMessage {
|
||||
t.Errorf("Expect text message, got %d", messageType)
|
||||
} else if !assert.Equal(t, websocket.TextMessage, messageType) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -266,13 +265,8 @@ func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient
|
|||
|
||||
client := NewTestClientContext(ctx, t, server, hub)
|
||||
msg, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if msg.Type != "welcome" {
|
||||
t.Errorf("Expected welcome message, got %+v", msg)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "welcome", msg.Type)
|
||||
return client
|
||||
}
|
||||
|
||||
|
|
@ -380,9 +374,7 @@ func (c *TestClient) WriteJSON(data interface{}) error {
|
|||
}
|
||||
|
||||
func (c *TestClient) EnsuerWriteJSON(data interface{}) {
|
||||
if err := c.WriteJSON(data); err != nil {
|
||||
c.t.Fatalf("Could not write JSON %+v: %s", data, err)
|
||||
}
|
||||
require.NoError(c.t, c.WriteJSON(data), "Could not write JSON %+v", data)
|
||||
}
|
||||
|
||||
func (c *TestClient) SendHello(userid string) error {
|
||||
|
|
@ -443,9 +435,7 @@ func (c *TestClient) CreateHelloV2Token(userid string, issuedAt time.Time, expir
|
|||
|
||||
func (c *TestClient) SendHelloV2WithTimes(userid string, issuedAt time.Time, expiresAt time.Time) error {
|
||||
tokenString, err := c.CreateHelloV2Token(userid, issuedAt, expiresAt)
|
||||
if err != nil {
|
||||
c.t.Fatal(err)
|
||||
}
|
||||
require.NoError(c.t, err)
|
||||
|
||||
params := HelloV2AuthParams{
|
||||
Token: tokenString,
|
||||
|
|
@ -493,9 +483,7 @@ func (c *TestClient) SendHelloInternalWithFeatures(features []string) error {
|
|||
|
||||
func (c *TestClient) SendHelloParams(url string, version string, clientType string, features []string, params interface{}) error {
|
||||
data, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
c.t.Fatal(err)
|
||||
}
|
||||
require.NoError(c.t, err)
|
||||
|
||||
hello := &ClientMessage{
|
||||
Id: "1234",
|
||||
|
|
@ -524,9 +512,7 @@ func (c *TestClient) SendBye() error {
|
|||
|
||||
func (c *TestClient) SendMessage(recipient MessageClientMessageRecipient, data interface{}) error {
|
||||
payload, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
c.t.Fatal(err)
|
||||
}
|
||||
require.NoError(c.t, err)
|
||||
|
||||
message := &ClientMessage{
|
||||
Id: "abcd",
|
||||
|
|
@ -541,9 +527,7 @@ func (c *TestClient) SendMessage(recipient MessageClientMessageRecipient, data i
|
|||
|
||||
func (c *TestClient) SendControl(recipient MessageClientMessageRecipient, data interface{}) error {
|
||||
payload, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
c.t.Fatal(err)
|
||||
}
|
||||
require.NoError(c.t, err)
|
||||
|
||||
message := &ClientMessage{
|
||||
Id: "abcd",
|
||||
|
|
@ -608,9 +592,7 @@ func (c *TestClient) SendInternalDialout(msg *DialoutInternalClientMessage) erro
|
|||
|
||||
func (c *TestClient) SetTransientData(key string, value interface{}, ttl time.Duration) error {
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
c.t.Fatal(err)
|
||||
}
|
||||
require.NoError(c.t, err)
|
||||
|
||||
message := &ClientMessage{
|
||||
Id: "efgh",
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ import (
|
|||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var listenSignalOnce sync.Once
|
||||
|
|
@ -76,7 +78,7 @@ func ensureNoGoroutinesLeak(t *testing.T, f func(t *testing.T)) {
|
|||
if after != before {
|
||||
io.Copy(os.Stderr, &prev) // nolint
|
||||
dumpGoroutines("After:", os.Stderr)
|
||||
t.Fatalf("Number of Go routines has changed from %d to %d", before, after)
|
||||
require.Equal(t, before, after, "Number of Go routines has changed")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
119
throttle_test.go
119
throttle_test.go
|
|
@ -25,14 +25,15 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newMemoryThrottlerForTest(t *testing.T) *memoryThrottler {
|
||||
t.Helper()
|
||||
result, err := NewMemoryThrottler()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
result.Close()
|
||||
|
|
@ -54,12 +55,11 @@ func (t *throttlerTiming) getNow() time.Time {
|
|||
|
||||
func (t *throttlerTiming) doDelay(ctx context.Context, duration time.Duration) {
|
||||
t.t.Helper()
|
||||
if duration != t.expectedSleep {
|
||||
t.t.Errorf("expected sleep %s, got %s", t.expectedSleep, duration)
|
||||
}
|
||||
assert.Equal(t.t, t.expectedSleep, duration)
|
||||
}
|
||||
|
||||
func TestThrottler(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
timing := &throttlerTiming{
|
||||
t: t,
|
||||
now: time.Now(),
|
||||
|
|
@ -71,38 +71,31 @@ func TestThrottler(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
|
||||
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle1(ctx)
|
||||
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 200 * time.Millisecond
|
||||
throttle2(ctx)
|
||||
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle3(ctx)
|
||||
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle4(ctx)
|
||||
}
|
||||
|
||||
func TestThrottlerIPv6(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
timing := &throttlerTiming{
|
||||
t: t,
|
||||
now: time.Now(),
|
||||
|
|
@ -115,40 +108,33 @@ func TestThrottlerIPv6(t *testing.T) {
|
|||
|
||||
// Make sure full /64 subnets are throttled for IPv6.
|
||||
throttle1, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle1(ctx)
|
||||
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
throttle2, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::2", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 200 * time.Millisecond
|
||||
throttle2(ctx)
|
||||
|
||||
// A diffent /64 subnet is not throttled yet.
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
throttle3, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0013::1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle3(ctx)
|
||||
|
||||
// A different action is not throttled.
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
throttle4, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::1", "action2")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle4(ctx)
|
||||
}
|
||||
|
||||
func TestThrottler_Bruteforce(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
timing := &throttlerTiming{
|
||||
t: t,
|
||||
now: time.Now(),
|
||||
|
|
@ -162,9 +148,7 @@ func TestThrottler_Bruteforce(t *testing.T) {
|
|||
for i := 0; i < maxBruteforceAttempts; i++ {
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
throttle, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
if i == 0 {
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
} else {
|
||||
|
|
@ -177,14 +161,12 @@ func TestThrottler_Bruteforce(t *testing.T) {
|
|||
}
|
||||
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
if _, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1"); err == nil {
|
||||
t.Error("expected bruteforce error")
|
||||
} else if err != ErrBruteforceDetected {
|
||||
t.Errorf("expected error %s, got %s", ErrBruteforceDetected, err)
|
||||
}
|
||||
_, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
assert.ErrorIs(err, ErrBruteforceDetected)
|
||||
}
|
||||
|
||||
func TestThrottler_Cleanup(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
timing := &throttlerTiming{
|
||||
t: t,
|
||||
now: time.Now(),
|
||||
|
|
@ -196,59 +178,46 @@ func TestThrottler_Cleanup(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
|
||||
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle1(ctx)
|
||||
|
||||
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle2(ctx)
|
||||
|
||||
timing.now = timing.now.Add(time.Hour)
|
||||
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle3(ctx)
|
||||
|
||||
throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 200 * time.Millisecond
|
||||
throttle4(ctx)
|
||||
|
||||
timing.now = timing.now.Add(-time.Hour).Add(maxBruteforceAge).Add(time.Second)
|
||||
th.cleanup(timing.now)
|
||||
|
||||
if entries := th.getEntries("192.168.0.1", "action1"); len(entries) != 1 {
|
||||
t.Errorf("should have removed one entry, got %+v", entries)
|
||||
}
|
||||
if entries := th.getEntries("192.168.0.1", "action2"); len(entries) != 1 {
|
||||
t.Errorf("should have kept entry, got %+v", entries)
|
||||
}
|
||||
assert.Len(th.getEntries("192.168.0.1", "action1"), 1)
|
||||
assert.Len(th.getEntries("192.168.0.1", "action2"), 1)
|
||||
|
||||
th.mu.RLock()
|
||||
if _, found := th.clients["192.168.0.2"]; found {
|
||||
t.Error("should have removed client \"192.168.0.2\"")
|
||||
assert.Fail("should have removed client \"192.168.0.2\"")
|
||||
}
|
||||
th.mu.RUnlock()
|
||||
|
||||
throttle5, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 200 * time.Millisecond
|
||||
throttle5(ctx)
|
||||
}
|
||||
|
||||
func TestThrottler_ExpirePartial(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
timing := &throttlerTiming{
|
||||
t: t,
|
||||
now: time.Now(),
|
||||
|
|
@ -260,32 +229,27 @@ func TestThrottler_ExpirePartial(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
|
||||
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle1(ctx)
|
||||
|
||||
timing.now = timing.now.Add(time.Minute)
|
||||
|
||||
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 200 * time.Millisecond
|
||||
throttle2(ctx)
|
||||
|
||||
timing.now = timing.now.Add(maxBruteforceAge).Add(-time.Minute + time.Second)
|
||||
|
||||
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 200 * time.Millisecond
|
||||
throttle3(ctx)
|
||||
}
|
||||
|
||||
func TestThrottler_ExpireAll(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
timing := &throttlerTiming{
|
||||
t: t,
|
||||
now: time.Now(),
|
||||
|
|
@ -297,32 +261,27 @@ func TestThrottler_ExpireAll(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
|
||||
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle1(ctx)
|
||||
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
|
||||
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 200 * time.Millisecond
|
||||
throttle2(ctx)
|
||||
|
||||
timing.now = timing.now.Add(maxBruteforceAge).Add(time.Second)
|
||||
|
||||
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
throttle3(ctx)
|
||||
}
|
||||
|
||||
func TestThrottler_Negative(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
timing := &throttlerTiming{
|
||||
t: t,
|
||||
now: time.Now(),
|
||||
|
|
@ -336,8 +295,8 @@ func TestThrottler_Negative(t *testing.T) {
|
|||
for i := 0; i < maxBruteforceAttempts*10; i++ {
|
||||
timing.now = timing.now.Add(time.Millisecond)
|
||||
throttle, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
|
||||
if err != nil && err != ErrBruteforceDetected {
|
||||
t.Error(err)
|
||||
if err != nil {
|
||||
assert.ErrorIs(err, ErrBruteforceDetected)
|
||||
}
|
||||
if i == 0 {
|
||||
timing.expectedSleep = 100 * time.Millisecond
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func (t *TransientData) SetTTLChannel(ch chan<- struct{}) {
|
||||
|
|
@ -35,106 +38,57 @@ func (t *TransientData) SetTTLChannel(ch chan<- struct{}) {
|
|||
}
|
||||
|
||||
func Test_TransientData(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
data := NewTransientData()
|
||||
if data.Set("foo", nil) {
|
||||
t.Errorf("should not have set value")
|
||||
}
|
||||
if !data.Set("foo", "bar") {
|
||||
t.Errorf("should have set value")
|
||||
}
|
||||
if data.Set("foo", "bar") {
|
||||
t.Errorf("should not have set value")
|
||||
}
|
||||
if !data.Set("foo", "baz") {
|
||||
t.Errorf("should have set value")
|
||||
}
|
||||
if data.CompareAndSet("foo", "bar", "lala") {
|
||||
t.Errorf("should not have set value")
|
||||
}
|
||||
if !data.CompareAndSet("foo", "baz", "lala") {
|
||||
t.Errorf("should have set value")
|
||||
}
|
||||
if data.CompareAndSet("test", nil, nil) {
|
||||
t.Errorf("should not have set value")
|
||||
}
|
||||
if !data.CompareAndSet("test", nil, "123") {
|
||||
t.Errorf("should have set value")
|
||||
}
|
||||
if data.CompareAndSet("test", nil, "456") {
|
||||
t.Errorf("should not have set value")
|
||||
}
|
||||
if data.CompareAndRemove("test", "1234") {
|
||||
t.Errorf("should not have removed value")
|
||||
}
|
||||
if !data.CompareAndRemove("test", "123") {
|
||||
t.Errorf("should have removed value")
|
||||
}
|
||||
if data.Remove("lala") {
|
||||
t.Errorf("should not have removed value")
|
||||
}
|
||||
if !data.Remove("foo") {
|
||||
t.Errorf("should have removed value")
|
||||
}
|
||||
assert.False(data.Set("foo", nil))
|
||||
assert.True(data.Set("foo", "bar"))
|
||||
assert.False(data.Set("foo", "bar"))
|
||||
assert.True(data.Set("foo", "baz"))
|
||||
assert.False(data.CompareAndSet("foo", "bar", "lala"))
|
||||
assert.True(data.CompareAndSet("foo", "baz", "lala"))
|
||||
assert.False(data.CompareAndSet("test", nil, nil))
|
||||
assert.True(data.CompareAndSet("test", nil, "123"))
|
||||
assert.False(data.CompareAndSet("test", nil, "456"))
|
||||
assert.False(data.CompareAndRemove("test", "1234"))
|
||||
assert.True(data.CompareAndRemove("test", "123"))
|
||||
assert.False(data.Remove("lala"))
|
||||
assert.True(data.Remove("foo"))
|
||||
|
||||
ttlCh := make(chan struct{})
|
||||
data.SetTTLChannel(ttlCh)
|
||||
if !data.SetTTL("test", "1234", time.Millisecond) {
|
||||
t.Errorf("should have set value")
|
||||
}
|
||||
if value := data.GetData()["test"]; value != "1234" {
|
||||
t.Errorf("expected 1234, got %v", value)
|
||||
}
|
||||
assert.True(data.SetTTL("test", "1234", time.Millisecond))
|
||||
assert.Equal("1234", data.GetData()["test"])
|
||||
// Data is removed after the TTL
|
||||
<-ttlCh
|
||||
if value := data.GetData()["test"]; value != nil {
|
||||
t.Errorf("expected no value, got %v", value)
|
||||
}
|
||||
assert.Nil(data.GetData()["test"])
|
||||
|
||||
if !data.SetTTL("test", "1234", time.Millisecond) {
|
||||
t.Errorf("should have set value")
|
||||
}
|
||||
if value := data.GetData()["test"]; value != "1234" {
|
||||
t.Errorf("expected 1234, got %v", value)
|
||||
}
|
||||
if !data.SetTTL("test", "2345", 3*time.Millisecond) {
|
||||
t.Errorf("should have set value")
|
||||
}
|
||||
if value := data.GetData()["test"]; value != "2345" {
|
||||
t.Errorf("expected 2345, got %v", value)
|
||||
}
|
||||
assert.True(data.SetTTL("test", "1234", time.Millisecond))
|
||||
assert.Equal("1234", data.GetData()["test"])
|
||||
assert.True(data.SetTTL("test", "2345", 3*time.Millisecond))
|
||||
assert.Equal("2345", data.GetData()["test"])
|
||||
// Data is removed after the TTL only if the value still matches
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
if value := data.GetData()["test"]; value != "2345" {
|
||||
t.Errorf("expected 2345, got %v", value)
|
||||
}
|
||||
assert.Equal("2345", data.GetData()["test"])
|
||||
// Data is removed after the (second) TTL
|
||||
<-ttlCh
|
||||
if value := data.GetData()["test"]; value != nil {
|
||||
t.Errorf("expected no value, got %v", value)
|
||||
}
|
||||
assert.Nil(data.GetData()["test"])
|
||||
|
||||
// Setting existing key will update the TTL
|
||||
if !data.SetTTL("test", "1234", time.Millisecond) {
|
||||
t.Errorf("should have set value")
|
||||
}
|
||||
if data.SetTTL("test", "1234", 3*time.Millisecond) {
|
||||
t.Errorf("should not have set value")
|
||||
}
|
||||
assert.True(data.SetTTL("test", "1234", time.Millisecond))
|
||||
assert.False(data.SetTTL("test", "1234", 3*time.Millisecond))
|
||||
// Data still exists after the first TTL
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
if value := data.GetData()["test"]; value != "1234" {
|
||||
t.Errorf("expected 1234, got %v", value)
|
||||
}
|
||||
assert.Equal("1234", data.GetData()["test"])
|
||||
// Data is removed after the (updated) TTL
|
||||
<-ttlCh
|
||||
if value := data.GetData()["test"]; value != nil {
|
||||
t.Errorf("expected no value, got %v", value)
|
||||
}
|
||||
assert.Nil(data.GetData()["test"])
|
||||
}
|
||||
|
||||
func Test_TransientMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
hub, _, _, server := CreateHubForTest(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
|
|
@ -142,226 +96,139 @@ func Test_TransientMessages(t *testing.T) {
|
|||
|
||||
client1 := NewTestClient(t, server, hub)
|
||||
defer client1.CloseWithBye()
|
||||
if err := client1.SendHello(testDefaultUserId + "1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client1.SendHello(testDefaultUserId + "1"))
|
||||
hello1, err := client1.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
if err := client1.SetTransientData("foo", "bar", 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if err := checkMessageError(msg, "not_in_room"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client1.SetTransientData("foo", "bar", 0))
|
||||
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageError(msg, "not_in_room"))
|
||||
}
|
||||
|
||||
client2 := NewTestClient(t, server, hub)
|
||||
defer client2.CloseWithBye()
|
||||
if err := client2.SendHello(testDefaultUserId + "2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client2.SendHello(testDefaultUserId + "2"))
|
||||
hello2, err := client2.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client1.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err := client1.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
// Give message processing some time.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err = client2.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2)
|
||||
|
||||
session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession)
|
||||
if session1 == nil {
|
||||
t.Fatalf("Session %s does not exist", hello1.Hello.SessionId)
|
||||
}
|
||||
require.NotNil(session1, "Session %s does not exist", hello1.Hello.SessionId)
|
||||
session2 := hub.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession)
|
||||
if session2 == nil {
|
||||
t.Fatalf("Session %s does not exist", hello2.Hello.SessionId)
|
||||
}
|
||||
require.NotNil(session2, "Session %s does not exist", hello2.Hello.SessionId)
|
||||
|
||||
// Client 1 may modify transient data.
|
||||
session1.SetPermissions([]Permission{PERMISSION_TRANSIENT_DATA})
|
||||
// Client 2 may not modify transient data.
|
||||
session2.SetPermissions([]Permission{})
|
||||
|
||||
if err := client2.SetTransientData("foo", "bar", 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if err := checkMessageError(msg, "not_allowed"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client2.SetTransientData("foo", "bar", 0))
|
||||
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageError(msg, "not_allowed"))
|
||||
}
|
||||
|
||||
if err := client1.SetTransientData("foo", "bar", 0); err != nil {
|
||||
t.Fatal(err)
|
||||
require.NoError(client1.SetTransientData("foo", "bar", 0))
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageTransientSet(msg, "foo", "bar", nil))
|
||||
}
|
||||
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageTransientSet(msg, "foo", "bar", nil))
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if err := checkMessageTransientSet(msg, "foo", "bar", nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if err := checkMessageTransientSet(msg, "foo", "bar", nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := client2.RemoveTransientData("foo"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if err := checkMessageError(msg, "not_allowed"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client2.RemoveTransientData("foo"))
|
||||
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageError(msg, "not_allowed"))
|
||||
}
|
||||
|
||||
// Setting the same value is ignored by the server.
|
||||
if err := client1.SetTransientData("foo", "bar", 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client1.SetTransientData("foo", "bar", 0))
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel2()
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx2); err != nil {
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, err := client1.RunUntilMessage(ctx2); err == nil {
|
||||
assert.Fail("Expected no payload, got %+v", msg)
|
||||
} else {
|
||||
t.Errorf("Expected no payload, got %+v", msg)
|
||||
require.ErrorIs(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"hello": "world",
|
||||
}
|
||||
if err := client1.SetTransientData("foo", data, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
require.NoError(client1.SetTransientData("foo", data, 0))
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageTransientSet(msg, "foo", data, "bar"))
|
||||
}
|
||||
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageTransientSet(msg, "foo", data, "bar"))
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if err := checkMessageTransientSet(msg, "foo", data, "bar"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if err := checkMessageTransientSet(msg, "foo", data, "bar"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
require.NoError(client1.RemoveTransientData("foo"))
|
||||
|
||||
if err := client1.RemoveTransientData("foo"); err != nil {
|
||||
t.Fatal(err)
|
||||
if msg, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageTransientRemove(msg, "foo", data))
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if err := checkMessageTransientRemove(msg, "foo", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if err := checkMessageTransientRemove(msg, "foo", data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageTransientRemove(msg, "foo", data))
|
||||
}
|
||||
|
||||
// Removing a non-existing key is ignored by the server.
|
||||
if err := client1.RemoveTransientData("foo"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client1.RemoveTransientData("foo"))
|
||||
ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel3()
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx3); err != nil {
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, err := client1.RunUntilMessage(ctx3); err == nil {
|
||||
assert.Fail("Expected no payload, got %+v", msg)
|
||||
} else {
|
||||
t.Errorf("Expected no payload, got %+v", msg)
|
||||
require.ErrorIs(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
if err := client1.SetTransientData("abc", data, 10*time.Millisecond); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client1.SetTransientData("abc", data, 10*time.Millisecond))
|
||||
|
||||
client3 := NewTestClient(t, server, hub)
|
||||
defer client3.CloseWithBye()
|
||||
if err := client3.SendHello(testDefaultUserId + "3"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client3.SendHello(testDefaultUserId + "3"))
|
||||
hello3, err := client3.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
if room, err := client3.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err = client3.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
_, ignored, err := client3.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello, hello3.Hello)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
var msg *ServerMessage
|
||||
if len(ignored) == 0 {
|
||||
if msg, err = client3.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
msg, err = client3.RunUntilMessage(ctx)
|
||||
require.NoError(err)
|
||||
} else if len(ignored) == 1 {
|
||||
msg = ignored[0]
|
||||
} else {
|
||||
t.Fatalf("Received too many messages: %+v", ignored)
|
||||
require.Fail("Received too many messages: %+v", ignored)
|
||||
}
|
||||
|
||||
if err := checkMessageTransientInitial(msg, map[string]interface{}{
|
||||
require.NoError(checkMessageTransientInitial(msg, map[string]interface{}{
|
||||
"abc": data,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}))
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if msg, err = client3.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := checkMessageTransientRemove(msg, "abc", data); err != nil {
|
||||
t.Fatal(err)
|
||||
if msg, err = client3.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
require.NoError(checkMessageTransientRemove(msg, "abc", data))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,11 +27,16 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVirtualSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
hub, _, _, server := CreateHubForTest(t)
|
||||
|
||||
roomId := "the-room-id"
|
||||
|
|
@ -41,54 +46,34 @@ func TestVirtualSession(t *testing.T) {
|
|||
compat: true,
|
||||
}
|
||||
room, err := hub.createRoom(roomId, emptyProperties, backend)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not create room: %s", err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer room.Close()
|
||||
|
||||
clientInternal := NewTestClient(t, server, hub)
|
||||
defer clientInternal.CloseWithBye()
|
||||
if err := clientInternal.SendHelloInternal(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.SendHelloInternal())
|
||||
|
||||
client := NewTestClient(t, server, hub)
|
||||
defer client.CloseWithBye()
|
||||
if err := client.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.SendHello(testDefaultUserId))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
if hello, err := clientInternal.RunUntilHello(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if hello.Hello.UserId != "" {
|
||||
t.Errorf("Expected empty user id, got %+v", hello.Hello)
|
||||
}
|
||||
if hello.Hello.SessionId == "" {
|
||||
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||
}
|
||||
if hello.Hello.ResumeId == "" {
|
||||
t.Errorf("Expected resume id, got %+v", hello.Hello)
|
||||
}
|
||||
if hello, err := clientInternal.RunUntilHello(ctx); assert.NoError(err) {
|
||||
assert.Empty(hello.Hello.UserId)
|
||||
assert.NotEmpty(hello.Hello.SessionId)
|
||||
assert.NotEmpty(hello.Hello.ResumeId)
|
||||
}
|
||||
hello, err := client.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(err)
|
||||
|
||||
if room, err := client.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err := client.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
// Ignore "join" events.
|
||||
if err := client.DrainMessages(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client.DrainMessages(ctx))
|
||||
|
||||
internalSessionId := "session1"
|
||||
userId := "user1"
|
||||
|
|
@ -106,64 +91,39 @@ func TestVirtualSession(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
if err := clientInternal.WriteJSON(msgAdd); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.WriteJSON(msgAdd))
|
||||
|
||||
msg1, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
// The public session id will be generated by the server, so don't check for it.
|
||||
if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.checkMessageJoinedSession(msg1, "", userId))
|
||||
sessionId := msg1.Event.Join[0].SessionId
|
||||
session := hub.GetSessionByPublicId(sessionId)
|
||||
if session == nil {
|
||||
t.Fatalf("Could not get virtual session %s", sessionId)
|
||||
}
|
||||
if session.ClientType() != HelloClientTypeVirtual {
|
||||
t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType())
|
||||
}
|
||||
if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId {
|
||||
t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid)
|
||||
if assert.NotNil(session, "Could not get virtual session %s", sessionId) {
|
||||
assert.Equal(HelloClientTypeVirtual, session.ClientType())
|
||||
sid := session.(*VirtualSession).SessionId()
|
||||
assert.Equal(internalSessionId, sid)
|
||||
}
|
||||
|
||||
// Also a participants update event will be triggered for the virtual user.
|
||||
msg2, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updateMsg, err := checkMessageParticipantsInCall(msg2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if updateMsg.RoomId != roomId {
|
||||
t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId)
|
||||
} else if len(updateMsg.Users) != 1 {
|
||||
t.Errorf("Expected one user, got %+v", updateMsg.Users)
|
||||
} else if sid, ok := updateMsg.Users[0]["sessionId"].(string); !ok || sid != sessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", sessionId, updateMsg.Users[0])
|
||||
} else if virtual, ok := updateMsg.Users[0]["virtual"].(bool); !ok || !virtual {
|
||||
t.Errorf("Expected virtual user, got %+v", updateMsg.Users[0])
|
||||
} else if inCall, ok := updateMsg.Users[0]["inCall"].(float64); !ok || inCall != (FlagInCall|FlagWithPhone) {
|
||||
t.Errorf("Expected user in call with phone, got %+v", updateMsg.Users[0])
|
||||
require.NoError(err)
|
||||
if updateMsg, err := checkMessageParticipantsInCall(msg2); assert.NoError(err) {
|
||||
assert.Equal(roomId, updateMsg.RoomId)
|
||||
if assert.Len(updateMsg.Users, 1) {
|
||||
assert.Equal(sessionId, updateMsg.Users[0]["sessionId"])
|
||||
assert.Equal(true, updateMsg.Users[0]["virtual"])
|
||||
assert.EqualValues((FlagInCall | FlagWithPhone), updateMsg.Users[0]["inCall"])
|
||||
}
|
||||
}
|
||||
|
||||
msg3, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
flagsMsg, err := checkMessageParticipantFlags(msg3)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if flagsMsg.RoomId != roomId {
|
||||
t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId)
|
||||
} else if flagsMsg.SessionId != sessionId {
|
||||
t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId)
|
||||
} else if flagsMsg.Flags != FLAG_MUTED_SPEAKING {
|
||||
t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, flagsMsg.Flags)
|
||||
if flagsMsg, err := checkMessageParticipantFlags(msg3); assert.NoError(err) {
|
||||
assert.Equal(roomId, flagsMsg.RoomId)
|
||||
assert.Equal(sessionId, flagsMsg.SessionId)
|
||||
assert.EqualValues(FLAG_MUTED_SPEAKING, flagsMsg.Flags)
|
||||
}
|
||||
|
||||
newFlags := uint32(FLAG_TALKING)
|
||||
|
|
@ -180,49 +140,35 @@ func TestVirtualSession(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
if err := clientInternal.WriteJSON(msgFlags); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.WriteJSON(msgFlags))
|
||||
|
||||
msg4, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
flagsMsg, err = checkMessageParticipantFlags(msg4)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if flagsMsg.RoomId != roomId {
|
||||
t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId)
|
||||
} else if flagsMsg.SessionId != sessionId {
|
||||
t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId)
|
||||
} else if flagsMsg.Flags != newFlags {
|
||||
t.Errorf("Expected flags %d, got %+v", newFlags, flagsMsg.Flags)
|
||||
if flagsMsg, err := checkMessageParticipantFlags(msg4); assert.NoError(err) {
|
||||
assert.Equal(roomId, flagsMsg.RoomId)
|
||||
assert.Equal(sessionId, flagsMsg.SessionId)
|
||||
assert.EqualValues(newFlags, flagsMsg.Flags)
|
||||
}
|
||||
|
||||
// A new client will receive the initial flags of the virtual session.
|
||||
client2 := NewTestClient(t, server, hub)
|
||||
defer client2.CloseWithBye()
|
||||
if err := client2.SendHello(testDefaultUserId + "2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client2.SendHello(testDefaultUserId + "2"))
|
||||
|
||||
if _, err := client2.RunUntilHello(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = client2.RunUntilHello(ctx)
|
||||
require.NoError(err)
|
||||
|
||||
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err = client2.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
gotFlags := false
|
||||
var receivedMessages []*ServerMessage
|
||||
for !gotFlags {
|
||||
messages, err := client2.GetPendingMessages(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
assert.NoError(err)
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
break
|
||||
}
|
||||
|
|
@ -234,26 +180,18 @@ func TestVirtualSession(t *testing.T) {
|
|||
continue
|
||||
}
|
||||
|
||||
if msg.Event.Flags.RoomId != roomId {
|
||||
t.Errorf("Expected flags in room %s, got %s", roomId, msg.Event.Flags.RoomId)
|
||||
} else if msg.Event.Flags.SessionId != sessionId {
|
||||
t.Errorf("Expected flags for session %s, got %s", sessionId, msg.Event.Flags.SessionId)
|
||||
} else if msg.Event.Flags.Flags != newFlags {
|
||||
t.Errorf("Expected flags %d, got %d", newFlags, msg.Event.Flags.Flags)
|
||||
} else {
|
||||
if assert.Equal(roomId, msg.Event.Flags.RoomId) &&
|
||||
assert.Equal(sessionId, msg.Event.Flags.SessionId) &&
|
||||
assert.EqualValues(newFlags, msg.Event.Flags.Flags) {
|
||||
gotFlags = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !gotFlags {
|
||||
t.Errorf("Didn't receive initial flags in %+v", receivedMessages)
|
||||
}
|
||||
assert.True(gotFlags, "Didn't receive initial flags in %+v", receivedMessages)
|
||||
|
||||
// Ignore "join" messages from second client
|
||||
if err := client.DrainMessages(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client.DrainMessages(ctx))
|
||||
|
||||
// When sending to a virtual session, the message is sent to the actual
|
||||
// client and contains a "Recipient" block with the internal session id.
|
||||
|
|
@ -263,32 +201,21 @@ func TestVirtualSession(t *testing.T) {
|
|||
}
|
||||
|
||||
data := "from-client-to-virtual"
|
||||
if err := client.SendMessage(recipient, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.SendMessage(recipient, data))
|
||||
|
||||
msg2, err = clientInternal.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := checkMessageType(msg2, "message"); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := checkMessageSender(hub, msg2.Message.Sender, "session", hello.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NoError(checkMessageType(msg2, "message"))
|
||||
require.NoError(checkMessageSender(hub, msg2.Message.Sender, "session", hello.Hello))
|
||||
|
||||
if msg2.Message.Recipient == nil {
|
||||
t.Errorf("Expected recipient, got none")
|
||||
} else if msg2.Message.Recipient.Type != "session" {
|
||||
t.Errorf("Expected recipient type session, got %s", msg2.Message.Recipient.Type)
|
||||
} else if msg2.Message.Recipient.SessionId != internalSessionId {
|
||||
t.Errorf("Expected recipient %s, got %s", internalSessionId, msg2.Message.Recipient.SessionId)
|
||||
if assert.NotNil(msg2.Message.Recipient) {
|
||||
assert.Equal("session", msg2.Message.Recipient.Type)
|
||||
assert.Equal(internalSessionId, msg2.Message.Recipient.SessionId)
|
||||
}
|
||||
|
||||
var payload string
|
||||
if err := json.Unmarshal(msg2.Message.Data, &payload); err != nil {
|
||||
t.Error(err)
|
||||
} else if payload != data {
|
||||
t.Errorf("Expected payload %s, got %s", data, payload)
|
||||
if err := json.Unmarshal(msg2.Message.Data, &payload); assert.NoError(err) {
|
||||
assert.Equal(data, payload)
|
||||
}
|
||||
|
||||
msgRemove := &ClientMessage{
|
||||
|
|
@ -303,16 +230,10 @@ func TestVirtualSession(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
if err := clientInternal.WriteJSON(msgRemove); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.WriteJSON(msgRemove))
|
||||
|
||||
msg5, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := client.checkMessageRoomLeaveSession(msg5, sessionId); err != nil {
|
||||
t.Error(err)
|
||||
if msg5, err := client.RunUntilMessage(ctx); assert.NoError(err) {
|
||||
assert.NoError(client.checkMessageRoomLeaveSession(msg5, sessionId))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -342,6 +263,8 @@ func checkHasEntryWithInCall(message *RoomEventServerMessage, sessionId string,
|
|||
func TestVirtualSessionCustomInCall(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
hub, _, _, server := CreateHubForTest(t)
|
||||
|
||||
roomId := "the-room-id"
|
||||
|
|
@ -351,9 +274,7 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
|
|||
compat: true,
|
||||
}
|
||||
room, err := hub.createRoom(roomId, emptyProperties, backend)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not create room: %s", err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer room.Close()
|
||||
|
||||
clientInternal := NewTestClient(t, server, hub)
|
||||
|
|
@ -361,67 +282,40 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
|
|||
features := []string{
|
||||
ClientFeatureInternalInCall,
|
||||
}
|
||||
if err := clientInternal.SendHelloInternalWithFeatures(features); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.SendHelloInternalWithFeatures(features))
|
||||
|
||||
client := NewTestClient(t, server, hub)
|
||||
defer client.CloseWithBye()
|
||||
if err := client.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.SendHello(testDefaultUserId))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
helloInternal, err := clientInternal.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if helloInternal.Hello.UserId != "" {
|
||||
t.Errorf("Expected empty user id, got %+v", helloInternal.Hello)
|
||||
}
|
||||
if helloInternal.Hello.SessionId == "" {
|
||||
t.Errorf("Expected session id, got %+v", helloInternal.Hello)
|
||||
}
|
||||
if helloInternal.Hello.ResumeId == "" {
|
||||
t.Errorf("Expected resume id, got %+v", helloInternal.Hello)
|
||||
}
|
||||
}
|
||||
if room, err := clientInternal.JoinRoomWithRoomSession(ctx, roomId, ""); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
if assert.NoError(err) {
|
||||
assert.Empty(helloInternal.Hello.UserId)
|
||||
assert.NotEmpty(helloInternal.Hello.SessionId)
|
||||
assert.NotEmpty(helloInternal.Hello.ResumeId)
|
||||
}
|
||||
roomMsg, err := clientInternal.JoinRoomWithRoomSession(ctx, roomId, "")
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
hello, err := client.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if room, err := client.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
assert.NoError(err)
|
||||
roomMsg, err = client.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
if _, additional, err := clientInternal.RunUntilJoinedAndReturn(ctx, helloInternal.Hello, hello.Hello); err != nil {
|
||||
t.Error(err)
|
||||
} else if len(additional) != 1 {
|
||||
t.Errorf("expected one additional message, got %+v", additional)
|
||||
} else if additional[0].Type != "event" {
|
||||
t.Errorf("expected event message, got %+v", additional[0])
|
||||
} else if additional[0].Event.Target != "participants" {
|
||||
t.Errorf("expected event participants message, got %+v", additional[0])
|
||||
} else if additional[0].Event.Type != "update" {
|
||||
t.Errorf("expected event participants update message, got %+v", additional[0])
|
||||
} else if additional[0].Event.Update.Users[0]["sessionId"].(string) != helloInternal.Hello.SessionId {
|
||||
t.Errorf("expected event update message for internal session, got %+v", additional[0])
|
||||
} else if additional[0].Event.Update.Users[0]["inCall"].(float64) != 0 {
|
||||
t.Errorf("expected event update message with session not in call, got %+v", additional[0])
|
||||
}
|
||||
if err := client.RunUntilJoined(ctx, helloInternal.Hello, hello.Hello); err != nil {
|
||||
t.Error(err)
|
||||
if _, additional, err := clientInternal.RunUntilJoinedAndReturn(ctx, helloInternal.Hello, hello.Hello); assert.NoError(err) {
|
||||
if assert.Len(additional, 1) && assert.Equal("event", additional[0].Type) {
|
||||
assert.Equal("participants", additional[0].Event.Target)
|
||||
assert.Equal("update", additional[0].Event.Type)
|
||||
assert.Equal(helloInternal.Hello.SessionId, additional[0].Event.Update.Users[0]["sessionId"])
|
||||
assert.EqualValues(0, additional[0].Event.Update.Users[0]["inCall"])
|
||||
}
|
||||
}
|
||||
assert.NoError(client.RunUntilJoined(ctx, helloInternal.Hello, hello.Hello))
|
||||
|
||||
internalSessionId := "session1"
|
||||
userId := "user1"
|
||||
|
|
@ -439,65 +333,38 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
if err := clientInternal.WriteJSON(msgAdd); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.WriteJSON(msgAdd))
|
||||
|
||||
msg1, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
// The public session id will be generated by the server, so don't check for it.
|
||||
if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.checkMessageJoinedSession(msg1, "", userId))
|
||||
sessionId := msg1.Event.Join[0].SessionId
|
||||
session := hub.GetSessionByPublicId(sessionId)
|
||||
if session == nil {
|
||||
t.Fatalf("Could not get virtual session %s", sessionId)
|
||||
}
|
||||
if session.ClientType() != HelloClientTypeVirtual {
|
||||
t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType())
|
||||
}
|
||||
if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId {
|
||||
t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid)
|
||||
if assert.NotNil(session) {
|
||||
assert.Equal(HelloClientTypeVirtual, session.ClientType())
|
||||
sid := session.(*VirtualSession).SessionId()
|
||||
assert.Equal(internalSessionId, sid)
|
||||
}
|
||||
|
||||
// Also a participants update event will be triggered for the virtual user.
|
||||
msg2, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updateMsg, err := checkMessageParticipantsInCall(msg2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if updateMsg.RoomId != roomId {
|
||||
t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId)
|
||||
} else if len(updateMsg.Users) != 2 {
|
||||
t.Errorf("Expected two users, got %+v", updateMsg.Users)
|
||||
}
|
||||
require.NoError(err)
|
||||
if updateMsg, err := checkMessageParticipantsInCall(msg2); assert.NoError(err) {
|
||||
assert.Equal(roomId, updateMsg.RoomId)
|
||||
assert.Len(updateMsg.Users, 2)
|
||||
|
||||
if err := checkHasEntryWithInCall(updateMsg, sessionId, "virtual", 0); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", 0); err != nil {
|
||||
t.Error(err)
|
||||
assert.NoError(checkHasEntryWithInCall(updateMsg, sessionId, "virtual", 0))
|
||||
assert.NoError(checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", 0))
|
||||
}
|
||||
|
||||
msg3, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
flagsMsg, err := checkMessageParticipantFlags(msg3)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if flagsMsg.RoomId != roomId {
|
||||
t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId)
|
||||
} else if flagsMsg.SessionId != sessionId {
|
||||
t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId)
|
||||
} else if flagsMsg.Flags != FLAG_MUTED_SPEAKING {
|
||||
t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, flagsMsg.Flags)
|
||||
if flagsMsg, err := checkMessageParticipantFlags(msg3); assert.NoError(err) {
|
||||
assert.Equal(roomId, flagsMsg.RoomId)
|
||||
assert.Equal(sessionId, flagsMsg.SessionId)
|
||||
assert.EqualValues(FLAG_MUTED_SPEAKING, flagsMsg.Flags)
|
||||
}
|
||||
|
||||
// The internal session can change its "inCall" flags
|
||||
|
|
@ -510,27 +377,15 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
if err := clientInternal.WriteJSON(msgInCall); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.WriteJSON(msgInCall))
|
||||
|
||||
msg4, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updateMsg2, err := checkMessageParticipantsInCall(msg4)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if updateMsg2.RoomId != roomId {
|
||||
t.Errorf("Expected room %s, got %s", roomId, updateMsg2.RoomId)
|
||||
} else if len(updateMsg2.Users) != 2 {
|
||||
t.Errorf("Expected two users, got %+v", updateMsg2.Users)
|
||||
}
|
||||
if err := checkHasEntryWithInCall(updateMsg2, sessionId, "virtual", 0); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := checkHasEntryWithInCall(updateMsg2, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio); err != nil {
|
||||
t.Error(err)
|
||||
require.NoError(err)
|
||||
if updateMsg, err := checkMessageParticipantsInCall(msg4); assert.NoError(err) {
|
||||
assert.Equal(roomId, updateMsg.RoomId)
|
||||
assert.Len(updateMsg.Users, 2)
|
||||
assert.NoError(checkHasEntryWithInCall(updateMsg, sessionId, "virtual", 0))
|
||||
assert.NoError(checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio))
|
||||
}
|
||||
|
||||
// The internal session can change the "inCall" flags of a virtual session
|
||||
|
|
@ -548,33 +403,23 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
if err := clientInternal.WriteJSON(msgInCall2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.WriteJSON(msgInCall2))
|
||||
|
||||
msg5, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updateMsg3, err := checkMessageParticipantsInCall(msg5)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if updateMsg3.RoomId != roomId {
|
||||
t.Errorf("Expected room %s, got %s", roomId, updateMsg3.RoomId)
|
||||
} else if len(updateMsg3.Users) != 2 {
|
||||
t.Errorf("Expected two users, got %+v", updateMsg3.Users)
|
||||
}
|
||||
if err := checkHasEntryWithInCall(updateMsg3, sessionId, "virtual", newInCall); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := checkHasEntryWithInCall(updateMsg3, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio); err != nil {
|
||||
t.Error(err)
|
||||
require.NoError(err)
|
||||
if updateMsg, err := checkMessageParticipantsInCall(msg5); assert.NoError(err) {
|
||||
assert.Equal(roomId, updateMsg.RoomId)
|
||||
assert.Len(updateMsg.Users, 2)
|
||||
assert.NoError(checkHasEntryWithInCall(updateMsg, sessionId, "virtual", newInCall))
|
||||
assert.NoError(checkHasEntryWithInCall(updateMsg, helloInternal.Hello.SessionId, "internal", FlagInCall|FlagWithAudio))
|
||||
}
|
||||
}
|
||||
|
||||
func TestVirtualSessionCleanup(t *testing.T) {
|
||||
t.Parallel()
|
||||
CatchLogForTest(t)
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
hub, _, _, server := CreateHubForTest(t)
|
||||
|
||||
roomId := "the-room-id"
|
||||
|
|
@ -584,53 +429,34 @@ func TestVirtualSessionCleanup(t *testing.T) {
|
|||
compat: true,
|
||||
}
|
||||
room, err := hub.createRoom(roomId, emptyProperties, backend)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not create room: %s", err)
|
||||
}
|
||||
require.NoError(err)
|
||||
defer room.Close()
|
||||
|
||||
clientInternal := NewTestClient(t, server, hub)
|
||||
defer clientInternal.CloseWithBye()
|
||||
if err := clientInternal.SendHelloInternal(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.SendHelloInternal())
|
||||
|
||||
client := NewTestClient(t, server, hub)
|
||||
defer client.CloseWithBye()
|
||||
if err := client.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.SendHello(testDefaultUserId))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
if hello, err := clientInternal.RunUntilHello(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if hello.Hello.UserId != "" {
|
||||
t.Errorf("Expected empty user id, got %+v", hello.Hello)
|
||||
}
|
||||
if hello.Hello.SessionId == "" {
|
||||
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||
}
|
||||
if hello.Hello.ResumeId == "" {
|
||||
t.Errorf("Expected resume id, got %+v", hello.Hello)
|
||||
}
|
||||
}
|
||||
if _, err := client.RunUntilHello(ctx); err != nil {
|
||||
t.Error(err)
|
||||
if hello, err := clientInternal.RunUntilHello(ctx); assert.NoError(err) {
|
||||
assert.Empty(hello.Hello.UserId)
|
||||
assert.NotEmpty(hello.Hello.SessionId)
|
||||
assert.NotEmpty(hello.Hello.ResumeId)
|
||||
}
|
||||
_, err = client.RunUntilHello(ctx)
|
||||
assert.NoError(err)
|
||||
|
||||
if room, err := client.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
roomMsg, err := client.JoinRoom(ctx, roomId)
|
||||
require.NoError(err)
|
||||
require.Equal(roomId, roomMsg.Room.RoomId)
|
||||
|
||||
// Ignore "join" events.
|
||||
if err := client.DrainMessages(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(client.DrainMessages(ctx))
|
||||
|
||||
internalSessionId := "session1"
|
||||
userId := "user1"
|
||||
|
|
@ -648,72 +474,45 @@ func TestVirtualSessionCleanup(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
if err := clientInternal.WriteJSON(msgAdd); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(clientInternal.WriteJSON(msgAdd))
|
||||
|
||||
msg1, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
// The public session id will be generated by the server, so don't check for it.
|
||||
if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(client.checkMessageJoinedSession(msg1, "", userId))
|
||||
sessionId := msg1.Event.Join[0].SessionId
|
||||
session := hub.GetSessionByPublicId(sessionId)
|
||||
if session == nil {
|
||||
t.Fatalf("Could not get virtual session %s", sessionId)
|
||||
}
|
||||
if session.ClientType() != HelloClientTypeVirtual {
|
||||
t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType())
|
||||
}
|
||||
if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId {
|
||||
t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid)
|
||||
if assert.NotNil(session) {
|
||||
assert.Equal(HelloClientTypeVirtual, session.ClientType())
|
||||
sid := session.(*VirtualSession).SessionId()
|
||||
assert.Equal(internalSessionId, sid)
|
||||
}
|
||||
|
||||
// Also a participants update event will be triggered for the virtual user.
|
||||
msg2, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updateMsg, err := checkMessageParticipantsInCall(msg2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if updateMsg.RoomId != roomId {
|
||||
t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId)
|
||||
} else if len(updateMsg.Users) != 1 {
|
||||
t.Errorf("Expected one user, got %+v", updateMsg.Users)
|
||||
} else if sid, ok := updateMsg.Users[0]["sessionId"].(string); !ok || sid != sessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", sessionId, updateMsg.Users[0])
|
||||
} else if virtual, ok := updateMsg.Users[0]["virtual"].(bool); !ok || !virtual {
|
||||
t.Errorf("Expected virtual user, got %+v", updateMsg.Users[0])
|
||||
} else if inCall, ok := updateMsg.Users[0]["inCall"].(float64); !ok || inCall != (FlagInCall|FlagWithPhone) {
|
||||
t.Errorf("Expected user in call with phone, got %+v", updateMsg.Users[0])
|
||||
require.NoError(err)
|
||||
if updateMsg, err := checkMessageParticipantsInCall(msg2); assert.NoError(err) {
|
||||
assert.Equal(roomId, updateMsg.RoomId)
|
||||
if assert.Len(updateMsg.Users, 1) {
|
||||
assert.Equal(sessionId, updateMsg.Users[0]["sessionId"])
|
||||
assert.Equal(true, updateMsg.Users[0]["virtual"])
|
||||
assert.EqualValues((FlagInCall | FlagWithPhone), updateMsg.Users[0]["inCall"])
|
||||
}
|
||||
}
|
||||
|
||||
msg3, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
flagsMsg, err := checkMessageParticipantFlags(msg3)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if flagsMsg.RoomId != roomId {
|
||||
t.Errorf("Expected room %s, got %s", roomId, flagsMsg.RoomId)
|
||||
} else if flagsMsg.SessionId != sessionId {
|
||||
t.Errorf("Expected session id %s, got %s", sessionId, flagsMsg.SessionId)
|
||||
} else if flagsMsg.Flags != FLAG_MUTED_SPEAKING {
|
||||
t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, flagsMsg.Flags)
|
||||
if flagsMsg, err := checkMessageParticipantFlags(msg3); assert.NoError(err) {
|
||||
assert.Equal(roomId, flagsMsg.RoomId)
|
||||
assert.Equal(sessionId, flagsMsg.SessionId)
|
||||
assert.EqualValues(FLAG_MUTED_SPEAKING, flagsMsg.Flags)
|
||||
}
|
||||
|
||||
// The virtual sessions are closed when the parent session is deleted.
|
||||
clientInternal.CloseWithBye()
|
||||
|
||||
if msg2, err := client.RunUntilMessage(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := client.checkMessageRoomLeaveSession(msg2, sessionId); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
msg2, err = client.RunUntilMessage(ctx)
|
||||
require.NoError(err)
|
||||
assert.NoError(client.checkMessageRoomLeaveSession(msg2, sessionId))
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue