Merge pull request #591 from strukturag/common-flags

Move common flags code to own struct.
This commit is contained in:
Joachim Bauch 2023-10-30 10:23:11 +01:00 committed by GitHub
commit 00c2e4ea9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 231 additions and 98 deletions

View file

@ -60,7 +60,7 @@ type ClientSession struct {
userId string
userData *json.RawMessage
inCall atomic.Uint32
inCall Flags
supportsPermissions bool
permissions map[Permission]bool
@ -169,7 +169,7 @@ func (s *ClientSession) ClientType() string {
// GetInCall is only used for internal clients.
func (s *ClientSession) GetInCall() int {
return int(s.inCall.Load())
return int(s.inCall.Get())
}
func (s *ClientSession) SetInCall(inCall int) bool {
@ -177,16 +177,7 @@ func (s *ClientSession) SetInCall(inCall int) bool {
inCall = 0
}
for {
old := s.inCall.Load()
if old == uint32(inCall) {
return false
}
if s.inCall.CompareAndSwap(old, uint32(inCall)) {
return true
}
}
return s.inCall.Set(uint32(inCall))
}
func (s *ClientSession) GetFeatures() []string {

77
flags.go Normal file
View file

@ -0,0 +1,77 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2023 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"sync/atomic"
)
type Flags struct {
flags atomic.Uint32
}
func (f *Flags) Add(flags uint32) bool {
for {
old := f.flags.Load()
if old&flags == flags {
// Flags already set.
return false
}
newFlags := old | flags
if f.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
}
func (f *Flags) Remove(flags uint32) bool {
for {
old := f.flags.Load()
if old&flags == 0 {
// Flags not set.
return false
}
newFlags := old & ^flags
if f.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
}
func (f *Flags) Set(flags uint32) bool {
for {
old := f.flags.Load()
if old == flags {
return false
}
if f.flags.CompareAndSwap(old, flags) {
return true
}
}
}
func (f *Flags) Get() uint32 {
return f.flags.Load()
}

140
flags_test.go Normal file
View file

@ -0,0 +1,140 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2023 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"sync"
"sync/atomic"
"testing"
)
func TestFlags(t *testing.T) {
var f Flags
if f.Get() != 0 {
t.Fatalf("Expected flags 0, got %d", f.Get())
}
if !f.Add(1) {
t.Error("expected true")
}
if f.Get() != 1 {
t.Fatalf("Expected flags 1, got %d", f.Get())
}
if f.Add(1) {
t.Error("expected false")
}
if f.Get() != 1 {
t.Fatalf("Expected flags 1, got %d", f.Get())
}
if !f.Add(2) {
t.Error("expected true")
}
if f.Get() != 3 {
t.Fatalf("Expected flags 3, got %d", f.Get())
}
if !f.Remove(1) {
t.Error("expected true")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
if f.Remove(1) {
t.Error("expected false")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
if !f.Add(3) {
t.Error("expected true")
}
if f.Get() != 3 {
t.Fatalf("Expected flags 3, got %d", f.Get())
}
if !f.Remove(1) {
t.Error("expected true")
}
if f.Get() != 2 {
t.Fatalf("Expected flags 2, got %d", f.Get())
}
}
func runConcurrentFlags(t *testing.T, count int, f func()) {
var start sync.WaitGroup
start.Add(1)
var ready sync.WaitGroup
var done sync.WaitGroup
for i := 0; i < count; i++ {
done.Add(1)
ready.Add(1)
go func() {
defer done.Done()
ready.Done()
start.Wait()
f()
}()
}
ready.Wait()
start.Done()
done.Wait()
}
func TestFlagsConcurrentAdd(t *testing.T) {
var flags Flags
var added atomic.Int32
runConcurrentFlags(t, 100, func() {
if flags.Add(1) {
added.Add(1)
}
})
if added.Load() != 1 {
t.Errorf("expected only one successfull attempt, got %d", added.Load())
}
}
func TestFlagsConcurrentRemove(t *testing.T) {
var flags Flags
flags.Set(1)
var removed atomic.Int32
runConcurrentFlags(t, 100, func() {
if flags.Remove(1) {
removed.Add(1)
}
})
if removed.Load() != 1 {
t.Errorf("expected only one successfull attempt, got %d", removed.Load())
}
}
func TestFlagsConcurrentSet(t *testing.T) {
var flags Flags
var set atomic.Int32
runConcurrentFlags(t, 100, func() {
if flags.Set(1) {
set.Add(1)
}
})
if set.Load() != 1 {
t.Errorf("expected only one successfull attempt, got %d", set.Load())
}
}

View file

@ -47,8 +47,8 @@ type VirtualSession struct {
sessionId string
userId string
userData *json.RawMessage
inCall atomic.Uint32
flags atomic.Uint32
inCall Flags
flags Flags
options *AddSessionOptions
}
@ -69,7 +69,6 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
userData: msg.User,
options: msg.Options,
}
result.flags.Store(msg.Flags)
if err := session.events.RegisterSessionListener(publicId, session.Backend(), result); err != nil {
return nil, err
@ -80,6 +79,9 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
} else if !session.HasFeature(ClientFeatureInternalInCall) {
result.SetInCall(FlagInCall | FlagWithPhone)
}
if msg.Flags != 0 {
result.SetFlags(msg.Flags)
}
return result, nil
}
@ -97,7 +99,7 @@ func (s *VirtualSession) ClientType() string {
}
func (s *VirtualSession) GetInCall() int {
return int(s.inCall.Load())
return int(s.inCall.Get())
}
func (s *VirtualSession) SetInCall(inCall int) bool {
@ -105,16 +107,7 @@ func (s *VirtualSession) SetInCall(inCall int) bool {
inCall = 0
}
for {
old := s.inCall.Load()
if old == uint32(inCall) {
return false
}
if s.inCall.CompareAndSwap(old, uint32(inCall)) {
return true
}
}
return s.inCall.Set(uint32(inCall))
}
func (s *VirtualSession) Data() *SessionIdData {
@ -240,50 +233,19 @@ func (s *VirtualSession) SessionId() string {
}
func (s *VirtualSession) AddFlags(flags uint32) bool {
for {
old := s.flags.Load()
if old&flags == flags {
// Flags already set.
return false
}
newFlags := old | flags
if s.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
return s.flags.Add(flags)
}
func (s *VirtualSession) RemoveFlags(flags uint32) bool {
for {
old := s.flags.Load()
if old&flags == 0 {
// Flags not set.
return false
}
newFlags := old & ^flags
if s.flags.CompareAndSwap(old, newFlags) {
return true
}
// Another thread updated the flags while we were checking, retry.
}
return s.flags.Remove(flags)
}
func (s *VirtualSession) SetFlags(flags uint32) bool {
for {
old := s.flags.Load()
if old == flags {
return false
}
if s.flags.CompareAndSwap(old, flags) {
return true
}
}
return s.flags.Set(flags)
}
func (s *VirtualSession) Flags() uint32 {
return s.flags.Load()
return s.flags.Get()
}
func (s *VirtualSession) Options() *AddSessionOptions {

View file

@ -711,40 +711,3 @@ func TestVirtualSessionCleanup(t *testing.T) {
t.Error(err)
}
}
func TestVirtualSessionFlags(t *testing.T) {
s := &VirtualSession{
publicId: "dummy-for-testing",
}
if s.Flags() != 0 {
t.Fatalf("Expected flags 0, got %d", s.Flags())
}
s.AddFlags(1)
if s.Flags() != 1 {
t.Fatalf("Expected flags 1, got %d", s.Flags())
}
s.AddFlags(1)
if s.Flags() != 1 {
t.Fatalf("Expected flags 1, got %d", s.Flags())
}
s.AddFlags(2)
if s.Flags() != 3 {
t.Fatalf("Expected flags 3, got %d", s.Flags())
}
s.RemoveFlags(1)
if s.Flags() != 2 {
t.Fatalf("Expected flags 2, got %d", s.Flags())
}
s.RemoveFlags(1)
if s.Flags() != 2 {
t.Fatalf("Expected flags 2, got %d", s.Flags())
}
s.AddFlags(3)
if s.Flags() != 3 {
t.Fatalf("Expected flags 3, got %d", s.Flags())
}
s.RemoveFlags(1)
if s.Flags() != 2 {
t.Fatalf("Expected flags 2, got %d", s.Flags())
}
}