diff --git a/clientsession.go b/clientsession.go index d3174da..1567278 100644 --- a/clientsession.go +++ b/clientsession.go @@ -60,7 +60,7 @@ type ClientSession struct { userId string userData *json.RawMessage - inCall atomic.Uint32 + inCall Flags supportsPermissions bool permissions map[Permission]bool @@ -169,7 +169,7 @@ func (s *ClientSession) ClientType() string { // GetInCall is only used for internal clients. func (s *ClientSession) GetInCall() int { - return int(s.inCall.Load()) + return int(s.inCall.Get()) } func (s *ClientSession) SetInCall(inCall int) bool { @@ -177,16 +177,7 @@ func (s *ClientSession) SetInCall(inCall int) bool { inCall = 0 } - for { - old := s.inCall.Load() - if old == uint32(inCall) { - return false - } - - if s.inCall.CompareAndSwap(old, uint32(inCall)) { - return true - } - } + return s.inCall.Set(uint32(inCall)) } func (s *ClientSession) GetFeatures() []string { diff --git a/flags.go b/flags.go new file mode 100644 index 0000000..3f67283 --- /dev/null +++ b/flags.go @@ -0,0 +1,77 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2023 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "sync/atomic" +) + +type Flags struct { + flags atomic.Uint32 +} + +func (f *Flags) Add(flags uint32) bool { + for { + old := f.flags.Load() + if old&flags == flags { + // Flags already set. + return false + } + newFlags := old | flags + if f.flags.CompareAndSwap(old, newFlags) { + return true + } + // Another thread updated the flags while we were checking, retry. + } +} + +func (f *Flags) Remove(flags uint32) bool { + for { + old := f.flags.Load() + if old&flags == 0 { + // Flags not set. + return false + } + newFlags := old & ^flags + if f.flags.CompareAndSwap(old, newFlags) { + return true + } + // Another thread updated the flags while we were checking, retry. + } +} + +func (f *Flags) Set(flags uint32) bool { + for { + old := f.flags.Load() + if old == flags { + return false + } + + if f.flags.CompareAndSwap(old, flags) { + return true + } + } +} + +func (f *Flags) Get() uint32 { + return f.flags.Load() +} diff --git a/flags_test.go b/flags_test.go new file mode 100644 index 0000000..406b2c5 --- /dev/null +++ b/flags_test.go @@ -0,0 +1,140 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2023 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "sync" + "sync/atomic" + "testing" +) + +func TestFlags(t *testing.T) { + var f Flags + if f.Get() != 0 { + t.Fatalf("Expected flags 0, got %d", f.Get()) + } + if !f.Add(1) { + t.Error("expected true") + } + if f.Get() != 1 { + t.Fatalf("Expected flags 1, got %d", f.Get()) + } + if f.Add(1) { + t.Error("expected false") + } + if f.Get() != 1 { + t.Fatalf("Expected flags 1, got %d", f.Get()) + } + if !f.Add(2) { + t.Error("expected true") + } + if f.Get() != 3 { + t.Fatalf("Expected flags 3, got %d", f.Get()) + } + if !f.Remove(1) { + t.Error("expected true") + } + if f.Get() != 2 { + t.Fatalf("Expected flags 2, got %d", f.Get()) + } + if f.Remove(1) { + t.Error("expected false") + } + if f.Get() != 2 { + t.Fatalf("Expected flags 2, got %d", f.Get()) + } + if !f.Add(3) { + t.Error("expected true") + } + if f.Get() != 3 { + t.Fatalf("Expected flags 3, got %d", f.Get()) + } + if !f.Remove(1) { + t.Error("expected true") + } + if f.Get() != 2 { + t.Fatalf("Expected flags 2, got %d", f.Get()) + } +} + +func runConcurrentFlags(t *testing.T, count int, f func()) { + var start sync.WaitGroup + start.Add(1) + var ready sync.WaitGroup + var done sync.WaitGroup + for i := 0; i < count; i++ { + done.Add(1) + ready.Add(1) + go func() { + defer done.Done() + ready.Done() + start.Wait() + f() + }() + } + ready.Wait() + start.Done() + done.Wait() +} + +func TestFlagsConcurrentAdd(t *testing.T) { + var flags Flags + + var added atomic.Int32 + runConcurrentFlags(t, 100, func() { + if flags.Add(1) { + added.Add(1) + } + }) + if added.Load() != 1 { + t.Errorf("expected only one successfull attempt, got %d", added.Load()) + } +} + +func TestFlagsConcurrentRemove(t *testing.T) { + var flags Flags + flags.Set(1) + + var removed atomic.Int32 + runConcurrentFlags(t, 100, func() { + if flags.Remove(1) { + removed.Add(1) + } + }) + if removed.Load() != 1 { + t.Errorf("expected only one successfull attempt, got %d", removed.Load()) + } +} + +func TestFlagsConcurrentSet(t *testing.T) { + var flags Flags + + var set atomic.Int32 + runConcurrentFlags(t, 100, func() { + if flags.Set(1) { + set.Add(1) + } + }) + if set.Load() != 1 { + t.Errorf("expected only one successfull attempt, got %d", set.Load()) + } +} diff --git a/virtualsession.go b/virtualsession.go index 8614199..426731e 100644 --- a/virtualsession.go +++ b/virtualsession.go @@ -47,8 +47,8 @@ type VirtualSession struct { sessionId string userId string userData *json.RawMessage - inCall atomic.Uint32 - flags atomic.Uint32 + inCall Flags + flags Flags options *AddSessionOptions } @@ -69,7 +69,6 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string userData: msg.User, options: msg.Options, } - result.flags.Store(msg.Flags) if err := session.events.RegisterSessionListener(publicId, session.Backend(), result); err != nil { return nil, err @@ -80,6 +79,9 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string } else if !session.HasFeature(ClientFeatureInternalInCall) { result.SetInCall(FlagInCall | FlagWithPhone) } + if msg.Flags != 0 { + result.SetFlags(msg.Flags) + } return result, nil } @@ -97,7 +99,7 @@ func (s *VirtualSession) ClientType() string { } func (s *VirtualSession) GetInCall() int { - return int(s.inCall.Load()) + return int(s.inCall.Get()) } func (s *VirtualSession) SetInCall(inCall int) bool { @@ -105,16 +107,7 @@ func (s *VirtualSession) SetInCall(inCall int) bool { inCall = 0 } - for { - old := s.inCall.Load() - if old == uint32(inCall) { - return false - } - - if s.inCall.CompareAndSwap(old, uint32(inCall)) { - return true - } - } + return s.inCall.Set(uint32(inCall)) } func (s *VirtualSession) Data() *SessionIdData { @@ -240,50 +233,19 @@ func (s *VirtualSession) SessionId() string { } func (s *VirtualSession) AddFlags(flags uint32) bool { - for { - old := s.flags.Load() - if old&flags == flags { - // Flags already set. - return false - } - newFlags := old | flags - if s.flags.CompareAndSwap(old, newFlags) { - return true - } - // Another thread updated the flags while we were checking, retry. - } + return s.flags.Add(flags) } func (s *VirtualSession) RemoveFlags(flags uint32) bool { - for { - old := s.flags.Load() - if old&flags == 0 { - // Flags not set. - return false - } - newFlags := old & ^flags - if s.flags.CompareAndSwap(old, newFlags) { - return true - } - // Another thread updated the flags while we were checking, retry. - } + return s.flags.Remove(flags) } func (s *VirtualSession) SetFlags(flags uint32) bool { - for { - old := s.flags.Load() - if old == flags { - return false - } - - if s.flags.CompareAndSwap(old, flags) { - return true - } - } + return s.flags.Set(flags) } func (s *VirtualSession) Flags() uint32 { - return s.flags.Load() + return s.flags.Get() } func (s *VirtualSession) Options() *AddSessionOptions { diff --git a/virtualsession_test.go b/virtualsession_test.go index 7a986bf..76673f9 100644 --- a/virtualsession_test.go +++ b/virtualsession_test.go @@ -711,40 +711,3 @@ func TestVirtualSessionCleanup(t *testing.T) { t.Error(err) } } - -func TestVirtualSessionFlags(t *testing.T) { - s := &VirtualSession{ - publicId: "dummy-for-testing", - } - if s.Flags() != 0 { - t.Fatalf("Expected flags 0, got %d", s.Flags()) - } - s.AddFlags(1) - if s.Flags() != 1 { - t.Fatalf("Expected flags 1, got %d", s.Flags()) - } - s.AddFlags(1) - if s.Flags() != 1 { - t.Fatalf("Expected flags 1, got %d", s.Flags()) - } - s.AddFlags(2) - if s.Flags() != 3 { - t.Fatalf("Expected flags 3, got %d", s.Flags()) - } - s.RemoveFlags(1) - if s.Flags() != 2 { - t.Fatalf("Expected flags 2, got %d", s.Flags()) - } - s.RemoveFlags(1) - if s.Flags() != 2 { - t.Fatalf("Expected flags 2, got %d", s.Flags()) - } - s.AddFlags(3) - if s.Flags() != 3 { - t.Fatalf("Expected flags 3, got %d", s.Flags()) - } - s.RemoveFlags(1) - if s.Flags() != 2 { - t.Fatalf("Expected flags 2, got %d", s.Flags()) - } -}