/** * Standalone signaling server for the Nextcloud Spreed app. * Copyright (C) 2017 struktur AG * * @author Joachim Bauch * * @license GNU AGPL version 3 or any later version * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ package signaling import ( "bytes" "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "net" "net/http/httptest" "reflect" "strconv" "strings" "sync" "testing" "time" "github.com/golang-jwt/jwt/v4" "github.com/gorilla/websocket" ) var ( testBackendSecret = []byte("secret") testInternalSecret = []byte("internal-secret") ErrNoMessageReceived = fmt.Errorf("no message was received by the server") ) type TestBackendClientAuthParams struct { UserId string `json:"userid"` } func getWebsocketUrl(url string) string { if strings.HasPrefix(url, "http://") { return "ws://" + url[7:] + "/spreed" } else if strings.HasPrefix(url, "https://") { return "wss://" + url[8:] + "/spreed" } else { panic("Unsupported URL: " + url) } } func getPubliceSessionIdData(h *Hub, publicId string) *SessionIdData { decodedPublic := h.decodeSessionId(publicId, publicSessionName) if decodedPublic == nil { panic("invalid public session id") } return decodedPublic } func checkUnexpectedClose(err error) error { if err != nil && websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { return fmt.Errorf("Connection was closed with unexpected error: %s", err) } return nil } func checkMessageType(message *ServerMessage, expectedType string) error { if message == nil { return ErrNoMessageReceived } if message.Type != expectedType { return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } switch message.Type { case "hello": if message.Hello == nil { return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } case "message": if message.Message == nil { return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } else if message.Message.Data == nil || len(*message.Message.Data) == 0 { return fmt.Errorf("Received message without data") } case "room": if message.Room == nil { return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } case "event": if message.Event == nil { return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } case "transient": if message.TransientData == nil { return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } } return nil } func checkMessageSender(hub *Hub, sender *MessageServerMessageSender, senderType string, hello *HelloServerMessage) error { if sender.Type != senderType { return fmt.Errorf("Expected sender type %s, got %s", senderType, sender.SessionId) } else if sender.SessionId != hello.SessionId { return fmt.Errorf("Expected session id %+v, got %+v", getPubliceSessionIdData(hub, hello.SessionId), getPubliceSessionIdData(hub, sender.SessionId)) } else if sender.UserId != hello.UserId { return fmt.Errorf("Expected user id %s, got %s", hello.UserId, sender.UserId) } return nil } func checkReceiveClientMessageWithSenderAndRecipient(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender, recipient **MessageClientMessageRecipient) error { message, err := client.RunUntilMessage(ctx) if err := checkUnexpectedClose(err); err != nil { return err } else if err := checkMessageType(message, "message"); err != nil { return err } else if err := checkMessageSender(client.hub, message.Message.Sender, senderType, hello); err != nil { return err } else { if err := json.Unmarshal(*message.Message.Data, payload); err != nil { return err } } if sender != nil { *sender = message.Message.Sender } if recipient != nil { *recipient = message.Message.Recipient } return nil } func checkReceiveClientMessageWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error { return checkReceiveClientMessageWithSenderAndRecipient(ctx, client, senderType, hello, payload, sender, nil) } func checkReceiveClientMessage(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}) error { return checkReceiveClientMessageWithSenderAndRecipient(ctx, client, senderType, hello, payload, nil, nil) } func checkReceiveClientControlWithSenderAndRecipient(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender, recipient **MessageClientMessageRecipient) error { message, err := client.RunUntilMessage(ctx) if err := checkUnexpectedClose(err); err != nil { return err } else if err := checkMessageType(message, "control"); err != nil { return err } else if err := checkMessageSender(client.hub, message.Control.Sender, senderType, hello); err != nil { return err } else { if err := json.Unmarshal(*message.Control.Data, payload); err != nil { return err } } if sender != nil { *sender = message.Control.Sender } if recipient != nil { *recipient = message.Control.Recipient } return nil } func checkReceiveClientControlWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error { // nolint return checkReceiveClientControlWithSenderAndRecipient(ctx, client, senderType, hello, payload, sender, nil) } func checkReceiveClientControl(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}) error { return checkReceiveClientControlWithSenderAndRecipient(ctx, client, senderType, hello, payload, nil, nil) } func checkReceiveClientEvent(ctx context.Context, client *TestClient, eventType string, msg **EventServerMessage) error { message, err := client.RunUntilMessage(ctx) if err := checkUnexpectedClose(err); err != nil { return err } else if err := checkMessageType(message, "event"); err != nil { return err } else if message.Event.Type != eventType { return fmt.Errorf("Expected \"%s\" event type, got \"%s\"", eventType, message.Event.Type) } else { if msg != nil { *msg = message.Event } } return nil } type TestClient struct { t *testing.T hub *Hub server *httptest.Server mu sync.Mutex conn *websocket.Conn localAddr net.Addr messageChan chan []byte readErrorChan chan error publicId string } func NewTestClientContext(ctx context.Context, t *testing.T, server *httptest.Server, hub *Hub) *TestClient { // Reference "hub" to prevent compiler error. conn, _, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil) if err != nil { t.Fatal(err) } messageChan := make(chan []byte) readErrorChan := make(chan error, 1) go func() { for { messageType, data, err := conn.ReadMessage() if err != nil { readErrorChan <- err return } else if messageType != websocket.TextMessage { t.Errorf("Expect text message, got %d", messageType) return } messageChan <- data } }() return &TestClient{ t: t, hub: hub, server: server, conn: conn, localAddr: conn.LocalAddr(), messageChan: messageChan, readErrorChan: readErrorChan, } } func NewTestClient(t *testing.T, server *httptest.Server, hub *Hub) *TestClient { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() client := NewTestClientContext(ctx, t, server, hub) msg, err := client.RunUntilMessage(ctx) if err != nil { t.Fatal(err) } if msg.Type != "welcome" { t.Errorf("Expected welcome message, got %+v", msg) } return client } func (c *TestClient) CloseWithBye() { c.SendBye() // nolint c.Close() } func (c *TestClient) Close() { c.mu.Lock() defer c.mu.Unlock() if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err == websocket.ErrCloseSent { // Already closed return } // Wait a bit for close message to be processed. time.Sleep(100 * time.Millisecond) c.conn.Close() // Drain any entries in the channels to terminate the read goroutine. loop: for { select { case <-c.readErrorChan: case <-c.messageChan: default: break loop } } } func (c *TestClient) WaitForClientRemoved(ctx context.Context) error { c.hub.mu.Lock() defer c.hub.mu.Unlock() for { found := false for _, client := range c.hub.clients { if cc, ok := client.(*Client); ok { cc.mu.Lock() conn := cc.conn cc.mu.Unlock() if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() { found = true break } } } if !found { break } c.hub.mu.Unlock() select { case <-ctx.Done(): c.hub.mu.Lock() return ctx.Err() default: time.Sleep(time.Millisecond) } c.hub.mu.Lock() } return nil } func (c *TestClient) WaitForSessionRemoved(ctx context.Context, sessionId string) error { data := c.hub.decodeSessionId(sessionId, publicSessionName) if data == nil { return fmt.Errorf("Invalid session id passed") } c.hub.mu.Lock() defer c.hub.mu.Unlock() for { _, found := c.hub.sessions[data.Sid] if !found { break } c.hub.mu.Unlock() select { case <-ctx.Done(): c.hub.mu.Lock() return ctx.Err() default: time.Sleep(time.Millisecond) } c.hub.mu.Lock() } return nil } func (c *TestClient) WriteJSON(data interface{}) error { if !strings.Contains(c.t.Name(), "HelloUnsupportedVersion") { if msg, ok := data.(*ClientMessage); ok { if err := msg.CheckValid(); err != nil { return err } } } c.mu.Lock() defer c.mu.Unlock() return c.conn.WriteJSON(data) } func (c *TestClient) EnsuerWriteJSON(data interface{}) { if err := c.WriteJSON(data); err != nil { c.t.Fatalf("Could not write JSON %+v: %s", data, err) } } func (c *TestClient) SendHello(userid string) error { return c.SendHelloV1(userid) } func (c *TestClient) SendHelloV1(userid string) error { params := TestBackendClientAuthParams{ UserId: userid, } return c.SendHelloParams(c.server.URL, HelloVersionV1, "", nil, params) } func (c *TestClient) SendHelloV2(userid string) error { now := time.Now() return c.SendHelloV2WithTimes(userid, now, now.Add(time.Minute)) } func (c *TestClient) SendHelloV2WithTimes(userid string, issuedAt time.Time, expiresAt time.Time) error { userdata := map[string]string{ "displayname": "Displayname " + userid, } data, err := json.Marshal(userdata) if err != nil { c.t.Fatal(err) } claims := &HelloV2TokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ Issuer: c.server.URL, Subject: userid, }, UserData: (*json.RawMessage)(&data), } if !issuedAt.IsZero() { claims.IssuedAt = jwt.NewNumericDate(issuedAt) } if !expiresAt.IsZero() { claims.ExpiresAt = jwt.NewNumericDate(expiresAt) } var token *jwt.Token if strings.Contains(c.t.Name(), "ECDSA") { token = jwt.NewWithClaims(jwt.SigningMethodES256, claims) } else if strings.Contains(c.t.Name(), "Ed25519") { token = jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) } else { token = jwt.NewWithClaims(jwt.SigningMethodRS256, claims) } private := getPrivateAuthToken(c.t) tokenString, err := token.SignedString(private) if err != nil { c.t.Fatal(err) } params := HelloV2AuthParams{ Token: tokenString, } return c.SendHelloParams(c.server.URL, HelloVersionV2, "", nil, params) } func (c *TestClient) SendHelloResume(resumeId string) error { hello := &ClientMessage{ Id: "1234", Type: "hello", Hello: &HelloClientMessage{ Version: HelloVersionV1, ResumeId: resumeId, }, } return c.WriteJSON(hello) } func (c *TestClient) SendHelloClient(userid string) error { params := TestBackendClientAuthParams{ UserId: userid, } return c.SendHelloParams(c.server.URL, HelloVersionV1, "client", nil, params) } func (c *TestClient) SendHelloInternal() error { return c.SendHelloInternalWithFeatures(nil) } func (c *TestClient) SendHelloInternalWithFeatures(features []string) error { random := newRandomString(48) mac := hmac.New(sha256.New, testInternalSecret) mac.Write([]byte(random)) // nolint token := hex.EncodeToString(mac.Sum(nil)) backend := c.server.URL params := ClientTypeInternalAuthParams{ Random: random, Token: token, Backend: backend, } return c.SendHelloParams("", HelloVersionV1, "internal", features, params) } func (c *TestClient) SendHelloParams(url string, version string, clientType string, features []string, params interface{}) error { data, err := json.Marshal(params) if err != nil { c.t.Fatal(err) } hello := &ClientMessage{ Id: "1234", Type: "hello", Hello: &HelloClientMessage{ Version: version, Features: features, Auth: &HelloClientMessageAuth{ Type: clientType, Url: url, Params: (*json.RawMessage)(&data), }, }, } return c.WriteJSON(hello) } func (c *TestClient) SendBye() error { hello := &ClientMessage{ Id: "9876", Type: "bye", Bye: &ByeClientMessage{}, } return c.WriteJSON(hello) } func (c *TestClient) SendMessage(recipient MessageClientMessageRecipient, data interface{}) error { payload, err := json.Marshal(data) if err != nil { c.t.Fatal(err) } message := &ClientMessage{ Id: "abcd", Type: "message", Message: &MessageClientMessage{ Recipient: recipient, Data: (*json.RawMessage)(&payload), }, } return c.WriteJSON(message) } func (c *TestClient) SendControl(recipient MessageClientMessageRecipient, data interface{}) error { payload, err := json.Marshal(data) if err != nil { c.t.Fatal(err) } message := &ClientMessage{ Id: "abcd", Type: "control", Control: &ControlClientMessage{ MessageClientMessage: MessageClientMessage{ Recipient: recipient, Data: (*json.RawMessage)(&payload), }, }, } return c.WriteJSON(message) } func (c *TestClient) SendInternalAddSession(msg *AddSessionInternalClientMessage) error { message := &ClientMessage{ Id: "abcd", Type: "internal", Internal: &InternalClientMessage{ Type: "addsession", AddSession: msg, }, } return c.WriteJSON(message) } func (c *TestClient) SendInternalUpdateSession(msg *UpdateSessionInternalClientMessage) error { message := &ClientMessage{ Id: "abcd", Type: "internal", Internal: &InternalClientMessage{ Type: "updatesession", UpdateSession: msg, }, } return c.WriteJSON(message) } func (c *TestClient) SendInternalRemoveSession(msg *RemoveSessionInternalClientMessage) error { message := &ClientMessage{ Id: "abcd", Type: "internal", Internal: &InternalClientMessage{ Type: "removesession", RemoveSession: msg, }, } return c.WriteJSON(message) } func (c *TestClient) SendInternalDialout(msg *DialoutInternalClientMessage) error { message := &ClientMessage{ Id: "abcd", Type: "internal", Internal: &InternalClientMessage{ Type: "dialout", Dialout: msg, }, } return c.WriteJSON(message) } func (c *TestClient) SetTransientData(key string, value interface{}, ttl time.Duration) error { payload, err := json.Marshal(value) if err != nil { c.t.Fatal(err) } message := &ClientMessage{ Id: "efgh", Type: "transient", TransientData: &TransientDataClientMessage{ Type: "set", Key: key, Value: (*json.RawMessage)(&payload), TTL: ttl, }, } return c.WriteJSON(message) } func (c *TestClient) RemoveTransientData(key string) error { message := &ClientMessage{ Id: "ijkl", Type: "transient", TransientData: &TransientDataClientMessage{ Type: "remove", Key: key, }, } return c.WriteJSON(message) } func (c *TestClient) DrainMessages(ctx context.Context) error { select { case err := <-c.readErrorChan: return err case <-c.messageChan: n := len(c.messageChan) for i := 0; i < n; i++ { <-c.messageChan } case <-ctx.Done(): return ctx.Err() } return nil } func (c *TestClient) GetPendingMessages(ctx context.Context) ([]*ServerMessage, error) { var result []*ServerMessage select { case err := <-c.readErrorChan: return nil, err case msg := <-c.messageChan: var m ServerMessage if err := json.Unmarshal(msg, &m); err != nil { return nil, err } result = append(result, &m) n := len(c.messageChan) for i := 0; i < n; i++ { var m ServerMessage msg = <-c.messageChan if err := json.Unmarshal(msg, &m); err != nil { return nil, err } result = append(result, &m) } case <-ctx.Done(): return nil, ctx.Err() } return result, nil } func (c *TestClient) RunUntilMessage(ctx context.Context) (message *ServerMessage, err error) { select { case err = <-c.readErrorChan: case msg := <-c.messageChan: var m ServerMessage if err = json.Unmarshal(msg, &m); err == nil { message = &m } case <-ctx.Done(): err = ctx.Err() } return } func (c *TestClient) RunUntilHello(ctx context.Context) (message *ServerMessage, err error) { if message, err = c.RunUntilMessage(ctx); err != nil { return nil, err } if err := checkUnexpectedClose(err); err != nil { return nil, err } if err := checkMessageType(message, "hello"); err != nil { return nil, err } c.publicId = message.Hello.SessionId return message, nil } func (c *TestClient) JoinRoom(ctx context.Context, roomId string) (message *ServerMessage, err error) { return c.JoinRoomWithRoomSession(ctx, roomId, roomId+"-"+c.publicId) } func (c *TestClient) JoinRoomWithRoomSession(ctx context.Context, roomId string, roomSessionId string) (message *ServerMessage, err error) { msg := &ClientMessage{ Id: "ABCD", Type: "room", Room: &RoomClientMessage{ RoomId: roomId, SessionId: roomSessionId, }, } if err := c.WriteJSON(msg); err != nil { return nil, err } if message, err = c.RunUntilMessage(ctx); err != nil { return nil, err } if err := checkUnexpectedClose(err); err != nil { return nil, err } if err := checkMessageType(message, "room"); err != nil { return nil, err } return message, nil } func checkMessageRoomId(message *ServerMessage, roomId string) error { if err := checkMessageType(message, "room"); err != nil { return err } if message.Room.RoomId != roomId { return fmt.Errorf("Expected room id %s, got %+v", roomId, message.Room) } return nil } func (c *TestClient) RunUntilRoom(ctx context.Context, roomId string) error { message, err := c.RunUntilMessage(ctx) if err != nil { return err } if err := checkUnexpectedClose(err); err != nil { return err } return checkMessageRoomId(message, roomId) } func (c *TestClient) checkMessageJoined(message *ServerMessage, hello *HelloServerMessage) error { return c.checkMessageJoinedSession(message, hello.SessionId, hello.UserId) } func (c *TestClient) checkSingleMessageJoined(message *ServerMessage) error { if err := checkMessageType(message, "event"); err != nil { return err } else if message.Event.Target != "room" { return fmt.Errorf("Expected event target room, got %+v", message.Event) } else if message.Event.Type != "join" { return fmt.Errorf("Expected event type join, got %+v", message.Event) } else if len(message.Event.Join) != 1 { return fmt.Errorf("Expected one join event entry, got %+v", message.Event) } return nil } func (c *TestClient) checkMessageJoinedSession(message *ServerMessage, sessionId string, userId string) error { if err := c.checkSingleMessageJoined(message); err != nil { return err } evt := message.Event.Join[0] if sessionId != "" && evt.SessionId != sessionId { return fmt.Errorf("Expected join session id %+v, got %+v", getPubliceSessionIdData(c.hub, sessionId), getPubliceSessionIdData(c.hub, evt.SessionId)) } if evt.UserId != userId { return fmt.Errorf("Expected join user id %s, got %+v", userId, evt) } return nil } func (c *TestClient) RunUntilJoinedAndReturn(ctx context.Context, hello ...*HelloServerMessage) ([]*EventServerMessageSessionEntry, []*ServerMessage, error) { received := make([]*EventServerMessageSessionEntry, len(hello)) var ignored []*ServerMessage hellos := make(map[*HelloServerMessage]int, len(hello)) for idx, h := range hello { hellos[h] = idx } for len(hellos) > 0 { message, err := c.RunUntilMessage(ctx) if err != nil { return nil, nil, fmt.Errorf("got error while waiting for %+v: %w", hellos, err) } if err := checkMessageType(message, "event"); err != nil { ignored = append(ignored, message) continue } else if message.Event.Target != "room" || message.Event.Type != "join" { ignored = append(ignored, message) continue } for len(message.Event.Join) > 0 { found := false loop: for h, idx := range hellos { for idx2, evt := range message.Event.Join { if evt.SessionId == h.SessionId && evt.UserId == h.UserId { received[idx] = evt delete(hellos, h) message.Event.Join = append(message.Event.Join[:idx2], message.Event.Join[idx2+1:]...) found = true break loop } } } if !found { return nil, nil, fmt.Errorf("expected one of the passed hello sessions, got %+v", message.Event.Join[0]) } } } return received, ignored, nil } func (c *TestClient) RunUntilJoined(ctx context.Context, hello ...*HelloServerMessage) error { _, unexpected, err := c.RunUntilJoinedAndReturn(ctx, hello...) if err != nil { return err } if len(unexpected) > 0 { return fmt.Errorf("Received unexpected messages: %+v", unexpected) } return nil } func (c *TestClient) checkMessageRoomLeave(message *ServerMessage, hello *HelloServerMessage) error { return c.checkMessageRoomLeaveSession(message, hello.SessionId) } func (c *TestClient) checkMessageRoomLeaveSession(message *ServerMessage, sessionId string) error { if err := checkMessageType(message, "event"); err != nil { return err } else if message.Event.Target != "room" { return fmt.Errorf("Expected event target room, got %+v", message.Event) } else if message.Event.Type != "leave" { return fmt.Errorf("Expected event type leave, got %+v", message.Event) } else if len(message.Event.Leave) != 1 { return fmt.Errorf("Expected one leave event entry, got %+v", message.Event) } else if message.Event.Leave[0] != sessionId { return fmt.Errorf("Expected leave session id %+v, got %+v", getPubliceSessionIdData(c.hub, sessionId), getPubliceSessionIdData(c.hub, message.Event.Leave[0])) } return nil } func (c *TestClient) RunUntilLeft(ctx context.Context, hello *HelloServerMessage) error { message, err := c.RunUntilMessage(ctx) if err != nil { return err } return c.checkMessageRoomLeave(message, hello) } func checkMessageRoomlistUpdate(message *ServerMessage) (*RoomEventServerMessage, error) { if err := checkMessageType(message, "event"); err != nil { return nil, err } else if message.Event.Target != "roomlist" { return nil, fmt.Errorf("Expected event target room, got %+v", message.Event) } else if message.Event.Type != "update" || message.Event.Update == nil { return nil, fmt.Errorf("Expected event type update, got %+v", message.Event) } else { return message.Event.Update, nil } } func (c *TestClient) RunUntilRoomlistUpdate(ctx context.Context) (*RoomEventServerMessage, error) { message, err := c.RunUntilMessage(ctx) if err != nil { return nil, err } return checkMessageRoomlistUpdate(message) } func checkMessageRoomlistDisinvite(message *ServerMessage) (*RoomDisinviteEventServerMessage, error) { if err := checkMessageType(message, "event"); err != nil { return nil, err } else if message.Event.Target != "roomlist" { return nil, fmt.Errorf("Expected event target room, got %+v", message.Event) } else if message.Event.Type != "disinvite" || message.Event.Disinvite == nil { return nil, fmt.Errorf("Expected event type disinvite, got %+v", message.Event) } return message.Event.Disinvite, nil } func (c *TestClient) RunUntilRoomlistDisinvite(ctx context.Context) (*RoomDisinviteEventServerMessage, error) { message, err := c.RunUntilMessage(ctx) if err != nil { return nil, err } return checkMessageRoomlistDisinvite(message) } func checkMessageParticipantsInCall(message *ServerMessage) (*RoomEventServerMessage, error) { if err := checkMessageType(message, "event"); err != nil { return nil, err } else if message.Event.Target != "participants" { return nil, fmt.Errorf("Expected event target participants, got %+v", message.Event) } else if message.Event.Type != "update" || message.Event.Update == nil { return nil, fmt.Errorf("Expected event type update, got %+v", message.Event) } return message.Event.Update, nil } func checkMessageParticipantFlags(message *ServerMessage) (*RoomFlagsServerMessage, error) { if err := checkMessageType(message, "event"); err != nil { return nil, err } else if message.Event.Target != "participants" { return nil, fmt.Errorf("Expected event target room, got %+v", message.Event) } else if message.Event.Type != "flags" || message.Event.Flags == nil { return nil, fmt.Errorf("Expected event type flags, got %+v", message.Event) } return message.Event.Flags, nil } func checkMessageRoomMessage(message *ServerMessage) (*RoomEventMessage, error) { if err := checkMessageType(message, "event"); err != nil { return nil, err } else if message.Event.Target != "room" { return nil, fmt.Errorf("Expected event target room, got %+v", message.Event) } else if message.Event.Type != "message" || message.Event.Message == nil { return nil, fmt.Errorf("Expected event type message, got %+v", message.Event) } return message.Event.Message, nil } func (c *TestClient) RunUntilRoomMessage(ctx context.Context) (*RoomEventMessage, error) { message, err := c.RunUntilMessage(ctx) if err != nil { return nil, err } return checkMessageRoomMessage(message) } func checkMessageError(message *ServerMessage, msgid string) error { if err := checkMessageType(message, "error"); err != nil { return err } else if message.Error.Code != msgid { return fmt.Errorf("Expected error \"%s\", got \"%s\" (%+v)", msgid, message.Error.Code, message.Error) } return nil } func (c *TestClient) RunUntilOffer(ctx context.Context, offer string) error { message, err := c.RunUntilMessage(ctx) if err != nil { return err } if err := checkUnexpectedClose(err); err != nil { return err } else if err := checkMessageType(message, "message"); err != nil { return err } var data map[string]interface{} if err := json.Unmarshal(*message.Message.Data, &data); err != nil { return err } if data["type"].(string) != "offer" { return fmt.Errorf("expected data type offer, got %+v", data) } payload := data["payload"].(map[string]interface{}) if payload["type"].(string) != "offer" { return fmt.Errorf("expected payload type offer, got %+v", payload) } if payload["sdp"].(string) != offer { return fmt.Errorf("expected payload answer %s, got %+v", offer, payload) } return nil } func (c *TestClient) RunUntilAnswer(ctx context.Context, answer string) error { message, err := c.RunUntilMessage(ctx) if err != nil { return err } if err := checkUnexpectedClose(err); err != nil { return err } else if err := checkMessageType(message, "message"); err != nil { return err } var data map[string]interface{} if err := json.Unmarshal(*message.Message.Data, &data); err != nil { return err } if data["type"].(string) != "answer" { return fmt.Errorf("expected data type answer, got %+v", data) } payload := data["payload"].(map[string]interface{}) if payload["type"].(string) != "answer" { return fmt.Errorf("expected payload type answer, got %+v", payload) } if payload["sdp"].(string) != answer { return fmt.Errorf("expected payload answer %s, got %+v", answer, payload) } return nil } func checkMessageTransientSet(message *ServerMessage, key string, value interface{}, oldValue interface{}) error { if err := checkMessageType(message, "transient"); err != nil { return err } else if message.TransientData.Type != "set" { return fmt.Errorf("Expected transient set, got %+v", message.TransientData) } else if message.TransientData.Key != key { return fmt.Errorf("Expected transient set key %s, got %+v", key, message.TransientData) } else if !reflect.DeepEqual(message.TransientData.Value, value) { return fmt.Errorf("Expected transient set value %+v, got %+v", value, message.TransientData.Value) } else if !reflect.DeepEqual(message.TransientData.OldValue, oldValue) { return fmt.Errorf("Expected transient set old value %+v, got %+v", oldValue, message.TransientData.OldValue) } return nil } func checkMessageTransientRemove(message *ServerMessage, key string, oldValue interface{}) error { if err := checkMessageType(message, "transient"); err != nil { return err } else if message.TransientData.Type != "remove" { return fmt.Errorf("Expected transient remove, got %+v", message.TransientData) } else if message.TransientData.Key != key { return fmt.Errorf("Expected transient remove key %s, got %+v", key, message.TransientData) } else if !reflect.DeepEqual(message.TransientData.OldValue, oldValue) { return fmt.Errorf("Expected transient remove old value %+v, got %+v", oldValue, message.TransientData.OldValue) } return nil } func checkMessageTransientInitial(message *ServerMessage, data map[string]interface{}) error { if err := checkMessageType(message, "transient"); err != nil { return err } else if message.TransientData.Type != "initial" { return fmt.Errorf("Expected transient initial, got %+v", message.TransientData) } else if !reflect.DeepEqual(message.TransientData.Data, data) { return fmt.Errorf("Expected transient initial data %+v, got %+v", data, message.TransientData.Data) } return nil } func checkMessageInCallAll(message *ServerMessage, roomId string, inCall int) error { if err := checkMessageType(message, "event"); err != nil { return err } else if message.Event.Type != "update" { return fmt.Errorf("Expected update event, got %+v", message.Event) } else if message.Event.Target != "participants" { return fmt.Errorf("Expected participants update event, got %+v", message.Event) } else if message.Event.Update.RoomId != roomId { return fmt.Errorf("Expected participants update event for room %s, got %+v", roomId, message.Event.Update) } else if !message.Event.Update.All { return fmt.Errorf("Expected participants update event for all, got %+v", message.Event.Update) } else if !bytes.Equal(*message.Event.Update.InCall, []byte(strconv.FormatInt(int64(inCall), 10))) { return fmt.Errorf("Expected incall flags %d, got %+v", inCall, message.Event.Update) } return nil } func checkMessageSwitchTo(message *ServerMessage, roomId string, details json.RawMessage) (*EventServerMessageSwitchTo, error) { if err := checkMessageType(message, "event"); err != nil { return nil, err } else if message.Event.Type != "switchto" { return nil, fmt.Errorf("Expected switchto event, got %+v", message.Event) } else if message.Event.Target != "room" { return nil, fmt.Errorf("Expected room switchto event, got %+v", message.Event) } else if message.Event.SwitchTo.RoomId != roomId { return nil, fmt.Errorf("Expected room switchto event for room %s, got %+v", roomId, message.Event) } if details != nil { if message.Event.SwitchTo.Details == nil || !bytes.Equal(details, message.Event.SwitchTo.Details) { return nil, fmt.Errorf("Expected details %s, got %+v", string(details), message.Event) } } else if message.Event.SwitchTo.Details != nil { return nil, fmt.Errorf("Expected no details, got %+v", message.Event) } return message.Event.SwitchTo, nil } func (c *TestClient) RunUntilSwitchTo(ctx context.Context, roomId string, details json.RawMessage) (*EventServerMessageSwitchTo, error) { message, err := c.RunUntilMessage(ctx) if err != nil { return nil, err } return checkMessageSwitchTo(message, roomId, details) }