mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-05-23 16:02:12 +02:00
Merge pull request #281 from strukturag/refactor-async-events
Clustering support
This commit is contained in:
commit
d3f8876d25
1
.github/workflows/lint.yml
vendored
1
.github/workflows/lint.yml
vendored
|
@ -45,6 +45,7 @@ jobs:
|
|||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt -y update && sudo apt -y install protobuf-compiler
|
||||
make common
|
||||
|
||||
- name: lint
|
||||
|
|
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
|
@ -50,6 +50,10 @@ jobs:
|
|||
path: ${{ steps.go-cache-paths.outputs.go-mod }}
|
||||
key: ${{ runner.os }}-${{ steps.go-cache-paths.outputs.go-version }}-mod-${{ hashFiles('**/go.mod', '**/go.sum') }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt -y update && sudo apt -y install protobuf-compiler
|
||||
|
||||
- name: Build applications
|
||||
run: |
|
||||
echo "Building with $(nproc) threads"
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -3,6 +3,7 @@ vendor/
|
|||
|
||||
*_easyjson.go
|
||||
*.pem
|
||||
*.pb.go
|
||||
*.prof
|
||||
*.socket
|
||||
*.tar.gz
|
||||
|
|
|
@ -3,6 +3,7 @@ FROM golang:1.18 AS builder
|
|||
WORKDIR /workdir
|
||||
|
||||
COPY . .
|
||||
RUN apt -y update && apt -y install protobuf-compiler
|
||||
RUN make build
|
||||
|
||||
FROM alpine:3.15
|
||||
|
|
34
Makefile
34
Makefile
|
@ -56,6 +56,14 @@ $(GOPATHBIN)/easyjson:
|
|||
$(GO) get -u -d github.com/mailru/easyjson/...
|
||||
$(GO) install github.com/mailru/easyjson/...
|
||||
|
||||
$(GOPATHBIN)/protoc-gen-go:
|
||||
$(GO) get -u -d google.golang.org/protobuf/cmd/protoc-gen-go
|
||||
$(GO) install google.golang.org/protobuf/cmd/protoc-gen-go
|
||||
|
||||
$(GOPATHBIN)/protoc-gen-go-grpc:
|
||||
$(GO) get -u -d google.golang.org/grpc/cmd/protoc-gen-go-grpc
|
||||
$(GO) install google.golang.org/grpc/cmd/protoc-gen-go-grpc
|
||||
|
||||
continentmap.go:
|
||||
$(CURDIR)/scripts/get_continent_map.py $@
|
||||
|
||||
|
@ -70,7 +78,7 @@ check-continentmap:
|
|||
get:
|
||||
$(GO) get $(PACKAGE)
|
||||
|
||||
fmt: hook
|
||||
fmt: hook | common_proto
|
||||
$(GOFMT) -s -w *.go client proxy server
|
||||
|
||||
vet: common
|
||||
|
@ -83,22 +91,37 @@ cover: vet common
|
|||
rm -f cover.out && \
|
||||
$(GO) test -v -timeout $(TIMEOUT) -coverprofile cover.out $(ALL_PACKAGES) && \
|
||||
sed -i "/_easyjson/d" cover.out && \
|
||||
sed -i "/\.pb\.go/d" cover.out && \
|
||||
$(GO) tool cover -func=cover.out
|
||||
|
||||
coverhtml: vet common
|
||||
rm -f cover.out && \
|
||||
$(GO) test -v -timeout $(TIMEOUT) -coverprofile cover.out $(ALL_PACKAGES) && \
|
||||
sed -i "/_easyjson/d" cover.out && \
|
||||
sed -i "/\.pb\.go/d" cover.out && \
|
||||
$(GO) tool cover -html=cover.out -o coverage.html
|
||||
|
||||
%_easyjson.go: %.go $(GOPATHBIN)/easyjson
|
||||
%_easyjson.go: %.go $(GOPATHBIN)/easyjson | common_proto
|
||||
PATH="$(GODIR)":$(PATH) "$(GOPATHBIN)/easyjson" -all $*.go
|
||||
|
||||
common: \
|
||||
%.pb.go: %.proto $(GOPATHBIN)/protoc-gen-go $(GOPATHBIN)/protoc-gen-go-grpc
|
||||
PATH="$(GODIR)":"$(GOPATHBIN)":$(PATH) protoc --go_out=. --go_opt=paths=source_relative \
|
||||
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
||||
$*.proto
|
||||
|
||||
common: common_easyjson common_proto
|
||||
|
||||
common_easyjson: \
|
||||
api_async_easyjson.go \
|
||||
api_backend_easyjson.go \
|
||||
api_grpc_easyjson.go \
|
||||
api_proxy_easyjson.go \
|
||||
api_signaling_easyjson.go \
|
||||
natsclient_easyjson.go
|
||||
api_signaling_easyjson.go
|
||||
|
||||
common_proto: \
|
||||
grpc_internal.pb.go \
|
||||
grpc_mcu.pb.go \
|
||||
grpc_sessions.pb.go
|
||||
|
||||
$(BINDIR):
|
||||
mkdir -p $(BINDIR)
|
||||
|
@ -115,6 +138,7 @@ proxy: common $(BINDIR)
|
|||
clean:
|
||||
rm -f *_easyjson.go
|
||||
rm -f easyjson-bootstrap*.go
|
||||
rm -f *.pb.go
|
||||
|
||||
build: server proxy
|
||||
|
||||
|
|
15
README.md
15
README.md
|
@ -19,6 +19,7 @@ The following tools are required for building the signaling server.
|
|||
- git
|
||||
- go >= 1.17
|
||||
- make
|
||||
- protobuf-compiler >= 3
|
||||
|
||||
All other dependencies are fetched automatically while building.
|
||||
|
||||
|
@ -156,6 +157,20 @@ proxy process gracefully after all clients have been disconnected. No new
|
|||
publishers will be accepted in this case.
|
||||
|
||||
|
||||
### Clustering
|
||||
|
||||
The signaling server supports a clustering mode where multiple running servers
|
||||
can be interconnected to form a single "virtual" server. This can be used to
|
||||
increase the capacity of the signaling server or provide a failover setup.
|
||||
|
||||
For that a central NATS server / cluster must be used by all instances. Each
|
||||
instance must running a GRPC server (enable `listening` in section `grpc` and
|
||||
optionally setup certificate, private key and CA). The list of other GRPC
|
||||
targets must be configured as `targets` in section `grpc` or can be retrieved
|
||||
from an etcd cluster. See `server.conf.in` in section `grpc` for configuration
|
||||
details.
|
||||
|
||||
|
||||
## Setup of frontend webserver
|
||||
|
||||
Usually the standalone signaling server is running behind a webserver that does
|
||||
|
|
54
api_async.go
Normal file
54
api_async.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import "time"
|
||||
|
||||
type AsyncMessage struct {
|
||||
SendTime time.Time `json:"sendtime"`
|
||||
|
||||
Type string `json:"type"`
|
||||
|
||||
Message *ServerMessage `json:"message,omitempty"`
|
||||
|
||||
Room *BackendServerRoomRequest `json:"room,omitempty"`
|
||||
|
||||
Permissions []Permission `json:"permissions,omitempty"`
|
||||
|
||||
AsyncRoom *AsyncRoomMessage `json:"asyncroom,omitempty"`
|
||||
|
||||
SendOffer *SendOfferMessage `json:"sendoffer,omitempty"`
|
||||
|
||||
Id string `json:"id"`
|
||||
}
|
||||
|
||||
type AsyncRoomMessage struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
}
|
||||
|
||||
type SendOfferMessage struct {
|
||||
MessageId string `json:"messageid,omitempty"`
|
||||
SessionId string `json:"sessionid"`
|
||||
Data *MessageClientMessageData `json:"data"`
|
||||
}
|
|
@ -28,7 +28,10 @@ import (
|
|||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -321,3 +324,39 @@ type TurnCredentials struct {
|
|||
TTL int64 `json:"ttl"`
|
||||
URIs []string `json:"uris"`
|
||||
}
|
||||
|
||||
// Information on a backend in the etcd cluster.
|
||||
|
||||
type BackendInformationEtcd struct {
|
||||
parsedUrl *url.URL
|
||||
|
||||
Url string `json:"url"`
|
||||
Secret string `json:"secret"`
|
||||
|
||||
MaxStreamBitrate int `json:"maxstreambitrate,omitempty"`
|
||||
MaxScreenBitrate int `json:"maxscreenbitrate,omitempty"`
|
||||
|
||||
SessionLimit uint64 `json:"sessionlimit,omitempty"`
|
||||
}
|
||||
|
||||
func (p *BackendInformationEtcd) CheckValid() error {
|
||||
if p.Url == "" {
|
||||
return fmt.Errorf("url missing")
|
||||
}
|
||||
if p.Secret == "" {
|
||||
return fmt.Errorf("secret missing")
|
||||
}
|
||||
|
||||
parsedUrl, err := url.Parse(p.Url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid url: %w", err)
|
||||
}
|
||||
|
||||
if strings.Contains(parsedUrl.Host, ":") && hasStandardPort(parsedUrl) {
|
||||
parsedUrl.Host = parsedUrl.Hostname()
|
||||
p.Url = parsedUrl.String()
|
||||
}
|
||||
|
||||
p.parsedUrl = parsedUrl
|
||||
return nil
|
||||
}
|
||||
|
|
41
api_grpc.go
Normal file
41
api_grpc.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Information on a GRPC target in the etcd cluster.
|
||||
|
||||
type GrpcTargetInformationEtcd struct {
|
||||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
func (p *GrpcTargetInformationEtcd) CheckValid() error {
|
||||
if l := len(p.Address); l == 0 {
|
||||
return fmt.Errorf("address missing")
|
||||
} else if p.Address[l-1] == '/' {
|
||||
p.Address = p.Address[:l-1]
|
||||
}
|
||||
return nil
|
||||
}
|
210
async_events.go
Normal file
210
async_events.go
Normal file
|
@ -0,0 +1,210 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import "sync"
|
||||
|
||||
type AsyncBackendRoomEventListener interface {
|
||||
ProcessBackendRoomRequest(message *AsyncMessage)
|
||||
}
|
||||
|
||||
type AsyncRoomEventListener interface {
|
||||
ProcessAsyncRoomMessage(message *AsyncMessage)
|
||||
}
|
||||
|
||||
type AsyncUserEventListener interface {
|
||||
ProcessAsyncUserMessage(message *AsyncMessage)
|
||||
}
|
||||
|
||||
type AsyncSessionEventListener interface {
|
||||
ProcessAsyncSessionMessage(message *AsyncMessage)
|
||||
}
|
||||
|
||||
type AsyncEvents interface {
|
||||
Close()
|
||||
|
||||
RegisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener) error
|
||||
UnregisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener)
|
||||
|
||||
RegisterRoomListener(roomId string, backend *Backend, listener AsyncRoomEventListener) error
|
||||
UnregisterRoomListener(roomId string, backend *Backend, listener AsyncRoomEventListener)
|
||||
|
||||
RegisterUserListener(userId string, backend *Backend, listener AsyncUserEventListener) error
|
||||
UnregisterUserListener(userId string, backend *Backend, listener AsyncUserEventListener)
|
||||
|
||||
RegisterSessionListener(sessionId string, backend *Backend, listener AsyncSessionEventListener) error
|
||||
UnregisterSessionListener(sessionId string, backend *Backend, listener AsyncSessionEventListener)
|
||||
|
||||
PublishBackendRoomMessage(roomId string, backend *Backend, message *AsyncMessage) error
|
||||
PublishRoomMessage(roomId string, backend *Backend, message *AsyncMessage) error
|
||||
PublishUserMessage(userId string, backend *Backend, message *AsyncMessage) error
|
||||
PublishSessionMessage(sessionId string, backend *Backend, message *AsyncMessage) error
|
||||
}
|
||||
|
||||
func NewAsyncEvents(url string) (AsyncEvents, error) {
|
||||
client, err := NewNatsClient(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewAsyncEventsNats(client)
|
||||
}
|
||||
|
||||
type asyncBackendRoomSubscriber struct {
|
||||
mu sync.Mutex
|
||||
|
||||
listeners map[AsyncBackendRoomEventListener]bool
|
||||
}
|
||||
|
||||
func (s *asyncBackendRoomSubscriber) processBackendRoomRequest(message *AsyncMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for listener := range s.listeners {
|
||||
s.mu.Unlock()
|
||||
listener.ProcessBackendRoomRequest(message)
|
||||
s.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *asyncBackendRoomSubscriber) addListener(listener AsyncBackendRoomEventListener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listeners == nil {
|
||||
s.listeners = make(map[AsyncBackendRoomEventListener]bool)
|
||||
}
|
||||
s.listeners[listener] = true
|
||||
}
|
||||
|
||||
func (s *asyncBackendRoomSubscriber) removeListener(listener AsyncBackendRoomEventListener) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.listeners, listener)
|
||||
return len(s.listeners) > 0
|
||||
}
|
||||
|
||||
type asyncRoomSubscriber struct {
|
||||
mu sync.Mutex
|
||||
|
||||
listeners map[AsyncRoomEventListener]bool
|
||||
}
|
||||
|
||||
func (s *asyncRoomSubscriber) processAsyncRoomMessage(message *AsyncMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for listener := range s.listeners {
|
||||
s.mu.Unlock()
|
||||
listener.ProcessAsyncRoomMessage(message)
|
||||
s.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *asyncRoomSubscriber) addListener(listener AsyncRoomEventListener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listeners == nil {
|
||||
s.listeners = make(map[AsyncRoomEventListener]bool)
|
||||
}
|
||||
s.listeners[listener] = true
|
||||
}
|
||||
|
||||
func (s *asyncRoomSubscriber) removeListener(listener AsyncRoomEventListener) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.listeners, listener)
|
||||
return len(s.listeners) > 0
|
||||
}
|
||||
|
||||
type asyncUserSubscriber struct {
|
||||
mu sync.Mutex
|
||||
|
||||
listeners map[AsyncUserEventListener]bool
|
||||
}
|
||||
|
||||
func (s *asyncUserSubscriber) processAsyncUserMessage(message *AsyncMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for listener := range s.listeners {
|
||||
s.mu.Unlock()
|
||||
listener.ProcessAsyncUserMessage(message)
|
||||
s.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *asyncUserSubscriber) addListener(listener AsyncUserEventListener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listeners == nil {
|
||||
s.listeners = make(map[AsyncUserEventListener]bool)
|
||||
}
|
||||
s.listeners[listener] = true
|
||||
}
|
||||
|
||||
func (s *asyncUserSubscriber) removeListener(listener AsyncUserEventListener) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.listeners, listener)
|
||||
return len(s.listeners) > 0
|
||||
}
|
||||
|
||||
type asyncSessionSubscriber struct {
|
||||
mu sync.Mutex
|
||||
|
||||
listeners map[AsyncSessionEventListener]bool
|
||||
}
|
||||
|
||||
func (s *asyncSessionSubscriber) processAsyncSessionMessage(message *AsyncMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for listener := range s.listeners {
|
||||
s.mu.Unlock()
|
||||
listener.ProcessAsyncSessionMessage(message)
|
||||
s.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *asyncSessionSubscriber) addListener(listener AsyncSessionEventListener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listeners == nil {
|
||||
s.listeners = make(map[AsyncSessionEventListener]bool)
|
||||
}
|
||||
s.listeners[listener] = true
|
||||
}
|
||||
|
||||
func (s *asyncSessionSubscriber) removeListener(listener AsyncSessionEventListener) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.listeners, listener)
|
||||
return len(s.listeners) > 0
|
||||
}
|
450
async_events_nats.go
Normal file
450
async_events_nats.go
Normal file
|
@ -0,0 +1,450 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
func GetSubjectForBackendRoomId(roomId string, backend *Backend) string {
|
||||
if backend == nil || backend.IsCompat() {
|
||||
return GetEncodedSubject("backend.room", roomId)
|
||||
}
|
||||
|
||||
return GetEncodedSubject("backend.room", roomId+"|"+backend.Id())
|
||||
}
|
||||
|
||||
func GetSubjectForRoomId(roomId string, backend *Backend) string {
|
||||
if backend == nil || backend.IsCompat() {
|
||||
return GetEncodedSubject("room", roomId)
|
||||
}
|
||||
|
||||
return GetEncodedSubject("room", roomId+"|"+backend.Id())
|
||||
}
|
||||
|
||||
func GetSubjectForUserId(userId string, backend *Backend) string {
|
||||
if backend == nil || backend.IsCompat() {
|
||||
return GetEncodedSubject("user", userId)
|
||||
}
|
||||
|
||||
return GetEncodedSubject("user", userId+"|"+backend.Id())
|
||||
}
|
||||
|
||||
func GetSubjectForSessionId(sessionId string, backend *Backend) string {
|
||||
return "session." + sessionId
|
||||
}
|
||||
|
||||
type asyncSubscriberNats struct {
|
||||
key string
|
||||
client NatsClient
|
||||
|
||||
receiver chan *nats.Msg
|
||||
closeChan chan bool
|
||||
subscription NatsSubscription
|
||||
|
||||
processMessage func(*nats.Msg)
|
||||
}
|
||||
|
||||
func newAsyncSubscriberNats(key string, client NatsClient) (*asyncSubscriberNats, error) {
|
||||
receiver := make(chan *nats.Msg, 64)
|
||||
sub, err := client.Subscribe(key, receiver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &asyncSubscriberNats{
|
||||
key: key,
|
||||
client: client,
|
||||
|
||||
receiver: receiver,
|
||||
closeChan: make(chan bool),
|
||||
subscription: sub,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *asyncSubscriberNats) run() {
|
||||
defer func() {
|
||||
if err := s.subscription.Unsubscribe(); err != nil {
|
||||
log.Printf("Error unsubscribing %s: %s", s.key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-s.receiver:
|
||||
s.processMessage(msg)
|
||||
for count := len(s.receiver); count > 0; count-- {
|
||||
s.processMessage(<-s.receiver)
|
||||
}
|
||||
case <-s.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *asyncSubscriberNats) close() {
|
||||
close(s.closeChan)
|
||||
}
|
||||
|
||||
type asyncBackendRoomSubscriberNats struct {
|
||||
*asyncSubscriberNats
|
||||
asyncBackendRoomSubscriber
|
||||
}
|
||||
|
||||
func newAsyncBackendRoomSubscriberNats(key string, client NatsClient) (*asyncBackendRoomSubscriberNats, error) {
|
||||
sub, err := newAsyncSubscriberNats(key, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &asyncBackendRoomSubscriberNats{
|
||||
asyncSubscriberNats: sub,
|
||||
}
|
||||
result.processMessage = result.doProcessMessage
|
||||
go result.run()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *asyncBackendRoomSubscriberNats) doProcessMessage(msg *nats.Msg) {
|
||||
var message AsyncMessage
|
||||
if err := s.client.Decode(msg, &message); err != nil {
|
||||
log.Printf("Could not decode NATS message %+v, %s", msg, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.processBackendRoomRequest(&message)
|
||||
}
|
||||
|
||||
type asyncRoomSubscriberNats struct {
|
||||
asyncRoomSubscriber
|
||||
*asyncSubscriberNats
|
||||
}
|
||||
|
||||
func newAsyncRoomSubscriberNats(key string, client NatsClient) (*asyncRoomSubscriberNats, error) {
|
||||
sub, err := newAsyncSubscriberNats(key, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &asyncRoomSubscriberNats{
|
||||
asyncSubscriberNats: sub,
|
||||
}
|
||||
result.processMessage = result.doProcessMessage
|
||||
go result.run()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *asyncRoomSubscriberNats) doProcessMessage(msg *nats.Msg) {
|
||||
var message AsyncMessage
|
||||
if err := s.client.Decode(msg, &message); err != nil {
|
||||
log.Printf("Could not decode nats message %+v, %s", msg, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.processAsyncRoomMessage(&message)
|
||||
}
|
||||
|
||||
type asyncUserSubscriberNats struct {
|
||||
*asyncSubscriberNats
|
||||
asyncUserSubscriber
|
||||
}
|
||||
|
||||
func newAsyncUserSubscriberNats(key string, client NatsClient) (*asyncUserSubscriberNats, error) {
|
||||
sub, err := newAsyncSubscriberNats(key, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &asyncUserSubscriberNats{
|
||||
asyncSubscriberNats: sub,
|
||||
}
|
||||
result.processMessage = result.doProcessMessage
|
||||
go result.run()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *asyncUserSubscriberNats) doProcessMessage(msg *nats.Msg) {
|
||||
var message AsyncMessage
|
||||
if err := s.client.Decode(msg, &message); err != nil {
|
||||
log.Printf("Could not decode nats message %+v, %s", msg, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.processAsyncUserMessage(&message)
|
||||
}
|
||||
|
||||
type asyncSessionSubscriberNats struct {
|
||||
*asyncSubscriberNats
|
||||
asyncSessionSubscriber
|
||||
}
|
||||
|
||||
func newAsyncSessionSubscriberNats(key string, client NatsClient) (*asyncSessionSubscriberNats, error) {
|
||||
sub, err := newAsyncSubscriberNats(key, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &asyncSessionSubscriberNats{
|
||||
asyncSubscriberNats: sub,
|
||||
}
|
||||
result.processMessage = result.doProcessMessage
|
||||
go result.run()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *asyncSessionSubscriberNats) doProcessMessage(msg *nats.Msg) {
|
||||
var message AsyncMessage
|
||||
if err := s.client.Decode(msg, &message); err != nil {
|
||||
log.Printf("Could not decode nats message %+v, %s", msg, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.processAsyncSessionMessage(&message)
|
||||
}
|
||||
|
||||
type asyncEventsNats struct {
|
||||
mu sync.Mutex
|
||||
client NatsClient
|
||||
|
||||
backendRoomSubscriptions map[string]*asyncBackendRoomSubscriberNats
|
||||
roomSubscriptions map[string]*asyncRoomSubscriberNats
|
||||
userSubscriptions map[string]*asyncUserSubscriberNats
|
||||
sessionSubscriptions map[string]*asyncSessionSubscriberNats
|
||||
}
|
||||
|
||||
func NewAsyncEventsNats(client NatsClient) (AsyncEvents, error) {
|
||||
events := &asyncEventsNats{
|
||||
client: client,
|
||||
|
||||
backendRoomSubscriptions: make(map[string]*asyncBackendRoomSubscriberNats),
|
||||
roomSubscriptions: make(map[string]*asyncRoomSubscriberNats),
|
||||
userSubscriptions: make(map[string]*asyncUserSubscriberNats),
|
||||
sessionSubscriptions: make(map[string]*asyncSessionSubscriberNats),
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) Close() {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func(subscriptions map[string]*asyncBackendRoomSubscriberNats) {
|
||||
defer wg.Done()
|
||||
for _, sub := range subscriptions {
|
||||
sub.close()
|
||||
}
|
||||
}(e.backendRoomSubscriptions)
|
||||
wg.Add(1)
|
||||
go func(subscriptions map[string]*asyncRoomSubscriberNats) {
|
||||
defer wg.Done()
|
||||
for _, sub := range subscriptions {
|
||||
sub.close()
|
||||
}
|
||||
}(e.roomSubscriptions)
|
||||
wg.Add(1)
|
||||
go func(subscriptions map[string]*asyncUserSubscriberNats) {
|
||||
defer wg.Done()
|
||||
for _, sub := range subscriptions {
|
||||
sub.close()
|
||||
}
|
||||
}(e.userSubscriptions)
|
||||
wg.Add(1)
|
||||
go func(subscriptions map[string]*asyncSessionSubscriberNats) {
|
||||
defer wg.Done()
|
||||
for _, sub := range subscriptions {
|
||||
sub.close()
|
||||
}
|
||||
}(e.sessionSubscriptions)
|
||||
e.backendRoomSubscriptions = make(map[string]*asyncBackendRoomSubscriberNats)
|
||||
e.roomSubscriptions = make(map[string]*asyncRoomSubscriberNats)
|
||||
e.userSubscriptions = make(map[string]*asyncUserSubscriberNats)
|
||||
e.sessionSubscriptions = make(map[string]*asyncSessionSubscriberNats)
|
||||
wg.Wait()
|
||||
e.client.Close()
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) RegisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener) error {
|
||||
key := GetSubjectForBackendRoomId(roomId, backend)
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
sub, found := e.backendRoomSubscriptions[key]
|
||||
if !found {
|
||||
var err error
|
||||
if sub, err = newAsyncBackendRoomSubscriberNats(key, e.client); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.backendRoomSubscriptions[key] = sub
|
||||
}
|
||||
sub.addListener(listener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) UnregisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener) {
|
||||
key := GetSubjectForBackendRoomId(roomId, backend)
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
sub, found := e.backendRoomSubscriptions[key]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
if !sub.removeListener(listener) {
|
||||
delete(e.backendRoomSubscriptions, key)
|
||||
sub.close()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) RegisterRoomListener(roomId string, backend *Backend, listener AsyncRoomEventListener) error {
|
||||
key := GetSubjectForRoomId(roomId, backend)
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
sub, found := e.roomSubscriptions[key]
|
||||
if !found {
|
||||
var err error
|
||||
if sub, err = newAsyncRoomSubscriberNats(key, e.client); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.roomSubscriptions[key] = sub
|
||||
}
|
||||
sub.addListener(listener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) UnregisterRoomListener(roomId string, backend *Backend, listener AsyncRoomEventListener) {
|
||||
key := GetSubjectForRoomId(roomId, backend)
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
sub, found := e.roomSubscriptions[key]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
if !sub.removeListener(listener) {
|
||||
delete(e.roomSubscriptions, key)
|
||||
sub.close()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) RegisterUserListener(roomId string, backend *Backend, listener AsyncUserEventListener) error {
|
||||
key := GetSubjectForUserId(roomId, backend)
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
sub, found := e.userSubscriptions[key]
|
||||
if !found {
|
||||
var err error
|
||||
if sub, err = newAsyncUserSubscriberNats(key, e.client); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.userSubscriptions[key] = sub
|
||||
}
|
||||
sub.addListener(listener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) UnregisterUserListener(roomId string, backend *Backend, listener AsyncUserEventListener) {
|
||||
key := GetSubjectForUserId(roomId, backend)
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
sub, found := e.userSubscriptions[key]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
if !sub.removeListener(listener) {
|
||||
delete(e.userSubscriptions, key)
|
||||
sub.close()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) RegisterSessionListener(sessionId string, backend *Backend, listener AsyncSessionEventListener) error {
|
||||
key := GetSubjectForSessionId(sessionId, backend)
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
sub, found := e.sessionSubscriptions[key]
|
||||
if !found {
|
||||
var err error
|
||||
if sub, err = newAsyncSessionSubscriberNats(key, e.client); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.sessionSubscriptions[key] = sub
|
||||
}
|
||||
sub.addListener(listener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) UnregisterSessionListener(sessionId string, backend *Backend, listener AsyncSessionEventListener) {
|
||||
key := GetSubjectForSessionId(sessionId, backend)
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
sub, found := e.sessionSubscriptions[key]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
if !sub.removeListener(listener) {
|
||||
delete(e.sessionSubscriptions, key)
|
||||
sub.close()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) publish(subject string, message *AsyncMessage) error {
|
||||
message.SendTime = time.Now()
|
||||
return e.client.Publish(subject, message)
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) PublishBackendRoomMessage(roomId string, backend *Backend, message *AsyncMessage) error {
|
||||
subject := GetSubjectForBackendRoomId(roomId, backend)
|
||||
return e.publish(subject, message)
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) PublishRoomMessage(roomId string, backend *Backend, message *AsyncMessage) error {
|
||||
subject := GetSubjectForRoomId(roomId, backend)
|
||||
return e.publish(subject, message)
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) PublishUserMessage(userId string, backend *Backend, message *AsyncMessage) error {
|
||||
subject := GetSubjectForUserId(userId, backend)
|
||||
return e.publish(subject, message)
|
||||
}
|
||||
|
||||
func (e *asyncEventsNats) PublishSessionMessage(sessionId string, backend *Backend, message *AsyncMessage) error {
|
||||
subject := GetSubjectForSessionId(sessionId, backend)
|
||||
return e.publish(subject, message)
|
||||
}
|
73
async_events_test.go
Normal file
73
async_events_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
eventBackendsForTest = []string{
|
||||
"loopback",
|
||||
"nats",
|
||||
}
|
||||
)
|
||||
|
||||
func getAsyncEventsForTest(t *testing.T) AsyncEvents {
|
||||
var events AsyncEvents
|
||||
if strings.HasSuffix(t.Name(), "/nats") {
|
||||
events = getRealAsyncEventsForTest(t)
|
||||
} else {
|
||||
events = getLoopbackAsyncEventsForTest(t)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
events.Close()
|
||||
})
|
||||
return events
|
||||
}
|
||||
|
||||
func getRealAsyncEventsForTest(t *testing.T) AsyncEvents {
|
||||
url := startLocalNatsServer(t)
|
||||
events, err := NewAsyncEvents(url)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func getLoopbackAsyncEventsForTest(t *testing.T) AsyncEvents {
|
||||
events, err := NewAsyncEvents(NatsLoopbackUrl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
nats := (events.(*asyncEventsNats)).client
|
||||
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
|
||||
})
|
||||
return events
|
||||
}
|
|
@ -50,8 +50,8 @@ type BackendClient struct {
|
|||
capabilities *Capabilities
|
||||
}
|
||||
|
||||
func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string) (*BackendClient, error) {
|
||||
backends, err := NewBackendConfiguration(config)
|
||||
func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string, etcdClient *EtcdClient) (*BackendClient, error) {
|
||||
backends, err := NewBackendConfiguration(config, etcdClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -80,6 +80,10 @@ func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost in
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (b *BackendClient) Close() {
|
||||
b.backends.Close()
|
||||
}
|
||||
|
||||
func (b *BackendClient) Reload(config *goconf.ConfigFile) {
|
||||
b.backends.Reload(config)
|
||||
}
|
||||
|
|
|
@ -95,7 +95,7 @@ func TestPostOnRedirect(t *testing.T) {
|
|||
if u.Scheme == "http" {
|
||||
config.AddOption("backend", "allowhttp", "true")
|
||||
}
|
||||
client, err := NewBackendClient(config, 1, "0.0")
|
||||
client, err := NewBackendClient(config, 1, "0.0", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -134,7 +134,7 @@ func TestPostOnRedirectDifferentHost(t *testing.T) {
|
|||
if u.Scheme == "http" {
|
||||
config.AddOption("backend", "allowhttp", "true")
|
||||
}
|
||||
client, err := NewBackendClient(config, 1, "0.0")
|
||||
client, err := NewBackendClient(config, 1, "0.0", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -187,7 +187,7 @@ func TestPostOnRedirectStatusFound(t *testing.T) {
|
|||
if u.Scheme == "http" {
|
||||
config.AddOption("backend", "allowhttp", "true")
|
||||
}
|
||||
client, err := NewBackendClient(config, 1, "0.0")
|
||||
client, err := NewBackendClient(config, 1, "0.0", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -22,15 +22,21 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"log"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
)
|
||||
|
||||
const (
|
||||
BackendTypeStatic = "static"
|
||||
BackendTypeEtcd = "etcd"
|
||||
|
||||
DefaultBackendType = BackendTypeStatic
|
||||
)
|
||||
|
||||
var (
|
||||
SessionLimitExceeded = NewError("session_limit_exceeded", "Too many sessions connected for this backend.")
|
||||
)
|
||||
|
@ -105,271 +111,43 @@ func (b *Backend) RemoveSession(session Session) {
|
|||
delete(b.sessions, session.PublicId())
|
||||
}
|
||||
|
||||
type BackendConfiguration struct {
|
||||
type BackendStorage interface {
|
||||
Close()
|
||||
Reload(config *goconf.ConfigFile)
|
||||
|
||||
GetCompatBackend() *Backend
|
||||
GetBackend(u *url.URL) *Backend
|
||||
GetBackends() []*Backend
|
||||
}
|
||||
|
||||
type backendStorageCommon struct {
|
||||
mu sync.RWMutex
|
||||
backends map[string][]*Backend
|
||||
|
||||
// Deprecated
|
||||
allowAll bool
|
||||
commonSecret []byte
|
||||
compatBackend *Backend
|
||||
}
|
||||
|
||||
func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, error) {
|
||||
allowAll, _ := config.GetBool("backend", "allowall")
|
||||
allowHttp, _ := config.GetBool("backend", "allowhttp")
|
||||
commonSecret, _ := config.GetString("backend", "secret")
|
||||
sessionLimit, err := config.GetInt("backend", "sessionlimit")
|
||||
if err != nil || sessionLimit < 0 {
|
||||
sessionLimit = 0
|
||||
func (s *backendStorageCommon) GetBackends() []*Backend {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []*Backend
|
||||
for _, entries := range s.backends {
|
||||
result = append(result, entries...)
|
||||
}
|
||||
backends := make(map[string][]*Backend)
|
||||
var compatBackend *Backend
|
||||
numBackends := 0
|
||||
if allowAll {
|
||||
log.Println("WARNING: All backend hostnames are allowed, only use for development!")
|
||||
compatBackend = &Backend{
|
||||
id: "compat",
|
||||
secret: []byte(commonSecret),
|
||||
compat: true,
|
||||
|
||||
allowHttp: allowHttp,
|
||||
|
||||
sessionLimit: uint64(sessionLimit),
|
||||
}
|
||||
if sessionLimit > 0 {
|
||||
log.Printf("Allow a maximum of %d sessions", sessionLimit)
|
||||
}
|
||||
numBackends++
|
||||
} else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
|
||||
for host, configuredBackends := range getConfiguredHosts(backendIds, config) {
|
||||
backends[host] = append(backends[host], configuredBackends...)
|
||||
for _, be := range configuredBackends {
|
||||
log.Printf("Backend %s added for %s", be.id, be.url)
|
||||
}
|
||||
numBackends += len(configuredBackends)
|
||||
}
|
||||
} else if allowedUrls, _ := config.GetString("backend", "allowed"); allowedUrls != "" {
|
||||
// Old-style configuration, only hosts are configured and are using a common secret.
|
||||
allowMap := make(map[string]bool)
|
||||
for _, u := range strings.Split(allowedUrls, ",") {
|
||||
u = strings.TrimSpace(u)
|
||||
if idx := strings.IndexByte(u, '/'); idx != -1 {
|
||||
log.Printf("WARNING: Removing path from allowed hostname \"%s\", check your configuration!", u)
|
||||
u = u[:idx]
|
||||
}
|
||||
if u != "" {
|
||||
allowMap[strings.ToLower(u)] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowMap) == 0 {
|
||||
log.Println("WARNING: No backend hostnames are allowed, check your configuration!")
|
||||
} else {
|
||||
compatBackend = &Backend{
|
||||
id: "compat",
|
||||
secret: []byte(commonSecret),
|
||||
compat: true,
|
||||
|
||||
allowHttp: allowHttp,
|
||||
|
||||
sessionLimit: uint64(sessionLimit),
|
||||
}
|
||||
hosts := make([]string, 0, len(allowMap))
|
||||
for host := range allowMap {
|
||||
hosts = append(hosts, host)
|
||||
backends[host] = []*Backend{compatBackend}
|
||||
}
|
||||
if len(hosts) > 1 {
|
||||
log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.")
|
||||
}
|
||||
log.Printf("Allowed backend hostnames: %s", hosts)
|
||||
if sessionLimit > 0 {
|
||||
log.Printf("Allow a maximum of %d sessions", sessionLimit)
|
||||
}
|
||||
numBackends++
|
||||
}
|
||||
}
|
||||
|
||||
RegisterBackendConfigurationStats()
|
||||
statsBackendsCurrent.Add(float64(numBackends))
|
||||
|
||||
return &BackendConfiguration{
|
||||
backends: backends,
|
||||
|
||||
allowAll: allowAll,
|
||||
commonSecret: []byte(commonSecret),
|
||||
compatBackend: compatBackend,
|
||||
}, nil
|
||||
return result
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) RemoveBackendsForHost(host string) {
|
||||
if oldBackends := b.backends[host]; len(oldBackends) > 0 {
|
||||
for _, backend := range oldBackends {
|
||||
log.Printf("Backend %s removed for %s", backend.id, backend.url)
|
||||
}
|
||||
statsBackendsCurrent.Sub(float64(len(oldBackends)))
|
||||
}
|
||||
delete(b.backends, host)
|
||||
}
|
||||
func (s *backendStorageCommon) getBackendLocked(u *url.URL) *Backend {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
func (b *BackendConfiguration) UpsertHost(host string, backends []*Backend) {
|
||||
for existingIndex, existingBackend := range b.backends[host] {
|
||||
found := false
|
||||
index := 0
|
||||
for _, newBackend := range backends {
|
||||
if reflect.DeepEqual(existingBackend, newBackend) { // otherwise we could manually compare the struct members here
|
||||
found = true
|
||||
backends = append(backends[:index], backends[index+1:]...)
|
||||
break
|
||||
} else if newBackend.id == existingBackend.id {
|
||||
found = true
|
||||
b.backends[host][existingIndex] = newBackend
|
||||
backends = append(backends[:index], backends[index+1:]...)
|
||||
log.Printf("Backend %s updated for %s", newBackend.id, newBackend.url)
|
||||
break
|
||||
}
|
||||
index++
|
||||
}
|
||||
if !found {
|
||||
removed := b.backends[host][existingIndex]
|
||||
log.Printf("Backend %s removed for %s", removed.id, removed.url)
|
||||
b.backends[host] = append(b.backends[host][:existingIndex], b.backends[host][existingIndex+1:]...)
|
||||
statsBackendsCurrent.Dec()
|
||||
}
|
||||
}
|
||||
|
||||
b.backends[host] = append(b.backends[host], backends...)
|
||||
for _, added := range backends {
|
||||
log.Printf("Backend %s added for %s", added.id, added.url)
|
||||
}
|
||||
statsBackendsCurrent.Add(float64(len(backends)))
|
||||
}
|
||||
|
||||
func getConfiguredBackendIDs(backendIds string) (ids []string) {
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, id := range strings.Split(backendIds, ",") {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if seen[id] {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, id)
|
||||
seen[id] = true
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map[string][]*Backend) {
|
||||
hosts = make(map[string][]*Backend)
|
||||
for _, id := range getConfiguredBackendIDs(backendIds) {
|
||||
u, _ := config.GetString(id, "url")
|
||||
if u == "" {
|
||||
log.Printf("Backend %s is missing or incomplete, skipping", id)
|
||||
continue
|
||||
}
|
||||
|
||||
if u[len(u)-1] != '/' {
|
||||
u += "/"
|
||||
}
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
log.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(parsed.Host, ":") && hasStandardPort(parsed) {
|
||||
parsed.Host = parsed.Hostname()
|
||||
u = parsed.String()
|
||||
}
|
||||
|
||||
secret, _ := config.GetString(id, "secret")
|
||||
if u == "" || secret == "" {
|
||||
log.Printf("Backend %s is missing or incomplete, skipping", id)
|
||||
continue
|
||||
}
|
||||
|
||||
sessionLimit, err := config.GetInt(id, "sessionlimit")
|
||||
if err != nil || sessionLimit < 0 {
|
||||
sessionLimit = 0
|
||||
}
|
||||
if sessionLimit > 0 {
|
||||
log.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit)
|
||||
}
|
||||
|
||||
maxStreamBitrate, err := config.GetInt(id, "maxstreambitrate")
|
||||
if err != nil || maxStreamBitrate < 0 {
|
||||
maxStreamBitrate = 0
|
||||
}
|
||||
maxScreenBitrate, err := config.GetInt(id, "maxscreenbitrate")
|
||||
if err != nil || maxScreenBitrate < 0 {
|
||||
maxScreenBitrate = 0
|
||||
}
|
||||
|
||||
hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{
|
||||
id: id,
|
||||
url: u,
|
||||
secret: []byte(secret),
|
||||
|
||||
allowHttp: parsed.Scheme == "http",
|
||||
|
||||
maxStreamBitrate: maxStreamBitrate,
|
||||
maxScreenBitrate: maxScreenBitrate,
|
||||
|
||||
sessionLimit: uint64(sessionLimit),
|
||||
})
|
||||
}
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) Reload(config *goconf.ConfigFile) {
|
||||
if b.compatBackend != nil {
|
||||
log.Println("Old-style configuration active, reload is not supported")
|
||||
return
|
||||
}
|
||||
|
||||
if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
|
||||
configuredHosts := getConfiguredHosts(backendIds, config)
|
||||
|
||||
// remove backends that are no longer configured
|
||||
for hostname := range b.backends {
|
||||
if _, ok := configuredHosts[hostname]; !ok {
|
||||
b.RemoveBackendsForHost(hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// rewrite backends adding newly configured ones and rewriting existing ones
|
||||
for hostname, configuredBackends := range configuredHosts {
|
||||
b.UpsertHost(hostname, configuredBackends)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) GetCompatBackend() *Backend {
|
||||
return b.compatBackend
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend {
|
||||
if strings.Contains(u.Host, ":") && hasStandardPort(u) {
|
||||
u.Host = u.Hostname()
|
||||
}
|
||||
|
||||
entries, found := b.backends[u.Host]
|
||||
entries, found := s.backends[u.Host]
|
||||
if !found {
|
||||
if b.allowAll {
|
||||
return b.compatBackend
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
s := u.String()
|
||||
if s[len(s)-1] != '/' {
|
||||
s += "/"
|
||||
url := u.String()
|
||||
if url[len(url)-1] != '/' {
|
||||
url += "/"
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if !entry.IsUrlAllowed(u) {
|
||||
|
@ -379,7 +157,7 @@ func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend {
|
|||
if entry.url == "" {
|
||||
// Old-style configuration, only hosts are configured.
|
||||
return entry
|
||||
} else if strings.HasPrefix(s, entry.url) {
|
||||
} else if strings.HasPrefix(url, entry.url) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
|
@ -387,12 +165,59 @@ func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) GetBackends() []*Backend {
|
||||
var result []*Backend
|
||||
for _, entries := range b.backends {
|
||||
result = append(result, entries...)
|
||||
type BackendConfiguration struct {
|
||||
storage BackendStorage
|
||||
}
|
||||
|
||||
func NewBackendConfiguration(config *goconf.ConfigFile, etcdClient *EtcdClient) (*BackendConfiguration, error) {
|
||||
backendType, _ := config.GetString("backend", "backendtype")
|
||||
if backendType == "" {
|
||||
backendType = DefaultBackendType
|
||||
}
|
||||
return result
|
||||
|
||||
RegisterBackendConfigurationStats()
|
||||
|
||||
var storage BackendStorage
|
||||
var err error
|
||||
switch backendType {
|
||||
case BackendTypeStatic:
|
||||
storage, err = NewBackendStorageStatic(config)
|
||||
case BackendTypeEtcd:
|
||||
storage, err = NewBackendStorageEtcd(config, etcdClient)
|
||||
default:
|
||||
err = fmt.Errorf("unknown backend type: %s", backendType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &BackendConfiguration{
|
||||
storage: storage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) Close() {
|
||||
b.storage.Close()
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) Reload(config *goconf.ConfigFile) {
|
||||
b.storage.Reload(config)
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) GetCompatBackend() *Backend {
|
||||
return b.storage.GetCompatBackend()
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend {
|
||||
if strings.Contains(u.Host, ":") && hasStandardPort(u) {
|
||||
u.Host = u.Hostname()
|
||||
}
|
||||
|
||||
return b.storage.GetBackend(u)
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) GetBackends() []*Backend {
|
||||
return b.storage.GetBackends()
|
||||
}
|
||||
|
||||
func (b *BackendConfiguration) IsUrlAllowed(u *url.URL) bool {
|
||||
|
@ -416,5 +241,5 @@ func (b *BackendConfiguration) GetSecret(u *url.URL) []byte {
|
|||
return nil
|
||||
}
|
||||
|
||||
return entry.secret
|
||||
return entry.Secret()
|
||||
}
|
||||
|
|
|
@ -23,8 +23,10 @@ package signaling
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
|
@ -104,7 +106,7 @@ func TestIsUrlAllowed_Compat(t *testing.T) {
|
|||
config.AddOption("backend", "allowed", "domain.invalid")
|
||||
config.AddOption("backend", "allowhttp", "true")
|
||||
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||
cfg, err := NewBackendConfiguration(config)
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -125,7 +127,7 @@ func TestIsUrlAllowed_CompatForceHttps(t *testing.T) {
|
|||
config := goconf.NewConfigFile()
|
||||
config.AddOption("backend", "allowed", "domain.invalid")
|
||||
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||
cfg, err := NewBackendConfiguration(config)
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -170,7 +172,7 @@ func TestIsUrlAllowed(t *testing.T) {
|
|||
config.AddOption("baz", "secret", string(testBackendSecret)+"-baz")
|
||||
config.AddOption("lala", "url", "https://otherdomain.invalid/")
|
||||
config.AddOption("lala", "secret", string(testBackendSecret)+"-lala")
|
||||
cfg, err := NewBackendConfiguration(config)
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -187,7 +189,7 @@ func TestIsUrlAllowed_EmptyAllowlist(t *testing.T) {
|
|||
config := goconf.NewConfigFile()
|
||||
config.AddOption("backend", "allowed", "")
|
||||
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||
cfg, err := NewBackendConfiguration(config)
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -207,7 +209,7 @@ func TestIsUrlAllowed_AllowAll(t *testing.T) {
|
|||
config.AddOption("backend", "allowall", "true")
|
||||
config.AddOption("backend", "allowed", "")
|
||||
config.AddOption("backend", "secret", string(testBackendSecret))
|
||||
cfg, err := NewBackendConfiguration(config)
|
||||
cfg, err := NewBackendConfiguration(config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -247,7 +249,7 @@ func TestBackendReloadNoChange(t *testing.T) {
|
|||
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config)
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -260,7 +262,7 @@ func TestBackendReloadNoChange(t *testing.T) {
|
|||
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
n_cfg, err := NewBackendConfiguration(new_config)
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -282,7 +284,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
|
|||
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config)
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -296,7 +298,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
|
|||
new_config.AddOption("backend1", "sessionlimit", "10")
|
||||
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
n_cfg, err := NewBackendConfiguration(new_config)
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -322,7 +324,7 @@ func TestBackendReloadChangeSecret(t *testing.T) {
|
|||
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config)
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -335,7 +337,7 @@ func TestBackendReloadChangeSecret(t *testing.T) {
|
|||
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend3")
|
||||
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
n_cfg, err := NewBackendConfiguration(new_config)
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -358,7 +360,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
|
|||
original_config.AddOption("backend", "allowall", "false")
|
||||
original_config.AddOption("backend1", "url", "http://domain1.invalid")
|
||||
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
o_cfg, err := NewBackendConfiguration(original_config)
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -372,7 +374,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
|
|||
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
new_config.AddOption("backend2", "sessionlimit", "10")
|
||||
n_cfg, err := NewBackendConfiguration(new_config)
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -400,7 +402,7 @@ func TestBackendReloadRemoveHost(t *testing.T) {
|
|||
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config)
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -411,7 +413,7 @@ func TestBackendReloadRemoveHost(t *testing.T) {
|
|||
new_config.AddOption("backend", "allowall", "false")
|
||||
new_config.AddOption("backend1", "url", "http://domain1.invalid")
|
||||
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
n_cfg, err := NewBackendConfiguration(new_config)
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -437,7 +439,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
|
|||
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
original_config.AddOption("backend2", "url", "http://domain1.invalid/bar/")
|
||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||
o_cfg, err := NewBackendConfiguration(original_config)
|
||||
o_cfg, err := NewBackendConfiguration(original_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -448,7 +450,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
|
|||
new_config.AddOption("backend", "allowall", "false")
|
||||
new_config.AddOption("backend1", "url", "http://domain1.invalid/foo/")
|
||||
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||
n_cfg, err := NewBackendConfiguration(new_config)
|
||||
n_cfg, err := NewBackendConfiguration(new_config, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -464,3 +466,155 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
|
|||
t.Error("BackendConfiguration should be equal after Reload")
|
||||
}
|
||||
}
|
||||
|
||||
func sortBackends(backends []*Backend) []*Backend {
|
||||
result := make([]*Backend, len(backends))
|
||||
copy(result, backends)
|
||||
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].Id() < result[j].Id()
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
func mustParse(s string) *url.URL {
|
||||
p, err := url.Parse(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestBackendConfiguration_Etcd(t *testing.T) {
|
||||
etcd, client := NewEtcdClientForTest(t)
|
||||
|
||||
url1 := "https://domain1.invalid/foo"
|
||||
initialSecret1 := string(testBackendSecret) + "-backend1-initial"
|
||||
secret1 := string(testBackendSecret) + "-backend1"
|
||||
|
||||
SetEtcdValue(etcd, "/backends/1_one", []byte("{\"url\":\""+url1+"\",\"secret\":\""+initialSecret1+"\"}"))
|
||||
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("backend", "backendtype", "etcd")
|
||||
config.AddOption("backend", "backendprefix", "/backends")
|
||||
|
||||
cfg, err := NewBackendConfiguration(config, client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cfg.Close()
|
||||
|
||||
storage := cfg.storage.(*backendStorageEtcd)
|
||||
ch := make(chan bool, 1)
|
||||
storage.SetWakeupForTesting(ch)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := storage.WaitForInitialized(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
|
||||
t.Errorf("Expected one backend, got %+v", backends)
|
||||
} else if backends[0].url != url1 {
|
||||
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
|
||||
} else if string(backends[0].secret) != initialSecret1 {
|
||||
t.Errorf("Expected backend secret %s, got %s", initialSecret1, string(backends[0].secret))
|
||||
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
SetEtcdValue(etcd, "/backends/1_one", []byte("{\"url\":\""+url1+"\",\"secret\":\""+secret1+"\"}"))
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
|
||||
t.Errorf("Expected one backend, got %+v", backends)
|
||||
} else if backends[0].url != url1 {
|
||||
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret1 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
|
||||
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
|
||||
}
|
||||
|
||||
url2 := "https://domain1.invalid/bar"
|
||||
secret2 := string(testBackendSecret) + "-backend2"
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
SetEtcdValue(etcd, "/backends/2_two", []byte("{\"url\":\""+url2+"\",\"secret\":\""+secret2+"\"}"))
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 {
|
||||
t.Errorf("Expected two backends, got %+v", backends)
|
||||
} else if backends[0].url != url1 {
|
||||
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret1 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
|
||||
} else if backends[1].url != url2 {
|
||||
t.Errorf("Expected backend url %s, got %s", url2, backends[1].url)
|
||||
} else if string(backends[1].secret) != secret2 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret))
|
||||
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
|
||||
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[1], backend)
|
||||
}
|
||||
|
||||
url3 := "https://domain2.invalid/foo"
|
||||
secret3 := string(testBackendSecret) + "-backend3"
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
SetEtcdValue(etcd, "/backends/3_three", []byte("{\"url\":\""+url3+"\",\"secret\":\""+secret3+"\"}"))
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 3 {
|
||||
t.Errorf("Expected three backends, got %+v", backends)
|
||||
} else if backends[0].url != url1 {
|
||||
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret1 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
|
||||
} else if backends[1].url != url2 {
|
||||
t.Errorf("Expected backend url %s, got %s", url2, backends[1].url)
|
||||
} else if string(backends[1].secret) != secret2 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret))
|
||||
} else if backends[2].url != url3 {
|
||||
t.Errorf("Expected backend url %s, got %s", url3, backends[2].url)
|
||||
} else if string(backends[2].secret) != secret3 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[2].secret))
|
||||
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
|
||||
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[1], backend)
|
||||
} else if backend := cfg.GetBackend(mustParse(url3)); backend != backends[2] {
|
||||
t.Errorf("Expected backend %+v, got %+v", backends[2], backend)
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
DeleteEtcdValue(etcd, "/backends/1_one")
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 {
|
||||
t.Errorf("Expected two backends, got %+v", backends)
|
||||
} else if backends[0].url != url2 {
|
||||
t.Errorf("Expected backend url %s, got %s", url2, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret2 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[0].secret))
|
||||
} else if backends[1].url != url3 {
|
||||
t.Errorf("Expected backend url %s, got %s", url3, backends[1].url)
|
||||
} else if string(backends[1].secret) != secret3 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[1].secret))
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
DeleteEtcdValue(etcd, "/backends/2_two")
|
||||
<-ch
|
||||
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
|
||||
t.Errorf("Expected one backend, got %+v", backends)
|
||||
} else if backends[0].url != url3 {
|
||||
t.Errorf("Expected backend url %s, got %s", url3, backends[0].url)
|
||||
} else if string(backends[0].secret) != secret3 {
|
||||
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[0].secret))
|
||||
}
|
||||
|
||||
if _, found := storage.backends["domain1.invalid"]; found {
|
||||
t.Errorf("Should have removed host information for %s", "domain1.invalid")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
|
@ -53,7 +54,7 @@ const (
|
|||
|
||||
type BackendServer struct {
|
||||
hub *Hub
|
||||
nats NatsClient
|
||||
events AsyncEvents
|
||||
roomSessions RoomSessions
|
||||
|
||||
version string
|
||||
|
@ -123,7 +124,7 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac
|
|||
|
||||
return &BackendServer{
|
||||
hub: hub,
|
||||
nats: hub.nats,
|
||||
events: hub.events,
|
||||
roomSessions: hub.roomSessions,
|
||||
version: version,
|
||||
|
||||
|
@ -279,45 +280,53 @@ func (b *BackendServer) parseRequestBody(f func(http.ResponseWriter, *http.Reque
|
|||
}
|
||||
|
||||
func (b *BackendServer) sendRoomInvite(roomid string, backend *Backend, userids []string, properties *json.RawMessage) {
|
||||
msg := &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "roomlist",
|
||||
Type: "invite",
|
||||
Invite: &RoomEventServerMessage{
|
||||
RoomId: roomid,
|
||||
Properties: properties,
|
||||
msg := &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "roomlist",
|
||||
Type: "invite",
|
||||
Invite: &RoomEventServerMessage{
|
||||
RoomId: roomid,
|
||||
Properties: properties,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, userid := range userids {
|
||||
if err := b.nats.PublishMessage(GetSubjectForUserId(userid, backend), msg); err != nil {
|
||||
if err := b.events.PublishUserMessage(userid, backend, msg); err != nil {
|
||||
log.Printf("Could not publish room invite for user %s in backend %s: %s", userid, backend.Id(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackendServer) sendRoomDisinvite(roomid string, backend *Backend, reason string, userids []string, sessionids []string) {
|
||||
msg := &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "roomlist",
|
||||
Type: "disinvite",
|
||||
Disinvite: &RoomDisinviteEventServerMessage{
|
||||
RoomEventServerMessage: RoomEventServerMessage{
|
||||
RoomId: roomid,
|
||||
msg := &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "roomlist",
|
||||
Type: "disinvite",
|
||||
Disinvite: &RoomDisinviteEventServerMessage{
|
||||
RoomEventServerMessage: RoomEventServerMessage{
|
||||
RoomId: roomid,
|
||||
},
|
||||
Reason: reason,
|
||||
},
|
||||
Reason: reason,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, userid := range userids {
|
||||
if err := b.nats.PublishMessage(GetSubjectForUserId(userid, backend), msg); err != nil {
|
||||
if err := b.events.PublishUserMessage(userid, backend, msg); err != nil {
|
||||
log.Printf("Could not publish room disinvite for user %s in backend %s: %s", userid, backend.Id(), err)
|
||||
}
|
||||
}
|
||||
|
||||
timeout := time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
for _, sessionid := range sessionids {
|
||||
if sessionid == sessionIdNotInMeeting {
|
||||
|
@ -328,10 +337,10 @@ func (b *BackendServer) sendRoomDisinvite(roomid string, backend *Backend, reaso
|
|||
wg.Add(1)
|
||||
go func(sessionid string) {
|
||||
defer wg.Done()
|
||||
if sid, err := b.lookupByRoomSessionId(sessionid, nil, timeout); err != nil {
|
||||
if sid, err := b.lookupByRoomSessionId(ctx, sessionid, nil); err != nil {
|
||||
log.Printf("Could not lookup by room session %s: %s", sessionid, err)
|
||||
} else if sid != "" {
|
||||
if err := b.nats.PublishMessage("session."+sid, msg); err != nil {
|
||||
if err := b.events.PublishSessionMessage(sid, backend, msg); err != nil {
|
||||
log.Printf("Could not publish room disinvite for session %s: %s", sid, err)
|
||||
}
|
||||
}
|
||||
|
@ -341,14 +350,17 @@ func (b *BackendServer) sendRoomDisinvite(roomid string, backend *Backend, reaso
|
|||
}
|
||||
|
||||
func (b *BackendServer) sendRoomUpdate(roomid string, backend *Backend, notified_userids []string, all_userids []string, properties *json.RawMessage) {
|
||||
msg := &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "roomlist",
|
||||
Type: "update",
|
||||
Update: &RoomEventServerMessage{
|
||||
RoomId: roomid,
|
||||
Properties: properties,
|
||||
msg := &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "roomlist",
|
||||
Type: "update",
|
||||
Update: &RoomEventServerMessage{
|
||||
RoomId: roomid,
|
||||
Properties: properties,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -362,13 +374,13 @@ func (b *BackendServer) sendRoomUpdate(roomid string, backend *Backend, notified
|
|||
continue
|
||||
}
|
||||
|
||||
if err := b.nats.PublishMessage(GetSubjectForUserId(userid, backend), msg); err != nil {
|
||||
if err := b.events.PublishUserMessage(userid, backend, msg); err != nil {
|
||||
log.Printf("Could not publish room update for user %s in backend %s: %s", userid, backend.Id(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackendServer) lookupByRoomSessionId(roomSessionId string, cache *ConcurrentStringStringMap, timeout time.Duration) (string, error) {
|
||||
func (b *BackendServer) lookupByRoomSessionId(ctx context.Context, roomSessionId string, cache *ConcurrentStringStringMap) (string, error) {
|
||||
if roomSessionId == sessionIdNotInMeeting {
|
||||
log.Printf("Trying to lookup empty room session id: %s", roomSessionId)
|
||||
return "", nil
|
||||
|
@ -380,7 +392,7 @@ func (b *BackendServer) lookupByRoomSessionId(roomSessionId string, cache *Concu
|
|||
}
|
||||
}
|
||||
|
||||
sid, err := b.roomSessions.GetSessionId(roomSessionId)
|
||||
sid, err := b.roomSessions.LookupSessionId(ctx, roomSessionId)
|
||||
if err == ErrNoSuchRoomSession {
|
||||
return "", nil
|
||||
} else if err != nil {
|
||||
|
@ -393,7 +405,7 @@ func (b *BackendServer) lookupByRoomSessionId(roomSessionId string, cache *Concu
|
|||
return sid, nil
|
||||
}
|
||||
|
||||
func (b *BackendServer) fixupUserSessions(cache *ConcurrentStringStringMap, users []map[string]interface{}, timeout time.Duration) []map[string]interface{} {
|
||||
func (b *BackendServer) fixupUserSessions(ctx context.Context, cache *ConcurrentStringStringMap, users []map[string]interface{}) []map[string]interface{} {
|
||||
if len(users) == 0 {
|
||||
return users
|
||||
}
|
||||
|
@ -421,7 +433,7 @@ func (b *BackendServer) fixupUserSessions(cache *ConcurrentStringStringMap, user
|
|||
wg.Add(1)
|
||||
go func(roomSessionId string, u map[string]interface{}) {
|
||||
defer wg.Done()
|
||||
if sessionId, err := b.lookupByRoomSessionId(roomSessionId, cache, timeout); err != nil {
|
||||
if sessionId, err := b.lookupByRoomSessionId(ctx, roomSessionId, cache); err != nil {
|
||||
log.Printf("Could not lookup by room session %s: %s", roomSessionId, err)
|
||||
delete(u, "sessionId")
|
||||
} else if sessionId != "" {
|
||||
|
@ -447,27 +459,35 @@ func (b *BackendServer) sendRoomIncall(roomid string, backend *Backend, request
|
|||
if !request.InCall.All {
|
||||
timeout := time.Second
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
var cache ConcurrentStringStringMap
|
||||
// Convert (Nextcloud) session ids to signaling session ids.
|
||||
request.InCall.Users = b.fixupUserSessions(&cache, request.InCall.Users, timeout)
|
||||
request.InCall.Users = b.fixupUserSessions(ctx, &cache, request.InCall.Users)
|
||||
// Entries in "Changed" are most likely already fetched through the "Users" list.
|
||||
request.InCall.Changed = b.fixupUserSessions(&cache, request.InCall.Changed, timeout)
|
||||
request.InCall.Changed = b.fixupUserSessions(ctx, &cache, request.InCall.Changed)
|
||||
|
||||
if len(request.InCall.Users) == 0 && len(request.InCall.Changed) == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), request)
|
||||
message := &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: request,
|
||||
}
|
||||
return b.events.PublishBackendRoomMessage(roomid, backend, message)
|
||||
}
|
||||
|
||||
func (b *BackendServer) sendRoomParticipantsUpdate(roomid string, backend *Backend, request *BackendServerRoomRequest) error {
|
||||
timeout := time.Second
|
||||
|
||||
// Convert (Nextcloud) session ids to signaling session ids.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
var cache ConcurrentStringStringMap
|
||||
request.Participants.Users = b.fixupUserSessions(&cache, request.Participants.Users, timeout)
|
||||
request.Participants.Changed = b.fixupUserSessions(&cache, request.Participants.Changed, timeout)
|
||||
request.Participants.Users = b.fixupUserSessions(ctx, &cache, request.Participants.Users)
|
||||
request.Participants.Changed = b.fixupUserSessions(ctx, &cache, request.Participants.Changed)
|
||||
|
||||
if len(request.Participants.Users) == 0 && len(request.Participants.Changed) == 0 {
|
||||
return nil
|
||||
|
@ -500,22 +520,30 @@ loop:
|
|||
|
||||
go func(sessionId string, permissions []Permission) {
|
||||
defer wg.Done()
|
||||
message := &NatsMessage{
|
||||
message := &AsyncMessage{
|
||||
Type: "permissions",
|
||||
Permissions: permissions,
|
||||
}
|
||||
if err := b.nats.Publish("session."+sessionId, message); err != nil {
|
||||
if err := b.events.PublishSessionMessage(sessionId, backend, message); err != nil {
|
||||
log.Printf("Could not send permissions update (%+v) to session %s: %s", permissions, sessionId, err)
|
||||
}
|
||||
}(sessionId, permissions)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), request)
|
||||
message := &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: request,
|
||||
}
|
||||
return b.events.PublishBackendRoomMessage(roomid, backend, message)
|
||||
}
|
||||
|
||||
func (b *BackendServer) sendRoomMessage(roomid string, backend *Backend, request *BackendServerRoomRequest) error {
|
||||
return b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), request)
|
||||
message := &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: request,
|
||||
}
|
||||
return b.events.PublishBackendRoomMessage(roomid, backend, message)
|
||||
}
|
||||
|
||||
func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body []byte) {
|
||||
|
@ -580,10 +608,18 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body
|
|||
b.sendRoomDisinvite(roomid, backend, DisinviteReasonDisinvited, request.Disinvite.UserIds, request.Disinvite.SessionIds)
|
||||
b.sendRoomUpdate(roomid, backend, request.Disinvite.UserIds, request.Disinvite.AllUserIds, request.Disinvite.Properties)
|
||||
case "update":
|
||||
err = b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), &request)
|
||||
message := &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: &request,
|
||||
}
|
||||
err = b.events.PublishBackendRoomMessage(roomid, backend, message)
|
||||
b.sendRoomUpdate(roomid, backend, nil, request.Update.UserIds, request.Update.Properties)
|
||||
case "delete":
|
||||
err = b.nats.PublishBackendServerRoomRequest(GetSubjectForBackendRoomId(roomid, backend), &request)
|
||||
message := &AsyncMessage{
|
||||
Type: "room",
|
||||
Room: &request,
|
||||
}
|
||||
err = b.events.PublishBackendRoomMessage(roomid, backend, message)
|
||||
b.sendRoomDisinvite(roomid, backend, DisinviteReasonDeleted, request.Delete.UserIds, nil)
|
||||
case "incall":
|
||||
err = b.sendRoomIncall(roomid, backend, &request)
|
||||
|
|
|
@ -42,7 +42,6 @@ import (
|
|||
"github.com/dlintw/goconf"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -52,11 +51,11 @@ var (
|
|||
turnServers = strings.Split(turnServersString, ",")
|
||||
)
|
||||
|
||||
func CreateBackendServerForTest(t *testing.T) (*goconf.ConfigFile, *BackendServer, NatsClient, *Hub, *mux.Router, *httptest.Server) {
|
||||
func CreateBackendServerForTest(t *testing.T) (*goconf.ConfigFile, *BackendServer, AsyncEvents, *Hub, *mux.Router, *httptest.Server) {
|
||||
return CreateBackendServerForTestFromConfig(t, nil)
|
||||
}
|
||||
|
||||
func CreateBackendServerForTestWithTurn(t *testing.T) (*goconf.ConfigFile, *BackendServer, NatsClient, *Hub, *mux.Router, *httptest.Server) {
|
||||
func CreateBackendServerForTestWithTurn(t *testing.T) (*goconf.ConfigFile, *BackendServer, AsyncEvents, *Hub, *mux.Router, *httptest.Server) {
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("turn", "apikey", turnApiKey)
|
||||
config.AddOption("turn", "secret", turnSecret)
|
||||
|
@ -64,11 +63,14 @@ func CreateBackendServerForTestWithTurn(t *testing.T) (*goconf.ConfigFile, *Back
|
|||
return CreateBackendServerForTestFromConfig(t, config)
|
||||
}
|
||||
|
||||
func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFile) (*goconf.ConfigFile, *BackendServer, NatsClient, *Hub, *mux.Router, *httptest.Server) {
|
||||
func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFile) (*goconf.ConfigFile, *BackendServer, AsyncEvents, *Hub, *mux.Router, *httptest.Server) {
|
||||
r := mux.NewRouter()
|
||||
registerBackendHandler(t, r)
|
||||
|
||||
server := httptest.NewServer(r)
|
||||
t.Cleanup(func() {
|
||||
server.Close()
|
||||
})
|
||||
if config == nil {
|
||||
config = goconf.NewConfigFile()
|
||||
}
|
||||
|
@ -85,11 +87,8 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil
|
|||
config.AddOption("sessions", "blockkey", "09876543210987654321098765432109")
|
||||
config.AddOption("clients", "internalsecret", string(testInternalSecret))
|
||||
config.AddOption("geoip", "url", "none")
|
||||
nats, err := NewLoopbackNatsClient()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hub, err := NewHub(config, nats, r, "no-version")
|
||||
events := getAsyncEventsForTest(t)
|
||||
hub, err := NewHub(config, events, nil, nil, nil, r, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -108,12 +107,122 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil
|
|||
defer cancel()
|
||||
|
||||
WaitForHub(ctx, t, hub)
|
||||
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
|
||||
nats.Close()
|
||||
server.Close()
|
||||
})
|
||||
|
||||
return config, b, nats, hub, r, server
|
||||
return config, b, events, hub, r, server
|
||||
}
|
||||
|
||||
func CreateBackendServerWithClusteringForTest(t *testing.T) (*BackendServer, *BackendServer, *Hub, *Hub, *httptest.Server, *httptest.Server) {
|
||||
return CreateBackendServerWithClusteringForTestFromConfig(t, nil, nil)
|
||||
}
|
||||
|
||||
func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *goconf.ConfigFile, config2 *goconf.ConfigFile) (*BackendServer, *BackendServer, *Hub, *Hub, *httptest.Server, *httptest.Server) {
|
||||
r1 := mux.NewRouter()
|
||||
registerBackendHandler(t, r1)
|
||||
|
||||
server1 := httptest.NewServer(r1)
|
||||
t.Cleanup(func() {
|
||||
server1.Close()
|
||||
})
|
||||
|
||||
r2 := mux.NewRouter()
|
||||
registerBackendHandler(t, r2)
|
||||
|
||||
server2 := httptest.NewServer(r2)
|
||||
t.Cleanup(func() {
|
||||
server2.Close()
|
||||
})
|
||||
|
||||
nats := startLocalNatsServer(t)
|
||||
grpcServer1, addr1 := NewGrpcServerForTest(t)
|
||||
grpcServer2, addr2 := NewGrpcServerForTest(t)
|
||||
|
||||
if config1 == nil {
|
||||
config1 = goconf.NewConfigFile()
|
||||
}
|
||||
u1, err := url.Parse(server1.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config1.AddOption("backend", "allowed", u1.Host)
|
||||
if u1.Scheme == "http" {
|
||||
config1.AddOption("backend", "allowhttp", "true")
|
||||
}
|
||||
config1.AddOption("backend", "secret", string(testBackendSecret))
|
||||
config1.AddOption("sessions", "hashkey", "12345678901234567890123456789012")
|
||||
config1.AddOption("sessions", "blockkey", "09876543210987654321098765432109")
|
||||
config1.AddOption("clients", "internalsecret", string(testInternalSecret))
|
||||
config1.AddOption("geoip", "url", "none")
|
||||
|
||||
events1, err := NewAsyncEvents(nats)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
events1.Close()
|
||||
})
|
||||
client1 := NewGrpcClientsForTest(t, addr2)
|
||||
hub1, err := NewHub(config1, events1, grpcServer1, client1, nil, r1, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if config2 == nil {
|
||||
config2 = goconf.NewConfigFile()
|
||||
}
|
||||
u2, err := url.Parse(server2.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config2.AddOption("backend", "allowed", u2.Host)
|
||||
if u2.Scheme == "http" {
|
||||
config2.AddOption("backend", "allowhttp", "true")
|
||||
}
|
||||
config2.AddOption("backend", "secret", string(testBackendSecret))
|
||||
config2.AddOption("sessions", "hashkey", "12345678901234567890123456789012")
|
||||
config2.AddOption("sessions", "blockkey", "09876543210987654321098765432109")
|
||||
config2.AddOption("clients", "internalsecret", string(testInternalSecret))
|
||||
config2.AddOption("geoip", "url", "none")
|
||||
events2, err := NewAsyncEvents(nats)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
events2.Close()
|
||||
})
|
||||
client2 := NewGrpcClientsForTest(t, addr1)
|
||||
hub2, err := NewHub(config2, events2, grpcServer2, client2, nil, r2, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
b1, err := NewBackendServer(config1, hub1, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b1.Start(r1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b2, err := NewBackendServer(config2, hub2, "no-version")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := b2.Start(r2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go hub1.Run()
|
||||
go hub2.Run()
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
WaitForHub(ctx, t, hub1)
|
||||
WaitForHub(ctx, t, hub2)
|
||||
})
|
||||
|
||||
return b1, b2, hub1, hub2, server1, server2
|
||||
}
|
||||
|
||||
func performBackendRequest(url string, body []byte) (*http.Response, error) {
|
||||
|
@ -131,23 +240,16 @@ func performBackendRequest(url string, body []byte) (*http.Response, error) {
|
|||
return client.Do(request)
|
||||
}
|
||||
|
||||
func expectRoomlistEvent(n NatsClient, ch chan *nats.Msg, subject string, msgType string) (*EventServerMessage, error) {
|
||||
func expectRoomlistEvent(ch chan *AsyncMessage, msgType string) (*EventServerMessage, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
select {
|
||||
case message := <-ch:
|
||||
if message.Subject != subject {
|
||||
return nil, fmt.Errorf("Expected subject %s, got %s", subject, message.Subject)
|
||||
}
|
||||
var natsMsg NatsMessage
|
||||
if err := n.Decode(message, &natsMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if natsMsg.Type != "message" || natsMsg.Message == nil {
|
||||
return nil, fmt.Errorf("Expected message type message, got %+v", natsMsg)
|
||||
if message.Type != "message" || message.Message == nil {
|
||||
return nil, fmt.Errorf("Expected message type message, got %+v", message)
|
||||
}
|
||||
|
||||
msg := natsMsg.Message
|
||||
msg := message.Message
|
||||
if msg.Type != "event" || msg.Event == nil {
|
||||
return nil, fmt.Errorf("Expected message type event, got %+v", msg)
|
||||
}
|
||||
|
@ -309,7 +411,23 @@ func TestBackendServer_UnsupportedRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackendServer_RoomInvite(t *testing.T) {
|
||||
_, _, n, hub, _, server := CreateBackendServerForTest(t)
|
||||
for _, backend := range eventBackendsForTest {
|
||||
t.Run(backend, func(t *testing.T) {
|
||||
RunTestBackendServer_RoomInvite(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type channelEventListener struct {
|
||||
ch chan *AsyncMessage
|
||||
}
|
||||
|
||||
func (l *channelEventListener) ProcessAsyncUserMessage(message *AsyncMessage) {
|
||||
l.ch <- message
|
||||
}
|
||||
|
||||
func RunTestBackendServer_RoomInvite(t *testing.T) {
|
||||
_, _, events, hub, _, server := CreateBackendServerForTest(t)
|
||||
|
||||
u, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
|
@ -320,17 +438,14 @@ func TestBackendServer_RoomInvite(t *testing.T) {
|
|||
roomProperties := json.RawMessage("{\"foo\":\"bar\"}")
|
||||
backend := hub.backend.GetBackend(u)
|
||||
|
||||
natsChan := make(chan *nats.Msg, 1)
|
||||
subject := GetSubjectForUserId(userid, backend)
|
||||
sub, err := n.Subscribe(subject, natsChan)
|
||||
if err != nil {
|
||||
eventsChan := make(chan *AsyncMessage, 1)
|
||||
listener := &channelEventListener{
|
||||
ch: eventsChan,
|
||||
}
|
||||
if err := events.RegisterUserListener(userid, backend, listener); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := sub.Unsubscribe(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
defer events.UnregisterUserListener(userid, backend, listener)
|
||||
|
||||
msg := &BackendServerRoomRequest{
|
||||
Type: "invite",
|
||||
|
@ -363,7 +478,7 @@ func TestBackendServer_RoomInvite(t *testing.T) {
|
|||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
|
||||
event, err := expectRoomlistEvent(n, natsChan, subject, "invite")
|
||||
event, err := expectRoomlistEvent(eventsChan, "invite")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if event.Invite == nil {
|
||||
|
@ -376,7 +491,15 @@ func TestBackendServer_RoomInvite(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackendServer_RoomDisinvite(t *testing.T) {
|
||||
_, _, n, hub, _, server := CreateBackendServerForTest(t)
|
||||
for _, backend := range eventBackendsForTest {
|
||||
t.Run(backend, func(t *testing.T) {
|
||||
RunTestBackendServer_RoomDisinvite(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunTestBackendServer_RoomDisinvite(t *testing.T) {
|
||||
_, _, events, hub, _, server := CreateBackendServerForTest(t)
|
||||
|
||||
u, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
|
@ -414,17 +537,14 @@ func TestBackendServer_RoomDisinvite(t *testing.T) {
|
|||
|
||||
roomProperties := json.RawMessage("{\"foo\":\"bar\"}")
|
||||
|
||||
natsChan := make(chan *nats.Msg, 1)
|
||||
subject := GetSubjectForUserId(testDefaultUserId, backend)
|
||||
sub, err := n.Subscribe(subject, natsChan)
|
||||
if err != nil {
|
||||
eventsChan := make(chan *AsyncMessage, 1)
|
||||
listener := &channelEventListener{
|
||||
ch: eventsChan,
|
||||
}
|
||||
if err := events.RegisterUserListener(testDefaultUserId, backend, listener); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := sub.Unsubscribe(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
defer events.UnregisterUserListener(testDefaultUserId, backend, listener)
|
||||
|
||||
msg := &BackendServerRoomRequest{
|
||||
Type: "disinvite",
|
||||
|
@ -457,7 +577,7 @@ func TestBackendServer_RoomDisinvite(t *testing.T) {
|
|||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
|
||||
event, err := expectRoomlistEvent(n, natsChan, subject, "disinvite")
|
||||
event, err := expectRoomlistEvent(eventsChan, "disinvite")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if event.Disinvite == nil {
|
||||
|
@ -606,11 +726,18 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) {
|
|||
} else if message.RoomId != roomId2 {
|
||||
t.Errorf("Expected message for room %s, got %s", roomId2, message.RoomId)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestBackendServer_RoomUpdate(t *testing.T) {
|
||||
_, _, n, hub, _, server := CreateBackendServerForTest(t)
|
||||
for _, backend := range eventBackendsForTest {
|
||||
t.Run(backend, func(t *testing.T) {
|
||||
RunTestBackendServer_RoomUpdate(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunTestBackendServer_RoomUpdate(t *testing.T) {
|
||||
_, _, events, hub, _, server := CreateBackendServerForTest(t)
|
||||
|
||||
u, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
|
@ -632,17 +759,14 @@ func TestBackendServer_RoomUpdate(t *testing.T) {
|
|||
userid := "test-userid"
|
||||
roomProperties := json.RawMessage("{\"foo\":\"bar\"}")
|
||||
|
||||
natsChan := make(chan *nats.Msg, 1)
|
||||
subject := GetSubjectForUserId(userid, backend)
|
||||
sub, err := n.Subscribe(subject, natsChan)
|
||||
if err != nil {
|
||||
eventsChan := make(chan *AsyncMessage, 1)
|
||||
listener := &channelEventListener{
|
||||
ch: eventsChan,
|
||||
}
|
||||
if err := events.RegisterUserListener(userid, backend, listener); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := sub.Unsubscribe(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
defer events.UnregisterUserListener(userid, backend, listener)
|
||||
|
||||
msg := &BackendServerRoomRequest{
|
||||
Type: "update",
|
||||
|
@ -671,7 +795,7 @@ func TestBackendServer_RoomUpdate(t *testing.T) {
|
|||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
|
||||
event, err := expectRoomlistEvent(n, natsChan, subject, "update")
|
||||
event, err := expectRoomlistEvent(eventsChan, "update")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if event.Update == nil {
|
||||
|
@ -682,7 +806,7 @@ func TestBackendServer_RoomUpdate(t *testing.T) {
|
|||
t.Errorf("Room properties don't match: expected %s, got %s", string(roomProperties), string(*event.Update.Properties))
|
||||
}
|
||||
|
||||
// TODO: Use event to wait for NATS messages.
|
||||
// TODO: Use event to wait for asynchronous messages.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
room = hub.getRoom(roomId)
|
||||
|
@ -695,7 +819,15 @@ func TestBackendServer_RoomUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackendServer_RoomDelete(t *testing.T) {
|
||||
_, _, n, hub, _, server := CreateBackendServerForTest(t)
|
||||
for _, backend := range eventBackendsForTest {
|
||||
t.Run(backend, func(t *testing.T) {
|
||||
RunTestBackendServer_RoomDelete(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunTestBackendServer_RoomDelete(t *testing.T) {
|
||||
_, _, events, hub, _, server := CreateBackendServerForTest(t)
|
||||
|
||||
u, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
|
@ -713,18 +845,14 @@ func TestBackendServer_RoomDelete(t *testing.T) {
|
|||
}
|
||||
|
||||
userid := "test-userid"
|
||||
|
||||
natsChan := make(chan *nats.Msg, 1)
|
||||
subject := GetSubjectForUserId(userid, backend)
|
||||
sub, err := n.Subscribe(subject, natsChan)
|
||||
if err != nil {
|
||||
eventsChan := make(chan *AsyncMessage, 1)
|
||||
listener := &channelEventListener{
|
||||
ch: eventsChan,
|
||||
}
|
||||
if err := events.RegisterUserListener(userid, backend, listener); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := sub.Unsubscribe(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
defer events.UnregisterUserListener(userid, backend, listener)
|
||||
|
||||
msg := &BackendServerRoomRequest{
|
||||
Type: "delete",
|
||||
|
@ -753,7 +881,7 @@ func TestBackendServer_RoomDelete(t *testing.T) {
|
|||
}
|
||||
|
||||
// A deleted room is signalled as a "disinvite" event.
|
||||
event, err := expectRoomlistEvent(n, natsChan, subject, "disinvite")
|
||||
event, err := expectRoomlistEvent(eventsChan, "disinvite")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if event.Disinvite == nil {
|
||||
|
@ -766,7 +894,7 @@ func TestBackendServer_RoomDelete(t *testing.T) {
|
|||
t.Errorf("Reason should be deleted, got %s", event.Disinvite.Reason)
|
||||
}
|
||||
|
||||
// TODO: Use event to wait for NATS messages.
|
||||
// TODO: Use event to wait for asynchronous messages.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
room := hub.getRoom(roomId)
|
||||
|
@ -776,117 +904,134 @@ func TestBackendServer_RoomDelete(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) {
|
||||
_, _, _, hub, _, server := CreateBackendServerForTest(t)
|
||||
for _, subtest := range clusteredTests {
|
||||
t.Run(subtest, func(t *testing.T) {
|
||||
var hub1 *Hub
|
||||
var hub2 *Hub
|
||||
var server1 *httptest.Server
|
||||
var server2 *httptest.Server
|
||||
|
||||
client1 := NewTestClient(t, server, hub)
|
||||
defer client1.CloseWithBye()
|
||||
if err := client1.SendHello(testDefaultUserId + "1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client2 := NewTestClient(t, server, hub)
|
||||
defer client2.CloseWithBye()
|
||||
if err := client2.SendHello(testDefaultUserId + "2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isLocalTest(t) {
|
||||
_, _, _, hub1, _, server1 = CreateBackendServerForTest(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
hub2 = hub1
|
||||
server2 = server1
|
||||
} else {
|
||||
_, _, hub1, hub2, server1, server2 = CreateBackendServerWithClusteringForTest(t)
|
||||
}
|
||||
|
||||
hello1, err := client1.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hello2, err := client2.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client1 := NewTestClient(t, server1, hub1)
|
||||
defer client1.CloseWithBye()
|
||||
if err := client1.SendHello(testDefaultUserId + "1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client2 := NewTestClient(t, server2, hub2)
|
||||
defer client2.CloseWithBye()
|
||||
if err := client2.SendHello(testDefaultUserId + "2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId)
|
||||
if session1 == nil {
|
||||
t.Fatalf("Session %s does not exist", hello1.Hello.SessionId)
|
||||
}
|
||||
session2 := hub.GetSessionByPublicId(hello2.Hello.SessionId)
|
||||
if session2 == nil {
|
||||
t.Fatalf("Session %s does not exist", hello2.Hello.SessionId)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Sessions have all permissions initially (fallback for old-style sessions).
|
||||
assertSessionHasPermission(t, session1, PERMISSION_MAY_PUBLISH_MEDIA)
|
||||
assertSessionHasPermission(t, session1, PERMISSION_MAY_PUBLISH_SCREEN)
|
||||
assertSessionHasPermission(t, session2, PERMISSION_MAY_PUBLISH_MEDIA)
|
||||
assertSessionHasPermission(t, session2, PERMISSION_MAY_PUBLISH_SCREEN)
|
||||
hello1, err := client1.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hello2, err := client2.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client1.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId)
|
||||
if session1 == nil {
|
||||
t.Fatalf("Session %s does not exist", hello1.Hello.SessionId)
|
||||
}
|
||||
session2 := hub2.GetSessionByPublicId(hello2.Hello.SessionId)
|
||||
if session2 == nil {
|
||||
t.Fatalf("Session %s does not exist", hello2.Hello.SessionId)
|
||||
}
|
||||
|
||||
// Ignore "join" events.
|
||||
if err := client1.DrainMessages(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := client2.DrainMessages(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
// Sessions have all permissions initially (fallback for old-style sessions).
|
||||
assertSessionHasPermission(t, session1, PERMISSION_MAY_PUBLISH_MEDIA)
|
||||
assertSessionHasPermission(t, session1, PERMISSION_MAY_PUBLISH_SCREEN)
|
||||
assertSessionHasPermission(t, session2, PERMISSION_MAY_PUBLISH_MEDIA)
|
||||
assertSessionHasPermission(t, session2, PERMISSION_MAY_PUBLISH_SCREEN)
|
||||
|
||||
msg := &BackendServerRoomRequest{
|
||||
Type: "participants",
|
||||
Participants: &BackendRoomParticipantsRequest{
|
||||
Changed: []map[string]interface{}{
|
||||
{
|
||||
"sessionId": roomId + "-" + hello1.Hello.SessionId,
|
||||
"permissions": []Permission{PERMISSION_MAY_PUBLISH_MEDIA},
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client1.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
|
||||
// Ignore "join" events.
|
||||
if err := client1.DrainMessages(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := client2.DrainMessages(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
msg := &BackendServerRoomRequest{
|
||||
Type: "participants",
|
||||
Participants: &BackendRoomParticipantsRequest{
|
||||
Changed: []map[string]interface{}{
|
||||
{
|
||||
"sessionId": roomId + "-" + hello1.Hello.SessionId,
|
||||
"permissions": []Permission{PERMISSION_MAY_PUBLISH_MEDIA},
|
||||
},
|
||||
{
|
||||
"sessionId": roomId + "-" + hello2.Hello.SessionId,
|
||||
"permissions": []Permission{PERMISSION_MAY_PUBLISH_SCREEN},
|
||||
},
|
||||
},
|
||||
Users: []map[string]interface{}{
|
||||
{
|
||||
"sessionId": roomId + "-" + hello1.Hello.SessionId,
|
||||
"permissions": []Permission{PERMISSION_MAY_PUBLISH_MEDIA},
|
||||
},
|
||||
{
|
||||
"sessionId": roomId + "-" + hello2.Hello.SessionId,
|
||||
"permissions": []Permission{PERMISSION_MAY_PUBLISH_SCREEN},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"sessionId": roomId + "-" + hello2.Hello.SessionId,
|
||||
"permissions": []Permission{PERMISSION_MAY_PUBLISH_SCREEN},
|
||||
},
|
||||
},
|
||||
Users: []map[string]interface{}{
|
||||
{
|
||||
"sessionId": roomId + "-" + hello1.Hello.SessionId,
|
||||
"permissions": []Permission{PERMISSION_MAY_PUBLISH_MEDIA},
|
||||
},
|
||||
{
|
||||
"sessionId": roomId + "-" + hello2.Hello.SessionId,
|
||||
"permissions": []Permission{PERMISSION_MAY_PUBLISH_SCREEN},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// The request could be sent to any of the backend servers.
|
||||
res, err := performBackendRequest(server1.URL+"/api/v1/room/"+roomId, data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
|
||||
// TODO: Use event to wait for NATS messages.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
// TODO: Use event to wait for asynchronous messages.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
assertSessionHasPermission(t, session1, PERMISSION_MAY_PUBLISH_MEDIA)
|
||||
assertSessionHasNotPermission(t, session1, PERMISSION_MAY_PUBLISH_SCREEN)
|
||||
assertSessionHasNotPermission(t, session2, PERMISSION_MAY_PUBLISH_MEDIA)
|
||||
assertSessionHasPermission(t, session2, PERMISSION_MAY_PUBLISH_SCREEN)
|
||||
assertSessionHasPermission(t, session1, PERMISSION_MAY_PUBLISH_MEDIA)
|
||||
assertSessionHasNotPermission(t, session1, PERMISSION_MAY_PUBLISH_SCREEN)
|
||||
assertSessionHasNotPermission(t, session2, PERMISSION_MAY_PUBLISH_MEDIA)
|
||||
assertSessionHasPermission(t, session2, PERMISSION_MAY_PUBLISH_SCREEN)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendServer_ParticipantsUpdateEmptyPermissions(t *testing.T) {
|
||||
|
@ -967,7 +1112,7 @@ func TestBackendServer_ParticipantsUpdateEmptyPermissions(t *testing.T) {
|
|||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
|
||||
// TODO: Use event to wait for NATS messages.
|
||||
// TODO: Use event to wait for asynchronous messages.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
assertSessionHasNotPermission(t, session, PERMISSION_MAY_PUBLISH_MEDIA)
|
||||
|
@ -1187,6 +1332,256 @@ func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBackendServer_InCallAll(t *testing.T) {
|
||||
for _, subtest := range clusteredTests {
|
||||
t.Run(subtest, func(t *testing.T) {
|
||||
var hub1 *Hub
|
||||
var hub2 *Hub
|
||||
var server1 *httptest.Server
|
||||
var server2 *httptest.Server
|
||||
|
||||
if isLocalTest(t) {
|
||||
_, _, _, hub1, _, server1 = CreateBackendServerForTest(t)
|
||||
|
||||
hub2 = hub1
|
||||
server2 = server1
|
||||
} else {
|
||||
_, _, hub1, hub2, server1, server2 = CreateBackendServerWithClusteringForTest(t)
|
||||
}
|
||||
|
||||
client1 := NewTestClient(t, server1, hub1)
|
||||
defer client1.CloseWithBye()
|
||||
if err := client1.SendHello(testDefaultUserId + "1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client2 := NewTestClient(t, server2, hub2)
|
||||
defer client2.CloseWithBye()
|
||||
if err := client2.SendHello(testDefaultUserId + "2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
hello1, err := client1.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hello2, err := client2.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId)
|
||||
if session1 == nil {
|
||||
t.Fatalf("Could not find session %s", hello1.Hello.SessionId)
|
||||
}
|
||||
session2 := hub2.GetSessionByPublicId(hello2.Hello.SessionId)
|
||||
if session2 == nil {
|
||||
t.Fatalf("Could not find session %s", hello2.Hello.SessionId)
|
||||
}
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client1.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
|
||||
// Give message processing some time.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
|
||||
WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2)
|
||||
|
||||
room1 := hub1.getRoom(roomId)
|
||||
if room1 == nil {
|
||||
t.Fatalf("Could not find room %s in hub1", roomId)
|
||||
}
|
||||
room2 := hub2.getRoom(roomId)
|
||||
if room2 == nil {
|
||||
t.Fatalf("Could not find room %s in hub2", roomId)
|
||||
}
|
||||
|
||||
if room1.IsSessionInCall(session1) {
|
||||
t.Errorf("Session %s should not be in room %s", session1.PublicId(), room1.Id())
|
||||
}
|
||||
if room2.IsSessionInCall(session2) {
|
||||
t.Errorf("Session %s should not be in room %s", session2.PublicId(), room2.Id())
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
msg := &BackendServerRoomRequest{
|
||||
Type: "incall",
|
||||
InCall: &BackendRoomInCallRequest{
|
||||
InCall: json.RawMessage("7"),
|
||||
All: true,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
res, err := performBackendRequest(server1.URL+"/api/v1/room/"+roomId, data)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
if t.Failed() {
|
||||
return
|
||||
}
|
||||
|
||||
if msg1_a, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if in_call_1, err := checkMessageParticipantsInCall(msg1_a); err != nil {
|
||||
t.Error(err)
|
||||
} else if !in_call_1.All {
|
||||
t.Errorf("All flag not set in message %+v", in_call_1)
|
||||
} else if !bytes.Equal(*in_call_1.InCall, []byte("7")) {
|
||||
t.Errorf("Expected inCall flag 7, got %s", string(*in_call_1.InCall))
|
||||
}
|
||||
|
||||
if msg2_a, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if in_call_1, err := checkMessageParticipantsInCall(msg2_a); err != nil {
|
||||
t.Error(err)
|
||||
} else if !in_call_1.All {
|
||||
t.Errorf("All flag not set in message %+v", in_call_1)
|
||||
} else if !bytes.Equal(*in_call_1.InCall, []byte("7")) {
|
||||
t.Errorf("Expected inCall flag 7, got %s", string(*in_call_1.InCall))
|
||||
}
|
||||
|
||||
if !room1.IsSessionInCall(session1) {
|
||||
t.Errorf("Session %s should be in room %s", session1.PublicId(), room1.Id())
|
||||
}
|
||||
if !room2.IsSessionInCall(session2) {
|
||||
t.Errorf("Session %s should be in room %s", session2.PublicId(), room2.Id())
|
||||
}
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel2()
|
||||
|
||||
if message, err := client1.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded {
|
||||
t.Error(err)
|
||||
} else if message != nil {
|
||||
t.Errorf("Expected no message, got %+v", message)
|
||||
}
|
||||
|
||||
ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel3()
|
||||
|
||||
if message, err := client2.RunUntilMessage(ctx3); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded {
|
||||
t.Error(err)
|
||||
} else if message != nil {
|
||||
t.Errorf("Expected no message, got %+v", message)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
msg := &BackendServerRoomRequest{
|
||||
Type: "incall",
|
||||
InCall: &BackendRoomInCallRequest{
|
||||
InCall: json.RawMessage("0"),
|
||||
All: true,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
res, err := performBackendRequest(server1.URL+"/api/v1/room/"+roomId, data)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
t.Errorf("Expected successful request, got %s: %s", res.Status, string(body))
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
if t.Failed() {
|
||||
return
|
||||
}
|
||||
|
||||
if msg1_a, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if in_call_1, err := checkMessageParticipantsInCall(msg1_a); err != nil {
|
||||
t.Error(err)
|
||||
} else if !in_call_1.All {
|
||||
t.Errorf("All flag not set in message %+v", in_call_1)
|
||||
} else if !bytes.Equal(*in_call_1.InCall, []byte("0")) {
|
||||
t.Errorf("Expected inCall flag 0, got %s", string(*in_call_1.InCall))
|
||||
}
|
||||
|
||||
if msg2_a, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if in_call_1, err := checkMessageParticipantsInCall(msg2_a); err != nil {
|
||||
t.Error(err)
|
||||
} else if !in_call_1.All {
|
||||
t.Errorf("All flag not set in message %+v", in_call_1)
|
||||
} else if !bytes.Equal(*in_call_1.InCall, []byte("0")) {
|
||||
t.Errorf("Expected inCall flag 0, got %s", string(*in_call_1.InCall))
|
||||
}
|
||||
|
||||
if room1.IsSessionInCall(session1) {
|
||||
t.Errorf("Session %s should not be in room %s", session1.PublicId(), room1.Id())
|
||||
}
|
||||
if room2.IsSessionInCall(session2) {
|
||||
t.Errorf("Session %s should not be in room %s", session2.PublicId(), room2.Id())
|
||||
}
|
||||
|
||||
ctx4, cancel4 := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel4()
|
||||
|
||||
if message, err := client1.RunUntilMessage(ctx4); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded {
|
||||
t.Error(err)
|
||||
} else if message != nil {
|
||||
t.Errorf("Expected no message, got %+v", message)
|
||||
}
|
||||
|
||||
ctx5, cancel5 := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel5()
|
||||
|
||||
if message, err := client2.RunUntilMessage(ctx5); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded {
|
||||
t.Error(err)
|
||||
} else if message != nil {
|
||||
t.Errorf("Expected no message, got %+v", message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendServer_RoomMessage(t *testing.T) {
|
||||
_, _, _, hub, _, server := CreateBackendServerForTest(t)
|
||||
|
||||
|
|
256
backend_storage_etcd.go
Normal file
256
backend_storage_etcd.go
Normal file
|
@ -0,0 +1,256 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
)
|
||||
|
||||
type backendStorageEtcd struct {
|
||||
backendStorageCommon
|
||||
|
||||
etcdClient *EtcdClient
|
||||
keyPrefix string
|
||||
keyInfos map[string]*BackendInformationEtcd
|
||||
|
||||
initializedCtx context.Context
|
||||
initializedFunc context.CancelFunc
|
||||
wakeupChanForTesting chan bool
|
||||
}
|
||||
|
||||
func NewBackendStorageEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient) (BackendStorage, error) {
|
||||
if etcdClient == nil || !etcdClient.IsConfigured() {
|
||||
return nil, fmt.Errorf("no etcd endpoints configured")
|
||||
}
|
||||
|
||||
keyPrefix, _ := config.GetString("backend", "backendprefix")
|
||||
if keyPrefix == "" {
|
||||
return nil, fmt.Errorf("no backend prefix configured")
|
||||
}
|
||||
|
||||
initializedCtx, initializedFunc := context.WithCancel(context.Background())
|
||||
result := &backendStorageEtcd{
|
||||
backendStorageCommon: backendStorageCommon{
|
||||
backends: make(map[string][]*Backend),
|
||||
},
|
||||
etcdClient: etcdClient,
|
||||
keyPrefix: keyPrefix,
|
||||
keyInfos: make(map[string]*BackendInformationEtcd),
|
||||
|
||||
initializedCtx: initializedCtx,
|
||||
initializedFunc: initializedFunc,
|
||||
}
|
||||
|
||||
etcdClient.AddListener(result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) WaitForInitialized(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-s.initializedCtx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) SetWakeupForTesting(ch chan bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.wakeupChanForTesting = ch
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) wakeupForTesting() {
|
||||
if s.wakeupChanForTesting == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case s.wakeupChanForTesting <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) EtcdClientCreated(client *EtcdClient) {
|
||||
go func() {
|
||||
if err := client.Watch(context.Background(), s.keyPrefix, s, clientv3.WithPrefix()); err != nil {
|
||||
log.Printf("Error processing watch for %s: %s", s.keyPrefix, err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
client.WaitForConnection()
|
||||
|
||||
waitDelay := initialWaitDelay
|
||||
for {
|
||||
response, err := s.getBackends(client, s.keyPrefix)
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded {
|
||||
log.Printf("Timeout getting initial list of backends, retry in %s", waitDelay)
|
||||
} else {
|
||||
log.Printf("Could not get initial list of backends, retry in %s: %s", waitDelay, err)
|
||||
}
|
||||
|
||||
time.Sleep(waitDelay)
|
||||
waitDelay = waitDelay * 2
|
||||
if waitDelay > maxWaitDelay {
|
||||
waitDelay = maxWaitDelay
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ev := range response.Kvs {
|
||||
s.EtcdKeyUpdated(client, string(ev.Key), ev.Value)
|
||||
}
|
||||
s.initializedFunc()
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) getBackends(client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
return client.Get(ctx, keyPrefix, clientv3.WithPrefix())
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data []byte) {
|
||||
var info BackendInformationEtcd
|
||||
if err := json.Unmarshal(data, &info); err != nil {
|
||||
log.Printf("Could not decode backend information %s: %s", string(data), err)
|
||||
return
|
||||
}
|
||||
if err := info.CheckValid(); err != nil {
|
||||
log.Printf("Received invalid backend information %s: %s", string(data), err)
|
||||
return
|
||||
}
|
||||
|
||||
backend := &Backend{
|
||||
id: key,
|
||||
url: info.Url,
|
||||
secret: []byte(info.Secret),
|
||||
|
||||
allowHttp: info.parsedUrl.Scheme == "http",
|
||||
|
||||
maxStreamBitrate: info.MaxStreamBitrate,
|
||||
maxScreenBitrate: info.MaxScreenBitrate,
|
||||
sessionLimit: info.SessionLimit,
|
||||
}
|
||||
|
||||
host := info.parsedUrl.Host
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.keyInfos[key] = &info
|
||||
entries, found := s.backends[host]
|
||||
if !found {
|
||||
// Simple case, first backend for this host
|
||||
log.Printf("Added backend %s (from %s)", info.Url, key)
|
||||
s.backends[host] = []*Backend{backend}
|
||||
statsBackendsCurrent.Inc()
|
||||
s.wakeupForTesting()
|
||||
return
|
||||
}
|
||||
|
||||
// Was the backend changed?
|
||||
replaced := false
|
||||
for idx, entry := range entries {
|
||||
if entry.id == key {
|
||||
log.Printf("Updated backend %s (from %s)", info.Url, key)
|
||||
entries[idx] = backend
|
||||
replaced = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !replaced {
|
||||
// New backend, add to list.
|
||||
log.Printf("Added backend %s (from %s)", info.Url, key)
|
||||
s.backends[host] = append(entries, backend)
|
||||
statsBackendsCurrent.Inc()
|
||||
}
|
||||
s.wakeupForTesting()
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) EtcdKeyDeleted(client *EtcdClient, key string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
info, found := s.keyInfos[key]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
delete(s.keyInfos, key)
|
||||
host := info.parsedUrl.Host
|
||||
entries, found := s.backends[host]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Removing backend %s (from %s)", info.Url, key)
|
||||
newEntries := make([]*Backend, 0, len(entries)-1)
|
||||
for _, entry := range entries {
|
||||
if entry.id == key {
|
||||
statsBackendsCurrent.Dec()
|
||||
continue
|
||||
}
|
||||
|
||||
newEntries = append(newEntries, entry)
|
||||
}
|
||||
if len(newEntries) > 0 {
|
||||
s.backends[host] = newEntries
|
||||
} else {
|
||||
delete(s.backends, host)
|
||||
}
|
||||
s.wakeupForTesting()
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) Close() {
|
||||
s.etcdClient.RemoveListener(s)
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) Reload(config *goconf.ConfigFile) {
|
||||
// Backend updates are processed through etcd.
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) GetCompatBackend() *Backend {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *backendStorageEtcd) GetBackend(u *url.URL) *Backend {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.getBackendLocked(u)
|
||||
}
|
303
backend_storage_static.go
Normal file
303
backend_storage_static.go
Normal file
|
@ -0,0 +1,303 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
)
|
||||
|
||||
type backendStorageStatic struct {
|
||||
backendStorageCommon
|
||||
|
||||
// Deprecated
|
||||
allowAll bool
|
||||
commonSecret []byte
|
||||
compatBackend *Backend
|
||||
}
|
||||
|
||||
func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) {
|
||||
allowAll, _ := config.GetBool("backend", "allowall")
|
||||
allowHttp, _ := config.GetBool("backend", "allowhttp")
|
||||
commonSecret, _ := config.GetString("backend", "secret")
|
||||
sessionLimit, err := config.GetInt("backend", "sessionlimit")
|
||||
if err != nil || sessionLimit < 0 {
|
||||
sessionLimit = 0
|
||||
}
|
||||
backends := make(map[string][]*Backend)
|
||||
var compatBackend *Backend
|
||||
numBackends := 0
|
||||
if allowAll {
|
||||
log.Println("WARNING: All backend hostnames are allowed, only use for development!")
|
||||
compatBackend = &Backend{
|
||||
id: "compat",
|
||||
secret: []byte(commonSecret),
|
||||
compat: true,
|
||||
|
||||
allowHttp: allowHttp,
|
||||
|
||||
sessionLimit: uint64(sessionLimit),
|
||||
}
|
||||
if sessionLimit > 0 {
|
||||
log.Printf("Allow a maximum of %d sessions", sessionLimit)
|
||||
}
|
||||
numBackends++
|
||||
} else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
|
||||
for host, configuredBackends := range getConfiguredHosts(backendIds, config) {
|
||||
backends[host] = append(backends[host], configuredBackends...)
|
||||
for _, be := range configuredBackends {
|
||||
log.Printf("Backend %s added for %s", be.id, be.url)
|
||||
}
|
||||
numBackends += len(configuredBackends)
|
||||
}
|
||||
} else if allowedUrls, _ := config.GetString("backend", "allowed"); allowedUrls != "" {
|
||||
// Old-style configuration, only hosts are configured and are using a common secret.
|
||||
allowMap := make(map[string]bool)
|
||||
for _, u := range strings.Split(allowedUrls, ",") {
|
||||
u = strings.TrimSpace(u)
|
||||
if idx := strings.IndexByte(u, '/'); idx != -1 {
|
||||
log.Printf("WARNING: Removing path from allowed hostname \"%s\", check your configuration!", u)
|
||||
u = u[:idx]
|
||||
}
|
||||
if u != "" {
|
||||
allowMap[strings.ToLower(u)] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowMap) == 0 {
|
||||
log.Println("WARNING: No backend hostnames are allowed, check your configuration!")
|
||||
} else {
|
||||
compatBackend = &Backend{
|
||||
id: "compat",
|
||||
secret: []byte(commonSecret),
|
||||
compat: true,
|
||||
|
||||
allowHttp: allowHttp,
|
||||
|
||||
sessionLimit: uint64(sessionLimit),
|
||||
}
|
||||
hosts := make([]string, 0, len(allowMap))
|
||||
for host := range allowMap {
|
||||
hosts = append(hosts, host)
|
||||
backends[host] = []*Backend{compatBackend}
|
||||
}
|
||||
if len(hosts) > 1 {
|
||||
log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.")
|
||||
}
|
||||
log.Printf("Allowed backend hostnames: %s", hosts)
|
||||
if sessionLimit > 0 {
|
||||
log.Printf("Allow a maximum of %d sessions", sessionLimit)
|
||||
}
|
||||
numBackends++
|
||||
}
|
||||
}
|
||||
|
||||
statsBackendsCurrent.Add(float64(numBackends))
|
||||
return &backendStorageStatic{
|
||||
backendStorageCommon: backendStorageCommon{
|
||||
backends: backends,
|
||||
},
|
||||
|
||||
allowAll: allowAll,
|
||||
commonSecret: []byte(commonSecret),
|
||||
compatBackend: compatBackend,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *backendStorageStatic) Close() {
|
||||
}
|
||||
|
||||
func (s *backendStorageStatic) RemoveBackendsForHost(host string) {
|
||||
if oldBackends := s.backends[host]; len(oldBackends) > 0 {
|
||||
for _, backend := range oldBackends {
|
||||
log.Printf("Backend %s removed for %s", backend.id, backend.url)
|
||||
}
|
||||
statsBackendsCurrent.Sub(float64(len(oldBackends)))
|
||||
}
|
||||
delete(s.backends, host)
|
||||
}
|
||||
|
||||
func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend) {
|
||||
for existingIndex, existingBackend := range s.backends[host] {
|
||||
found := false
|
||||
index := 0
|
||||
for _, newBackend := range backends {
|
||||
if reflect.DeepEqual(existingBackend, newBackend) { // otherwise we could manually compare the struct members here
|
||||
found = true
|
||||
backends = append(backends[:index], backends[index+1:]...)
|
||||
break
|
||||
} else if newBackend.id == existingBackend.id {
|
||||
found = true
|
||||
s.backends[host][existingIndex] = newBackend
|
||||
backends = append(backends[:index], backends[index+1:]...)
|
||||
log.Printf("Backend %s updated for %s", newBackend.id, newBackend.url)
|
||||
break
|
||||
}
|
||||
index++
|
||||
}
|
||||
if !found {
|
||||
removed := s.backends[host][existingIndex]
|
||||
log.Printf("Backend %s removed for %s", removed.id, removed.url)
|
||||
s.backends[host] = append(s.backends[host][:existingIndex], s.backends[host][existingIndex+1:]...)
|
||||
statsBackendsCurrent.Dec()
|
||||
}
|
||||
}
|
||||
|
||||
s.backends[host] = append(s.backends[host], backends...)
|
||||
for _, added := range backends {
|
||||
log.Printf("Backend %s added for %s", added.id, added.url)
|
||||
}
|
||||
statsBackendsCurrent.Add(float64(len(backends)))
|
||||
}
|
||||
|
||||
func getConfiguredBackendIDs(backendIds string) (ids []string) {
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, id := range strings.Split(backendIds, ",") {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if seen[id] {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, id)
|
||||
seen[id] = true
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map[string][]*Backend) {
|
||||
hosts = make(map[string][]*Backend)
|
||||
for _, id := range getConfiguredBackendIDs(backendIds) {
|
||||
u, _ := config.GetString(id, "url")
|
||||
if u == "" {
|
||||
log.Printf("Backend %s is missing or incomplete, skipping", id)
|
||||
continue
|
||||
}
|
||||
|
||||
if u[len(u)-1] != '/' {
|
||||
u += "/"
|
||||
}
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
log.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(parsed.Host, ":") && hasStandardPort(parsed) {
|
||||
parsed.Host = parsed.Hostname()
|
||||
u = parsed.String()
|
||||
}
|
||||
|
||||
secret, _ := config.GetString(id, "secret")
|
||||
if u == "" || secret == "" {
|
||||
log.Printf("Backend %s is missing or incomplete, skipping", id)
|
||||
continue
|
||||
}
|
||||
|
||||
sessionLimit, err := config.GetInt(id, "sessionlimit")
|
||||
if err != nil || sessionLimit < 0 {
|
||||
sessionLimit = 0
|
||||
}
|
||||
if sessionLimit > 0 {
|
||||
log.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit)
|
||||
}
|
||||
|
||||
maxStreamBitrate, err := config.GetInt(id, "maxstreambitrate")
|
||||
if err != nil || maxStreamBitrate < 0 {
|
||||
maxStreamBitrate = 0
|
||||
}
|
||||
maxScreenBitrate, err := config.GetInt(id, "maxscreenbitrate")
|
||||
if err != nil || maxScreenBitrate < 0 {
|
||||
maxScreenBitrate = 0
|
||||
}
|
||||
|
||||
hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{
|
||||
id: id,
|
||||
url: u,
|
||||
secret: []byte(secret),
|
||||
|
||||
allowHttp: parsed.Scheme == "http",
|
||||
|
||||
maxStreamBitrate: maxStreamBitrate,
|
||||
maxScreenBitrate: maxScreenBitrate,
|
||||
|
||||
sessionLimit: uint64(sessionLimit),
|
||||
})
|
||||
}
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
func (s *backendStorageStatic) Reload(config *goconf.ConfigFile) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.compatBackend != nil {
|
||||
log.Println("Old-style configuration active, reload is not supported")
|
||||
return
|
||||
}
|
||||
|
||||
if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
|
||||
configuredHosts := getConfiguredHosts(backendIds, config)
|
||||
|
||||
// remove backends that are no longer configured
|
||||
for hostname := range s.backends {
|
||||
if _, ok := configuredHosts[hostname]; !ok {
|
||||
s.RemoveBackendsForHost(hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// rewrite backends adding newly configured ones and rewriting existing ones
|
||||
for hostname, configuredBackends := range configuredHosts {
|
||||
s.UpsertHost(hostname, configuredBackends)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *backendStorageStatic) GetCompatBackend() *Backend {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.compatBackend
|
||||
}
|
||||
|
||||
func (s *backendStorageStatic) GetBackend(u *url.URL) *Backend {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if _, found := s.backends[u.Host]; !found {
|
||||
if s.allowAll {
|
||||
return s.compatBackend
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.getBackendLocked(u)
|
||||
}
|
76
backoff.go
Normal file
76
backoff.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Backoff interface {
|
||||
Reset()
|
||||
NextWait() time.Duration
|
||||
Wait(context.Context)
|
||||
}
|
||||
|
||||
type exponentialBackoff struct {
|
||||
initial time.Duration
|
||||
maxWait time.Duration
|
||||
nextWait time.Duration
|
||||
}
|
||||
|
||||
func NewExponentialBackoff(initial time.Duration, maxWait time.Duration) (Backoff, error) {
|
||||
if initial <= 0 {
|
||||
return nil, fmt.Errorf("initial must be larger than 0")
|
||||
}
|
||||
if maxWait < initial {
|
||||
return nil, fmt.Errorf("maxWait must be larger or equal to initial")
|
||||
}
|
||||
|
||||
return &exponentialBackoff{
|
||||
initial: initial,
|
||||
maxWait: maxWait,
|
||||
|
||||
nextWait: initial,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *exponentialBackoff) Reset() {
|
||||
b.nextWait = b.initial
|
||||
}
|
||||
|
||||
func (b *exponentialBackoff) NextWait() time.Duration {
|
||||
return b.nextWait
|
||||
}
|
||||
|
||||
func (b *exponentialBackoff) Wait(ctx context.Context) {
|
||||
waiter, cancel := context.WithTimeout(ctx, b.nextWait)
|
||||
defer cancel()
|
||||
|
||||
b.nextWait = b.nextWait * 2
|
||||
if b.nextWait > b.maxWait {
|
||||
b.nextWait = b.maxWait
|
||||
}
|
||||
|
||||
<-waiter.Done()
|
||||
}
|
64
backoff_test.go
Normal file
64
backoff_test.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBackoff_Exponential(t *testing.T) {
|
||||
backoff, err := NewExponentialBackoff(100*time.Millisecond, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
waitTimes := []time.Duration{
|
||||
100 * time.Millisecond,
|
||||
200 * time.Millisecond,
|
||||
400 * time.Millisecond,
|
||||
500 * time.Millisecond,
|
||||
500 * time.Millisecond,
|
||||
}
|
||||
|
||||
for _, wait := range waitTimes {
|
||||
if backoff.NextWait() != wait {
|
||||
t.Errorf("Wait time should be %s, got %s", wait, backoff.NextWait())
|
||||
}
|
||||
|
||||
a := time.Now()
|
||||
backoff.Wait(context.Background())
|
||||
b := time.Now()
|
||||
if b.Sub(a) < wait {
|
||||
t.Errorf("Should have waited %s, got %s", wait, b.Sub(a))
|
||||
}
|
||||
}
|
||||
|
||||
backoff.Reset()
|
||||
a := time.Now()
|
||||
backoff.Wait(context.Background())
|
||||
b := time.Now()
|
||||
if b.Sub(a) < 100*time.Millisecond {
|
||||
t.Errorf("Should have waited %s, got %s", 100*time.Millisecond, b.Sub(a))
|
||||
}
|
||||
}
|
190
certificate_reloader.go
Normal file
190
certificate_reloader.go
Normal file
|
@ -0,0 +1,190 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// CertificateCheckInterval defines the interval in which certificate files
|
||||
// are checked for modifications.
|
||||
CertificateCheckInterval = time.Minute
|
||||
)
|
||||
|
||||
type CertificateReloader struct {
|
||||
mu sync.Mutex
|
||||
|
||||
certFile string
|
||||
keyFile string
|
||||
|
||||
certificate *tls.Certificate
|
||||
lastModified time.Time
|
||||
|
||||
nextCheck time.Time
|
||||
}
|
||||
|
||||
func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) {
|
||||
pair, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not load certificate / key: %w", err)
|
||||
}
|
||||
|
||||
stat, err := os.Stat(certFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not stat %s: %w", certFile, err)
|
||||
}
|
||||
|
||||
return &CertificateReloader{
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
|
||||
certificate: &pair,
|
||||
lastModified: stat.ModTime(),
|
||||
|
||||
nextCheck: time.Now().Add(CertificateCheckInterval),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *CertificateReloader) getCertificate() (*tls.Certificate, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if now.Before(r.nextCheck) {
|
||||
return r.certificate, nil
|
||||
}
|
||||
|
||||
r.nextCheck = now.Add(CertificateCheckInterval)
|
||||
|
||||
stat, err := os.Stat(r.certFile)
|
||||
if err != nil {
|
||||
log.Printf("could not stat %s: %s", r.certFile, err)
|
||||
return r.certificate, nil
|
||||
}
|
||||
|
||||
if !stat.ModTime().Equal(r.lastModified) {
|
||||
log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile)
|
||||
pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile)
|
||||
if err != nil {
|
||||
log.Printf("could not load certificate / key: %s", err)
|
||||
return r.certificate, nil
|
||||
}
|
||||
|
||||
r.certificate = &pair
|
||||
r.lastModified = stat.ModTime()
|
||||
}
|
||||
|
||||
return r.certificate, nil
|
||||
}
|
||||
|
||||
func (r *CertificateReloader) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return r.getCertificate()
|
||||
}
|
||||
|
||||
func (r *CertificateReloader) GetClientCertificate(i *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return r.getCertificate()
|
||||
}
|
||||
|
||||
type CertPoolReloader struct {
|
||||
mu sync.Mutex
|
||||
|
||||
certFile string
|
||||
|
||||
pool *x509.CertPool
|
||||
lastModified time.Time
|
||||
|
||||
nextCheck time.Time
|
||||
}
|
||||
|
||||
func loadCertPool(filename string) (*x509.CertPool, error) {
|
||||
cert, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(cert) {
|
||||
return nil, fmt.Errorf("invalid CA in %s: %w", filename, err)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func NewCertPoolReloader(certFile string) (*CertPoolReloader, error) {
|
||||
pool, err := loadCertPool(certFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stat, err := os.Stat(certFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not stat %s: %w", certFile, err)
|
||||
}
|
||||
|
||||
return &CertPoolReloader{
|
||||
certFile: certFile,
|
||||
|
||||
pool: pool,
|
||||
lastModified: stat.ModTime(),
|
||||
|
||||
nextCheck: time.Now().Add(CertificateCheckInterval),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *CertPoolReloader) GetCertPool() *x509.CertPool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if now.Before(r.nextCheck) {
|
||||
return r.pool
|
||||
}
|
||||
|
||||
r.nextCheck = now.Add(CertificateCheckInterval)
|
||||
|
||||
stat, err := os.Stat(r.certFile)
|
||||
if err != nil {
|
||||
log.Printf("could not stat %s: %s", r.certFile, err)
|
||||
return r.pool
|
||||
}
|
||||
|
||||
if !stat.ModTime().Equal(r.lastModified) {
|
||||
log.Printf("reloading certificate pool from %s", r.certFile)
|
||||
pool, err := loadCertPool(r.certFile)
|
||||
if err != nil {
|
||||
log.Printf("could not load certificate pool: %s", err)
|
||||
return r.pool
|
||||
}
|
||||
|
||||
r.pool = pool
|
||||
r.lastModified = stat.ModTime()
|
||||
}
|
||||
|
||||
return r.pool
|
||||
}
|
36
certificate_reloader_test.go
Normal file
36
certificate_reloader_test.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateCertificateCheckIntervalForTest(t *testing.T, interval time.Duration) {
|
||||
old := CertificateCheckInterval
|
||||
t.Cleanup(func() {
|
||||
CertificateCheckInterval = old
|
||||
})
|
||||
|
||||
CertificateCheckInterval = interval
|
||||
}
|
250
clientsession.go
250
clientsession.go
|
@ -33,7 +33,6 @@ import (
|
|||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/pion/sdp"
|
||||
)
|
||||
|
||||
|
@ -50,8 +49,8 @@ var (
|
|||
type ClientSession struct {
|
||||
roomJoinTime int64
|
||||
|
||||
running int32
|
||||
hub *Hub
|
||||
events AsyncEvents
|
||||
privateId string
|
||||
publicId string
|
||||
data *SessionIdData
|
||||
|
@ -68,10 +67,7 @@ type ClientSession struct {
|
|||
backendUrl string
|
||||
parsedBackendUrl *url.URL
|
||||
|
||||
natsReceiver chan *nats.Msg
|
||||
stopRun chan bool
|
||||
runStopped chan bool
|
||||
expires time.Time
|
||||
expires time.Time
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
|
@ -79,9 +75,8 @@ type ClientSession struct {
|
|||
room unsafe.Pointer
|
||||
roomSessionId string
|
||||
|
||||
userSubscription NatsSubscription
|
||||
sessionSubscription NatsSubscription
|
||||
roomSubscription NatsSubscription
|
||||
publisherWaitersId uint64
|
||||
publisherWaiters map[uint64]chan bool
|
||||
|
||||
publishers map[string]McuPublisher
|
||||
subscribers map[string]McuSubscriber
|
||||
|
@ -96,6 +91,7 @@ type ClientSession struct {
|
|||
func NewClientSession(hub *Hub, privateId string, publicId string, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) {
|
||||
s := &ClientSession{
|
||||
hub: hub,
|
||||
events: hub.events,
|
||||
privateId: privateId,
|
||||
publicId: publicId,
|
||||
data: data,
|
||||
|
@ -106,10 +102,6 @@ func NewClientSession(hub *Hub, privateId string, publicId string, data *Session
|
|||
userData: auth.User,
|
||||
|
||||
backend: backend,
|
||||
|
||||
natsReceiver: make(chan *nats.Msg, 64),
|
||||
stopRun: make(chan bool, 1),
|
||||
runStopped: make(chan bool, 1),
|
||||
}
|
||||
if s.clientType == HelloClientTypeInternal {
|
||||
s.backendUrl = hello.Auth.internalParams.Backend
|
||||
|
@ -137,11 +129,9 @@ func NewClientSession(hub *Hub, privateId string, publicId string, data *Session
|
|||
s.parsedBackendUrl = u
|
||||
}
|
||||
|
||||
if err := s.SubscribeNats(hub.nats); err != nil {
|
||||
if err := s.SubscribeEvents(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
atomic.StoreInt32(&s.running, 1)
|
||||
go s.run()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -298,19 +288,6 @@ func (s *ClientSession) UserData() *json.RawMessage {
|
|||
return s.userData
|
||||
}
|
||||
|
||||
func (s *ClientSession) run() {
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case msg := <-s.natsReceiver:
|
||||
s.processClientMessage(msg)
|
||||
case <-s.stopRun:
|
||||
break loop
|
||||
}
|
||||
}
|
||||
s.runStopped <- true
|
||||
}
|
||||
|
||||
func (s *ClientSession) StartExpire() {
|
||||
// The hub mutex must be held when calling this method.
|
||||
s.expires = time.Now().Add(sessionExpireDuration)
|
||||
|
@ -378,18 +355,10 @@ func (s *ClientSession) closeAndWait(wait bool) {
|
|||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.userSubscription != nil {
|
||||
if err := s.userSubscription.Unsubscribe(); err != nil {
|
||||
log.Printf("Error closing user subscription in session %s: %s", s.PublicId(), err)
|
||||
}
|
||||
s.userSubscription = nil
|
||||
}
|
||||
if s.sessionSubscription != nil {
|
||||
if err := s.sessionSubscription.Unsubscribe(); err != nil {
|
||||
log.Printf("Error closing session subscription in session %s: %s", s.PublicId(), err)
|
||||
}
|
||||
s.sessionSubscription = nil
|
||||
if s.userId != "" {
|
||||
s.events.UnregisterUserListener(s.userId, s.backend, s)
|
||||
}
|
||||
s.events.UnregisterSessionListener(s.publicId, s.backend, s)
|
||||
go func(virtualSessions map[*VirtualSession]bool) {
|
||||
for session := range virtualSessions {
|
||||
session.Close()
|
||||
|
@ -399,56 +368,32 @@ func (s *ClientSession) closeAndWait(wait bool) {
|
|||
s.releaseMcuObjects()
|
||||
s.clearClientLocked(nil)
|
||||
s.backend.RemoveSession(s)
|
||||
if atomic.CompareAndSwapInt32(&s.running, 1, 0) {
|
||||
s.stopRun <- true
|
||||
// Only wait if called from outside the Session goroutine.
|
||||
if wait {
|
||||
s.mu.Unlock()
|
||||
// Wait for Session goroutine to stop
|
||||
<-s.runStopped
|
||||
s.mu.Lock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetSubjectForUserId(userId string, backend *Backend) string {
|
||||
if backend == nil || backend.IsCompat() {
|
||||
return GetEncodedSubject("user", userId)
|
||||
}
|
||||
|
||||
return GetEncodedSubject("user", userId+"|"+backend.Id())
|
||||
}
|
||||
|
||||
func (s *ClientSession) SubscribeNats(n NatsClient) error {
|
||||
func (s *ClientSession) SubscribeEvents() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err error
|
||||
if s.userId != "" {
|
||||
if s.userSubscription, err = n.Subscribe(GetSubjectForUserId(s.userId, s.backend), s.natsReceiver); err != nil {
|
||||
if err := s.events.RegisterUserListener(s.userId, s.backend, s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if s.sessionSubscription, err = n.Subscribe("session."+s.publicId, s.natsReceiver); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return s.events.RegisterSessionListener(s.publicId, s.backend, s)
|
||||
}
|
||||
|
||||
func (s *ClientSession) SubscribeRoomNats(n NatsClient, roomid string, roomSessionId string) error {
|
||||
func (s *ClientSession) SubscribeRoomEvents(roomid string, roomSessionId string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err error
|
||||
if s.roomSubscription, err = n.Subscribe(GetSubjectForRoomId(roomid, s.Backend()), s.natsReceiver); err != nil {
|
||||
if err := s.events.RegisterRoomListener(roomid, s.backend, s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if roomSessionId != "" {
|
||||
if err = s.hub.roomSessions.SetRoomSession(s, roomSessionId); err != nil {
|
||||
s.doUnsubscribeRoomNats(true)
|
||||
if err := s.hub.roomSessions.SetRoomSession(s, roomSessionId); err != nil {
|
||||
s.doUnsubscribeRoomEvents(true)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -479,29 +424,26 @@ func (s *ClientSession) LeaveRoom(notify bool) *Room {
|
|||
return nil
|
||||
}
|
||||
|
||||
s.doUnsubscribeRoomNats(notify)
|
||||
s.doUnsubscribeRoomEvents(notify)
|
||||
s.SetRoom(nil)
|
||||
s.releaseMcuObjects()
|
||||
room.RemoveSession(s)
|
||||
return room
|
||||
}
|
||||
|
||||
func (s *ClientSession) UnsubscribeRoomNats() {
|
||||
func (s *ClientSession) UnsubscribeRoomEvents() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.doUnsubscribeRoomNats(true)
|
||||
s.doUnsubscribeRoomEvents(true)
|
||||
}
|
||||
|
||||
func (s *ClientSession) doUnsubscribeRoomNats(notify bool) {
|
||||
if s.roomSubscription != nil {
|
||||
if err := s.roomSubscription.Unsubscribe(); err != nil {
|
||||
log.Printf("Error closing room subscription in session %s: %s", s.PublicId(), err)
|
||||
}
|
||||
s.roomSubscription = nil
|
||||
func (s *ClientSession) doUnsubscribeRoomEvents(notify bool) {
|
||||
room := s.GetRoom()
|
||||
if room != nil {
|
||||
s.events.UnregisterRoomListener(room.Id(), s.Backend(), s)
|
||||
}
|
||||
s.hub.roomSessions.DeleteRoomSession(s)
|
||||
room := s.GetRoom()
|
||||
if notify && room != nil && s.roomSessionId != "" {
|
||||
// Notify
|
||||
go func(sid string) {
|
||||
|
@ -864,6 +806,26 @@ func (s *ClientSession) checkOfferTypeLocked(streamType string, data *MessageCli
|
|||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *ClientSession) wakeupPublisherWaiters() {
|
||||
for _, ch := range s.publisherWaiters {
|
||||
ch <- true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClientSession) addPublisherWaiter(ch chan bool) uint64 {
|
||||
if s.publisherWaiters == nil {
|
||||
s.publisherWaiters = make(map[uint64]chan bool)
|
||||
}
|
||||
id := s.publisherWaitersId + 1
|
||||
s.publisherWaitersId = id
|
||||
s.publisherWaiters[id] = ch
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *ClientSession) removePublisherWaiter(id uint64) {
|
||||
delete(s.publisherWaiters, id)
|
||||
}
|
||||
|
||||
func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, streamType string, data *MessageClientMessageData) (McuPublisher, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
@ -912,6 +874,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
|
|||
s.publishers[streamType] = publisher
|
||||
}
|
||||
log.Printf("Publishing %s as %s for session %s", streamType, publisher.Id(), s.PublicId())
|
||||
s.wakeupPublisherWaiters()
|
||||
} else {
|
||||
publisher.SetMedia(mediaTypes)
|
||||
}
|
||||
|
@ -919,11 +882,44 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
|
|||
return publisher, nil
|
||||
}
|
||||
|
||||
func (s *ClientSession) getPublisherLocked(streamType string) McuPublisher {
|
||||
return s.publishers[streamType]
|
||||
}
|
||||
|
||||
func (s *ClientSession) GetPublisher(streamType string) McuPublisher {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return s.publishers[streamType]
|
||||
return s.getPublisherLocked(streamType)
|
||||
}
|
||||
|
||||
func (s *ClientSession) GetOrWaitForPublisher(ctx context.Context, streamType string) McuPublisher {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
publisher := s.getPublisherLocked(streamType)
|
||||
if publisher != nil {
|
||||
return publisher
|
||||
}
|
||||
|
||||
ch := make(chan bool, 1)
|
||||
id := s.addPublisherWaiter(ch)
|
||||
defer s.removePublisherWaiter(id)
|
||||
|
||||
for {
|
||||
s.mu.Unlock()
|
||||
select {
|
||||
case <-ch:
|
||||
s.mu.Lock()
|
||||
publisher := s.getPublisherLocked(streamType)
|
||||
if publisher != nil {
|
||||
return publisher
|
||||
}
|
||||
case <-ctx.Done():
|
||||
s.mu.Lock()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id string, streamType string) (McuSubscriber, error) {
|
||||
|
@ -967,13 +963,19 @@ func (s *ClientSession) GetSubscriber(id string, streamType string) McuSubscribe
|
|||
return s.subscribers[id+"|"+streamType]
|
||||
}
|
||||
|
||||
func (s *ClientSession) processClientMessage(msg *nats.Msg) {
|
||||
var message NatsMessage
|
||||
if err := s.hub.nats.Decode(msg, &message); err != nil {
|
||||
log.Printf("Could not decode NATS message %+v for session %s: %s", *msg, s.PublicId(), err)
|
||||
return
|
||||
}
|
||||
func (s *ClientSession) ProcessAsyncRoomMessage(message *AsyncMessage) {
|
||||
s.processAsyncMessage(message)
|
||||
}
|
||||
|
||||
func (s *ClientSession) ProcessAsyncUserMessage(message *AsyncMessage) {
|
||||
s.processAsyncMessage(message)
|
||||
}
|
||||
|
||||
func (s *ClientSession) ProcessAsyncSessionMessage(message *AsyncMessage) {
|
||||
s.processAsyncMessage(message)
|
||||
}
|
||||
|
||||
func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
|
||||
switch message.Type {
|
||||
case "permissions":
|
||||
s.SetPermissions(message.Permissions)
|
||||
|
@ -1015,9 +1017,71 @@ func (s *ClientSession) processClientMessage(msg *nats.Msg) {
|
|||
s.LeaveRoom(false)
|
||||
defer s.closeAndWait(false)
|
||||
}
|
||||
case "sendoffer":
|
||||
// Process asynchronously to not block other messages received.
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.hub.mcuTimeout)
|
||||
defer cancel()
|
||||
|
||||
mc, err := s.GetOrCreateSubscriber(ctx, s.hub.mcu, message.SendOffer.SessionId, message.SendOffer.Data.RoomType)
|
||||
if err != nil {
|
||||
log.Printf("Could not create MCU subscriber for session %s to process sendoffer in %s: %s", message.SendOffer.SessionId, s.PublicId(), err)
|
||||
if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: &ServerMessage{
|
||||
Id: message.SendOffer.MessageId,
|
||||
Type: "error",
|
||||
Error: NewError("client_not_found", "No MCU client found to send message to."),
|
||||
},
|
||||
}); err != nil {
|
||||
log.Printf("Error sending sendoffer error response to %s: %s", message.SendOffer.SessionId, err)
|
||||
}
|
||||
return
|
||||
} else if mc == nil {
|
||||
log.Printf("No MCU subscriber found for session %s to process sendoffer in %s", message.SendOffer.SessionId, s.PublicId())
|
||||
if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: &ServerMessage{
|
||||
Id: message.SendOffer.MessageId,
|
||||
Type: "error",
|
||||
Error: NewError("client_not_found", "No MCU client found to send message to."),
|
||||
},
|
||||
}); err != nil {
|
||||
log.Printf("Error sending sendoffer error response to %s: %s", message.SendOffer.SessionId, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
mc.SendMessage(context.TODO(), nil, message.SendOffer.Data, func(err error, response map[string]interface{}) {
|
||||
if err != nil {
|
||||
log.Printf("Could not send MCU message %+v for session %s to %s: %s", message.SendOffer.Data, message.SendOffer.SessionId, s.PublicId(), err)
|
||||
if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: &ServerMessage{
|
||||
Id: message.SendOffer.MessageId,
|
||||
Type: "error",
|
||||
Error: NewError("processing_failed", "Processing of the message failed, please check server logs."),
|
||||
},
|
||||
}); err != nil {
|
||||
log.Printf("Error sending sendoffer error response to %s: %s", message.SendOffer.SessionId, err)
|
||||
}
|
||||
return
|
||||
} else if response == nil {
|
||||
// No response received
|
||||
return
|
||||
}
|
||||
|
||||
s.hub.sendMcuMessageResponse(s, mc, &MessageClientMessage{
|
||||
Recipient: MessageClientMessageRecipient{
|
||||
SessionId: message.SendOffer.SessionId,
|
||||
},
|
||||
}, message.SendOffer.Data, response)
|
||||
})
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
serverMessage := s.processNatsMessage(&message)
|
||||
serverMessage := s.filterAsyncMessage(message)
|
||||
if serverMessage == nil {
|
||||
return
|
||||
}
|
||||
|
@ -1147,11 +1211,11 @@ func (s *ClientSession) filterMessage(message *ServerMessage) *ServerMessage {
|
|||
return message
|
||||
}
|
||||
|
||||
func (s *ClientSession) processNatsMessage(msg *NatsMessage) *ServerMessage {
|
||||
func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage {
|
||||
switch msg.Type {
|
||||
case "message":
|
||||
if msg.Message == nil {
|
||||
log.Printf("Received NATS message without payload: %+v", msg)
|
||||
log.Printf("Received asynchronous message without payload: %+v", msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1172,7 +1236,7 @@ func (s *ClientSession) processNatsMessage(msg *NatsMessage) *ServerMessage {
|
|||
}
|
||||
case "event":
|
||||
if msg.Message.Event.Target == "room" {
|
||||
// Can happen mostly during tests where an older room NATS message
|
||||
// Can happen mostly during tests where an older room async message
|
||||
// could be received by a subscriber that joined after it was sent.
|
||||
if joined := s.getRoomJoinTime(); joined.IsZero() || msg.SendTime.Before(joined) {
|
||||
log.Printf("Message %+v was sent before room was joined, ignoring", msg.Message)
|
||||
|
@ -1183,7 +1247,7 @@ func (s *ClientSession) processNatsMessage(msg *NatsMessage) *ServerMessage {
|
|||
|
||||
return msg.Message
|
||||
default:
|
||||
log.Printf("Received NATS message with unsupported type %s: %+v", msg.Type, msg)
|
||||
log.Printf("Received async message with unsupported type %s: %+v", msg.Type, msg)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,3 +45,6 @@ The following metrics are available:
|
|||
| `signaling_mcu_no_backend_available_total` | Counter | 0.4.0 | Total number of publishing requests where no backend was available | `type` |
|
||||
| `signaling_room_sessions` | Gauge | 0.4.0 | The current number of sessions in a room | `backend`, `room`, `clienttype` |
|
||||
| `signaling_server_messages_total` | Counter | 0.4.0 | The total number of signaling messages | `type` |
|
||||
| `signaling_grpc_clients` | Gauge | 1.0.0 | The current number of GRPC clients | |
|
||||
| `signaling_grpc_client_calls_total` | Counter | 1.0.0 | The total number of GRPC client calls | `method` |
|
||||
| `signaling_grpc_server_calls_total` | Counter | 1.0.0 | The total number of GRPC server calls | `method` |
|
||||
|
|
4
go.mod
4
go.mod
|
@ -19,6 +19,8 @@ require (
|
|||
go.etcd.io/etcd/client/pkg/v3 v3.5.4
|
||||
go.etcd.io/etcd/client/v3 v3.5.4
|
||||
go.etcd.io/etcd/server/v3 v3.5.4
|
||||
google.golang.org/grpc v1.47.0
|
||||
google.golang.org/protobuf v1.28.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
@ -78,8 +80,6 @@ require (
|
|||
golang.org/x/text v0.3.6 // indirect
|
||||
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect
|
||||
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c // indirect
|
||||
google.golang.org/grpc v1.38.0 // indirect
|
||||
google.golang.org/protobuf v1.26.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
sigs.k8s.io/yaml v1.2.0 // indirect
|
||||
|
|
16
go.sum
16
go.sum
|
@ -69,6 +69,10 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn
|
|||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI=
|
||||
github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/cockroachdb/datadriven v0.0.0-20200714090401-bf6692d28da5 h1:xD/lrqdvwsc+O2bjSSi3YqY73Ke3LAiSCx49aCesA0E=
|
||||
github.com/cockroachdb/datadriven v0.0.0-20200714090401-bf6692d28da5/go.mod h1:h6jFvWxBdQXxjopDMZyH2UVceIRfR84bdzbkoKrsWNo=
|
||||
github.com/cockroachdb/errors v1.2.4 h1:Lap807SXTH5tri2TivECb/4abUkMZC9zRoLarvcKDqs=
|
||||
|
@ -101,6 +105,7 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m
|
|||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
|
||||
github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c=
|
||||
|
@ -170,8 +175,9 @@ github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||
|
@ -581,6 +587,7 @@ golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
@ -732,8 +739,9 @@ google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM
|
|||
google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0=
|
||||
google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
|
||||
google.golang.org/grpc v1.37.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
|
||||
google.golang.org/grpc v1.38.0 h1:/9BgsAsa5nWe26HqOlvlgJnqBuktYOLCgjCPqsa56W0=
|
||||
google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
|
||||
google.golang.org/grpc v1.47.0 h1:9n77onPX5F3qfFCqjy9dhn8PbNQsIKeVU04J9G7umt8=
|
||||
google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
|
@ -745,8 +753,10 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD
|
|||
google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4=
|
||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
|
775
grpc_client.go
Normal file
775
grpc_client.go
Normal file
|
@ -0,0 +1,775 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/resolver"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
GrpcTargetTypeStatic = "static"
|
||||
GrpcTargetTypeEtcd = "etcd"
|
||||
|
||||
DefaultGrpcTargetType = GrpcTargetTypeStatic
|
||||
)
|
||||
|
||||
var (
|
||||
lookupGrpcIp = net.LookupIP // can be overwritten from tests
|
||||
|
||||
customResolverPrefix uint64
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterGrpcClientStats()
|
||||
}
|
||||
|
||||
type grpcClientImpl struct {
|
||||
RpcInternalClient
|
||||
RpcMcuClient
|
||||
RpcSessionsClient
|
||||
}
|
||||
|
||||
func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl {
|
||||
return &grpcClientImpl{
|
||||
RpcInternalClient: NewRpcInternalClient(conn),
|
||||
RpcMcuClient: NewRpcMcuClient(conn),
|
||||
RpcSessionsClient: NewRpcSessionsClient(conn),
|
||||
}
|
||||
}
|
||||
|
||||
type GrpcClient struct {
|
||||
isSelf uint32
|
||||
|
||||
ip net.IP
|
||||
target string
|
||||
conn *grpc.ClientConn
|
||||
impl *grpcClientImpl
|
||||
}
|
||||
|
||||
type customIpResolver struct {
|
||||
resolver.Builder
|
||||
resolver.Resolver
|
||||
|
||||
scheme string
|
||||
addr string
|
||||
hostname string
|
||||
}
|
||||
|
||||
func (r *customIpResolver) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
|
||||
state := resolver.State{
|
||||
Addresses: []resolver.Address{
|
||||
{
|
||||
Addr: r.addr,
|
||||
ServerName: r.hostname,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := cc.UpdateState(state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *customIpResolver) Scheme() string {
|
||||
return r.scheme
|
||||
}
|
||||
|
||||
func (r *customIpResolver) ResolveNow(opts resolver.ResolveNowOptions) {
|
||||
// Noop, we use a static configuration.
|
||||
}
|
||||
|
||||
func (r *customIpResolver) Close() {
|
||||
// Noop
|
||||
}
|
||||
|
||||
func NewGrpcClient(target string, ip net.IP, opts ...grpc.DialOption) (*GrpcClient, error) {
|
||||
var conn *grpc.ClientConn
|
||||
var err error
|
||||
if ip != nil {
|
||||
prefix := atomic.AddUint64(&customResolverPrefix, 1)
|
||||
addr := ip.String()
|
||||
hostname := target
|
||||
if host, port, err := net.SplitHostPort(target); err == nil {
|
||||
addr = net.JoinHostPort(addr, port)
|
||||
hostname = host
|
||||
}
|
||||
resolver := &customIpResolver{
|
||||
scheme: fmt.Sprintf("custom%d", prefix),
|
||||
addr: addr,
|
||||
hostname: hostname,
|
||||
}
|
||||
opts = append(opts, grpc.WithResolvers(resolver))
|
||||
conn, err = grpc.Dial(fmt.Sprintf("%s://%s", resolver.Scheme(), target), opts...)
|
||||
} else {
|
||||
conn, err = grpc.Dial(target, opts...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &GrpcClient{
|
||||
ip: ip,
|
||||
target: target,
|
||||
conn: conn,
|
||||
impl: newGrpcClientImpl(conn),
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
result.target += " (" + ip.String() + ")"
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) Target() string {
|
||||
return c.target
|
||||
}
|
||||
|
||||
func (c *GrpcClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *GrpcClient) IsSelf() bool {
|
||||
return atomic.LoadUint32(&c.isSelf) != 0
|
||||
}
|
||||
|
||||
func (c *GrpcClient) SetSelf(self bool) {
|
||||
if self {
|
||||
atomic.StoreUint32(&c.isSelf, 1)
|
||||
} else {
|
||||
atomic.StoreUint32(&c.isSelf, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) {
|
||||
statsGrpcClientCalls.WithLabelValues("GetServerId").Inc()
|
||||
response, err := c.impl.GetServerId(ctx, &GetServerIdRequest{}, grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return response.GetServerId(), nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) LookupSessionId(ctx context.Context, roomSessionId string) (string, error) {
|
||||
statsGrpcClientCalls.WithLabelValues("LookupSessionId").Inc()
|
||||
// TODO: Remove debug logging
|
||||
log.Printf("Lookup room session %s on %s", roomSessionId, c.Target())
|
||||
response, err := c.impl.LookupSessionId(ctx, &LookupSessionIdRequest{
|
||||
RoomSessionId: roomSessionId,
|
||||
}, grpc.WaitForReady(true))
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
return "", ErrNoSuchRoomSession
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sessionId := response.GetSessionId()
|
||||
if sessionId == "" {
|
||||
return "", ErrNoSuchRoomSession
|
||||
}
|
||||
|
||||
return sessionId, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) IsSessionInCall(ctx context.Context, sessionId string, room *Room) (bool, error) {
|
||||
statsGrpcClientCalls.WithLabelValues("IsSessionInCall").Inc()
|
||||
// TODO: Remove debug logging
|
||||
log.Printf("Check if session %s is in call %s on %s", sessionId, room.Id(), c.Target())
|
||||
response, err := c.impl.IsSessionInCall(ctx, &IsSessionInCallRequest{
|
||||
SessionId: sessionId,
|
||||
RoomId: room.Id(),
|
||||
BackendUrl: room.Backend().url,
|
||||
}, grpc.WaitForReady(true))
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return response.GetInCall(), nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, streamType string) (string, string, net.IP, error) {
|
||||
statsGrpcClientCalls.WithLabelValues("GetPublisherId").Inc()
|
||||
// TODO: Remove debug logging
|
||||
log.Printf("Get %s publisher id %s on %s", streamType, sessionId, c.Target())
|
||||
response, err := c.impl.GetPublisherId(ctx, &GetPublisherIdRequest{
|
||||
SessionId: sessionId,
|
||||
StreamType: streamType,
|
||||
}, grpc.WaitForReady(true))
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
return "", "", nil, nil
|
||||
} else if err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
|
||||
return response.GetPublisherId(), response.GetProxyUrl(), net.ParseIP(response.GetIp()), nil
|
||||
}
|
||||
|
||||
type GrpcClients struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
clientsMap map[string][]*GrpcClient
|
||||
clients []*GrpcClient
|
||||
|
||||
dnsDiscovery bool
|
||||
stopping chan bool
|
||||
stopped chan bool
|
||||
|
||||
etcdClient *EtcdClient
|
||||
targetPrefix string
|
||||
targetInformation map[string]*GrpcTargetInformationEtcd
|
||||
dialOptions atomic.Value // []grpc.DialOption
|
||||
|
||||
initializedCtx context.Context
|
||||
initializedFunc context.CancelFunc
|
||||
wakeupChanForTesting chan bool
|
||||
selfCheckWaitGroup sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, error) {
|
||||
initializedCtx, initializedFunc := context.WithCancel(context.Background())
|
||||
result := &GrpcClients{
|
||||
etcdClient: etcdClient,
|
||||
initializedCtx: initializedCtx,
|
||||
initializedFunc: initializedFunc,
|
||||
|
||||
stopping: make(chan bool, 1),
|
||||
stopped: make(chan bool, 1),
|
||||
}
|
||||
if err := result.load(config, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error {
|
||||
creds, err := NewReloadableCredentials(config, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)}
|
||||
c.dialOptions.Store(opts)
|
||||
|
||||
targetType, _ := config.GetString("grpc", "targettype")
|
||||
if targetType == "" {
|
||||
targetType = DefaultGrpcTargetType
|
||||
}
|
||||
|
||||
switch targetType {
|
||||
case GrpcTargetTypeStatic:
|
||||
err = c.loadTargetsStatic(config, fromReload, opts...)
|
||||
if err == nil && c.dnsDiscovery {
|
||||
go c.monitorGrpcIPs()
|
||||
}
|
||||
case GrpcTargetTypeEtcd:
|
||||
err = c.loadTargetsEtcd(config, fromReload, opts...)
|
||||
default:
|
||||
err = fmt.Errorf("unknown GRPC target type: %s", targetType)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *GrpcClients) closeClient(client *GrpcClient) {
|
||||
if client.IsSelf() {
|
||||
// Already closed.
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.Close(); err != nil {
|
||||
log.Printf("Error closing client to %s: %s", client.Target(), err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClients) isClientAvailable(target string, client *GrpcClient) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entries, found := c.clientsMap[target]
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry == client {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *GrpcClients) getServerIdWithTimeout(ctx context.Context, client *GrpcClient) (string, error) {
|
||||
ctx2, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
id, err := client.GetServerId(ctx2)
|
||||
return id, err
|
||||
}
|
||||
|
||||
func (c *GrpcClients) checkIsSelf(ctx context.Context, target string, client *GrpcClient) {
|
||||
backoff, _ := NewExponentialBackoff(initialWaitDelay, maxWaitDelay)
|
||||
defer c.selfCheckWaitGroup.Done()
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Cancelled
|
||||
return
|
||||
default:
|
||||
if !c.isClientAvailable(target, client) {
|
||||
return
|
||||
}
|
||||
|
||||
id, err := c.getServerIdWithTimeout(ctx, client)
|
||||
if err != nil {
|
||||
if status.Code(err) != codes.Canceled {
|
||||
log.Printf("Error checking GRPC server id of %s, retrying in %s: %s", client.Target(), backoff.NextWait(), err)
|
||||
}
|
||||
backoff.Wait(ctx)
|
||||
continue
|
||||
}
|
||||
|
||||
if id == GrpcServerId {
|
||||
log.Printf("GRPC target %s is this server, removing", client.Target())
|
||||
c.closeClient(client)
|
||||
client.SetSelf(true)
|
||||
} else {
|
||||
log.Printf("Checked GRPC server id of %s", client.Target())
|
||||
}
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bool, opts ...grpc.DialOption) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
clientsMap := make(map[string][]*GrpcClient)
|
||||
var clients []*GrpcClient
|
||||
removeTargets := make(map[string]bool, len(c.clientsMap))
|
||||
for target, entries := range c.clientsMap {
|
||||
removeTargets[target] = true
|
||||
clientsMap[target] = entries
|
||||
}
|
||||
|
||||
targets, _ := config.GetString("grpc", "targets")
|
||||
for _, target := range strings.Split(targets, ",") {
|
||||
target = strings.TrimSpace(target)
|
||||
if target == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if entries, found := clientsMap[target]; found {
|
||||
clients = append(clients, entries...)
|
||||
delete(removeTargets, target)
|
||||
continue
|
||||
}
|
||||
|
||||
host := target
|
||||
if h, _, err := net.SplitHostPort(target); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
var ips []net.IP
|
||||
if net.ParseIP(host) == nil {
|
||||
// Use dedicated client for each IP address.
|
||||
var err error
|
||||
ips, err = lookupGrpcIp(host)
|
||||
if err != nil {
|
||||
log.Printf("Could not lookup %s: %s", host, err)
|
||||
// Make sure updating continues even if initial lookup failed.
|
||||
clientsMap[target] = nil
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// Connect directly to IP address.
|
||||
ips = []net.IP{nil}
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
client, err := NewGrpcClient(target, ip, opts...)
|
||||
if err != nil {
|
||||
for _, clients := range clientsMap {
|
||||
for _, client := range clients {
|
||||
c.closeClient(client)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c.selfCheckWaitGroup.Add(1)
|
||||
go c.checkIsSelf(context.Background(), target, client)
|
||||
|
||||
log.Printf("Adding %s as GRPC target", client.Target())
|
||||
clientsMap[target] = append(clientsMap[target], client)
|
||||
clients = append(clients, client)
|
||||
}
|
||||
}
|
||||
|
||||
for target := range removeTargets {
|
||||
if clients, found := clientsMap[target]; found {
|
||||
for _, client := range clients {
|
||||
log.Printf("Deleting GRPC target %s", client.Target())
|
||||
c.closeClient(client)
|
||||
}
|
||||
delete(clientsMap, target)
|
||||
}
|
||||
}
|
||||
|
||||
dnsDiscovery, _ := config.GetBool("grpc", "dnsdiscovery")
|
||||
if dnsDiscovery != c.dnsDiscovery {
|
||||
if !dnsDiscovery && fromReload {
|
||||
c.stopping <- true
|
||||
<-c.stopped
|
||||
}
|
||||
c.dnsDiscovery = dnsDiscovery
|
||||
if dnsDiscovery && fromReload {
|
||||
go c.monitorGrpcIPs()
|
||||
}
|
||||
}
|
||||
|
||||
c.clients = clients
|
||||
c.clientsMap = clientsMap
|
||||
c.initializedFunc()
|
||||
statsGrpcClients.Set(float64(len(clients)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GrpcClients) monitorGrpcIPs() {
|
||||
log.Printf("Start monitoring GRPC client IPs")
|
||||
ticker := time.NewTicker(updateDnsInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.updateGrpcIPs()
|
||||
case <-c.stopping:
|
||||
c.stopped <- true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClients) updateGrpcIPs() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
opts := c.dialOptions.Load().([]grpc.DialOption)
|
||||
|
||||
mapModified := false
|
||||
for target, clients := range c.clientsMap {
|
||||
host := target
|
||||
if h, _, err := net.SplitHostPort(target); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
// No need to lookup endpoints that connect to IP addresses.
|
||||
continue
|
||||
}
|
||||
|
||||
ips, err := lookupGrpcIp(host)
|
||||
if err != nil {
|
||||
log.Printf("Could not lookup %s: %s", host, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var newClients []*GrpcClient
|
||||
changed := false
|
||||
for _, client := range clients {
|
||||
found := false
|
||||
for idx, ip := range ips {
|
||||
if ip.Equal(client.ip) {
|
||||
ips = append(ips[:idx], ips[idx+1:]...)
|
||||
found = true
|
||||
newClients = append(newClients, client)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
changed = true
|
||||
log.Printf("Removing connection to %s", client.Target())
|
||||
c.closeClient(client)
|
||||
c.wakeupForTesting()
|
||||
}
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
client, err := NewGrpcClient(target, ip, opts...)
|
||||
if err != nil {
|
||||
log.Printf("Error creating client to %s with IP %s: %s", target, ip.String(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.selfCheckWaitGroup.Add(1)
|
||||
go c.checkIsSelf(context.Background(), target, client)
|
||||
|
||||
log.Printf("Adding %s as GRPC target", client.Target())
|
||||
newClients = append(newClients, client)
|
||||
changed = true
|
||||
c.wakeupForTesting()
|
||||
}
|
||||
|
||||
if changed {
|
||||
c.clientsMap[target] = newClients
|
||||
mapModified = true
|
||||
}
|
||||
}
|
||||
|
||||
if mapModified {
|
||||
c.clients = make([]*GrpcClient, 0, len(c.clientsMap))
|
||||
for _, clients := range c.clientsMap {
|
||||
c.clients = append(c.clients, clients...)
|
||||
}
|
||||
statsGrpcClients.Set(float64(len(c.clients)))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClients) loadTargetsEtcd(config *goconf.ConfigFile, fromReload bool, opts ...grpc.DialOption) error {
|
||||
if !c.etcdClient.IsConfigured() {
|
||||
return fmt.Errorf("No etcd endpoints configured")
|
||||
}
|
||||
|
||||
targetPrefix, _ := config.GetString("grpc", "targetprefix")
|
||||
if targetPrefix == "" {
|
||||
return fmt.Errorf("No GRPC target prefix configured")
|
||||
}
|
||||
c.targetPrefix = targetPrefix
|
||||
if c.targetInformation == nil {
|
||||
c.targetInformation = make(map[string]*GrpcTargetInformationEtcd)
|
||||
}
|
||||
|
||||
c.etcdClient.AddListener(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GrpcClients) EtcdClientCreated(client *EtcdClient) {
|
||||
go func() {
|
||||
if err := client.Watch(context.Background(), c.targetPrefix, c, clientv3.WithPrefix()); err != nil {
|
||||
log.Printf("Error processing watch for %s: %s", c.targetPrefix, err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
client.WaitForConnection()
|
||||
|
||||
backoff, _ := NewExponentialBackoff(initialWaitDelay, maxWaitDelay)
|
||||
for {
|
||||
response, err := c.getGrpcTargets(client, c.targetPrefix)
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded {
|
||||
log.Printf("Timeout getting initial list of GRPC targets, retry in %s", backoff.NextWait())
|
||||
} else {
|
||||
log.Printf("Could not get initial list of GRPC targets, retry in %s: %s", backoff.NextWait(), err)
|
||||
}
|
||||
|
||||
backoff.Wait(context.Background())
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ev := range response.Kvs {
|
||||
c.EtcdKeyUpdated(client, string(ev.Key), ev.Value)
|
||||
}
|
||||
c.initializedFunc()
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *GrpcClients) getGrpcTargets(client *EtcdClient, targetPrefix string) (*clientv3.GetResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
return client.Get(ctx, targetPrefix, clientv3.WithPrefix())
|
||||
}
|
||||
|
||||
func (c *GrpcClients) EtcdKeyUpdated(client *EtcdClient, key string, data []byte) {
|
||||
var info GrpcTargetInformationEtcd
|
||||
if err := json.Unmarshal(data, &info); err != nil {
|
||||
log.Printf("Could not decode GRPC target %s=%s: %s", key, string(data), err)
|
||||
return
|
||||
}
|
||||
if err := info.CheckValid(); err != nil {
|
||||
log.Printf("Received invalid GRPC target %s=%s: %s", key, string(data), err)
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
prev, found := c.targetInformation[key]
|
||||
if found && prev.Address != info.Address {
|
||||
// Address of endpoint has changed, remove old one.
|
||||
c.removeEtcdClientLocked(key)
|
||||
}
|
||||
|
||||
if _, found := c.clientsMap[info.Address]; found {
|
||||
log.Printf("GRPC target %s already exists, ignoring %s", info.Address, key)
|
||||
return
|
||||
}
|
||||
|
||||
opts := c.dialOptions.Load().([]grpc.DialOption)
|
||||
cl, err := NewGrpcClient(info.Address, nil, opts...)
|
||||
if err != nil {
|
||||
log.Printf("Could not create GRPC client for target %s: %s", info.Address, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.selfCheckWaitGroup.Add(1)
|
||||
go c.checkIsSelf(context.Background(), info.Address, cl)
|
||||
|
||||
log.Printf("Adding %s as GRPC target", cl.Target())
|
||||
|
||||
if c.clientsMap == nil {
|
||||
c.clientsMap = make(map[string][]*GrpcClient)
|
||||
}
|
||||
c.clientsMap[info.Address] = []*GrpcClient{cl}
|
||||
c.clients = append(c.clients, cl)
|
||||
c.targetInformation[key] = &info
|
||||
statsGrpcClients.Inc()
|
||||
c.wakeupForTesting()
|
||||
}
|
||||
|
||||
func (c *GrpcClients) EtcdKeyDeleted(client *EtcdClient, key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.removeEtcdClientLocked(key)
|
||||
}
|
||||
|
||||
func (c *GrpcClients) removeEtcdClientLocked(key string) {
|
||||
info, found := c.targetInformation[key]
|
||||
if !found {
|
||||
log.Printf("No connection found for %s, ignoring", key)
|
||||
c.wakeupForTesting()
|
||||
return
|
||||
}
|
||||
|
||||
delete(c.targetInformation, key)
|
||||
clients, found := c.clientsMap[info.Address]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
for _, client := range clients {
|
||||
log.Printf("Removing connection to %s (from %s)", client.Target(), key)
|
||||
c.closeClient(client)
|
||||
}
|
||||
delete(c.clientsMap, info.Address)
|
||||
c.clients = make([]*GrpcClient, 0, len(c.clientsMap))
|
||||
for _, clients := range c.clientsMap {
|
||||
c.clients = append(c.clients, clients...)
|
||||
}
|
||||
statsGrpcClients.Dec()
|
||||
c.wakeupForTesting()
|
||||
}
|
||||
|
||||
func (c *GrpcClients) WaitForInitialized(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.initializedCtx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClients) wakeupForTesting() {
|
||||
if c.wakeupChanForTesting == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case c.wakeupChanForTesting <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClients) Reload(config *goconf.ConfigFile) {
|
||||
if err := c.load(config, true); err != nil {
|
||||
log.Printf("Could not reload RPC clients: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClients) Close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, clients := range c.clientsMap {
|
||||
for _, client := range clients {
|
||||
if err := client.Close(); err != nil {
|
||||
log.Printf("Error closing client to %s: %s", client.Target(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.clients = nil
|
||||
c.clientsMap = nil
|
||||
if c.dnsDiscovery {
|
||||
c.stopping <- true
|
||||
<-c.stopped
|
||||
c.dnsDiscovery = false
|
||||
}
|
||||
|
||||
if c.etcdClient != nil {
|
||||
c.etcdClient.RemoveListener(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClients) GetClients() []*GrpcClient {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if len(c.clients) == 0 {
|
||||
return c.clients
|
||||
}
|
||||
|
||||
result := make([]*GrpcClient, 0, len(c.clients)-1)
|
||||
for _, client := range c.clients {
|
||||
if client.IsSelf() {
|
||||
continue
|
||||
}
|
||||
|
||||
result = append(result, client)
|
||||
}
|
||||
return result
|
||||
}
|
358
grpc_client_test.go
Normal file
358
grpc_client_test.go
Normal file
|
@ -0,0 +1,358 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
)
|
||||
|
||||
func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) *GrpcClients {
|
||||
client, err := NewGrpcClients(config, etcdClient)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
client.Close()
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func NewGrpcClientsForTest(t *testing.T, addr string) *GrpcClients {
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("grpc", "targets", addr)
|
||||
config.AddOption("grpc", "dnsdiscovery", "true")
|
||||
|
||||
return NewGrpcClientsForTestWithConfig(t, config, nil)
|
||||
}
|
||||
|
||||
func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) *GrpcClients {
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("etcd", "endpoints", etcd.Config().LCUrls[0].String())
|
||||
|
||||
config.AddOption("grpc", "targettype", "etcd")
|
||||
config.AddOption("grpc", "targetprefix", "/grpctargets")
|
||||
|
||||
etcdClient, err := NewEtcdClient(config, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := etcdClient.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
|
||||
return NewGrpcClientsForTestWithConfig(t, config, etcdClient)
|
||||
}
|
||||
|
||||
func drainWakeupChannel(ch chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcClients_EtcdInitial(t *testing.T) {
|
||||
_, addr1 := NewGrpcServerForTest(t)
|
||||
_, addr2 := NewGrpcServerForTest(t)
|
||||
|
||||
etcd := NewEtcdForTest(t)
|
||||
|
||||
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
|
||||
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
|
||||
|
||||
client := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := client.WaitForInitialized(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 2 {
|
||||
t.Errorf("Expected two clients, got %+v", clients)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcClients_EtcdUpdate(t *testing.T) {
|
||||
etcd := NewEtcdForTest(t)
|
||||
client := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
ch := make(chan bool, 1)
|
||||
client.wakeupChanForTesting = ch
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 0 {
|
||||
t.Errorf("Expected no clients, got %+v", clients)
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
_, addr1 := NewGrpcServerForTest(t)
|
||||
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
|
||||
<-ch
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
_, addr2 := NewGrpcServerForTest(t)
|
||||
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
|
||||
<-ch
|
||||
if clients := client.GetClients(); len(clients) != 2 {
|
||||
t.Errorf("Expected two clients, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
} else if clients[1].Target() != addr2 {
|
||||
t.Errorf("Expected target %s, got %s", addr2, clients[1].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
DeleteEtcdValue(etcd, "/grpctargets/one")
|
||||
<-ch
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr2 {
|
||||
t.Errorf("Expected target %s, got %s", addr2, clients[0].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
_, addr3 := NewGrpcServerForTest(t)
|
||||
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr3+"\"}"))
|
||||
<-ch
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr3 {
|
||||
t.Errorf("Expected target %s, got %s", addr3, clients[0].Target())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
|
||||
etcd := NewEtcdForTest(t)
|
||||
client := NewGrpcClientsWithEtcdForTest(t, etcd)
|
||||
ch := make(chan bool, 1)
|
||||
client.wakeupChanForTesting = ch
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 0 {
|
||||
t.Errorf("Expected no clients, got %+v", clients)
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
_, addr1 := NewGrpcServerForTest(t)
|
||||
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
|
||||
<-ch
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
server2, addr2 := NewGrpcServerForTest(t)
|
||||
server2.serverId = GrpcServerId
|
||||
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
|
||||
<-ch
|
||||
client.selfCheckWaitGroup.Wait()
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
}
|
||||
|
||||
drainWakeupChannel(ch)
|
||||
DeleteEtcdValue(etcd, "/grpctargets/two")
|
||||
<-ch
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != addr1 {
|
||||
t.Errorf("Expected target %s, got %s", addr1, clients[0].Target())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcClients_DnsDiscovery(t *testing.T) {
|
||||
var ipsResult []net.IP
|
||||
lookupGrpcIp = func(host string) ([]net.IP, error) {
|
||||
if host == "testgrpc" {
|
||||
return ipsResult, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown host")
|
||||
}
|
||||
target := "testgrpc:12345"
|
||||
ip1 := net.ParseIP("192.168.0.1")
|
||||
ip2 := net.ParseIP("192.168.0.2")
|
||||
targetWithIp1 := fmt.Sprintf("%s (%s)", target, ip1)
|
||||
targetWithIp2 := fmt.Sprintf("%s (%s)", target, ip2)
|
||||
ipsResult = []net.IP{ip1}
|
||||
client := NewGrpcClientsForTest(t, target)
|
||||
ch := make(chan bool, 1)
|
||||
client.wakeupChanForTesting = ch
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != targetWithIp1 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
|
||||
} else if !clients[0].ip.Equal(ip1) {
|
||||
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
}
|
||||
|
||||
ipsResult = []net.IP{ip1, ip2}
|
||||
drainWakeupChannel(ch)
|
||||
client.updateGrpcIPs()
|
||||
<-ch
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 2 {
|
||||
t.Errorf("Expected two client, got %+v", clients)
|
||||
} else if clients[0].Target() != targetWithIp1 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
|
||||
} else if !clients[0].ip.Equal(ip1) {
|
||||
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
} else if clients[1].Target() != targetWithIp2 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp2, clients[1].Target())
|
||||
} else if !clients[1].ip.Equal(ip2) {
|
||||
t.Errorf("Expected IP %s, got %s", ip2, clients[1].ip)
|
||||
}
|
||||
|
||||
ipsResult = []net.IP{ip2}
|
||||
drainWakeupChannel(ch)
|
||||
client.updateGrpcIPs()
|
||||
<-ch
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != targetWithIp2 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp2, clients[0].Target())
|
||||
} else if !clients[0].ip.Equal(ip2) {
|
||||
t.Errorf("Expected IP %s, got %s", ip2, clients[0].ip)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
|
||||
var ipsResult []net.IP
|
||||
lookupGrpcIp = func(host string) ([]net.IP, error) {
|
||||
if host == "testgrpc" && len(ipsResult) > 0 {
|
||||
return ipsResult, nil
|
||||
}
|
||||
|
||||
return nil, &net.DNSError{
|
||||
Err: "no such host",
|
||||
Name: host,
|
||||
IsNotFound: true,
|
||||
}
|
||||
}
|
||||
target := "testgrpc:12345"
|
||||
ip1 := net.ParseIP("192.168.0.1")
|
||||
targetWithIp1 := fmt.Sprintf("%s (%s)", target, ip1)
|
||||
client := NewGrpcClientsForTest(t, target)
|
||||
ch := make(chan bool, 1)
|
||||
client.wakeupChanForTesting = ch
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := client.WaitForInitialized(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 0 {
|
||||
t.Errorf("Expected no client, got %+v", clients)
|
||||
}
|
||||
|
||||
ipsResult = []net.IP{ip1}
|
||||
drainWakeupChannel(ch)
|
||||
client.updateGrpcIPs()
|
||||
<-ch
|
||||
|
||||
if clients := client.GetClients(); len(clients) != 1 {
|
||||
t.Errorf("Expected one client, got %+v", clients)
|
||||
} else if clients[0].Target() != targetWithIp1 {
|
||||
t.Errorf("Expected target %s, got %s", targetWithIp1, clients[0].Target())
|
||||
} else if !clients[0].ip.Equal(ip1) {
|
||||
t.Errorf("Expected IP %s, got %s", ip1, clients[0].ip)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcClients_Encryption(t *testing.T) {
|
||||
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey)
|
||||
clientCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Testing client", clientKey)
|
||||
|
||||
dir := t.TempDir()
|
||||
serverPrivkeyFile := path.Join(dir, "server-privkey.pem")
|
||||
serverPubkeyFile := path.Join(dir, "server-pubkey.pem")
|
||||
serverCertFile := path.Join(dir, "server-cert.pem")
|
||||
WritePrivateKey(serverKey, serverPrivkeyFile) // nolint
|
||||
WritePublicKey(&serverKey.PublicKey, serverPubkeyFile) // nolint
|
||||
os.WriteFile(serverCertFile, serverCert, 0755) // nolint
|
||||
clientPrivkeyFile := path.Join(dir, "client-privkey.pem")
|
||||
clientPubkeyFile := path.Join(dir, "client-pubkey.pem")
|
||||
clientCertFile := path.Join(dir, "client-cert.pem")
|
||||
WritePrivateKey(clientKey, clientPrivkeyFile) // nolint
|
||||
WritePublicKey(&clientKey.PublicKey, clientPubkeyFile) // nolint
|
||||
os.WriteFile(clientCertFile, clientCert, 0755) // nolint
|
||||
|
||||
serverConfig := goconf.NewConfigFile()
|
||||
serverConfig.AddOption("grpc", "servercertificate", serverCertFile)
|
||||
serverConfig.AddOption("grpc", "serverkey", serverPrivkeyFile)
|
||||
serverConfig.AddOption("grpc", "clientca", clientCertFile)
|
||||
_, addr := NewGrpcServerForTestWithConfig(t, serverConfig)
|
||||
|
||||
clientConfig := goconf.NewConfigFile()
|
||||
clientConfig.AddOption("grpc", "targets", addr)
|
||||
clientConfig.AddOption("grpc", "clientcertificate", clientCertFile)
|
||||
clientConfig.AddOption("grpc", "clientkey", clientPrivkeyFile)
|
||||
clientConfig.AddOption("grpc", "serverca", serverCertFile)
|
||||
clients := NewGrpcClientsForTestWithConfig(t, clientConfig, nil)
|
||||
|
||||
ctx, cancel1 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel1()
|
||||
|
||||
if err := clients.WaitForInitialized(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, client := range clients.GetClients() {
|
||||
if _, err := client.GetServerId(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
172
grpc_common.go
Normal file
172
grpc_common.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
type reloadableCredentials struct {
|
||||
config *tls.Config
|
||||
|
||||
pool *CertPoolReloader
|
||||
}
|
||||
|
||||
func (c *reloadableCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
// use local cfg to avoid clobbering ServerName if using multiple endpoints
|
||||
cfg := c.config.Clone()
|
||||
cfg.RootCAs = c.pool.GetCertPool()
|
||||
if cfg.ServerName == "" {
|
||||
serverName, _, err := net.SplitHostPort(authority)
|
||||
if err != nil {
|
||||
// If the authority had no host port or if the authority cannot be parsed, use it as-is.
|
||||
serverName = authority
|
||||
}
|
||||
cfg.ServerName = serverName
|
||||
}
|
||||
conn := tls.Client(rawConn, cfg)
|
||||
errChannel := make(chan error, 1)
|
||||
go func() {
|
||||
errChannel <- conn.Handshake()
|
||||
close(errChannel)
|
||||
}()
|
||||
select {
|
||||
case err := <-errChannel:
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
conn.Close()
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
tlsInfo := credentials.TLSInfo{
|
||||
State: conn.ConnectionState(),
|
||||
CommonAuthInfo: credentials.CommonAuthInfo{
|
||||
SecurityLevel: credentials.PrivacyAndIntegrity,
|
||||
},
|
||||
}
|
||||
return WrapSyscallConn(rawConn, conn), tlsInfo, nil
|
||||
}
|
||||
|
||||
func (c *reloadableCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
cfg := c.config.Clone()
|
||||
cfg.ClientCAs = c.pool.GetCertPool()
|
||||
|
||||
conn := tls.Server(rawConn, cfg)
|
||||
if err := conn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
tlsInfo := credentials.TLSInfo{
|
||||
State: conn.ConnectionState(),
|
||||
CommonAuthInfo: credentials.CommonAuthInfo{
|
||||
SecurityLevel: credentials.PrivacyAndIntegrity,
|
||||
},
|
||||
}
|
||||
return WrapSyscallConn(rawConn, conn), tlsInfo, nil
|
||||
}
|
||||
|
||||
func (c *reloadableCredentials) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{
|
||||
SecurityProtocol: "tls",
|
||||
SecurityVersion: "1.2",
|
||||
ServerName: c.config.ServerName,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *reloadableCredentials) Clone() credentials.TransportCredentials {
|
||||
return &reloadableCredentials{
|
||||
config: c.config.Clone(),
|
||||
pool: c.pool,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *reloadableCredentials) OverrideServerName(serverName string) error {
|
||||
c.config.ServerName = serverName
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentials.TransportCredentials, error) {
|
||||
var prefix string
|
||||
var caPrefix string
|
||||
if server {
|
||||
prefix = "server"
|
||||
caPrefix = "client"
|
||||
} else {
|
||||
prefix = "client"
|
||||
caPrefix = "server"
|
||||
}
|
||||
certificateFile, _ := config.GetString("grpc", prefix+"certificate")
|
||||
keyFile, _ := config.GetString("grpc", prefix+"key")
|
||||
caFile, _ := config.GetString("grpc", caPrefix+"ca")
|
||||
cfg := &tls.Config{
|
||||
NextProtos: []string{"h2"},
|
||||
}
|
||||
if certificateFile != "" && keyFile != "" {
|
||||
loader, err := NewCertificateReloader(certificateFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid GRPC %s certificate / key in %s / %s: %w", prefix, certificateFile, keyFile, err)
|
||||
}
|
||||
|
||||
if server {
|
||||
cfg.GetCertificate = loader.GetCertificate
|
||||
} else {
|
||||
cfg.GetClientCertificate = loader.GetClientCertificate
|
||||
}
|
||||
}
|
||||
|
||||
if caFile != "" {
|
||||
pool, err := NewCertPoolReloader(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if server {
|
||||
cfg.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
creds := &reloadableCredentials{
|
||||
config: cfg,
|
||||
pool: pool,
|
||||
}
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
if cfg.GetCertificate == nil {
|
||||
if server {
|
||||
log.Printf("WARNING: No GRPC server certificate and/or key configured, running unencrypted")
|
||||
} else {
|
||||
log.Printf("WARNING: No GRPC CA configured, expecting unencrypted connections")
|
||||
}
|
||||
return insecure.NewCredentials(), nil
|
||||
}
|
||||
|
||||
return credentials.NewTLS(cfg), nil
|
||||
}
|
88
grpc_common_test.go
Normal file
88
grpc_common_test.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GenerateSelfSignedCertificateForTesting(t *testing.T, bits int, organization string, key *rsa.PrivateKey) []byte {
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{organization},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24 * 180),
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
|
||||
data, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: data,
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
func WritePrivateKey(key *rsa.PrivateKey, filename string) error {
|
||||
data := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
})
|
||||
|
||||
return os.WriteFile(filename, data, 0600)
|
||||
}
|
||||
|
||||
func WritePublicKey(key *rsa.PublicKey, filename string) error {
|
||||
data, err := x509.MarshalPKIXPublicKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: data,
|
||||
})
|
||||
|
||||
return os.WriteFile(filename, data, 0755)
|
||||
}
|
37
grpc_internal.proto
Normal file
37
grpc_internal.proto
Normal file
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
syntax = "proto3";
|
||||
|
||||
option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling";
|
||||
|
||||
package signaling;
|
||||
|
||||
service RpcInternal {
|
||||
rpc GetServerId(GetServerIdRequest) returns (GetServerIdReply) {}
|
||||
}
|
||||
|
||||
message GetServerIdRequest {
|
||||
}
|
||||
|
||||
message GetServerIdReply {
|
||||
string serverId = 1;
|
||||
}
|
41
grpc_mcu.proto
Normal file
41
grpc_mcu.proto
Normal file
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
syntax = "proto3";
|
||||
|
||||
option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling";
|
||||
|
||||
package signaling;
|
||||
|
||||
service RpcMcu {
|
||||
rpc GetPublisherId(GetPublisherIdRequest) returns (GetPublisherIdReply) {}
|
||||
}
|
||||
|
||||
message GetPublisherIdRequest {
|
||||
string sessionId = 1;
|
||||
string streamType = 2;
|
||||
}
|
||||
|
||||
message GetPublisherIdReply {
|
||||
string publisherId = 1;
|
||||
string proxyUrl = 2;
|
||||
string ip = 3;
|
||||
}
|
178
grpc_server.go
Normal file
178
grpc_server.go
Normal file
|
@ -0,0 +1,178 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
GrpcServerId string
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterGrpcServerStats()
|
||||
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = newRandomString(8)
|
||||
}
|
||||
md := sha256.New()
|
||||
md.Write([]byte(fmt.Sprintf("%s-%s-%d", newRandomString(32), hostname, os.Getpid())))
|
||||
GrpcServerId = hex.EncodeToString(md.Sum(nil))
|
||||
}
|
||||
|
||||
type GrpcServer struct {
|
||||
UnimplementedRpcInternalServer
|
||||
UnimplementedRpcMcuServer
|
||||
UnimplementedRpcSessionsServer
|
||||
|
||||
conn *grpc.Server
|
||||
listener net.Listener
|
||||
serverId string // can be overwritten from tests
|
||||
|
||||
hub *Hub
|
||||
}
|
||||
|
||||
func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) {
|
||||
var listener net.Listener
|
||||
if addr, _ := config.GetString("grpc", "listen"); addr != "" {
|
||||
var err error
|
||||
listener, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create GRPC listener %s: %w", addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
creds, err := NewReloadableCredentials(config, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := grpc.NewServer(grpc.Creds(creds))
|
||||
result := &GrpcServer{
|
||||
conn: conn,
|
||||
listener: listener,
|
||||
serverId: GrpcServerId,
|
||||
}
|
||||
RegisterRpcInternalServer(conn, result)
|
||||
RegisterRpcSessionsServer(conn, result)
|
||||
RegisterRpcMcuServer(conn, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *GrpcServer) Run() error {
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.conn.Serve(s.listener)
|
||||
}
|
||||
|
||||
func (s *GrpcServer) Close() {
|
||||
s.conn.GracefulStop()
|
||||
}
|
||||
|
||||
func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) {
|
||||
statsGrpcServerCalls.WithLabelValues("LookupSessionId").Inc()
|
||||
// TODO: Remove debug logging
|
||||
log.Printf("Lookup session id for room session id %s", request.RoomSessionId)
|
||||
sid, err := s.hub.roomSessions.GetSessionId(request.RoomSessionId)
|
||||
if errors.Is(err, ErrNoSuchRoomSession) {
|
||||
return nil, status.Error(codes.NotFound, "no such room session id")
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LookupSessionIdReply{
|
||||
SessionId: sid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GrpcServer) IsSessionInCall(ctx context.Context, request *IsSessionInCallRequest) (*IsSessionInCallReply, error) {
|
||||
statsGrpcServerCalls.WithLabelValues("IsSessionInCall").Inc()
|
||||
// TODO: Remove debug logging
|
||||
log.Printf("Check if session %s is in call %s on %s", request.SessionId, request.RoomId, request.BackendUrl)
|
||||
session := s.hub.GetSessionByPublicId(request.SessionId)
|
||||
if session == nil {
|
||||
return nil, status.Error(codes.NotFound, "no such session id")
|
||||
}
|
||||
|
||||
result := &IsSessionInCallReply{}
|
||||
room := session.GetRoom()
|
||||
if room == nil || room.Id() != request.GetRoomId() || room.Backend().url != request.GetBackendUrl() ||
|
||||
(session.ClientType() != HelloClientTypeInternal && !room.IsSessionInCall(session)) {
|
||||
// Recipient is not in a room, a different room or not in the call.
|
||||
result.InCall = false
|
||||
} else {
|
||||
result.InCall = true
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *GrpcServer) GetPublisherId(ctx context.Context, request *GetPublisherIdRequest) (*GetPublisherIdReply, error) {
|
||||
statsGrpcServerCalls.WithLabelValues("GetPublisherId").Inc()
|
||||
// TODO: Remove debug logging
|
||||
log.Printf("Get %s publisher id for session %s", request.StreamType, request.SessionId)
|
||||
session := s.hub.GetSessionByPublicId(request.SessionId)
|
||||
if session == nil {
|
||||
return nil, status.Error(codes.NotFound, "no such session")
|
||||
}
|
||||
|
||||
clientSession, ok := session.(*ClientSession)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.NotFound, "no such session")
|
||||
}
|
||||
|
||||
publisher := clientSession.GetOrWaitForPublisher(ctx, request.StreamType)
|
||||
if publisher, ok := publisher.(*mcuProxyPublisher); ok {
|
||||
reply := &GetPublisherIdReply{
|
||||
PublisherId: publisher.Id(),
|
||||
ProxyUrl: publisher.conn.rawUrl,
|
||||
}
|
||||
if ip := publisher.conn.ip; ip != nil {
|
||||
reply.Ip = ip.String()
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
return nil, status.Error(codes.NotFound, "no such publisher")
|
||||
}
|
||||
|
||||
func (s *GrpcServer) GetServerId(ctx context.Context, request *GetServerIdRequest) (*GetServerIdReply, error) {
|
||||
statsGrpcServerCalls.WithLabelValues("GetServerId").Inc()
|
||||
return &GetServerIdReply{
|
||||
ServerId: s.serverId,
|
||||
}, nil
|
||||
}
|
245
grpc_server_test.go
Normal file
245
grpc_server_test.go
Normal file
|
@ -0,0 +1,245 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) {
|
||||
for port := 50000; port < 50100; port++ {
|
||||
addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
|
||||
config.AddOption("grpc", "listen", addr)
|
||||
var err error
|
||||
server, err = NewGrpcServer(config)
|
||||
if isErrorAddressAlreadyInUse(err) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if server == nil {
|
||||
t.Fatal("could not find free port")
|
||||
}
|
||||
|
||||
// Don't match with own server id by default.
|
||||
server.serverId = "dont-match"
|
||||
|
||||
go func() {
|
||||
if err := server.Run(); err != nil {
|
||||
t.Errorf("could not start GRPC server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
server.Close()
|
||||
})
|
||||
return server, addr
|
||||
}
|
||||
|
||||
func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) {
|
||||
config := goconf.NewConfigFile()
|
||||
return NewGrpcServerForTestWithConfig(t, config)
|
||||
}
|
||||
|
||||
func Test_GrpcServer_ReloadCerts(t *testing.T) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
org1 := "Testing certificate"
|
||||
cert1 := GenerateSelfSignedCertificateForTesting(t, 1024, org1, key)
|
||||
|
||||
dir := t.TempDir()
|
||||
privkeyFile := path.Join(dir, "privkey.pem")
|
||||
pubkeyFile := path.Join(dir, "pubkey.pem")
|
||||
certFile := path.Join(dir, "cert.pem")
|
||||
WritePrivateKey(key, privkeyFile) // nolint
|
||||
WritePublicKey(&key.PublicKey, pubkeyFile) // nolint
|
||||
os.WriteFile(certFile, cert1, 0755) // nolint
|
||||
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("grpc", "servercertificate", certFile)
|
||||
config.AddOption("grpc", "serverkey", privkeyFile)
|
||||
|
||||
UpdateCertificateCheckIntervalForTest(t, time.Millisecond)
|
||||
_, addr := NewGrpcServerForTestWithConfig(t, config)
|
||||
|
||||
cp1 := x509.NewCertPool()
|
||||
if !cp1.AppendCertsFromPEM(cert1) {
|
||||
t.Fatalf("could not add certificate")
|
||||
}
|
||||
|
||||
cfg1 := &tls.Config{
|
||||
RootCAs: cp1,
|
||||
}
|
||||
conn1, err := tls.Dial("tcp", addr, cfg1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn1.Close() // nolint
|
||||
state1 := conn1.ConnectionState()
|
||||
if certs := state1.PeerCertificates; len(certs) == 0 {
|
||||
t.Errorf("expected certificates, got %+v", state1)
|
||||
} else if len(certs[0].Subject.Organization) == 0 {
|
||||
t.Errorf("expected organization, got %s", certs[0].Subject)
|
||||
} else if certs[0].Subject.Organization[0] != org1 {
|
||||
t.Errorf("expected organization %s, got %s", org1, certs[0].Subject)
|
||||
}
|
||||
|
||||
org2 := "Updated certificate"
|
||||
cert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, key)
|
||||
os.WriteFile(certFile, cert2, 0755) // nolint
|
||||
|
||||
cp2 := x509.NewCertPool()
|
||||
if !cp2.AppendCertsFromPEM(cert2) {
|
||||
t.Fatalf("could not add certificate")
|
||||
}
|
||||
|
||||
cfg2 := &tls.Config{
|
||||
RootCAs: cp2,
|
||||
}
|
||||
conn2, err := tls.Dial("tcp", addr, cfg2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn2.Close() // nolint
|
||||
state2 := conn2.ConnectionState()
|
||||
if certs := state2.PeerCertificates; len(certs) == 0 {
|
||||
t.Errorf("expected certificates, got %+v", state2)
|
||||
} else if len(certs[0].Subject.Organization) == 0 {
|
||||
t.Errorf("expected organization, got %s", certs[0].Subject)
|
||||
} else if certs[0].Subject.Organization[0] != org2 {
|
||||
t.Errorf("expected organization %s, got %s", org2, certs[0].Subject)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrpcServer_ReloadCA(t *testing.T) {
|
||||
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverCert := GenerateSelfSignedCertificateForTesting(t, 1024, "Server cert", serverKey)
|
||||
org1 := "Testing client"
|
||||
clientCert1 := GenerateSelfSignedCertificateForTesting(t, 1024, org1, clientKey)
|
||||
|
||||
dir := t.TempDir()
|
||||
privkeyFile := path.Join(dir, "privkey.pem")
|
||||
pubkeyFile := path.Join(dir, "pubkey.pem")
|
||||
certFile := path.Join(dir, "cert.pem")
|
||||
caFile := path.Join(dir, "ca.pem")
|
||||
WritePrivateKey(serverKey, privkeyFile) // nolint
|
||||
WritePublicKey(&serverKey.PublicKey, pubkeyFile) // nolint
|
||||
os.WriteFile(certFile, serverCert, 0755) // nolint
|
||||
os.WriteFile(caFile, clientCert1, 0755) // nolint
|
||||
|
||||
config := goconf.NewConfigFile()
|
||||
config.AddOption("grpc", "servercertificate", certFile)
|
||||
config.AddOption("grpc", "serverkey", privkeyFile)
|
||||
config.AddOption("grpc", "clientca", caFile)
|
||||
|
||||
UpdateCertificateCheckIntervalForTest(t, time.Millisecond)
|
||||
_, addr := NewGrpcServerForTestWithConfig(t, config)
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(serverCert) {
|
||||
t.Fatalf("could not add certificate")
|
||||
}
|
||||
|
||||
pair1, err := tls.X509KeyPair(clientCert1, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg1 := &tls.Config{
|
||||
RootCAs: pool,
|
||||
Certificates: []tls.Certificate{pair1},
|
||||
}
|
||||
client1, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer client1.Close() // nolint
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel1()
|
||||
|
||||
if _, err := client1.GetServerId(ctx1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
org2 := "Updated client"
|
||||
clientCert2 := GenerateSelfSignedCertificateForTesting(t, 1024, org2, clientKey)
|
||||
os.WriteFile(caFile, clientCert2, 0755) // nolint
|
||||
|
||||
pair2, err := tls.X509KeyPair(clientCert2, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg2 := &tls.Config{
|
||||
RootCAs: pool,
|
||||
Certificates: []tls.Certificate{pair2},
|
||||
}
|
||||
client2, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer client2.Close() // nolint
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel2()
|
||||
|
||||
// This will fail if the CA certificate has not been reloaded by the server.
|
||||
if _, err := client2.GetServerId(ctx2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
49
grpc_sessions.proto
Normal file
49
grpc_sessions.proto
Normal file
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
syntax = "proto3";
|
||||
|
||||
option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling";
|
||||
|
||||
package signaling;
|
||||
|
||||
service RpcSessions {
|
||||
rpc LookupSessionId(LookupSessionIdRequest) returns (LookupSessionIdReply) {}
|
||||
rpc IsSessionInCall(IsSessionInCallRequest) returns (IsSessionInCallReply) {}
|
||||
}
|
||||
|
||||
message LookupSessionIdRequest {
|
||||
string roomSessionId = 1;
|
||||
}
|
||||
|
||||
message LookupSessionIdReply {
|
||||
string sessionId = 1;
|
||||
}
|
||||
|
||||
message IsSessionInCallRequest {
|
||||
string sessionId = 1;
|
||||
string roomId = 2;
|
||||
string backendUrl = 3;
|
||||
}
|
||||
|
||||
message IsSessionInCallReply {
|
||||
bool inCall = 1;
|
||||
}
|
65
grpc_stats_prometheus.go
Normal file
65
grpc_stats_prometheus.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
var (
|
||||
statsGrpcClients = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "signaling",
|
||||
Subsystem: "grpc",
|
||||
Name: "clients",
|
||||
Help: "The current number of GRPC clients",
|
||||
})
|
||||
statsGrpcClientCalls = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "signaling",
|
||||
Subsystem: "grpc",
|
||||
Name: "client_calls_total",
|
||||
Help: "The total number of GRPC client calls",
|
||||
}, []string{"method"})
|
||||
|
||||
grpcClientStats = []prometheus.Collector{
|
||||
statsGrpcClients,
|
||||
statsGrpcClientCalls,
|
||||
}
|
||||
|
||||
statsGrpcServerCalls = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "signaling",
|
||||
Subsystem: "grpc",
|
||||
Name: "server_calls_total",
|
||||
Help: "The total number of GRPC server calls",
|
||||
}, []string{"method"})
|
||||
|
||||
grpcServerStats = []prometheus.Collector{
|
||||
statsGrpcServerCalls,
|
||||
}
|
||||
)
|
||||
|
||||
func RegisterGrpcClientStats() {
|
||||
registerAll(grpcClientStats...)
|
||||
}
|
||||
|
||||
func RegisterGrpcServerStats() {
|
||||
registerAll(grpcServerStats...)
|
||||
}
|
386
hub.go
386
hub.go
|
@ -28,6 +28,7 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log"
|
||||
|
@ -103,7 +104,7 @@ type Hub struct {
|
|||
// 64-bit members that are accessed atomically must be 64-bit aligned.
|
||||
sid uint64
|
||||
|
||||
nats NatsClient
|
||||
events AsyncEvents
|
||||
upgrader websocket.Upgrader
|
||||
cookie *securecookie.SecureCookie
|
||||
info *WelcomeServerMessage
|
||||
|
@ -149,9 +150,12 @@ type Hub struct {
|
|||
geoip *GeoLookup
|
||||
geoipOverrides map[*net.IPNet]string
|
||||
geoipUpdating int32
|
||||
|
||||
rpcServer *GrpcServer
|
||||
rpcClients *GrpcClients
|
||||
}
|
||||
|
||||
func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version string) (*Hub, error) {
|
||||
func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer, rpcClients *GrpcClients, etcdClient *EtcdClient, r *mux.Router, version string) (*Hub, error) {
|
||||
hashKey, _ := config.GetString("sessions", "hashkey")
|
||||
switch len(hashKey) {
|
||||
case 32:
|
||||
|
@ -182,7 +186,7 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s
|
|||
maxConcurrentRequestsPerHost = defaultMaxConcurrentRequestsPerHost
|
||||
}
|
||||
|
||||
backend, err := NewBackendClient(config, maxConcurrentRequestsPerHost, version)
|
||||
backend, err := NewBackendClient(config, maxConcurrentRequestsPerHost, version, etcdClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -211,7 +215,7 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s
|
|||
decodeCaches = append(decodeCaches, NewLruCache(decodeCacheSize))
|
||||
}
|
||||
|
||||
roomSessions, err := NewBuiltinRoomSessions()
|
||||
roomSessions, err := NewBuiltinRoomSessions(rpcClients)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -293,7 +297,7 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s
|
|||
}
|
||||
|
||||
hub := &Hub{
|
||||
nats: nats,
|
||||
events: events,
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: websocketReadBufferSize,
|
||||
WriteBufferSize: websocketWriteBufferSize,
|
||||
|
@ -333,12 +337,18 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s
|
|||
|
||||
geoip: geoip,
|
||||
geoipOverrides: geoipOverrides,
|
||||
|
||||
rpcServer: rpcServer,
|
||||
rpcClients: rpcClients,
|
||||
}
|
||||
hub.setWelcomeMessage(&ServerMessage{
|
||||
Type: "welcome",
|
||||
Welcome: NewWelcomeServerMessage(version, DefaultWelcomeFeatures...),
|
||||
})
|
||||
backend.hub = hub
|
||||
if rpcServer != nil {
|
||||
rpcServer.hub = hub
|
||||
}
|
||||
hub.upgrader.CheckOrigin = hub.checkOrigin
|
||||
r.HandleFunc("/spreed", func(w http.ResponseWriter, r *http.Request) {
|
||||
hub.serveWs(w, r)
|
||||
|
@ -418,6 +428,7 @@ func (h *Hub) Run() {
|
|||
go h.updateGeoDatabase()
|
||||
h.roomPing.Start()
|
||||
defer h.roomPing.Stop()
|
||||
defer h.backend.Close()
|
||||
|
||||
housekeeping := time.NewTicker(housekeepingInterval)
|
||||
geoipUpdater := time.NewTicker(24 * time.Hour)
|
||||
|
@ -461,6 +472,7 @@ func (h *Hub) Reload(config *goconf.ConfigFile) {
|
|||
h.mcu.Reload(config)
|
||||
}
|
||||
h.backend.Reload(config)
|
||||
h.rpcClients.Reload(config)
|
||||
}
|
||||
|
||||
func reverseSessionId(s string) (string, error) {
|
||||
|
@ -553,9 +565,13 @@ func (h *Hub) GetSessionByPublicId(sessionId string) Session {
|
|||
return nil
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
session := h.sessions[data.Sid]
|
||||
h.mu.Unlock()
|
||||
if session != nil && session.PublicId() != sessionId {
|
||||
// Session was created on different server.
|
||||
return nil
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
|
@ -985,7 +1001,7 @@ func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) {
|
|||
h.processRegister(client, message, backend, auth)
|
||||
}
|
||||
|
||||
func (h *Hub) disconnectByRoomSessionId(roomSessionId string) {
|
||||
func (h *Hub) disconnectByRoomSessionId(roomSessionId string, backend *Backend) {
|
||||
sessionId, err := h.roomSessions.GetSessionId(roomSessionId)
|
||||
if err == ErrNoSuchRoomSession {
|
||||
return
|
||||
|
@ -997,13 +1013,16 @@ func (h *Hub) disconnectByRoomSessionId(roomSessionId string) {
|
|||
session := h.GetSessionByPublicId(sessionId)
|
||||
if session == nil {
|
||||
// Session is located on a different server.
|
||||
msg := &ServerMessage{
|
||||
Type: "bye",
|
||||
Bye: &ByeServerMessage{
|
||||
Reason: "room_session_reconnected",
|
||||
msg := &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: &ServerMessage{
|
||||
Type: "bye",
|
||||
Bye: &ByeServerMessage{
|
||||
Reason: "room_session_reconnected",
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := h.nats.PublishMessage("session."+sessionId, msg); err != nil {
|
||||
if err := h.events.PublishSessionMessage(sessionId, backend, msg); err != nil {
|
||||
log.Printf("Could not send reconnect bye to session %s: %s", sessionId, err)
|
||||
}
|
||||
return
|
||||
|
@ -1097,7 +1116,7 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) {
|
|||
if message.Room.SessionId != "" {
|
||||
// There can only be one connection per Nextcloud Talk session,
|
||||
// disconnect any other connections without sending a "leave" event.
|
||||
h.disconnectByRoomSessionId(message.Room.SessionId)
|
||||
h.disconnectByRoomSessionId(message.Room.SessionId, session.Backend())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1125,7 +1144,7 @@ func (h *Hub) removeRoom(room *Room) {
|
|||
|
||||
func (h *Hub) createRoom(id string, properties *json.RawMessage, backend *Backend) (*Room, error) {
|
||||
// Note the write lock must be held.
|
||||
room, err := NewRoom(id, properties, h, h.nats, backend)
|
||||
room, err := NewRoom(id, properties, h, h.events, backend)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1149,7 +1168,7 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro
|
|||
|
||||
roomId := room.Room.RoomId
|
||||
internalRoomId := getRoomIdForBackend(roomId, session.Backend())
|
||||
if err := session.SubscribeRoomNats(h.nats, roomId, message.Room.SessionId); err != nil {
|
||||
if err := session.SubscribeRoomEvents(roomId, message.Room.SessionId); err != nil {
|
||||
session.SendMessage(message.NewWrappedErrorServerMessage(err))
|
||||
// The client (implicitly) left the room due to an error.
|
||||
h.sendRoom(session, nil, nil)
|
||||
|
@ -1164,7 +1183,7 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro
|
|||
h.ru.Unlock()
|
||||
session.SendMessage(message.NewWrappedErrorServerMessage(err))
|
||||
// The client (implicitly) left the room due to an error.
|
||||
session.UnsubscribeRoomNats()
|
||||
session.UnsubscribeRoomEvents()
|
||||
h.sendRoom(session, nil, nil)
|
||||
return
|
||||
}
|
||||
|
@ -1182,65 +1201,7 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro
|
|||
session.SetPermissions(*room.Room.Permissions)
|
||||
}
|
||||
h.sendRoom(session, message, r)
|
||||
h.notifyUserJoinedRoom(r, session, room.Room.Session)
|
||||
}
|
||||
|
||||
func (h *Hub) notifyUserJoinedRoom(room *Room, session *ClientSession, sessionData *json.RawMessage) {
|
||||
// Register session with the room
|
||||
if sessions := room.AddSession(session, sessionData); len(sessions) > 0 {
|
||||
events := make([]*EventServerMessageSessionEntry, 0, len(sessions))
|
||||
for _, s := range sessions {
|
||||
entry := &EventServerMessageSessionEntry{
|
||||
SessionId: s.PublicId(),
|
||||
UserId: s.UserId(),
|
||||
User: s.UserData(),
|
||||
}
|
||||
if s, ok := s.(*ClientSession); ok {
|
||||
entry.RoomSessionId = s.RoomSessionId()
|
||||
}
|
||||
events = append(events, entry)
|
||||
}
|
||||
msg := &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "room",
|
||||
Type: "join",
|
||||
Join: events,
|
||||
},
|
||||
}
|
||||
|
||||
// No need to send through NATS, the session is connected locally.
|
||||
session.SendMessage(msg)
|
||||
|
||||
// Notify about initial flags of virtual sessions.
|
||||
for _, s := range sessions {
|
||||
vsess, ok := s.(*VirtualSession)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
flags := vsess.Flags()
|
||||
if flags == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
msg := &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "participants",
|
||||
Type: "flags",
|
||||
Flags: &RoomFlagsServerMessage{
|
||||
RoomId: room.Id(),
|
||||
SessionId: vsess.PublicId(),
|
||||
Flags: vsess.Flags(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// No need to send through NATS, the session is connected locally.
|
||||
session.SendMessage(msg)
|
||||
}
|
||||
}
|
||||
r.AddSession(session, room.Room.Session)
|
||||
}
|
||||
|
||||
func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
|
||||
|
@ -1255,68 +1216,72 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
|
|||
var subject string
|
||||
var clientData *MessageClientMessageData
|
||||
var serverRecipient *MessageClientMessageRecipient
|
||||
var recipientSessionId string
|
||||
var room *Room
|
||||
switch msg.Recipient.Type {
|
||||
case RecipientTypeSession:
|
||||
data := h.decodeSessionId(msg.Recipient.SessionId, publicSessionName)
|
||||
if data != nil {
|
||||
if data.BackendId != session.Backend().Id() {
|
||||
if h.mcu != nil {
|
||||
// Maybe this is a message to be processed by the MCU.
|
||||
var data MessageClientMessageData
|
||||
if err := json.Unmarshal(*msg.Data, &data); err == nil {
|
||||
clientData = &data
|
||||
|
||||
switch clientData.Type {
|
||||
case "requestoffer":
|
||||
// Process asynchronously to avoid blocking regular
|
||||
// message processing for this client.
|
||||
go h.processMcuMessage(session, message, msg, clientData)
|
||||
return
|
||||
case "offer":
|
||||
fallthrough
|
||||
case "answer":
|
||||
fallthrough
|
||||
case "endOfCandidates":
|
||||
fallthrough
|
||||
case "selectStream":
|
||||
fallthrough
|
||||
case "candidate":
|
||||
h.processMcuMessage(session, message, msg, clientData)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sess := h.GetSessionByPublicId(msg.Recipient.SessionId)
|
||||
if sess != nil {
|
||||
// Recipient is also connected to this instance.
|
||||
if sess.Backend().Id() != session.Backend().Id() {
|
||||
// Clients are only allowed to send to sessions from the same backend.
|
||||
return
|
||||
}
|
||||
|
||||
if h.mcu != nil {
|
||||
// Maybe this is a message to be processed by the MCU.
|
||||
var data MessageClientMessageData
|
||||
if err := json.Unmarshal(*msg.Data, &data); err == nil {
|
||||
clientData = &data
|
||||
switch data.Type {
|
||||
case "requestoffer":
|
||||
// Process asynchronously to avoid blocking regular
|
||||
// message processing for this client.
|
||||
go h.processMcuMessage(session, session, message, msg, &data)
|
||||
return
|
||||
case "offer":
|
||||
fallthrough
|
||||
case "answer":
|
||||
fallthrough
|
||||
case "endOfCandidates":
|
||||
fallthrough
|
||||
case "selectStream":
|
||||
fallthrough
|
||||
case "candidate":
|
||||
h.processMcuMessage(session, session, message, msg, &data)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if msg.Recipient.SessionId == session.PublicId() {
|
||||
// Don't loop messages to the sender.
|
||||
return
|
||||
}
|
||||
|
||||
subject = "session." + msg.Recipient.SessionId
|
||||
h.mu.RLock()
|
||||
sess, found := h.sessions[data.Sid]
|
||||
if found {
|
||||
if sess, ok := sess.(*ClientSession); ok {
|
||||
recipient = sess
|
||||
}
|
||||
recipientSessionId = msg.Recipient.SessionId
|
||||
if sess, ok := sess.(*ClientSession); ok {
|
||||
recipient = sess
|
||||
}
|
||||
|
||||
// Send to client connection for virtual sessions.
|
||||
if sess.ClientType() == HelloClientTypeVirtual {
|
||||
virtualSession := sess.(*VirtualSession)
|
||||
clientSession := virtualSession.Session()
|
||||
subject = "session." + clientSession.PublicId()
|
||||
recipient = clientSession
|
||||
// The client should see his session id as recipient.
|
||||
serverRecipient = &MessageClientMessageRecipient{
|
||||
Type: "session",
|
||||
SessionId: virtualSession.SessionId(),
|
||||
}
|
||||
// Send to client connection for virtual sessions.
|
||||
if sess.ClientType() == HelloClientTypeVirtual {
|
||||
virtualSession := sess.(*VirtualSession)
|
||||
clientSession := virtualSession.Session()
|
||||
subject = "session." + clientSession.PublicId()
|
||||
recipientSessionId = clientSession.PublicId()
|
||||
recipient = clientSession
|
||||
// The client should see his session id as recipient.
|
||||
serverRecipient = &MessageClientMessageRecipient{
|
||||
Type: "session",
|
||||
SessionId: virtualSession.SessionId(),
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
} else {
|
||||
subject = "session." + msg.Recipient.SessionId
|
||||
recipientSessionId = msg.Recipient.SessionId
|
||||
}
|
||||
case RecipientTypeUser:
|
||||
if msg.Recipient.UserId != "" {
|
||||
|
@ -1331,7 +1296,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
|
|||
}
|
||||
case RecipientTypeRoom:
|
||||
if session != nil {
|
||||
if room := session.GetRoom(); room != nil {
|
||||
if room = session.GetRoom(); room != nil {
|
||||
subject = GetSubjectForRoomId(room.Id(), room.Backend())
|
||||
|
||||
if h.mcu != nil {
|
||||
|
@ -1384,7 +1349,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
|
|||
},
|
||||
}
|
||||
if recipient != nil {
|
||||
// The recipient is connected to this instance, no need to go through NATS.
|
||||
// The recipient is connected to this instance, no need to go through asynchronous events.
|
||||
if clientData != nil && clientData.Type == "sendoffer" {
|
||||
if err := session.IsAllowedToSend(clientData); err != nil {
|
||||
log.Printf("Session %s is not allowed to send offer for %s, ignoring (%s)", session.PublicId(), clientData.RoomType, err)
|
||||
|
@ -1392,21 +1357,83 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
|
|||
return
|
||||
}
|
||||
|
||||
msg.Recipient.SessionId = session.PublicId()
|
||||
// It may take some time for the publisher (which is the current
|
||||
// client) to start his stream, so we must not block the active
|
||||
// goroutine.
|
||||
go h.processMcuMessage(session, recipient, message, msg, clientData)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout)
|
||||
defer cancel()
|
||||
|
||||
mc, err := recipient.GetOrCreateSubscriber(ctx, h.mcu, session.PublicId(), clientData.RoomType)
|
||||
if err != nil {
|
||||
log.Printf("Could not create MCU subscriber for session %s to send %+v to %s: %s", session.PublicId(), clientData, recipient.PublicId(), err)
|
||||
sendMcuClientNotFound(session, message)
|
||||
return
|
||||
} else if mc == nil {
|
||||
log.Printf("No MCU subscriber found for session %s to send %+v to %s", session.PublicId(), clientData, recipient.PublicId())
|
||||
sendMcuClientNotFound(session, message)
|
||||
return
|
||||
}
|
||||
|
||||
mc.SendMessage(context.TODO(), msg, clientData, func(err error, response map[string]interface{}) {
|
||||
if err != nil {
|
||||
log.Printf("Could not send MCU message %+v for session %s to %s: %s", clientData, session.PublicId(), recipient.PublicId(), err)
|
||||
sendMcuProcessingFailed(session, message)
|
||||
return
|
||||
} else if response == nil {
|
||||
// No response received
|
||||
return
|
||||
}
|
||||
|
||||
// The response (i.e. the "offer") must be sent to the recipient but
|
||||
// should be coming from the sender.
|
||||
msg.Recipient.SessionId = session.PublicId()
|
||||
h.sendMcuMessageResponse(recipient, mc, msg, clientData, response)
|
||||
})
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
recipient.SendMessage(response)
|
||||
} else {
|
||||
if clientData != nil && clientData.Type == "sendoffer" {
|
||||
// TODO(jojo): Implement this.
|
||||
log.Printf("Sending offers to remote clients is not supported yet (client %s)", session.PublicId())
|
||||
if err := session.IsAllowedToSend(clientData); err != nil {
|
||||
log.Printf("Session %s is not allowed to send offer for %s, ignoring (%s)", session.PublicId(), clientData.RoomType, err)
|
||||
sendNotAllowed(session, message, "Not allowed to send offer")
|
||||
return
|
||||
}
|
||||
|
||||
async := &AsyncMessage{
|
||||
Type: "sendoffer",
|
||||
SendOffer: &SendOfferMessage{
|
||||
MessageId: message.Id,
|
||||
SessionId: session.PublicId(),
|
||||
Data: clientData,
|
||||
},
|
||||
}
|
||||
if err := h.events.PublishSessionMessage(recipientSessionId, session.Backend(), async); err != nil {
|
||||
log.Printf("Error publishing message to remote session: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.nats.PublishMessage(subject, response); err != nil {
|
||||
|
||||
async := &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: response,
|
||||
}
|
||||
var err error
|
||||
switch msg.Recipient.Type {
|
||||
case RecipientTypeSession:
|
||||
err = h.events.PublishSessionMessage(recipientSessionId, session.Backend(), async)
|
||||
case RecipientTypeUser:
|
||||
err = h.events.PublishUserMessage(msg.Recipient.UserId, session.Backend(), async)
|
||||
case RecipientTypeRoom:
|
||||
err = h.events.PublishRoomMessage(room.Id(), session.Backend(), async)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported recipient type: %s", msg.Recipient.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Error publishing message to remote session: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -1437,9 +1464,11 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
|
|||
return
|
||||
}
|
||||
|
||||
var recipient *Client
|
||||
var recipient *ClientSession
|
||||
var subject string
|
||||
var serverRecipient *MessageClientMessageRecipient
|
||||
var recipientSessionId string
|
||||
var room *Room
|
||||
switch msg.Recipient.Type {
|
||||
case RecipientTypeSession:
|
||||
data := h.decodeSessionId(msg.Recipient.SessionId, publicSessionName)
|
||||
|
@ -1450,16 +1479,21 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
|
|||
}
|
||||
|
||||
subject = "session." + msg.Recipient.SessionId
|
||||
recipientSessionId = msg.Recipient.SessionId
|
||||
h.mu.RLock()
|
||||
recipient = h.clients[data.Sid]
|
||||
if recipient == nil {
|
||||
sess, found := h.sessions[data.Sid]
|
||||
if found && sess.PublicId() == msg.Recipient.SessionId {
|
||||
if sess, ok := sess.(*ClientSession); ok {
|
||||
recipient = sess
|
||||
}
|
||||
|
||||
// Send to client connection for virtual sessions.
|
||||
sess := h.sessions[data.Sid]
|
||||
if sess != nil && sess.ClientType() == HelloClientTypeVirtual {
|
||||
if sess.ClientType() == HelloClientTypeVirtual {
|
||||
virtualSession := sess.(*VirtualSession)
|
||||
clientSession := virtualSession.Session()
|
||||
subject = "session." + clientSession.PublicId()
|
||||
recipient = clientSession.GetClient()
|
||||
recipientSessionId = clientSession.PublicId()
|
||||
recipient = clientSession
|
||||
// The client should see his session id as recipient.
|
||||
serverRecipient = &MessageClientMessageRecipient{
|
||||
Type: "session",
|
||||
|
@ -1482,7 +1516,7 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
|
|||
}
|
||||
case RecipientTypeRoom:
|
||||
if session != nil {
|
||||
if room := session.GetRoom(); room != nil {
|
||||
if room = session.GetRoom(); room != nil {
|
||||
subject = GetSubjectForRoomId(room.Id(), room.Backend())
|
||||
}
|
||||
}
|
||||
|
@ -1507,7 +1541,22 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
|
|||
if recipient != nil {
|
||||
recipient.SendMessage(response)
|
||||
} else {
|
||||
if err := h.nats.PublishMessage(subject, response); err != nil {
|
||||
async := &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: response,
|
||||
}
|
||||
var err error
|
||||
switch msg.Recipient.Type {
|
||||
case RecipientTypeSession:
|
||||
err = h.events.PublishSessionMessage(recipientSessionId, session.Backend(), async)
|
||||
case RecipientTypeUser:
|
||||
err = h.events.PublishUserMessage(msg.Recipient.UserId, session.Backend(), async)
|
||||
case RecipientTypeRoom:
|
||||
err = h.events.PublishRoomMessage(room.Id(), room.Backend(), async)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported recipient type: %s", msg.Recipient.Type)
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Error publishing message to remote session: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -1727,7 +1776,41 @@ func sendMcuProcessingFailed(session *ClientSession, message *ClientMessage) {
|
|||
session.SendMessage(response)
|
||||
}
|
||||
|
||||
func (h *Hub) isInSameCall(senderSession *ClientSession, recipientSessionId string) bool {
|
||||
func (h *Hub) isInSameCallRemote(ctx context.Context, senderSession *ClientSession, senderRoom *Room, recipientSessionId string) bool {
|
||||
clients := h.rpcClients.GetClients()
|
||||
if len(clients) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var result int32
|
||||
var wg sync.WaitGroup
|
||||
rpcCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
for _, client := range clients {
|
||||
wg.Add(1)
|
||||
go func(client *GrpcClient) {
|
||||
defer wg.Done()
|
||||
|
||||
inCall, err := client.IsSessionInCall(rpcCtx, recipientSessionId, senderRoom)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
} else if err != nil {
|
||||
log.Printf("Error checking session %s in call on %s: %s", recipientSessionId, client.Target(), err)
|
||||
return
|
||||
} else if !inCall {
|
||||
return
|
||||
}
|
||||
|
||||
cancel()
|
||||
atomic.StoreInt32(&result, 1)
|
||||
}(client)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return atomic.LoadInt32(&result) != 0
|
||||
}
|
||||
|
||||
func (h *Hub) isInSameCall(ctx context.Context, senderSession *ClientSession, recipientSessionId string) bool {
|
||||
if senderSession.ClientType() == HelloClientTypeInternal {
|
||||
// Internal clients may subscribe all streams.
|
||||
return true
|
||||
|
@ -1742,7 +1825,7 @@ func (h *Hub) isInSameCall(senderSession *ClientSession, recipientSessionId stri
|
|||
recipientSession := h.GetSessionByPublicId(recipientSessionId)
|
||||
if recipientSession == nil {
|
||||
// Recipient session does not exist.
|
||||
return false
|
||||
return h.isInSameCallRemote(ctx, senderSession, senderRoom, recipientSessionId)
|
||||
}
|
||||
|
||||
recipientRoom := recipientSession.GetRoom()
|
||||
|
@ -1755,7 +1838,7 @@ func (h *Hub) isInSameCall(senderSession *ClientSession, recipientSessionId stri
|
|||
return true
|
||||
}
|
||||
|
||||
func (h *Hub) processMcuMessage(senderSession *ClientSession, session *ClientSession, client_message *ClientMessage, message *MessageClientMessage, data *MessageClientMessageData) {
|
||||
func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMessage, message *MessageClientMessage, data *MessageClientMessageData) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
@ -1771,29 +1854,28 @@ func (h *Hub) processMcuMessage(senderSession *ClientSession, session *ClientSes
|
|||
|
||||
// A user is only allowed to subscribe a stream if she is in the same room
|
||||
// as the other user and both have their "inCall" flag set.
|
||||
if !h.allowSubscribeAnyStream && !h.isInSameCall(senderSession, message.Recipient.SessionId) {
|
||||
if !h.allowSubscribeAnyStream && !h.isInSameCall(ctx, session, message.Recipient.SessionId) {
|
||||
log.Printf("Session %s is not in the same call as session %s, not requesting offer", session.PublicId(), message.Recipient.SessionId)
|
||||
sendNotAllowed(senderSession, client_message, "Not allowed to request offer.")
|
||||
sendNotAllowed(session, client_message, "Not allowed to request offer.")
|
||||
return
|
||||
}
|
||||
|
||||
clientType = "subscriber"
|
||||
mc, err = session.GetOrCreateSubscriber(ctx, h.mcu, message.Recipient.SessionId, data.RoomType)
|
||||
case "sendoffer":
|
||||
// Permissions have already been checked in "processMessageMsg".
|
||||
clientType = "subscriber"
|
||||
mc, err = session.GetOrCreateSubscriber(ctx, h.mcu, message.Recipient.SessionId, data.RoomType)
|
||||
// Will be sent directly.
|
||||
return
|
||||
case "offer":
|
||||
clientType = "publisher"
|
||||
mc, err = session.GetOrCreatePublisher(ctx, h.mcu, data.RoomType, data)
|
||||
if err, ok := err.(*PermissionError); ok {
|
||||
log.Printf("Session %s is not allowed to offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err)
|
||||
sendNotAllowed(senderSession, client_message, "Not allowed to publish.")
|
||||
sendNotAllowed(session, client_message, "Not allowed to publish.")
|
||||
return
|
||||
}
|
||||
if err, ok := err.(*SdpError); ok {
|
||||
log.Printf("Session %s sent unsupported offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err)
|
||||
sendNotAllowed(senderSession, client_message, "Not allowed to publish.")
|
||||
sendNotAllowed(session, client_message, "Not allowed to publish.")
|
||||
return
|
||||
}
|
||||
case "selectStream":
|
||||
|
@ -1808,7 +1890,7 @@ func (h *Hub) processMcuMessage(senderSession *ClientSession, session *ClientSes
|
|||
if session.PublicId() == message.Recipient.SessionId {
|
||||
if err := session.IsAllowedToSend(data); err != nil {
|
||||
log.Printf("Session %s is not allowed to send candidate for %s, ignoring (%s)", session.PublicId(), data.RoomType, err)
|
||||
sendNotAllowed(senderSession, client_message, "Not allowed to send candidate.")
|
||||
sendNotAllowed(session, client_message, "Not allowed to send candidate.")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -1821,18 +1903,18 @@ func (h *Hub) processMcuMessage(senderSession *ClientSession, session *ClientSes
|
|||
}
|
||||
if err != nil {
|
||||
log.Printf("Could not create MCU %s for session %s to send %+v to %s: %s", clientType, session.PublicId(), data, message.Recipient.SessionId, err)
|
||||
sendMcuClientNotFound(senderSession, client_message)
|
||||
sendMcuClientNotFound(session, client_message)
|
||||
return
|
||||
} else if mc == nil {
|
||||
log.Printf("No MCU %s found for session %s to send %+v to %s", clientType, session.PublicId(), data, message.Recipient.SessionId)
|
||||
sendMcuClientNotFound(senderSession, client_message)
|
||||
sendMcuClientNotFound(session, client_message)
|
||||
return
|
||||
}
|
||||
|
||||
mc.SendMessage(context.TODO(), message, data, func(err error, response map[string]interface{}) {
|
||||
if err != nil {
|
||||
log.Printf("Could not send MCU message %+v for session %s to %s: %s", data, session.PublicId(), message.Recipient.SessionId, err)
|
||||
sendMcuProcessingFailed(senderSession, client_message)
|
||||
sendMcuProcessingFailed(session, client_message)
|
||||
return
|
||||
} else if response == nil {
|
||||
// No response received
|
||||
|
|
1048
hub_test.go
1048
hub_test.go
File diff suppressed because it is too large
Load diff
229
mcu_proxy.go
229
mcu_proxy.go
|
@ -26,6 +26,7 @@ import (
|
|||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
|
@ -308,6 +309,9 @@ type mcuProxyConnection struct {
|
|||
shutdownScheduled uint32
|
||||
closeScheduled uint32
|
||||
trackClose uint32
|
||||
temporary uint32
|
||||
|
||||
connectedNotifier SingleNotifier
|
||||
|
||||
helloMsgId string
|
||||
sessionId string
|
||||
|
@ -363,6 +367,7 @@ type mcuProxyConnectionStats struct {
|
|||
Clients int64 `json:"clients"`
|
||||
Load *int64 `json:"load,omitempty"`
|
||||
Shutdown *bool `json:"shutdown,omitempty"`
|
||||
Temporary *bool `json:"temporary,omitempty"`
|
||||
Uptime *time.Time `json:"uptime,omitempty"`
|
||||
}
|
||||
|
||||
|
@ -379,6 +384,8 @@ func (c *mcuProxyConnection) GetStats() *mcuProxyConnectionStats {
|
|||
result.Load = &load
|
||||
shutdown := c.IsShutdownScheduled()
|
||||
result.Shutdown = &shutdown
|
||||
temporary := c.IsTemporary()
|
||||
result.Temporary = &temporary
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.publishersLock.RLock()
|
||||
|
@ -399,6 +406,18 @@ func (c *mcuProxyConnection) Country() string {
|
|||
return c.country.Load().(string)
|
||||
}
|
||||
|
||||
func (c *mcuProxyConnection) IsTemporary() bool {
|
||||
return atomic.LoadUint32(&c.temporary) != 0
|
||||
}
|
||||
|
||||
func (c *mcuProxyConnection) setTemporary() {
|
||||
atomic.StoreUint32(&c.temporary, 1)
|
||||
}
|
||||
|
||||
func (c *mcuProxyConnection) clearTemporary() {
|
||||
atomic.StoreUint32(&c.temporary, 0)
|
||||
}
|
||||
|
||||
func (c *mcuProxyConnection) IsShutdownScheduled() bool {
|
||||
return atomic.LoadUint32(&c.shutdownScheduled) != 0 || atomic.LoadUint32(&c.closeScheduled) != 0
|
||||
}
|
||||
|
@ -483,6 +502,7 @@ func (c *mcuProxyConnection) writePump() {
|
|||
}()
|
||||
|
||||
c.reconnectTimer = time.NewTimer(0)
|
||||
defer c.reconnectTimer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.reconnectTimer.C:
|
||||
|
@ -540,6 +560,8 @@ func (c *mcuProxyConnection) close() {
|
|||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.connectedNotifier.Reset()
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
|
@ -587,6 +609,11 @@ func (c *mcuProxyConnection) scheduleReconnect() {
|
|||
}
|
||||
c.close()
|
||||
|
||||
if c.IsShutdownScheduled() {
|
||||
c.proxy.removeConnection(c)
|
||||
return
|
||||
}
|
||||
|
||||
interval := atomic.LoadInt64(&c.reconnectInterval)
|
||||
c.reconnectTimer.Reset(time.Duration(interval))
|
||||
|
||||
|
@ -634,6 +661,11 @@ func (c *mcuProxyConnection) reconnect() {
|
|||
return
|
||||
}
|
||||
|
||||
if c.IsShutdownScheduled() {
|
||||
c.proxy.removeConnection(c)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Connected to %s", c)
|
||||
atomic.StoreUint32(&c.closed, 0)
|
||||
|
||||
|
@ -657,6 +689,22 @@ func (c *mcuProxyConnection) reconnect() {
|
|||
go c.readPump()
|
||||
}
|
||||
|
||||
func (c *mcuProxyConnection) waitUntilConnected(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
waiter := c.connectedNotifier.NewWaiter()
|
||||
defer c.connectedNotifier.Release(waiter)
|
||||
|
||||
c.mu.Unlock()
|
||||
defer c.mu.Lock()
|
||||
return waiter.Wait(ctx)
|
||||
}
|
||||
|
||||
func (c *mcuProxyConnection) removePublisher(publisher *mcuProxyPublisher) {
|
||||
c.proxy.removePublisher(publisher)
|
||||
|
||||
|
@ -669,7 +717,7 @@ func (c *mcuProxyConnection) removePublisher(publisher *mcuProxyPublisher) {
|
|||
}
|
||||
delete(c.publisherIds, publisher.id+"|"+publisher.StreamType())
|
||||
|
||||
if len(c.publishers) == 0 && atomic.LoadUint32(&c.closeScheduled) != 0 {
|
||||
if len(c.publishers) == 0 && (atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary()) {
|
||||
go c.closeIfEmpty()
|
||||
}
|
||||
}
|
||||
|
@ -686,7 +734,7 @@ func (c *mcuProxyConnection) clearPublishers() {
|
|||
c.publishers = make(map[string]*mcuProxyPublisher)
|
||||
c.publisherIds = make(map[string]string)
|
||||
|
||||
if atomic.LoadUint32(&c.closeScheduled) != 0 {
|
||||
if atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary() {
|
||||
go c.closeIfEmpty()
|
||||
}
|
||||
}
|
||||
|
@ -700,7 +748,7 @@ func (c *mcuProxyConnection) removeSubscriber(subscriber *mcuProxySubscriber) {
|
|||
statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec()
|
||||
}
|
||||
|
||||
if len(c.subscribers) == 0 && atomic.LoadUint32(&c.closeScheduled) != 0 {
|
||||
if len(c.subscribers) == 0 && (atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary()) {
|
||||
go c.closeIfEmpty()
|
||||
}
|
||||
}
|
||||
|
@ -716,7 +764,7 @@ func (c *mcuProxyConnection) clearSubscribers() {
|
|||
}(c.subscribers)
|
||||
c.subscribers = make(map[string]*mcuProxySubscriber)
|
||||
|
||||
if atomic.LoadUint32(&c.closeScheduled) != 0 {
|
||||
if atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary() {
|
||||
go c.closeIfEmpty()
|
||||
}
|
||||
}
|
||||
|
@ -780,6 +828,8 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) {
|
|||
if atomic.CompareAndSwapUint32(&c.trackClose, 0, 1) {
|
||||
statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Inc()
|
||||
}
|
||||
|
||||
c.connectedNotifier.Notify()
|
||||
default:
|
||||
log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c)
|
||||
c.scheduleReconnect()
|
||||
|
@ -1006,20 +1056,13 @@ func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListe
|
|||
return publisher, nil
|
||||
}
|
||||
|
||||
func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) {
|
||||
c.publishersLock.Lock()
|
||||
id, found := c.publisherIds[publisher+"|"+streamType]
|
||||
c.publishersLock.Unlock()
|
||||
if !found {
|
||||
return nil, fmt.Errorf("Unknown publisher %s", publisher)
|
||||
}
|
||||
|
||||
func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType string) (McuSubscriber, error) {
|
||||
msg := &ProxyClientMessage{
|
||||
Type: "command",
|
||||
Command: &CommandProxyClientMessage{
|
||||
Type: "create-subscriber",
|
||||
StreamType: streamType,
|
||||
PublisherId: id,
|
||||
PublisherId: publisherId,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -1030,8 +1073,8 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList
|
|||
}
|
||||
|
||||
proxyId := response.Command.Id
|
||||
log.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c, publisher)
|
||||
subscriber := newMcuProxySubscriber(publisher, response.Command.Sid, streamType, proxyId, c, listener)
|
||||
log.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c, publisherSessionId)
|
||||
subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, proxyId, c, listener)
|
||||
c.subscribersLock.Lock()
|
||||
c.subscribers[proxyId] = subscriber
|
||||
c.subscribersLock.Unlock()
|
||||
|
@ -1075,9 +1118,11 @@ type mcuProxy struct {
|
|||
publisherWaiters map[uint64]chan bool
|
||||
|
||||
continentsMap atomic.Value
|
||||
|
||||
rpcClients *GrpcClients
|
||||
}
|
||||
|
||||
func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient) (Mcu, error) {
|
||||
func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients *GrpcClients) (Mcu, error) {
|
||||
urlType, _ := config.GetString("mcu", "urltype")
|
||||
if urlType == "" {
|
||||
urlType = proxyUrlTypeStatic
|
||||
|
@ -1139,6 +1184,8 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient) (Mcu, error)
|
|||
publishers: make(map[string]*mcuProxyConnection),
|
||||
|
||||
publisherWaiters: make(map[uint64]chan bool),
|
||||
|
||||
rpcClients: rpcClients,
|
||||
}
|
||||
|
||||
if err := mcu.loadContinentsMap(config); err != nil {
|
||||
|
@ -1282,6 +1329,11 @@ func (m *mcuProxy) updateProxyIPs() {
|
|||
host = h
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
// No need to lookup endpoints that connect to IP addresses.
|
||||
continue
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
log.Printf("Could not lookup %s: %s", host, err)
|
||||
|
@ -1297,6 +1349,7 @@ func (m *mcuProxy) updateProxyIPs() {
|
|||
ips = append(ips[:idx], ips[idx+1:]...)
|
||||
found = true
|
||||
conn.stopCloseIfEmpty()
|
||||
conn.clearTemporary()
|
||||
newConns = append(newConns, conn)
|
||||
break
|
||||
}
|
||||
|
@ -1363,6 +1416,7 @@ func (m *mcuProxy) configureStatic(config *goconf.ConfigFile, fromReload bool) e
|
|||
delete(remove, u)
|
||||
for _, conn := range existing {
|
||||
conn.stopCloseIfEmpty()
|
||||
conn.clearTemporary()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
@ -1561,6 +1615,7 @@ func (m *mcuProxy) EtcdKeyUpdated(client *EtcdClient, key string, data []byte) {
|
|||
m.urlToKey[info.Address] = key
|
||||
for _, conn := range conns {
|
||||
conn.stopCloseIfEmpty()
|
||||
conn.clearTemporary()
|
||||
}
|
||||
} else {
|
||||
conn, err := newMcuProxyConnection(m, info.Address, nil)
|
||||
|
@ -1815,7 +1870,7 @@ func (m *mcuProxy) removeWaiter(id uint64) {
|
|||
func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) {
|
||||
connections := m.getSortedConnections(initiator)
|
||||
for _, conn := range connections {
|
||||
if conn.IsShutdownScheduled() {
|
||||
if conn.IsShutdownScheduled() || conn.IsTemporary() {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -1850,19 +1905,18 @@ func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id st
|
|||
return nil, fmt.Errorf("No MCU connection available")
|
||||
}
|
||||
|
||||
func (m *mcuProxy) getPublisherConnection(ctx context.Context, publisher string, streamType string) *mcuProxyConnection {
|
||||
func (m *mcuProxy) getPublisherConnection(publisher string, streamType string) *mcuProxyConnection {
|
||||
m.mu.RLock()
|
||||
conn := m.publishers[publisher+"|"+streamType]
|
||||
m.mu.RUnlock()
|
||||
if conn != nil {
|
||||
return conn
|
||||
}
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
log.Printf("No %s publisher %s found yet, deferring", streamType, publisher)
|
||||
return m.publishers[publisher+"|"+streamType]
|
||||
}
|
||||
|
||||
func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher string, streamType string) *mcuProxyConnection {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
conn = m.publishers[publisher+"|"+streamType]
|
||||
conn := m.publishers[publisher+"|"+streamType]
|
||||
if conn != nil {
|
||||
// Publisher was created while waiting for lock.
|
||||
return conn
|
||||
|
@ -1890,10 +1944,127 @@ func (m *mcuProxy) getPublisherConnection(ctx context.Context, publisher string,
|
|||
}
|
||||
|
||||
func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) {
|
||||
conn := m.getPublisherConnection(ctx, publisher, streamType)
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher)
|
||||
if conn := m.getPublisherConnection(publisher, streamType); conn != nil {
|
||||
// Fast common path: publisher is available locally.
|
||||
conn.publishersLock.Lock()
|
||||
id, found := conn.publisherIds[publisher+"|"+streamType]
|
||||
conn.publishersLock.Unlock()
|
||||
if !found {
|
||||
return nil, fmt.Errorf("Unknown publisher %s", publisher)
|
||||
}
|
||||
|
||||
return conn.newSubscriber(ctx, listener, id, publisher, streamType)
|
||||
}
|
||||
|
||||
return conn.newSubscriber(ctx, listener, publisher, streamType)
|
||||
log.Printf("No %s publisher %s found yet, deferring", streamType, publisher)
|
||||
ch := make(chan McuSubscriber)
|
||||
getctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Wait for publisher to be created locally.
|
||||
go func() {
|
||||
if conn := m.waitForPublisherConnection(getctx, publisher, streamType); conn != nil {
|
||||
cancel() // Cancel pending RPC calls.
|
||||
|
||||
conn.publishersLock.Lock()
|
||||
id, found := conn.publisherIds[publisher+"|"+streamType]
|
||||
conn.publishersLock.Unlock()
|
||||
if !found {
|
||||
log.Printf("Unknown id for local %s publisher %s", streamType, publisher)
|
||||
return
|
||||
}
|
||||
|
||||
subscriber, err := conn.newSubscriber(ctx, listener, id, publisher, streamType)
|
||||
if subscriber != nil {
|
||||
ch <- subscriber
|
||||
} else if err != nil {
|
||||
log.Printf("Error creating local subscriber for %s publisher %s: %s", streamType, publisher, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for publisher to be created on one of the other servers in the cluster.
|
||||
if clients := m.rpcClients.GetClients(); len(clients) > 0 {
|
||||
for _, client := range clients {
|
||||
go func(client *GrpcClient) {
|
||||
id, url, ip, err := client.GetPublisherId(getctx, publisher, streamType)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
} else if err != nil {
|
||||
log.Printf("Error getting %s publisher id %s from %s: %s", streamType, publisher, client.Target(), err)
|
||||
return
|
||||
} else if id == "" {
|
||||
// Publisher not found on other server
|
||||
return
|
||||
}
|
||||
|
||||
cancel() // Cancel pending RPC calls.
|
||||
log.Printf("Found publisher id %s through %s on proxy %s", id, client.Target(), url)
|
||||
|
||||
m.connectionsMu.RLock()
|
||||
connections := m.connections
|
||||
m.connectionsMu.RUnlock()
|
||||
var publisherConn *mcuProxyConnection
|
||||
for _, conn := range connections {
|
||||
if conn.rawUrl != url || !ip.Equal(conn.ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Simple case, signaling server has a connection to the same endpoint
|
||||
publisherConn = conn
|
||||
break
|
||||
}
|
||||
|
||||
if publisherConn == nil {
|
||||
publisherConn, err = newMcuProxyConnection(m, url, ip)
|
||||
if err != nil {
|
||||
log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err)
|
||||
return
|
||||
}
|
||||
publisherConn.setTemporary()
|
||||
|
||||
if err := publisherConn.start(); err != nil {
|
||||
log.Printf("Could not start new connection to %s: %s", publisherConn, err)
|
||||
publisherConn.closeIfEmpty()
|
||||
return
|
||||
}
|
||||
|
||||
if err := publisherConn.waitUntilConnected(ctx); err != nil {
|
||||
log.Printf("Could not establish new connection to %s: %s", publisherConn, err)
|
||||
publisherConn.closeIfEmpty()
|
||||
return
|
||||
}
|
||||
|
||||
m.connectionsMu.Lock()
|
||||
m.connections = append(m.connections, publisherConn)
|
||||
conns, found := m.connectionsMap[url]
|
||||
if found {
|
||||
conns = append(conns, publisherConn)
|
||||
} else {
|
||||
conns = []*mcuProxyConnection{publisherConn}
|
||||
}
|
||||
m.connectionsMap[url] = conns
|
||||
m.connectionsMu.Unlock()
|
||||
}
|
||||
|
||||
subscriber, err := publisherConn.newSubscriber(ctx, listener, id, publisher, streamType)
|
||||
if err != nil {
|
||||
if publisherConn.IsTemporary() {
|
||||
publisherConn.closeIfEmpty()
|
||||
}
|
||||
log.Printf("Could not create subscriber for %s publisher %s: %s", streamType, publisher, err)
|
||||
return
|
||||
}
|
||||
|
||||
ch <- subscriber
|
||||
}(client)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case subscriber := <-ch:
|
||||
return subscriber, nil
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,22 +35,10 @@ import (
|
|||
const (
|
||||
initialConnectInterval = time.Second
|
||||
maxConnectInterval = 8 * time.Second
|
||||
|
||||
NatsLoopbackUrl = "nats://loopback"
|
||||
)
|
||||
|
||||
type NatsMessage struct {
|
||||
SendTime time.Time `json:"sendtime"`
|
||||
|
||||
Type string `json:"type"`
|
||||
|
||||
Message *ServerMessage `json:"message,omitempty"`
|
||||
|
||||
Room *BackendServerRoomRequest `json:"room,omitempty"`
|
||||
|
||||
Permissions []Permission `json:"permissions,omitempty"`
|
||||
|
||||
Id string `json:"id"`
|
||||
}
|
||||
|
||||
type NatsSubscription interface {
|
||||
Unsubscribe() error
|
||||
}
|
||||
|
@ -59,11 +47,7 @@ type NatsClient interface {
|
|||
Close()
|
||||
|
||||
Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error)
|
||||
|
||||
Publish(subject string, message interface{}) error
|
||||
PublishNats(subject string, message *NatsMessage) error
|
||||
PublishMessage(subject string, message *ServerMessage) error
|
||||
PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error
|
||||
|
||||
Decode(msg *nats.Msg, v interface{}) error
|
||||
}
|
||||
|
@ -82,7 +66,11 @@ type natsClient struct {
|
|||
|
||||
func NewNatsClient(url string) (NatsClient, error) {
|
||||
if url == ":loopback:" {
|
||||
log.Println("No NATS url configured, using internal loopback client")
|
||||
log.Printf("WARNING: events url %s is deprecated, please use %s instead", url, NatsLoopbackUrl)
|
||||
url = NatsLoopbackUrl
|
||||
}
|
||||
if url == NatsLoopbackUrl {
|
||||
log.Println("Using internal NATS loopback client")
|
||||
return NewLoopbackNatsClient()
|
||||
}
|
||||
|
||||
|
@ -148,28 +136,6 @@ func (c *natsClient) Publish(subject string, message interface{}) error {
|
|||
return c.conn.Publish(subject, message)
|
||||
}
|
||||
|
||||
func (c *natsClient) PublishNats(subject string, message *NatsMessage) error {
|
||||
return c.Publish(subject, message)
|
||||
}
|
||||
|
||||
func (c *natsClient) PublishMessage(subject string, message *ServerMessage) error {
|
||||
msg := &NatsMessage{
|
||||
SendTime: time.Now(),
|
||||
Type: "message",
|
||||
Message: message,
|
||||
}
|
||||
return c.PublishNats(subject, msg)
|
||||
}
|
||||
|
||||
func (c *natsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error {
|
||||
msg := &NatsMessage{
|
||||
SendTime: time.Now(),
|
||||
Type: "room",
|
||||
Room: message,
|
||||
}
|
||||
return c.PublishNats(subject, msg)
|
||||
}
|
||||
|
||||
func (c *natsClient) Decode(msg *nats.Msg, v interface{}) error {
|
||||
return c.conn.Enc.Decode(msg.Subject, msg.Data, v)
|
||||
}
|
||||
|
|
|
@ -27,7 +27,6 @@ import (
|
|||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
@ -170,28 +169,6 @@ func (c *LoopbackNatsClient) Publish(subject string, message interface{}) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) PublishNats(subject string, message *NatsMessage) error {
|
||||
return c.Publish(subject, message)
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) PublishMessage(subject string, message *ServerMessage) error {
|
||||
msg := &NatsMessage{
|
||||
SendTime: time.Now(),
|
||||
Type: "message",
|
||||
Message: message,
|
||||
}
|
||||
return c.PublishNats(subject, msg)
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error {
|
||||
msg := &NatsMessage{
|
||||
SendTime: time.Now(),
|
||||
Type: "room",
|
||||
Room: message,
|
||||
}
|
||||
return c.PublishNats(subject, msg)
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) Decode(msg *nats.Msg, v interface{}) error {
|
||||
return json.Unmarshal(msg.Data, v)
|
||||
}
|
||||
|
|
28
notifier.go
28
notifier.go
|
@ -29,17 +29,7 @@ import (
|
|||
type Waiter struct {
|
||||
key string
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (w *Waiter) Wait(ctx context.Context) error {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
SingleWaiter
|
||||
}
|
||||
|
||||
type Notifier struct {
|
||||
|
@ -56,9 +46,11 @@ func (n *Notifier) NewWaiter(key string) *Waiter {
|
|||
waiter, found := n.waiters[key]
|
||||
if found {
|
||||
w := &Waiter{
|
||||
key: key,
|
||||
ctx: waiter.ctx,
|
||||
cancel: waiter.cancel,
|
||||
key: key,
|
||||
SingleWaiter: SingleWaiter{
|
||||
ctx: waiter.ctx,
|
||||
cancel: waiter.cancel,
|
||||
},
|
||||
}
|
||||
n.waiterMap[key][w] = true
|
||||
return w
|
||||
|
@ -66,9 +58,11 @@ func (n *Notifier) NewWaiter(key string) *Waiter {
|
|||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
waiter = &Waiter{
|
||||
key: key,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
key: key,
|
||||
SingleWaiter: SingleWaiter{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
},
|
||||
}
|
||||
if n.waiters == nil {
|
||||
n.waiters = make(map[string]*Waiter)
|
||||
|
|
245
room.go
245
room.go
|
@ -32,7 +32,6 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
|
@ -56,7 +55,7 @@ func init() {
|
|||
type Room struct {
|
||||
id string
|
||||
hub *Hub
|
||||
nats NatsClient
|
||||
events AsyncEvents
|
||||
backend *Backend
|
||||
|
||||
properties *json.RawMessage
|
||||
|
@ -72,34 +71,15 @@ type Room struct {
|
|||
|
||||
statsRoomSessionsCurrent *prometheus.GaugeVec
|
||||
|
||||
natsReceiver chan *nats.Msg
|
||||
backendSubscription NatsSubscription
|
||||
|
||||
// Users currently in the room
|
||||
users []map[string]interface{}
|
||||
|
||||
// Timestamps of last NATS backend requests for the different types.
|
||||
lastNatsRoomRequests map[string]int64
|
||||
// Timestamps of last backend requests for the different types.
|
||||
lastRoomRequests map[string]int64
|
||||
|
||||
transientData *TransientData
|
||||
}
|
||||
|
||||
func GetSubjectForRoomId(roomId string, backend *Backend) string {
|
||||
if backend == nil || backend.IsCompat() {
|
||||
return GetEncodedSubject("room", roomId)
|
||||
}
|
||||
|
||||
return GetEncodedSubject("room", roomId+"|"+backend.Id())
|
||||
}
|
||||
|
||||
func GetSubjectForBackendRoomId(roomId string, backend *Backend) string {
|
||||
if backend == nil || backend.IsCompat() {
|
||||
return GetEncodedSubject("backend.room", roomId)
|
||||
}
|
||||
|
||||
return GetEncodedSubject("backend.room", roomId+"|"+backend.Id())
|
||||
}
|
||||
|
||||
func getRoomIdForBackend(id string, backend *Backend) string {
|
||||
if id == "" {
|
||||
return ""
|
||||
|
@ -108,18 +88,11 @@ func getRoomIdForBackend(id string, backend *Backend) string {
|
|||
return backend.Id() + "|" + id
|
||||
}
|
||||
|
||||
func NewRoom(roomId string, properties *json.RawMessage, hub *Hub, n NatsClient, backend *Backend) (*Room, error) {
|
||||
natsReceiver := make(chan *nats.Msg, 64)
|
||||
backendSubscription, err := n.Subscribe(GetSubjectForBackendRoomId(roomId, backend), natsReceiver)
|
||||
if err != nil {
|
||||
close(natsReceiver)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func NewRoom(roomId string, properties *json.RawMessage, hub *Hub, events AsyncEvents, backend *Backend) (*Room, error) {
|
||||
room := &Room{
|
||||
id: roomId,
|
||||
hub: hub,
|
||||
nats: n,
|
||||
events: events,
|
||||
backend: backend,
|
||||
|
||||
properties: properties,
|
||||
|
@ -138,13 +111,15 @@ func NewRoom(roomId string, properties *json.RawMessage, hub *Hub, n NatsClient,
|
|||
"room": roomId,
|
||||
}),
|
||||
|
||||
natsReceiver: natsReceiver,
|
||||
backendSubscription: backendSubscription,
|
||||
|
||||
lastNatsRoomRequests: make(map[string]int64),
|
||||
lastRoomRequests: make(map[string]int64),
|
||||
|
||||
transientData: NewTransientData(),
|
||||
}
|
||||
|
||||
if err := events.RegisterBackendRoomListener(roomId, backend, room); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go room.run()
|
||||
|
||||
return room, nil
|
||||
|
@ -193,10 +168,6 @@ loop:
|
|||
select {
|
||||
case <-r.closeChan:
|
||||
break loop
|
||||
case msg := <-r.natsReceiver:
|
||||
if msg != nil {
|
||||
r.processNatsMessage(msg)
|
||||
}
|
||||
case <-ticker.C:
|
||||
r.publishActiveSessions()
|
||||
}
|
||||
|
@ -211,16 +182,7 @@ func (r *Room) doClose() {
|
|||
}
|
||||
|
||||
func (r *Room) unsubscribeBackend() {
|
||||
if r.backendSubscription == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(subscription NatsSubscription) {
|
||||
if err := subscription.Unsubscribe(); err != nil {
|
||||
log.Printf("Error closing backend subscription for room %s: %s", r.Id(), err)
|
||||
}
|
||||
}(r.backendSubscription)
|
||||
r.backendSubscription = nil
|
||||
r.events.UnregisterBackendRoomListener(r.id, r.backend, r)
|
||||
}
|
||||
|
||||
func (r *Room) Close() []Session {
|
||||
|
@ -240,33 +202,29 @@ func (r *Room) Close() []Session {
|
|||
return result
|
||||
}
|
||||
|
||||
func (r *Room) processNatsMessage(message *nats.Msg) {
|
||||
var msg NatsMessage
|
||||
if err := r.nats.Decode(message, &msg); err != nil {
|
||||
log.Printf("Could not decode nats message %+v, %s", message, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
func (r *Room) ProcessBackendRoomRequest(message *AsyncMessage) {
|
||||
switch message.Type {
|
||||
case "room":
|
||||
r.processBackendRoomRequest(msg.Room)
|
||||
r.processBackendRoomRequestRoom(message.Room)
|
||||
case "asyncroom":
|
||||
r.processBackendRoomRequestAsyncRoom(message.AsyncRoom)
|
||||
default:
|
||||
log.Printf("Unsupported NATS room request with type %s: %+v", msg.Type, msg)
|
||||
log.Printf("Unsupported backend room request with type %s in %s: %+v", message.Type, r.id, message)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) processBackendRoomRequest(message *BackendServerRoomRequest) {
|
||||
func (r *Room) processBackendRoomRequestRoom(message *BackendServerRoomRequest) {
|
||||
received := message.ReceivedTime
|
||||
if last, found := r.lastNatsRoomRequests[message.Type]; found && last > received {
|
||||
if last, found := r.lastRoomRequests[message.Type]; found && last > received {
|
||||
if msg, err := json.Marshal(message); err == nil {
|
||||
log.Printf("Ignore old NATS backend room request for %s: %s", r.Id(), string(msg))
|
||||
log.Printf("Ignore old backend room request for %s: %s", r.Id(), string(msg))
|
||||
} else {
|
||||
log.Printf("Ignore old NATS backend room request for %s: %+v", r.Id(), message)
|
||||
log.Printf("Ignore old backend room request for %s: %+v", r.Id(), message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
r.lastNatsRoomRequests[message.Type] = received
|
||||
r.lastRoomRequests[message.Type] = received
|
||||
message.room = r
|
||||
switch message.Type {
|
||||
case "update":
|
||||
|
@ -281,11 +239,20 @@ func (r *Room) processBackendRoomRequest(message *BackendServerRoomRequest) {
|
|||
case "message":
|
||||
r.publishRoomMessage(message.Message)
|
||||
default:
|
||||
log.Printf("Unsupported NATS backend room request with type %s in %s: %+v", message.Type, r.Id(), message)
|
||||
log.Printf("Unsupported backend room request with type %s in %s: %+v", message.Type, r.Id(), message)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) AddSession(session Session, sessionData *json.RawMessage) []Session {
|
||||
func (r *Room) processBackendRoomRequestAsyncRoom(message *AsyncRoomMessage) {
|
||||
switch message.Type {
|
||||
case "sessionjoined":
|
||||
r.notifySessionJoined(message.SessionId)
|
||||
default:
|
||||
log.Printf("Unsupported async room request with type %s in %s: %+v", message.Type, r.Id(), message)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) AddSession(session Session, sessionData *json.RawMessage) {
|
||||
var roomSessionData *RoomSessionData
|
||||
if sessionData != nil && len(*sessionData) > 0 {
|
||||
roomSessionData = &RoomSessionData{}
|
||||
|
@ -298,13 +265,6 @@ func (r *Room) AddSession(session Session, sessionData *json.RawMessage) []Sessi
|
|||
sid := session.PublicId()
|
||||
r.mu.Lock()
|
||||
_, found := r.sessions[sid]
|
||||
// Return list of sessions already in the room.
|
||||
result := make([]Session, 0, len(r.sessions))
|
||||
for _, s := range r.sessions {
|
||||
if s != session {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
r.sessions[sid] = session
|
||||
if !found {
|
||||
r.statsRoomSessionsCurrent.With(prometheus.Labels{"clienttype": session.ClientType()}).Inc()
|
||||
|
@ -340,7 +300,116 @@ func (r *Room) AddSession(session Session, sessionData *json.RawMessage) []Sessi
|
|||
r.transientData.AddListener(clientSession)
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
// Trigger notifications that the session joined.
|
||||
if err := r.events.PublishBackendRoomMessage(r.id, r.backend, &AsyncMessage{
|
||||
Type: "asyncroom",
|
||||
AsyncRoom: &AsyncRoomMessage{
|
||||
Type: "sessionjoined",
|
||||
SessionId: sid,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Printf("Error publishing joined event for session %s: %s", sid, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) getOtherSessions(ignoreSessionId string) (Session, []Session) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
sessions := make([]Session, 0, len(r.sessions))
|
||||
for _, s := range r.sessions {
|
||||
if s.PublicId() == ignoreSessionId {
|
||||
continue
|
||||
}
|
||||
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
|
||||
return r.sessions[ignoreSessionId], sessions
|
||||
}
|
||||
|
||||
func (r *Room) notifySessionJoined(sessionId string) {
|
||||
session, sessions := r.getOtherSessions(sessionId)
|
||||
if len(sessions) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if session != nil && session.ClientType() != HelloClientTypeClient {
|
||||
session = nil
|
||||
}
|
||||
|
||||
events := make([]*EventServerMessageSessionEntry, 0, len(sessions))
|
||||
for _, s := range sessions {
|
||||
entry := &EventServerMessageSessionEntry{
|
||||
SessionId: s.PublicId(),
|
||||
UserId: s.UserId(),
|
||||
User: s.UserData(),
|
||||
}
|
||||
if s, ok := s.(*ClientSession); ok {
|
||||
entry.RoomSessionId = s.RoomSessionId()
|
||||
}
|
||||
events = append(events, entry)
|
||||
}
|
||||
|
||||
msg := &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "room",
|
||||
Type: "join",
|
||||
Join: events,
|
||||
},
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
// No need to send through asynchronous events, the session is connected locally.
|
||||
session.(*ClientSession).SendMessage(msg)
|
||||
} else {
|
||||
if err := r.events.PublishSessionMessage(sessionId, r.backend, &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: msg,
|
||||
}); err != nil {
|
||||
log.Printf("Error publishing joined events to session %s: %s", sessionId, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Notify about initial flags of virtual sessions.
|
||||
for _, s := range sessions {
|
||||
vsess, ok := s.(*VirtualSession)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
flags := vsess.Flags()
|
||||
if flags == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
msg := &ServerMessage{
|
||||
Type: "event",
|
||||
Event: &EventServerMessage{
|
||||
Target: "participants",
|
||||
Type: "flags",
|
||||
Flags: &RoomFlagsServerMessage{
|
||||
RoomId: r.id,
|
||||
SessionId: vsess.PublicId(),
|
||||
Flags: vsess.Flags(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
// No need to send through asynchronous events, the session is connected locally.
|
||||
session.(*ClientSession).SendMessage(msg)
|
||||
} else {
|
||||
if err := r.events.PublishSessionMessage(sessionId, r.backend, &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: msg,
|
||||
}); err != nil {
|
||||
log.Printf("Error publishing initial flags to session %s: %s", sessionId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) HasSession(session Session) bool {
|
||||
|
@ -394,7 +463,10 @@ func (r *Room) RemoveSession(session Session) bool {
|
|||
}
|
||||
|
||||
func (r *Room) publish(message *ServerMessage) error {
|
||||
return r.nats.PublishMessage(GetSubjectForRoomId(r.id, r.backend), message)
|
||||
return r.events.PublishRoomMessage(r.id, r.backend, &AsyncMessage{
|
||||
Type: "message",
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Room) UpdateProperties(properties *json.RawMessage) {
|
||||
|
@ -620,11 +692,13 @@ func (r *Room) PublishUsersInCallChangedAll(inCall int) {
|
|||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var notify []*ClientSession
|
||||
if inCall&FlagInCall != 0 {
|
||||
// All connected sessions join the call.
|
||||
var joined []string
|
||||
for _, session := range r.sessions {
|
||||
if _, ok := session.(*ClientSession); !ok {
|
||||
clientSession, ok := session.(*ClientSession)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -636,6 +710,7 @@ func (r *Room) PublishUsersInCallChangedAll(inCall int) {
|
|||
r.inCallSessions[session] = true
|
||||
joined = append(joined, session.PublicId())
|
||||
}
|
||||
notify = append(notify, clientSession)
|
||||
}
|
||||
|
||||
if len(joined) == 0 {
|
||||
|
@ -657,6 +732,15 @@ func (r *Room) PublishUsersInCallChangedAll(inCall int) {
|
|||
}
|
||||
}()
|
||||
|
||||
for _, session := range r.sessions {
|
||||
clientSession, ok := session.(*ClientSession)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
notify = append(notify, clientSession)
|
||||
}
|
||||
|
||||
for session := range r.inCallSessions {
|
||||
if clientSession, ok := session.(*ClientSession); ok {
|
||||
ch <- clientSession
|
||||
|
@ -683,8 +767,11 @@ func (r *Room) PublishUsersInCallChangedAll(inCall int) {
|
|||
},
|
||||
},
|
||||
}
|
||||
if err := r.publish(message); err != nil {
|
||||
log.Printf("Could not publish incall message in room %s: %s", r.Id(), err)
|
||||
|
||||
for _, session := range notify {
|
||||
if !session.SendMessage(message) {
|
||||
log.Printf("Could not send incall message from room %s to %s", r.Id(), session.PublicId())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
backend, err := NewBackendClient(config, 1, "0.0")
|
||||
backend, err := NewBackendClient(config, 1, "0.0", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -145,7 +145,7 @@ func TestRoom_Update(t *testing.T) {
|
|||
}
|
||||
|
||||
// The client receives a roomlist update and a changed room event. The
|
||||
// ordering is not defined because messages are sent by asynchronous NATS
|
||||
// ordering is not defined because messages are sent by asynchronous event
|
||||
// handlers.
|
||||
message1, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
|
@ -178,7 +178,7 @@ func TestRoom_Update(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Allow up to 100 milliseconds for NATS processing.
|
||||
// Allow up to 100 milliseconds for asynchronous event processing.
|
||||
ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
defer cancel2()
|
||||
|
||||
|
@ -280,7 +280,7 @@ func TestRoom_Delete(t *testing.T) {
|
|||
}
|
||||
|
||||
// The client is no longer invited to the room and leaves it. The ordering
|
||||
// of messages is not defined as they get published through NATS and handled
|
||||
// of messages is not defined as they get published through events and handled
|
||||
// by asynchronous channels.
|
||||
message1, err := client.RunUntilMessage(ctx)
|
||||
if err != nil {
|
||||
|
@ -318,7 +318,7 @@ func TestRoom_Delete(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Allow up to 100 milliseconds for NATS processing.
|
||||
// Allow up to 100 milliseconds for asynchronous event processing.
|
||||
ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
defer cancel2()
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
|
@ -34,4 +35,5 @@ type RoomSessions interface {
|
|||
DeleteRoomSession(session Session)
|
||||
|
||||
GetSessionId(roomSessionId string) (string, error)
|
||||
LookupSessionId(ctx context.Context, roomSessionId string) (string, error)
|
||||
}
|
||||
|
|
|
@ -22,19 +22,27 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type BuiltinRoomSessions struct {
|
||||
sessionIdToRoomSession map[string]string
|
||||
roomSessionToSessionid map[string]string
|
||||
mu sync.RWMutex
|
||||
|
||||
clients *GrpcClients
|
||||
}
|
||||
|
||||
func NewBuiltinRoomSessions() (RoomSessions, error) {
|
||||
func NewBuiltinRoomSessions(clients *GrpcClients) (RoomSessions, error) {
|
||||
return &BuiltinRoomSessions{
|
||||
sessionIdToRoomSession: make(map[string]string),
|
||||
roomSessionToSessionid: make(map[string]string),
|
||||
|
||||
clients: clients,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -78,3 +86,53 @@ func (r *BuiltinRoomSessions) GetSessionId(roomSessionId string) (string, error)
|
|||
|
||||
return sid, nil
|
||||
}
|
||||
|
||||
func (r *BuiltinRoomSessions) LookupSessionId(ctx context.Context, roomSessionId string) (string, error) {
|
||||
sid, err := r.GetSessionId(roomSessionId)
|
||||
if err == nil {
|
||||
return sid, nil
|
||||
}
|
||||
|
||||
if r.clients == nil {
|
||||
return "", ErrNoSuchRoomSession
|
||||
}
|
||||
|
||||
clients := r.clients.GetClients()
|
||||
if len(clients) == 0 {
|
||||
return "", ErrNoSuchRoomSession
|
||||
}
|
||||
|
||||
lookupctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var result atomic.Value
|
||||
for _, client := range clients {
|
||||
wg.Add(1)
|
||||
go func(client *GrpcClient) {
|
||||
defer wg.Done()
|
||||
|
||||
sid, err := client.LookupSessionId(lookupctx, roomSessionId)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
} else if err != nil {
|
||||
log.Printf("Received error while checking for room session id %s on %s: %s", roomSessionId, client.Target(), err)
|
||||
return
|
||||
} else if sid == "" {
|
||||
log.Printf("Received empty session id for room session id %s from %s", roomSessionId, client.Target())
|
||||
return
|
||||
}
|
||||
|
||||
cancel() // Cancel pending RPC calls.
|
||||
result.Store(sid)
|
||||
}(client)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
value := result.Load()
|
||||
if value == nil {
|
||||
return "", ErrNoSuchRoomSession
|
||||
}
|
||||
|
||||
return value.(string), nil
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ import (
|
|||
)
|
||||
|
||||
func TestBuiltinRoomSessions(t *testing.T) {
|
||||
sessions, err := NewBuiltinRoomSessions()
|
||||
sessions, err := NewBuiltinRoomSessions(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -50,6 +50,15 @@ blockkey = -encryption-key-
|
|||
internalsecret = the-shared-secret-for-internal-clients
|
||||
|
||||
[backend]
|
||||
# Type of backend configuration.
|
||||
# Defaults to "static".
|
||||
#
|
||||
# Possible values:
|
||||
# - static: A comma-separated list of backends is given in the "backends" option.
|
||||
# - etcd: Backends are retrieved from an etcd cluster.
|
||||
#backendtype = static
|
||||
|
||||
# For backend type "static":
|
||||
# Comma-separated list of backend ids from which clients are allowed to connect
|
||||
# from. Each backend will have isolated rooms, i.e. clients connecting to room
|
||||
# "abc12345" on backend 1 will be in a different room than clients connected to
|
||||
|
@ -57,6 +66,22 @@ internalsecret = the-shared-secret-for-internal-clients
|
|||
# backends will not be able to communicate with each other.
|
||||
#backends = backend-id, another-backend
|
||||
|
||||
# For backend type "etcd":
|
||||
# Key prefix of backend entries. All keys below will be watched and assumed to
|
||||
# contain a JSON document with the following entries:
|
||||
# - "url": Url of the Nextcloud instance.
|
||||
# - "secret": Shared secret for requests from and to the backend servers.
|
||||
#
|
||||
# Additional optional entries:
|
||||
# - "maxstreambitrate": Maximum bitrate per publishing stream (in bits per second).
|
||||
# - "maxscreenbitrate": Maximum bitrate per screensharing stream (in bits per second).
|
||||
# - "sessionlimit": Number of sessions that are allowed to connect.
|
||||
#
|
||||
# Example:
|
||||
# "/signaling/backend/one" -> {"url": "https://nextcloud.domain1.invalid", ...}
|
||||
# "/signaling/backend/two" -> {"url": "https://domain2.invalid/nextcloud", ...}
|
||||
#backendprefix = /signaling/backend
|
||||
|
||||
# Allow any hostname as backend endpoint. This is extremely insecure and should
|
||||
# only be used while running the benchmark client against the server.
|
||||
allowall = false
|
||||
|
@ -77,6 +102,7 @@ connectionsperhost = 8
|
|||
# certificates.
|
||||
#skipverify = false
|
||||
|
||||
# For backendtype "static":
|
||||
# Backend configurations as defined in the "[backend]" section above. The
|
||||
# section names must match the ids used in "backends" above.
|
||||
#[backend-id]
|
||||
|
@ -109,7 +135,7 @@ connectionsperhost = 8
|
|||
|
||||
[nats]
|
||||
# Url of NATS backend to use. This can also be a list of URLs to connect to
|
||||
# multiple backends. For local development, this can be set to ":loopback:"
|
||||
# multiple backends. For local development, this can be set to "nats://loopback"
|
||||
# to process NATS messages internally instead of sending them through an
|
||||
# external NATS backend.
|
||||
#url = nats://localhost:4222
|
||||
|
@ -234,3 +260,55 @@ connectionsperhost = 8
|
|||
#clientkey = /path/to/etcd-client.key
|
||||
#clientcert = /path/to/etcd-client.crt
|
||||
#cacert = /path/to/etcd-ca.crt
|
||||
|
||||
[grpc]
|
||||
# IP and port to listen on for GRPC requests.
|
||||
# Comment line to disable the listener.
|
||||
#listen = 0.0.0.0:9090
|
||||
|
||||
# Certificate / private key to use for the GRPC server.
|
||||
# Omit to use unencrypted connections.
|
||||
#servercertificate = /path/to/grpc-server.crt
|
||||
#serverkey = /path/to/grpc-server.key
|
||||
|
||||
# CA certificate that is allowed to issue certificates of GRPC servers.
|
||||
# Omit to expect unencrypted connections.
|
||||
#serverca = /path/to/grpc-ca.crt
|
||||
|
||||
# Certificate / private key to use for the GRPC client.
|
||||
# Omit if clients don't need to authenticate on the server.
|
||||
#clientcertificate = /path/to/grpc-client.crt
|
||||
#clientkey = /path/to/grpc-client.key
|
||||
|
||||
# CA certificate that is allowed to issue certificates of GRPC clients.
|
||||
# Omit to allow any clients to connect.
|
||||
#clientca = /path/to/grpc-ca.crt
|
||||
|
||||
# Type of GRPC target configuration.
|
||||
# Defaults to "static".
|
||||
#
|
||||
# Possible values:
|
||||
# - static: A comma-separated list of targets is given in the "targets" option.
|
||||
# - etcd: Target URLs are retrieved from an etcd cluster.
|
||||
#targettype = static
|
||||
|
||||
# For target type "static": Comma-separated list of GRPC targets to connect to
|
||||
# for clustering mode.
|
||||
#targets = 192.168.0.1:9090, 192.168.0.2:9090
|
||||
|
||||
# For target type "static": Enable DNS discovery on hostnames of GRPC target.
|
||||
# If a hostname resolves to multiple IP addresses, a connection is established
|
||||
# to each of them.
|
||||
# Changes to the DNS are monitored regularly and GRPC clients are created or
|
||||
# deleted as necessary.
|
||||
#dnsdiscovery = true
|
||||
|
||||
# For target type "etcd": Key prefix of GRPC target entries. All keys below will
|
||||
# be watched and assumed to contain a JSON document. The entry "address" from
|
||||
# this document will be used as target URL, other contents in the document will
|
||||
# be ignored.
|
||||
#
|
||||
# Example:
|
||||
# "/signaling/cluster/grpc/one" -> {"address": "192.168.0.1:9090"}
|
||||
# "/signaling/cluster/grpc/two" -> {"address": "192.168.0.2:9090"}
|
||||
#targetprefix = /signaling/cluster/grpc
|
||||
|
|
|
@ -148,10 +148,11 @@ func main() {
|
|||
natsUrl = nats.DefaultURL
|
||||
}
|
||||
|
||||
nats, err := signaling.NewNatsClient(natsUrl)
|
||||
events, err := signaling.NewAsyncEvents(natsUrl)
|
||||
if err != nil {
|
||||
log.Fatal("Could not create NATS client: ", err)
|
||||
log.Fatal("Could not create async events client: ", err)
|
||||
}
|
||||
defer events.Close()
|
||||
|
||||
etcdClient, err := signaling.NewEtcdClient(config, "mcu")
|
||||
if err != nil {
|
||||
|
@ -163,8 +164,25 @@ func main() {
|
|||
}
|
||||
}()
|
||||
|
||||
rpcServer, err := signaling.NewGrpcServer(config)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not create RPC server: %s", err)
|
||||
}
|
||||
go func() {
|
||||
if err := rpcServer.Run(); err != nil {
|
||||
log.Fatalf("Could not start RPC server: %s", err)
|
||||
}
|
||||
}()
|
||||
defer rpcServer.Close()
|
||||
|
||||
rpcClients, err := signaling.NewGrpcClients(config, etcdClient)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not create RPC clients: %s", err)
|
||||
}
|
||||
defer rpcClients.Close()
|
||||
|
||||
r := mux.NewRouter()
|
||||
hub, err := signaling.NewHub(config, nats, r, version)
|
||||
hub, err := signaling.NewHub(config, events, rpcServer, rpcClients, etcdClient, r, version)
|
||||
if err != nil {
|
||||
log.Fatal("Could not create hub: ", err)
|
||||
}
|
||||
|
@ -191,7 +209,7 @@ func main() {
|
|||
signaling.UnregisterProxyMcuStats()
|
||||
signaling.RegisterJanusMcuStats()
|
||||
case signaling.McuTypeProxy:
|
||||
mcu, err = signaling.NewMcuProxy(config, etcdClient)
|
||||
mcu, err = signaling.NewMcuProxy(config, etcdClient, rpcClients)
|
||||
signaling.UnregisterJanusMcuStats()
|
||||
signaling.RegisterProxyMcuStats()
|
||||
default:
|
||||
|
|
109
single_notifier.go
Normal file
109
single_notifier.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type SingleWaiter struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (w *SingleWaiter) Wait(ctx context.Context) error {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
type SingleNotifier struct {
|
||||
sync.Mutex
|
||||
|
||||
waiter *SingleWaiter
|
||||
waiters map[*SingleWaiter]bool
|
||||
}
|
||||
|
||||
func (n *SingleNotifier) NewWaiter() *SingleWaiter {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
if n.waiter == nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.waiter = &SingleWaiter{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
if n.waiters == nil {
|
||||
n.waiters = make(map[*SingleWaiter]bool)
|
||||
}
|
||||
|
||||
w := &SingleWaiter{
|
||||
ctx: n.waiter.ctx,
|
||||
cancel: n.waiter.cancel,
|
||||
}
|
||||
n.waiters[w] = true
|
||||
return w
|
||||
}
|
||||
|
||||
func (n *SingleNotifier) Reset() {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
if n.waiter != nil {
|
||||
n.waiter.cancel()
|
||||
n.waiter = nil
|
||||
}
|
||||
n.waiters = nil
|
||||
}
|
||||
|
||||
func (n *SingleNotifier) Release(w *SingleWaiter) {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
if _, found := n.waiters[w]; found {
|
||||
delete(n.waiters, w)
|
||||
if len(n.waiters) == 0 {
|
||||
n.waiters = nil
|
||||
if n.waiter != nil {
|
||||
n.waiter.cancel()
|
||||
n.waiter = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *SingleNotifier) Notify() {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
if n.waiter != nil {
|
||||
n.waiter.cancel()
|
||||
}
|
||||
n.waiters = nil
|
||||
}
|
150
single_notifier_test.go
Normal file
150
single_notifier_test.go
Normal file
|
@ -0,0 +1,150 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSingleNotifierNoWaiter(t *testing.T) {
|
||||
var notifier SingleNotifier
|
||||
|
||||
// Notifications can be sent even if no waiter exists.
|
||||
notifier.Notify()
|
||||
}
|
||||
|
||||
func TestSingleNotifierSimple(t *testing.T) {
|
||||
var notifier SingleNotifier
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
waiter := notifier.NewWaiter()
|
||||
defer notifier.Release(waiter)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := waiter.Wait(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
notifier.Notify()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSingleNotifierMultiNotify(t *testing.T) {
|
||||
var notifier SingleNotifier
|
||||
|
||||
waiter := notifier.NewWaiter()
|
||||
defer notifier.Release(waiter)
|
||||
|
||||
notifier.Notify()
|
||||
// The second notification will be ignored while the first is still pending.
|
||||
notifier.Notify()
|
||||
}
|
||||
|
||||
func TestSingleNotifierWaitClosed(t *testing.T) {
|
||||
var notifier SingleNotifier
|
||||
|
||||
waiter := notifier.NewWaiter()
|
||||
notifier.Release(waiter)
|
||||
|
||||
if err := waiter.Wait(context.Background()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleNotifierWaitClosedMulti(t *testing.T) {
|
||||
var notifier SingleNotifier
|
||||
|
||||
waiter1 := notifier.NewWaiter()
|
||||
waiter2 := notifier.NewWaiter()
|
||||
notifier.Release(waiter1)
|
||||
notifier.Release(waiter2)
|
||||
|
||||
if err := waiter1.Wait(context.Background()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := waiter2.Wait(context.Background()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleNotifierResetWillNotify(t *testing.T) {
|
||||
var notifier SingleNotifier
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
waiter := notifier.NewWaiter()
|
||||
defer notifier.Release(waiter)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := waiter.Wait(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
notifier.Reset()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSingleNotifierDuplicate(t *testing.T) {
|
||||
var notifier SingleNotifier
|
||||
var wgStart sync.WaitGroup
|
||||
var wgEnd sync.WaitGroup
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
wgStart.Add(1)
|
||||
wgEnd.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wgEnd.Done()
|
||||
waiter := notifier.NewWaiter()
|
||||
defer notifier.Release(waiter)
|
||||
|
||||
// Goroutine has created the waiter and is ready.
|
||||
wgStart.Done()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := waiter.Wait(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wgStart.Wait()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
notifier.Notify()
|
||||
wgEnd.Wait()
|
||||
}
|
58
syscallconn.go
Executable file
58
syscallconn.go
Executable file
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2018 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type sysConn = syscall.Conn
|
||||
|
||||
// syscallConn keeps reference of rawConn to support syscall.Conn for channelz.
|
||||
// SyscallConn() (the method in interface syscall.Conn) is explicitly
|
||||
// implemented on this type,
|
||||
//
|
||||
// Interface syscall.Conn is implemented by most net.Conn implementations (e.g.
|
||||
// TCPConn, UnixConn), but is not part of net.Conn interface. So wrapper conns
|
||||
// that embed net.Conn don't implement syscall.Conn. (Side note: tls.Conn
|
||||
// doesn't embed net.Conn, so even if syscall.Conn is part of net.Conn, it won't
|
||||
// help here).
|
||||
type syscallConn struct {
|
||||
net.Conn
|
||||
// sysConn is a type alias of syscall.Conn. It's necessary because the name
|
||||
// `Conn` collides with `net.Conn`.
|
||||
sysConn
|
||||
}
|
||||
|
||||
// WrapSyscallConn tries to wrap rawConn and newConn into a net.Conn that
|
||||
// implements syscall.Conn. rawConn will be used to support syscall, and newConn
|
||||
// will be used for read/write.
|
||||
//
|
||||
// This function returns newConn if rawConn doesn't implement syscall.Conn.
|
||||
func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn {
|
||||
sysConn, ok := rawConn.(syscall.Conn)
|
||||
if !ok {
|
||||
return newConn
|
||||
}
|
||||
return &syscallConn{
|
||||
Conn: newConn,
|
||||
sysConn: sysConn,
|
||||
}
|
||||
}
|
|
@ -124,14 +124,14 @@ func checkMessageType(message *ServerMessage, expectedType string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func checkMessageSender(hub *Hub, message *MessageServerMessage, senderType string, hello *HelloServerMessage) error {
|
||||
if message.Sender.Type != senderType {
|
||||
return fmt.Errorf("Expected sender type %s, got %s", senderType, message.Sender.SessionId)
|
||||
} else if message.Sender.SessionId != hello.SessionId {
|
||||
func checkMessageSender(hub *Hub, sender *MessageServerMessageSender, senderType string, hello *HelloServerMessage) error {
|
||||
if sender.Type != senderType {
|
||||
return fmt.Errorf("Expected sender type %s, got %s", senderType, sender.SessionId)
|
||||
} else if sender.SessionId != hello.SessionId {
|
||||
return fmt.Errorf("Expected session id %+v, got %+v",
|
||||
getPubliceSessionIdData(hub, hello.SessionId), getPubliceSessionIdData(hub, message.Sender.SessionId))
|
||||
} else if message.Sender.UserId != hello.UserId {
|
||||
return fmt.Errorf("Expected user id %s, got %s", hello.UserId, message.Sender.UserId)
|
||||
getPubliceSessionIdData(hub, hello.SessionId), getPubliceSessionIdData(hub, sender.SessionId))
|
||||
} else if sender.UserId != hello.UserId {
|
||||
return fmt.Errorf("Expected user id %s, got %s", hello.UserId, sender.UserId)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -143,7 +143,7 @@ func checkReceiveClientMessageWithSender(ctx context.Context, client *TestClient
|
|||
return err
|
||||
} else if err := checkMessageType(message, "message"); err != nil {
|
||||
return err
|
||||
} else if err := checkMessageSender(client.hub, message.Message, senderType, hello); err != nil {
|
||||
} else if err := checkMessageSender(client.hub, message.Message.Sender, senderType, hello); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if err := json.Unmarshal(*message.Message.Data, payload); err != nil {
|
||||
|
@ -160,6 +160,29 @@ func checkReceiveClientMessage(ctx context.Context, client *TestClient, senderTy
|
|||
return checkReceiveClientMessageWithSender(ctx, client, senderType, hello, payload, nil)
|
||||
}
|
||||
|
||||
func checkReceiveClientControlWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error {
|
||||
message, err := client.RunUntilMessage(ctx)
|
||||
if err := checkUnexpectedClose(err); err != nil {
|
||||
return err
|
||||
} else if err := checkMessageType(message, "control"); err != nil {
|
||||
return err
|
||||
} else if err := checkMessageSender(client.hub, message.Control.Sender, senderType, hello); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if err := json.Unmarshal(*message.Control.Data, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if sender != nil {
|
||||
*sender = message.Message.Sender
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkReceiveClientControl(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}) error {
|
||||
return checkReceiveClientControlWithSender(ctx, client, senderType, hello, payload, nil)
|
||||
}
|
||||
|
||||
func checkReceiveClientEvent(ctx context.Context, client *TestClient, eventType string, msg **EventServerMessage) error {
|
||||
message, err := client.RunUntilMessage(ctx)
|
||||
if err := checkUnexpectedClose(err); err != nil {
|
||||
|
@ -430,6 +453,25 @@ func (c *TestClient) SendMessage(recipient MessageClientMessageRecipient, data i
|
|||
return c.WriteJSON(message)
|
||||
}
|
||||
|
||||
func (c *TestClient) SendControl(recipient MessageClientMessageRecipient, data interface{}) error {
|
||||
payload, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
c.t.Fatal(err)
|
||||
}
|
||||
|
||||
message := &ClientMessage{
|
||||
Id: "abcd",
|
||||
Type: "control",
|
||||
Control: &ControlClientMessage{
|
||||
MessageClientMessage: MessageClientMessage{
|
||||
Recipient: recipient,
|
||||
Data: (*json.RawMessage)(&payload),
|
||||
},
|
||||
},
|
||||
}
|
||||
return c.WriteJSON(message)
|
||||
}
|
||||
|
||||
func (c *TestClient) SetTransientData(key string, value interface{}) error {
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
|
|
|
@ -265,7 +265,7 @@ func TestVirtualSession(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
} else if err := checkMessageType(msg2, "message"); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := checkMessageSender(hub, msg2.Message, "session", hello.Hello); err != nil {
|
||||
} else if err := checkMessageSender(hub, msg2.Message.Sender, "session", hello.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue