From 669ec5c1af8116510c46ccd3791841cd20d39ea8 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 20 Oct 2020 14:29:58 +0200 Subject: [PATCH] Support "internal" messages to create/update/remove virtual sessions. --- src/signaling/api_backend.go | 35 +++ src/signaling/api_signaling.go | 125 ++++++++- src/signaling/backend_client.go | 2 +- src/signaling/clientsession.go | 34 ++- src/signaling/hub.go | 272 ++++++++++++++++--- src/signaling/hub_test.go | 41 ++- src/signaling/room.go | 43 ++- src/signaling/testclient_test.go | 20 +- src/signaling/virtualsession.go | 195 ++++++++++++++ src/signaling/virtualsession_test.go | 378 +++++++++++++++++++++++++++ 10 files changed, 1094 insertions(+), 51 deletions(-) create mode 100644 src/signaling/virtualsession.go create mode 100644 src/signaling/virtualsession_test.go diff --git a/src/signaling/api_backend.go b/src/signaling/api_backend.go index 37d2e5e..a58b38a 100644 --- a/src/signaling/api_backend.go +++ b/src/signaling/api_backend.go @@ -155,6 +155,8 @@ type BackendClientRequest struct { Room *BackendClientRoomRequest `json:"room,omitempty"` Ping *BackendClientPingRequest `json:"ping,omitempty"` + + Session *BackendClientSessionRequest `json:"session,omitempty"` } func NewBackendClientAuthRequest(params *json.RawMessage) *BackendClientRequest { @@ -177,6 +179,8 @@ type BackendClientResponse struct { Room *BackendClientRoomResponse `json:"room,omitempty"` Ping *BackendClientRingResponse `json:"ping,omitempty"` + + Session *BackendClientSessionResponse `json:"session,omitempty"` } type BackendClientAuthResponse struct { @@ -249,6 +253,37 @@ type BackendClientRingResponse struct { RoomId string `json:"roomid"` } +type BackendClientSessionRequest struct { + Version string `json:"version"` + RoomId string `json:"roomid"` + Action string `json:"action"` + SessionId string `json:"sessionid"` + UserId string `json:"userid,omitempty"` + User *json.RawMessage `json:"user,omitempty"` +} + +type BackendClientSessionResponse struct { + Version string `json:"version"` + RoomId string `json:"roomid"` +} + +func NewBackendClientSessionRequest(roomid string, action string, sessionid string, msg *AddSessionInternalClientMessage) *BackendClientRequest { + request := &BackendClientRequest{ + Type: "session", + Session: &BackendClientSessionRequest{ + Version: BackendVersion, + RoomId: roomid, + Action: action, + SessionId: sessionid, + }, + } + if msg != nil { + request.Session.UserId = msg.UserId + request.Session.User = msg.User + } + return request +} + type OcsMeta struct { Status string `json:"status"` StatusCode int `json:"statuscode"` diff --git a/src/signaling/api_signaling.go b/src/signaling/api_signaling.go index 5762459..2a67b89 100644 --- a/src/signaling/api_signaling.go +++ b/src/signaling/api_signaling.go @@ -50,6 +50,8 @@ type ClientMessage struct { Message *MessageClientMessage `json:"message,omitempty"` Control *ControlClientMessage `json:"control,omitempty"` + + Internal *InternalClientMessage `json:"internal,omitempty"` } func (m *ClientMessage) CheckValid() error { @@ -82,6 +84,12 @@ func (m *ClientMessage) CheckValid() error { } else if err := m.Control.CheckValid(); err != nil { return err } + case "internal": + if m.Internal == nil { + return fmt.Errorf("internal missing") + } else if err := m.Internal.CheckValid(); err != nil { + return err + } } return nil } @@ -191,6 +199,8 @@ func (e *Error) Error() string { const ( HelloClientTypeClient = "client" HelloClientTypeInternal = "internal" + + HelloClientTypeVirtual = "virtual" ) type ClientTypeInternalAuthParams struct { @@ -272,7 +282,18 @@ func (m *HelloClientMessage) CheckValid() error { } const ( + // Features for all clients. ServerFeatureMcu = "mcu" + + // Features for internal clients only. + ServerFeatureInternalVirtualSessions = "virtual-sessions" +) + +var ( + DefaultFeatures []string + DefaultFeaturesInternal []string = []string{ + ServerFeatureInternalVirtualSessions, + } ) type HelloServerMessageServer struct { @@ -388,7 +409,8 @@ type MessageServerMessageData struct { } type MessageServerMessage struct { - Sender *MessageServerMessageSender `json:"sender"` + Sender *MessageServerMessageSender `json:"sender"` + Recipient *MessageClientMessageRecipient `json:"recipient,omitempty"` Data *json.RawMessage `json:"data"` } @@ -399,12 +421,111 @@ type ControlClientMessage struct { MessageClientMessage } +func (m *ControlClientMessage) CheckValid() error { + if err := m.MessageClientMessage.CheckValid(); err != nil { + return err + } + return nil +} + type ControlServerMessage struct { - Sender *MessageServerMessageSender `json:"sender"` + Sender *MessageServerMessageSender `json:"sender"` + Recipient *MessageClientMessageRecipient `json:"recipient,omitempty"` Data *json.RawMessage `json:"data"` } +// Type "internal" + +type CommonSessionInternalClientMessage struct { + SessionId string `json:"sessionid"` + + RoomId string `json:"roomid"` +} + +func (m *CommonSessionInternalClientMessage) CheckValid() error { + if m.SessionId == "" { + return fmt.Errorf("sessionid missing") + } + if m.RoomId == "" { + return fmt.Errorf("roomid missing") + } + return nil +} + +type AddSessionInternalClientMessage struct { + CommonSessionInternalClientMessage + + UserId string `json:"userid,omitempty"` + User *json.RawMessage `json:"user,omitempty"` + Flags uint32 `json:"flags,omitempty"` +} + +func (m *AddSessionInternalClientMessage) CheckValid() error { + if err := m.CommonSessionInternalClientMessage.CheckValid(); err != nil { + return err + } + return nil +} + +type UpdateSessionInternalClientMessage struct { + CommonSessionInternalClientMessage + + Flags *uint32 `json:"flags,omitempty"` +} + +func (m *UpdateSessionInternalClientMessage) CheckValid() error { + if err := m.CommonSessionInternalClientMessage.CheckValid(); err != nil { + return err + } + return nil +} + +type RemoveSessionInternalClientMessage struct { + CommonSessionInternalClientMessage +} + +func (m *RemoveSessionInternalClientMessage) CheckValid() error { + if err := m.CommonSessionInternalClientMessage.CheckValid(); err != nil { + return err + } + return nil +} + +type InternalClientMessage struct { + Type string `json:"type"` + + AddSession *AddSessionInternalClientMessage `json:"addsession,omitempty"` + + UpdateSession *UpdateSessionInternalClientMessage `json:"updatesession,omitempty"` + + RemoveSession *RemoveSessionInternalClientMessage `json:"removesession,omitempty"` +} + +func (m *InternalClientMessage) CheckValid() error { + switch m.Type { + case "addsession": + if m.AddSession == nil { + return fmt.Errorf("addsession missing") + } else if err := m.AddSession.CheckValid(); err != nil { + return err + } + case "updatesession": + if m.UpdateSession == nil { + return fmt.Errorf("updatesession missing") + } else if err := m.UpdateSession.CheckValid(); err != nil { + return err + } + case "removesession": + if m.RemoveSession == nil { + return fmt.Errorf("removesession missing") + } else if err := m.RemoveSession.CheckValid(); err != nil { + return err + } + } + return nil +} + // Type "event" type RoomEventServerMessage struct { diff --git a/src/signaling/backend_client.go b/src/signaling/backend_client.go index 44aa120..882bc95 100644 --- a/src/signaling/backend_client.go +++ b/src/signaling/backend_client.go @@ -347,7 +347,7 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ return err } - if isOcsRequest(u) { + if isOcsRequest(u) || req.Header.Get("OCS-APIRequest") != "" { // OCS response are wrapped in an OCS container that needs to be parsed // to get the actual contents: // { diff --git a/src/signaling/clientsession.go b/src/signaling/clientsession.go index 5c3cdf4..4900eb3 100644 --- a/src/signaling/clientsession.go +++ b/src/signaling/clientsession.go @@ -82,6 +82,8 @@ type ClientSession struct { pendingClientMessages []*ServerMessage hasPendingChat bool hasPendingParticipantsUpdate bool + + virtualSessions map[*VirtualSession]bool } func NewClientSession(hub *Hub, privateId string, publicId string, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) { @@ -96,14 +98,19 @@ func NewClientSession(hub *Hub, privateId string, publicId string, data *Session userId: auth.UserId, userData: auth.User, - backend: backend, - backendUrl: hello.Auth.Url, - parsedBackendUrl: hello.Auth.parsedUrl, + backend: backend, natsReceiver: make(chan *nats.Msg, 64), stopRun: make(chan bool, 1), runStopped: make(chan bool, 1), } + if s.clientType == HelloClientTypeInternal { + s.backendUrl = hello.Auth.internalParams.Backend + s.parsedBackendUrl = hello.Auth.internalParams.parsedBackend + } else { + s.backendUrl = hello.Auth.Url + s.parsedBackendUrl = hello.Auth.parsedUrl + } if err := s.SubscribeNats(hub.nats); err != nil { return nil, err } @@ -293,6 +300,12 @@ func (s *ClientSession) closeAndWait(wait bool) { s.sessionSubscription.Unsubscribe() s.sessionSubscription = nil } + go func(virtualSessions map[*VirtualSession]bool) { + for session, _ := range virtualSessions { + session.Close() + } + }(s.virtualSessions) + s.virtualSessions = nil s.releaseMcuObjects() s.clearClientLocked(nil) if atomic.CompareAndSwapInt32(&s.running, 1, 0) { @@ -783,3 +796,18 @@ func (s *ClientSession) NotifySessionResumed(client *Client) { } } } + +func (s *ClientSession) AddVirtualSession(session *VirtualSession) { + s.mu.Lock() + if s.virtualSessions == nil { + s.virtualSessions = make(map[*VirtualSession]bool) + } + s.virtualSessions[session] = true + s.mu.Unlock() +} + +func (s *ClientSession) RemoveVirtualSession(session *VirtualSession) { + s.mu.Lock() + delete(s.virtualSessions, session) + s.mu.Unlock() +} diff --git a/src/signaling/hub.go b/src/signaling/hub.go index eb46c41..742c048 100644 --- a/src/signaling/hub.go +++ b/src/signaling/hub.go @@ -96,10 +96,11 @@ const ( ) type Hub struct { - nats NatsClient - upgrader websocket.Upgrader - cookie *securecookie.SecureCookie - info *HelloServerMessageServer + nats NatsClient + upgrader websocket.Upgrader + cookie *securecookie.SecureCookie + info *HelloServerMessageServer + infoInternal *HelloServerMessageServer stopped int32 stopChan chan bool @@ -117,7 +118,8 @@ type Hub struct { sessions map[uint64]Session rooms map[string]*Room - roomSessions RoomSessions + roomSessions RoomSessions + virtualSessions map[string]uint64 decodeCaches []*LruCache @@ -276,7 +278,12 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s }, cookie: securecookie.New([]byte(hashKey), blockBytes).MaxAge(0), info: &HelloServerMessageServer{ - Version: version, + Version: version, + Features: DefaultFeatures, + }, + infoInternal: &HelloServerMessageServer{ + Version: version, + Features: DefaultFeaturesInternal, }, stopChan: make(chan bool), @@ -290,7 +297,8 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s sessions: make(map[uint64]Session), rooms: make(map[string]*Room), - roomSessions: roomSessions, + roomSessions: roomSessions, + virtualSessions: make(map[string]uint64), decodeCaches: decodeCaches, @@ -315,29 +323,41 @@ func NewHub(config *goconf.ConfigFile, nats NatsClient, r *mux.Router, version s return hub, nil } -func (h *Hub) SetMcu(mcu Mcu) { - h.mcu = mcu +func addFeature(msg *HelloServerMessageServer, feature string) { var newFeatures []string - if mcu == nil { - for _, f := range h.info.Features { - if f != ServerFeatureMcu { - newFeatures = append(newFeatures, f) - } - } - } else { - log.Printf("Using a timeout of %s for MCU requests", h.mcuTimeout) - added := false - for _, f := range h.info.Features { - newFeatures = append(newFeatures, f) - if f == ServerFeatureMcu { - added = true - } - } - if !added { - newFeatures = append(newFeatures, ServerFeatureMcu) + added := false + for _, f := range msg.Features { + newFeatures = append(newFeatures, f) + if f == feature { + added = true } } - h.info.Features = newFeatures + if !added { + newFeatures = append(newFeatures, feature) + } + msg.Features = newFeatures +} + +func removeFeature(msg *HelloServerMessageServer, feature string) { + var newFeatures []string + for _, f := range msg.Features { + if f != feature { + newFeatures = append(newFeatures, f) + } + } + msg.Features = newFeatures +} + +func (h *Hub) SetMcu(mcu Mcu) { + h.mcu = mcu + if mcu == nil { + removeFeature(h.info, ServerFeatureMcu) + removeFeature(h.infoInternal, ServerFeatureMcu) + } else { + log.Printf("Using a timeout of %s for MCU requests", h.mcuTimeout) + addFeature(h.info, ServerFeatureMcu) + addFeature(h.infoInternal, ServerFeatureMcu) + } } func (h *Hub) checkOrigin(r *http.Request) bool { @@ -345,7 +365,11 @@ func (h *Hub) checkOrigin(r *http.Request) bool { return true } -func (h *Hub) GetServerInfo() *HelloServerMessageServer { +func (h *Hub) GetServerInfo(session Session) *HelloServerMessageServer { + if session.ClientType() == HelloClientTypeInternal { + return h.infoInternal + } + return h.info } @@ -623,6 +647,19 @@ func (h *Hub) processNewClient(client *Client) { h.startExpectHello(client) } +func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData { + sid := atomic.AddUint64(&h.sid, 1) + for sid == 0 { + sid = atomic.AddUint64(&h.sid, 1) + } + sessionIdData := &SessionIdData{ + Sid: sid, + Created: time.Now(), + BackendId: backend.Id(), + } + return sessionIdData +} + func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *Backend, auth *BackendClientResponse) { if !client.IsConnected() { // Client disconnected while waiting for "hello" response. @@ -641,11 +678,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B for sid == 0 { sid = atomic.AddUint64(&h.sid, 1) } - sessionIdData := &SessionIdData{ - Sid: sid, - Created: time.Now(), - BackendId: backend.Id(), - } + sessionIdData := h.newSessionIdData(backend) privateSessionId, err := h.encodeSessionId(sessionIdData, privateSessionName) if err != nil { client.SendMessage(message.NewWrappedErrorServerMessage(err)) @@ -755,6 +788,8 @@ func (h *Hub) processMessage(client *Client, data []byte) { h.processMessageMsg(client, &message) case "control": h.processControlMsg(client, &message) + case "internal": + h.processInternalMsg(client, &message) case "bye": h.processByeMsg(client, &message) case "hello": @@ -773,7 +808,7 @@ func (h *Hub) sendHelloResponse(client *Client, message *ClientMessage, session SessionId: session.PublicId(), ResumeId: session.PrivateId(), UserId: session.UserId(), - Server: h.GetServerInfo(), + Server: h.GetServerInfo(session), }, } return client.SendMessage(response) @@ -1147,6 +1182,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { var recipient *Client var subject string var clientData *MessageClientMessageData + var serverRecipient *MessageClientMessageRecipient switch msg.Recipient.Type { case RecipientTypeSession: data := h.decodeSessionId(msg.Recipient.SessionId, publicSessionName) @@ -1188,6 +1224,21 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { subject = "session." + msg.Recipient.SessionId h.mu.RLock() recipient = h.clients[data.Sid] + if recipient == nil { + // Send to client connection for virtual sessions. + sess := h.sessions[data.Sid] + if sess != nil && sess.ClientType() == HelloClientTypeVirtual { + virtualSession := sess.(*VirtualSession) + clientSession := virtualSession.Session() + subject = "session." + clientSession.PublicId() + recipient = clientSession.GetClient() + // The client should see his session id as recipient. + serverRecipient = &MessageClientMessageRecipient{ + Type: "session", + SessionId: virtualSession.SessionId(), + } + } + } h.mu.RUnlock() } case RecipientTypeUser: @@ -1251,7 +1302,8 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { SessionId: session.PublicId(), UserId: session.UserId(), }, - Data: msg.Data, + Recipient: serverRecipient, + Data: msg.Data, }, } if recipient != nil { @@ -1312,6 +1364,7 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { var recipient *Client var subject string + var serverRecipient *MessageClientMessageRecipient switch msg.Recipient.Type { case RecipientTypeSession: data := h.decodeSessionId(msg.Recipient.SessionId, publicSessionName) @@ -1324,6 +1377,21 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { subject = "session." + msg.Recipient.SessionId h.mu.RLock() recipient = h.clients[data.Sid] + if recipient == nil { + // Send to client connection for virtual sessions. + sess := h.sessions[data.Sid] + if sess != nil && sess.ClientType() == HelloClientTypeVirtual { + virtualSession := sess.(*VirtualSession) + clientSession := virtualSession.Session() + subject = "session." + clientSession.PublicId() + recipient = clientSession.GetClient() + // The client should see his session id as recipient. + serverRecipient = &MessageClientMessageRecipient{ + Type: "session", + SessionId: virtualSession.SessionId(), + } + } + } h.mu.RUnlock() } case RecipientTypeUser: @@ -1357,7 +1425,8 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { SessionId: session.PublicId(), UserId: session.UserId(), }, - Data: msg.Data, + Recipient: serverRecipient, + Data: msg.Data, }, } if recipient != nil { @@ -1367,6 +1436,137 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { } } +func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { + msg := message.Control + session := client.GetSession() + if session == nil { + // Client is not connected yet. + return + } else if session.ClientType() != HelloClientTypeInternal { + log.Printf("Ignore internal message %+v from %s", msg, session.PublicId()) + return + } + + switch message.Internal.Type { + case "addsession": + msg := message.Internal.AddSession + room := h.getRoom(msg.RoomId) + if room == nil { + log.Printf("Ignore add session message %+v for invalid room %s from %s", *msg, msg.RoomId, session.PublicId()) + return + } + + sessionIdData := h.newSessionIdData(session.Backend()) + privateSessionId, err := h.encodeSessionId(sessionIdData, privateSessionName) + if err != nil { + log.Printf("Could not encode private virtual session id: %s", err) + return + } + publicSessionId, err := h.encodeSessionId(sessionIdData, publicSessionName) + if err != nil { + log.Printf("Could not encode public virtual session id: %s", err) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout) + defer cancel() + + virtualSessionId := GetVirtualSessionId(session, msg.SessionId) + + request := NewBackendClientSessionRequest(room.Id(), "add", publicSessionId, msg) + var response BackendClientSessionResponse + if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendUrl(), request, &response); err != nil { + log.Printf("Could not add virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err) + reply := message.NewErrorServerMessage(NewError("add_failed", "Could not add virtual session.")) + client.SendMessage(reply) + return + } + + sess := NewVirtualSession(session, privateSessionId, publicSessionId, sessionIdData, msg) + h.mu.Lock() + h.sessions[sessionIdData.Sid] = sess + h.virtualSessions[virtualSessionId] = sessionIdData.Sid + h.mu.Unlock() + log.Printf("Session %s added virtual session %s with initial flags %d", session.PublicId(), sess.PublicId(), sess.Flags()) + session.AddVirtualSession(sess) + sess.SetRoom(room) + room.AddSession(sess, nil) + case "updatesession": + msg := message.Internal.UpdateSession + room := h.getRoom(msg.RoomId) + if room == nil { + log.Printf("Ignore remove session message %+v for invalid room %s from %s", *msg, msg.RoomId, session.PublicId()) + return + } + + virtualSessionId := GetVirtualSessionId(session, msg.SessionId) + h.mu.Lock() + sid, found := h.virtualSessions[virtualSessionId] + if !found { + h.mu.Unlock() + return + } + + sess := h.sessions[sid] + h.mu.Unlock() + if sess != nil { + update := false + if virtualSession, ok := sess.(*VirtualSession); ok { + if msg.Flags != nil { + if virtualSession.SetFlags(*msg.Flags) { + update = true + } + } + } else { + log.Printf("Ignore update request for non-virtual session %s", sess.PublicId()) + } + if update { + room.NotifySessionChanged(sess) + } + } + case "removesession": + msg := message.Internal.RemoveSession + room := h.getRoom(msg.RoomId) + if room == nil { + log.Printf("Ignore remove session message %+v for invalid room %s from %s", *msg, msg.RoomId, session.PublicId()) + return + } + + virtualSessionId := GetVirtualSessionId(session, msg.SessionId) + h.mu.Lock() + sid, found := h.virtualSessions[virtualSessionId] + if !found { + h.mu.Unlock() + return + } + + delete(h.virtualSessions, virtualSessionId) + sess := h.sessions[sid] + h.mu.Unlock() + if sess != nil { + log.Printf("Session %s removed virtual session %s", session.PublicId(), sess.PublicId()) + sess.Close() + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout) + defer cancel() + + request := NewBackendClientSessionRequest(room.Id(), "remove", sess.PublicId(), nil) + var response BackendClientSessionResponse + err := h.backend.PerformJSONRequest(ctx, sess.ParsedBackendUrl(), request, &response) + if err != nil { + log.Printf("Could not remove virtual session %s from backend %s: %s", sess.PublicId(), sess.BackendUrl(), err) + reply := message.NewErrorServerMessage(NewError("remove_failed", "Could not remove virtual session from backend.")) + client.SendMessage(reply) + } + }() + } + default: + log.Printf("Ignore unsupported internal message %+v from %s", message.Internal, session.PublicId()) + return + } +} + func isAllowedToSend(session *ClientSession, data *MessageClientMessageData) bool { var permission Permission if data.RoomType == "screen" { diff --git a/src/signaling/hub_test.go b/src/signaling/hub_test.go index 4ba9e15..a28d664 100644 --- a/src/signaling/hub_test.go +++ b/src/signaling/hub_test.go @@ -183,6 +183,21 @@ func validateBackendChecksum(t *testing.T, f func(http.ResponseWriter, *http.Req t.Fatal(err) } + if r.Header.Get("OCS-APIRequest") != "" { + var ocs OcsResponse + ocs.Ocs = &OcsBody{ + Meta: OcsMeta{ + Status: "ok", + StatusCode: http.StatusOK, + Message: http.StatusText(http.StatusOK), + }, + Data: (*json.RawMessage)(&data), + } + if data, err = json.Marshal(ocs); err != nil { + t.Fatal(err) + } + } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(data) @@ -239,22 +254,44 @@ func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re return response } +func processSessionRequest(t *testing.T, w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { + if request.Type != "session" || request.Session == nil { + t.Fatalf("Expected an session backend request, got %+v", request) + } + + // TODO(jojo): Evaluate request. + + response := &BackendClientResponse{ + Type: "session", + Session: &BackendClientSessionResponse{ + Version: BackendVersion, + RoomId: request.Session.RoomId, + }, + } + return response +} + func registerBackendHandler(t *testing.T, router *mux.Router) { registerBackendHandlerUrl(t, router, "/") } func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { - router.HandleFunc(url, validateBackendChecksum(t, func(w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { + handleFunc := validateBackendChecksum(t, func(w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { switch request.Type { case "auth": return processAuthRequest(t, w, r, request) case "room": return processRoomRequest(t, w, r, request) + case "session": + return processSessionRequest(t, w, r, request) default: t.Fatalf("Unsupported request received: %+v", request) return nil } - })) + }) + + router.HandleFunc(url, handleFunc) + router.HandleFunc("/ocs/v2.php/apps/spreed/api/v1/signaling/backend", handleFunc) } func performHousekeeping(hub *Hub, now time.Time) *sync.WaitGroup { diff --git a/src/signaling/room.go b/src/signaling/room.go index f7a928e..d639b7b 100644 --- a/src/signaling/room.go +++ b/src/signaling/room.go @@ -25,6 +25,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "log" "net/url" "sync" @@ -58,6 +59,7 @@ type Room struct { sessions map[string]Session internalSessions map[Session]bool + virtualSessions map[*VirtualSession]bool inCallSessions map[Session]bool roomSessionData map[string]*RoomSessionData @@ -116,6 +118,7 @@ func NewRoom(roomId string, properties *json.RawMessage, hub *Hub, n NatsClient, sessions: make(map[string]Session), internalSessions: make(map[Session]bool), + virtualSessions: make(map[*VirtualSession]bool), inCallSessions: make(map[Session]bool), roomSessionData: make(map[string]*RoomSessionData), @@ -258,8 +261,19 @@ func (r *Room) AddSession(session Session, sessionData *json.RawMessage) []Sessi } } r.sessions[sid] = session - if session.ClientType() == HelloClientTypeInternal { + var publishUsersChanged bool + switch session.ClientType() { + case HelloClientTypeInternal: r.internalSessions[session] = true + case HelloClientTypeVirtual: + virtualSession, ok := session.(*VirtualSession) + if !ok { + delete(r.sessions, sid) + r.mu.Unlock() + panic(fmt.Sprintf("Expected a virtual session, got %v", session)) + } + r.virtualSessions[virtualSession] = true + publishUsersChanged = true } if roomSessionData != nil { r.roomSessionData[sid] = roomSessionData @@ -268,6 +282,9 @@ func (r *Room) AddSession(session Session, sessionData *json.RawMessage) []Sessi r.mu.Unlock() if !found { r.PublishSessionJoined(session, roomSessionData) + if publishUsersChanged { + r.publishUsersChangedWithInternal() + } } return result } @@ -290,6 +307,9 @@ func (r *Room) RemoveSession(session Session) bool { sid := session.PublicId() delete(r.sessions, sid) delete(r.internalSessions, session) + if virtualSession, ok := session.(*VirtualSession); ok { + delete(r.virtualSessions, virtualSession) + } delete(r.inCallSessions, session) delete(r.roomSessionData, sid) if len(r.sessions) > 0 { @@ -407,6 +427,15 @@ func (r *Room) addInternalSessions(users []map[string]interface{}) []map[string] "internal": true, }) } + for session := range r.virtualSessions { + users = append(users, map[string]interface{}{ + "inCall": true, + "sessionId": session.PublicId(), + "lastPing": now, + "virtual": true, + "flags": session.Flags(), + }) + } r.mu.Unlock() return users } @@ -548,6 +577,15 @@ func (r *Room) NotifySessionResumed(client *Client) { client.SendMessage(message) } +func (r *Room) NotifySessionChanged(session Session) { + if session.ClientType() != HelloClientTypeVirtual { + // Only notify if a virtual session has changed. + return + } + + r.publishUsersChangedWithInternal() +} + func (r *Room) publishUsersChangedWithInternal() { message := r.getParticipantsUpdateMessage(r.users) r.publish(message) @@ -570,6 +608,9 @@ func (r *Room) publishActiveSessions() { case *ClientSession: // Use Nextcloud session id sid = sess.RoomSessionId() + case *VirtualSession: + // Use our internal generated session id (will be added to Nextcloud). + sid = sess.PublicId() default: continue } diff --git a/src/signaling/testclient_test.go b/src/signaling/testclient_test.go index bdacaeb..28198d2 100644 --- a/src/signaling/testclient_test.go +++ b/src/signaling/testclient_test.go @@ -527,6 +527,10 @@ func (c *TestClient) RunUntilRoom(ctx context.Context, roomId string) error { } func (c *TestClient) checkMessageJoined(message *ServerMessage, hello *HelloServerMessage) error { + return c.checkMessageJoinedSession(message, hello.SessionId, hello.UserId) +} + +func (c *TestClient) checkMessageJoinedSession(message *ServerMessage, sessionId string, userId string) error { if err := checkMessageType(message, "event"); err != nil { return err } else if message.Event.Target != "room" { @@ -537,12 +541,12 @@ func (c *TestClient) checkMessageJoined(message *ServerMessage, hello *HelloServ return fmt.Errorf("Expected one join event entry, got %+v", message.Event) } else { evt := message.Event.Join[0] - if evt.SessionId != hello.SessionId { + if sessionId != "" && evt.SessionId != sessionId { return fmt.Errorf("Expected join session id %+v, got %+v", - getPubliceSessionIdData(c.hub, hello.SessionId), getPubliceSessionIdData(c.hub, evt.SessionId)) + getPubliceSessionIdData(c.hub, sessionId), getPubliceSessionIdData(c.hub, evt.SessionId)) } - if evt.UserId != hello.UserId { - return fmt.Errorf("Expected join user id %s, got %+v", hello.UserId, evt) + if evt.UserId != userId { + return fmt.Errorf("Expected join user id %s, got %+v", userId, evt) } } return nil @@ -557,6 +561,10 @@ func (c *TestClient) RunUntilJoined(ctx context.Context, hello *HelloServerMessa } func (c *TestClient) checkMessageRoomLeave(message *ServerMessage, hello *HelloServerMessage) error { + return c.checkMessageRoomLeaveSession(message, hello.SessionId) +} + +func (c *TestClient) checkMessageRoomLeaveSession(message *ServerMessage, sessionId string) error { if err := checkMessageType(message, "event"); err != nil { return err } else if message.Event.Target != "room" { @@ -565,9 +573,9 @@ func (c *TestClient) checkMessageRoomLeave(message *ServerMessage, hello *HelloS return fmt.Errorf("Expected event type leave, got %+v", message.Event) } else if len(message.Event.Leave) != 1 { return fmt.Errorf("Expected one leave event entry, got %+v", message.Event) - } else if message.Event.Leave[0] != hello.SessionId { + } else if message.Event.Leave[0] != sessionId { return fmt.Errorf("Expected leave session id %+v, got %+v", - getPubliceSessionIdData(c.hub, hello.SessionId), getPubliceSessionIdData(c.hub, message.Event.Leave[0])) + getPubliceSessionIdData(c.hub, sessionId), getPubliceSessionIdData(c.hub, message.Event.Leave[0])) } return nil } diff --git a/src/signaling/virtualsession.go b/src/signaling/virtualsession.go new file mode 100644 index 0000000..c6e6d2f --- /dev/null +++ b/src/signaling/virtualsession.go @@ -0,0 +1,195 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2019 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "encoding/json" + "log" + "net/url" + "sync/atomic" + "time" + "unsafe" +) + +const ( + FLAG_MUTED_SPEAKING = 1 + FLAG_MUTED_LISTENING = 2 + FLAG_TALKING = 4 +) + +type VirtualSession struct { + hub *Hub + session *ClientSession + privateId string + publicId string + data *SessionIdData + room unsafe.Pointer + + sessionId string + userId string + userData *json.RawMessage + flags uint32 +} + +func GetVirtualSessionId(session *ClientSession, sessionId string) string { + return session.PublicId() + "|" + sessionId +} + +func NewVirtualSession(session *ClientSession, privateId string, publicId string, data *SessionIdData, msg *AddSessionInternalClientMessage) *VirtualSession { + return &VirtualSession{ + session: session, + privateId: privateId, + publicId: publicId, + data: data, + + sessionId: msg.SessionId, + userId: msg.UserId, + userData: msg.User, + flags: msg.Flags, + } +} + +func (s *VirtualSession) PrivateId() string { + return s.privateId +} + +func (s *VirtualSession) PublicId() string { + return s.publicId +} + +func (s *VirtualSession) ClientType() string { + return HelloClientTypeVirtual +} + +func (s *VirtualSession) Data() *SessionIdData { + return s.data +} + +func (s *VirtualSession) Backend() *Backend { + return s.session.Backend() +} + +func (s *VirtualSession) BackendUrl() string { + return s.session.BackendUrl() +} + +func (s *VirtualSession) ParsedBackendUrl() *url.URL { + return s.session.ParsedBackendUrl() +} + +func (s *VirtualSession) UserId() string { + return s.userId +} + +func (s *VirtualSession) UserData() *json.RawMessage { + return s.userData +} + +func (s *VirtualSession) SetRoom(room *Room) { + atomic.StorePointer(&s.room, unsafe.Pointer(room)) +} + +func (s *VirtualSession) GetRoom() *Room { + return (*Room)(atomic.LoadPointer(&s.room)) +} + +func (s *VirtualSession) LeaveRoom(notify bool) *Room { + room := s.GetRoom() + if room == nil { + return nil + } + + s.SetRoom(nil) + room.RemoveSession(s) + return room +} + +func (s *VirtualSession) IsExpired(now time.Time) bool { + return false +} + +func (s *VirtualSession) Close() { + s.session.RemoveVirtualSession(s) + s.session.hub.removeSession(s) +} + +func (s *VirtualSession) HasPermission(permission Permission) bool { + return true +} + +func (s *VirtualSession) Session() *ClientSession { + return s.session +} + +func (s *VirtualSession) SessionId() string { + return s.sessionId +} + +func (s *VirtualSession) AddFlags(flags uint32) bool { + for { + old := atomic.LoadUint32(&s.flags) + if old&flags == flags { + // Flags already set. + return false + } + newFlags := old | flags + if atomic.CompareAndSwapUint32(&s.flags, old, newFlags) { + log.Printf("Flags for session %s now %d (added %d)", s.PublicId(), newFlags, flags) + return true + } + // Another thread updated the flags while we were checking, retry. + } +} + +func (s *VirtualSession) RemoveFlags(flags uint32) bool { + for { + old := atomic.LoadUint32(&s.flags) + if old&flags == 0 { + // Flags not set. + return false + } + newFlags := old & ^flags + if atomic.CompareAndSwapUint32(&s.flags, old, newFlags) { + log.Printf("Flags for session %s now %d (removed %d)", s.PublicId(), newFlags, flags) + return true + } + // Another thread updated the flags while we were checking, retry. + } +} + +func (s *VirtualSession) SetFlags(flags uint32) bool { + for { + old := atomic.LoadUint32(&s.flags) + if old == flags { + return false + } + + if atomic.CompareAndSwapUint32(&s.flags, old, flags) { + log.Printf("Flags for session %s now %d", s.PublicId(), flags) + return true + } + } +} + +func (s *VirtualSession) Flags() uint32 { + return atomic.LoadUint32(&s.flags) +} diff --git a/src/signaling/virtualsession_test.go b/src/signaling/virtualsession_test.go new file mode 100644 index 0000000..ffc8b62 --- /dev/null +++ b/src/signaling/virtualsession_test.go @@ -0,0 +1,378 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2019 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "encoding/json" + "testing" + + "golang.org/x/net/context" +) + +func TestVirtualSession(t *testing.T) { + hub, _, _, server, shutdown := CreateHubForTest(t) + defer shutdown() + + roomId := "the-room-id" + emptyProperties := json.RawMessage("{}") + backend := &Backend{ + id: "compat", + compat: true, + } + room, err := hub.createRoom(roomId, &emptyProperties, backend) + if err != nil { + t.Fatalf("Could not create room: %s", err) + } + defer room.Close() + + clientInternal := NewTestClient(t, server, hub) + defer clientInternal.CloseWithBye() + if err := clientInternal.SendHelloInternal(); err != nil { + t.Fatal(err) + } + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + if err := client.SendHello(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + if hello, err := clientInternal.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != "" { + t.Errorf("Expected empty user id, got %+v", hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + if hello.Hello.ResumeId == "" { + t.Errorf("Expected resume id, got %+v", hello.Hello) + } + } + hello, err := client.RunUntilHello(ctx) + if err != nil { + t.Error(err) + } + + if room, err := client.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + + // Ignore "join" events. + if err := client.DrainMessages(ctx); err != nil { + t.Error(err) + } + + internalSessionId := "session1" + userId := "user1" + msgAdd := &ClientMessage{ + Type: "internal", + Internal: &InternalClientMessage{ + Type: "addsession", + AddSession: &AddSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: internalSessionId, + RoomId: roomId, + }, + UserId: userId, + Flags: FLAG_MUTED_SPEAKING, + }, + }, + } + if err := clientInternal.WriteJSON(msgAdd); err != nil { + t.Fatal(err) + } + + msg1, err := client.RunUntilMessage(ctx) + if err != nil { + t.Fatal(err) + } + // The public session id will be generated by the server, so don't check for it. + if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil { + t.Fatal(err) + } + sessionId := msg1.Event.Join[0].SessionId + session := hub.GetSessionByPublicId(sessionId) + if session == nil { + t.Fatalf("Could not get virtual session %s", sessionId) + } + if session.ClientType() != HelloClientTypeVirtual { + t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType()) + } + if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId { + t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid) + } + + // Also a participants update event will be triggered for the virtual user. + msg2, err := client.RunUntilMessage(ctx) + if err != nil { + t.Fatal(err) + } + updateMsg, err := checkMessageParticipantsInCall(msg2) + if err != nil { + t.Error(err) + } else if updateMsg.RoomId != roomId { + t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId) + } else if len(updateMsg.Users) != 1 { + t.Errorf("Expected one user, got %+v", updateMsg.Users) + } else if sid, ok := updateMsg.Users[0]["sessionId"].(string); !ok || sid != sessionId { + t.Errorf("Expected session id %s, got %+v", sessionId, updateMsg.Users[0]) + } else if virtual, ok := updateMsg.Users[0]["virtual"].(bool); !ok || !virtual { + t.Errorf("Expected virtual user, got %+v", updateMsg.Users[0]) + } else if inCall, ok := updateMsg.Users[0]["inCall"].(bool); !ok || !inCall { + t.Errorf("Expected user in call, got %+v", updateMsg.Users[0]) + } else if flags, ok := updateMsg.Users[0]["flags"].(float64); !ok || flags != FLAG_MUTED_SPEAKING { + t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, updateMsg.Users[0]) + } + + // When sending to a virtual session, the message is sent to the actual + // client and contains a "Recipient" block with the internal session id. + recipient := MessageClientMessageRecipient{ + Type: "session", + SessionId: sessionId, + } + + data := "from-client-to-virtual" + client.SendMessage(recipient, data) + + msg2, err = clientInternal.RunUntilMessage(ctx) + if err != nil { + t.Fatal(err) + } else if err := checkMessageType(msg2, "message"); err != nil { + t.Fatal(err) + } else if err := checkMessageSender(hub, msg2.Message, "session", hello.Hello); err != nil { + t.Error(err) + } + + if msg2.Message.Recipient == nil { + t.Errorf("Expected recipient, got none") + } else if msg2.Message.Recipient.Type != "session" { + t.Errorf("Expected recipient type session, got %s", msg2.Message.Recipient.Type) + } else if msg2.Message.Recipient.SessionId != internalSessionId { + t.Errorf("Expected recipient %s, got %s", internalSessionId, msg2.Message.Recipient.SessionId) + } + + var payload string + if err := json.Unmarshal(*msg2.Message.Data, &payload); err != nil { + t.Error(err) + } else if payload != data { + t.Errorf("Expected payload %s, got %s", data, payload) + } + + msgRemove := &ClientMessage{ + Type: "internal", + Internal: &InternalClientMessage{ + Type: "removesession", + RemoveSession: &RemoveSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: internalSessionId, + RoomId: roomId, + }, + }, + }, + } + if err := clientInternal.WriteJSON(msgRemove); err != nil { + t.Fatal(err) + } + + msg3, err := client.RunUntilMessage(ctx) + if err != nil { + t.Fatal(err) + } + if err := client.checkMessageRoomLeaveSession(msg3, sessionId); err != nil { + t.Error(err) + } +} + +func TestVirtualSessionCleanup(t *testing.T) { + hub, _, _, server, shutdown := CreateHubForTest(t) + defer shutdown() + + roomId := "the-room-id" + emptyProperties := json.RawMessage("{}") + backend := &Backend{ + id: "compat", + compat: true, + } + room, err := hub.createRoom(roomId, &emptyProperties, backend) + if err != nil { + t.Fatalf("Could not create room: %s", err) + } + defer room.Close() + + clientInternal := NewTestClient(t, server, hub) + defer clientInternal.CloseWithBye() + if err := clientInternal.SendHelloInternal(); err != nil { + t.Fatal(err) + } + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + if err := client.SendHello(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + if hello, err := clientInternal.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != "" { + t.Errorf("Expected empty user id, got %+v", hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + if hello.Hello.ResumeId == "" { + t.Errorf("Expected resume id, got %+v", hello.Hello) + } + } + if _, err := client.RunUntilHello(ctx); err != nil { + t.Error(err) + } + + if room, err := client.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + + // Ignore "join" events. + if err := client.DrainMessages(ctx); err != nil { + t.Error(err) + } + + internalSessionId := "session1" + userId := "user1" + msgAdd := &ClientMessage{ + Type: "internal", + Internal: &InternalClientMessage{ + Type: "addsession", + AddSession: &AddSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: internalSessionId, + RoomId: roomId, + }, + UserId: userId, + Flags: FLAG_MUTED_SPEAKING, + }, + }, + } + if err := clientInternal.WriteJSON(msgAdd); err != nil { + t.Fatal(err) + } + + msg1, err := client.RunUntilMessage(ctx) + if err != nil { + t.Fatal(err) + } + // The public session id will be generated by the server, so don't check for it. + if err := client.checkMessageJoinedSession(msg1, "", userId); err != nil { + t.Fatal(err) + } + sessionId := msg1.Event.Join[0].SessionId + session := hub.GetSessionByPublicId(sessionId) + if session == nil { + t.Fatalf("Could not get virtual session %s", sessionId) + } + if session.ClientType() != HelloClientTypeVirtual { + t.Errorf("Expected client type %s, got %s", HelloClientTypeVirtual, session.ClientType()) + } + if sid := session.(*VirtualSession).SessionId(); sid != internalSessionId { + t.Errorf("Expected internal session id %s, got %s", internalSessionId, sid) + } + + // Also a participants update event will be triggered for the virtual user. + msg2, err := client.RunUntilMessage(ctx) + if err != nil { + t.Fatal(err) + } + updateMsg, err := checkMessageParticipantsInCall(msg2) + if err != nil { + t.Error(err) + } else if updateMsg.RoomId != roomId { + t.Errorf("Expected room %s, got %s", roomId, updateMsg.RoomId) + } else if len(updateMsg.Users) != 1 { + t.Errorf("Expected one user, got %+v", updateMsg.Users) + } else if sid, ok := updateMsg.Users[0]["sessionId"].(string); !ok || sid != sessionId { + t.Errorf("Expected session id %s, got %+v", sessionId, updateMsg.Users[0]) + } else if virtual, ok := updateMsg.Users[0]["virtual"].(bool); !ok || !virtual { + t.Errorf("Expected virtual user, got %+v", updateMsg.Users[0]) + } else if inCall, ok := updateMsg.Users[0]["inCall"].(bool); !ok || !inCall { + t.Errorf("Expected user in call, got %+v", updateMsg.Users[0]) + } else if flags, ok := updateMsg.Users[0]["flags"].(float64); !ok || flags != FLAG_MUTED_SPEAKING { + t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, updateMsg.Users[0]) + } + + // The virtual sessions are closed when the parent session is deleted. + clientInternal.CloseWithBye() + + if msg2, err := client.RunUntilMessage(ctx); err != nil { + t.Fatal(err) + } else if err := client.checkMessageRoomLeaveSession(msg2, sessionId); err != nil { + t.Error(err) + } +} + +func TestVirtualSessionFlags(t *testing.T) { + s := &VirtualSession{ + publicId: "dummy-for-testing", + } + if s.Flags() != 0 { + t.Fatalf("Expected flags 0, got %d", s.Flags()) + } + s.AddFlags(1) + if s.Flags() != 1 { + t.Fatalf("Expected flags 1, got %d", s.Flags()) + } + s.AddFlags(1) + if s.Flags() != 1 { + t.Fatalf("Expected flags 1, got %d", s.Flags()) + } + s.AddFlags(2) + if s.Flags() != 3 { + t.Fatalf("Expected flags 3, got %d", s.Flags()) + } + s.RemoveFlags(1) + if s.Flags() != 2 { + t.Fatalf("Expected flags 2, got %d", s.Flags()) + } + s.RemoveFlags(1) + if s.Flags() != 2 { + t.Fatalf("Expected flags 2, got %d", s.Flags()) + } + s.AddFlags(3) + if s.Flags() != 3 { + t.Fatalf("Expected flags 3, got %d", s.Flags()) + } + s.RemoveFlags(1) + if s.Flags() != 2 { + t.Fatalf("Expected flags 2, got %d", s.Flags()) + } +}