mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2026-03-14 14:35:44 +01:00
Move throttler code to async package.
This commit is contained in:
parent
446936f7ff
commit
1c3a03e972
6 changed files with 44 additions and 36 deletions
|
|
@ -19,7 +19,7 @@
|
|||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
package async
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
|
@ -92,9 +92,12 @@ type throttleEntry struct {
|
|||
ts time.Time
|
||||
}
|
||||
|
||||
type GetTimeFunc func() time.Time
|
||||
type ThrottleDelayFunc func(context.Context, time.Duration)
|
||||
|
||||
type memoryThrottler struct {
|
||||
getNow func() time.Time
|
||||
doDelay func(context.Context, time.Duration)
|
||||
getNow GetTimeFunc
|
||||
doDelay ThrottleDelayFunc
|
||||
|
||||
mu sync.RWMutex
|
||||
// +checklocks:mu
|
||||
|
|
@ -104,14 +107,18 @@ type memoryThrottler struct {
|
|||
}
|
||||
|
||||
func NewMemoryThrottler() (Throttler, error) {
|
||||
return NewCustomMemoryThrottler(time.Now, defaultDelay)
|
||||
}
|
||||
|
||||
func NewCustomMemoryThrottler(getNow GetTimeFunc, delay ThrottleDelayFunc) (Throttler, error) {
|
||||
result := &memoryThrottler{
|
||||
getNow: time.Now,
|
||||
getNow: getNow,
|
||||
doDelay: delay,
|
||||
|
||||
clients: make(map[string]map[string][]throttleEntry),
|
||||
|
||||
closer: internal.NewCloser(),
|
||||
}
|
||||
result.doDelay = result.delay
|
||||
go result.housekeeping()
|
||||
return result, nil
|
||||
}
|
||||
|
|
@ -310,7 +317,7 @@ func (t *memoryThrottler) throttle(ctx context.Context, client string, action st
|
|||
t.doDelay(ctx, delay)
|
||||
}
|
||||
|
||||
func (t *memoryThrottler) delay(ctx context.Context, duration time.Duration) {
|
||||
func defaultDelay(ctx context.Context, duration time.Duration) {
|
||||
c, cancel := context.WithTimeout(ctx, duration)
|
||||
defer cancel()
|
||||
|
||||
|
|
@ -19,7 +19,7 @@
|
|||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
package async
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
|
@ -19,10 +19,9 @@
|
|||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
package async
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -45,22 +44,6 @@ func newMemoryThrottlerForTest(t *testing.T) Throttler {
|
|||
return result
|
||||
}
|
||||
|
||||
type throttlerTiming struct {
|
||||
t *testing.T
|
||||
|
||||
now time.Time
|
||||
expectedSleep time.Duration
|
||||
}
|
||||
|
||||
func (t *throttlerTiming) getNow() time.Time {
|
||||
return t.now
|
||||
}
|
||||
|
||||
func (t *throttlerTiming) doDelay(ctx context.Context, duration time.Duration) {
|
||||
t.t.Helper()
|
||||
assert.Equal(t.t, t.expectedSleep, duration)
|
||||
}
|
||||
|
||||
func expectDelay(t *testing.T, f func(), delay time.Duration) {
|
||||
t.Helper()
|
||||
a := time.Now()
|
||||
|
|
@ -49,6 +49,7 @@ import (
|
|||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
"github.com/strukturag/nextcloud-spreed-signaling/api"
|
||||
"github.com/strukturag/nextcloud-spreed-signaling/async"
|
||||
"github.com/strukturag/nextcloud-spreed-signaling/log"
|
||||
)
|
||||
|
||||
|
|
@ -840,7 +841,7 @@ func (b *BackendServer) startDialout(ctx context.Context, roomid string, backend
|
|||
|
||||
func (b *BackendServer) roomHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, body []byte) {
|
||||
throttle, err := b.hub.throttler.CheckBruteforce(ctx, b.hub.getRealUserIP(r), "BackendRoomAuth")
|
||||
if err == ErrBruteforceDetected {
|
||||
if err == async.ErrBruteforceDetected {
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return
|
||||
} else if err != nil {
|
||||
|
|
|
|||
8
hub.go
8
hub.go
|
|
@ -213,7 +213,7 @@ type Hub struct {
|
|||
rpcServer *GrpcServer
|
||||
rpcClients *GrpcClients
|
||||
|
||||
throttler Throttler
|
||||
throttler async.Throttler
|
||||
|
||||
skipFederationVerify bool
|
||||
federationTimeout time.Duration
|
||||
|
|
@ -352,7 +352,7 @@ func NewHub(ctx context.Context, config *goconf.ConfigFile, events AsyncEvents,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
throttler, err := NewMemoryThrottler()
|
||||
throttler, err := async.NewMemoryThrottler()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -1277,7 +1277,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
|
|||
resumeId := message.Hello.ResumeId
|
||||
if resumeId != "" {
|
||||
throttle, err := h.throttler.CheckBruteforce(ctx, client.RemoteAddr(), "HelloResume")
|
||||
if err == ErrBruteforceDetected {
|
||||
if err == async.ErrBruteforceDetected {
|
||||
client.SendMessage(message.NewErrorServerMessage(TooManyRequests))
|
||||
return
|
||||
} else if err != nil {
|
||||
|
|
@ -1581,7 +1581,7 @@ func (h *Hub) processHelloInternal(client HandlerClient, message *ClientMessage)
|
|||
|
||||
ctx := log.NewLoggerContext(client.Context(), h.logger)
|
||||
throttle, err := h.throttler.CheckBruteforce(ctx, client.RemoteAddr(), "HelloInternal")
|
||||
if err == ErrBruteforceDetected {
|
||||
if err == async.ErrBruteforceDetected {
|
||||
client.SendMessage(message.NewErrorServerMessage(TooManyRequests))
|
||||
return
|
||||
} else if err != nil {
|
||||
|
|
|
|||
29
hub_test.go
29
hub_test.go
|
|
@ -53,6 +53,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/strukturag/nextcloud-spreed-signaling/api"
|
||||
"github.com/strukturag/nextcloud-spreed-signaling/async"
|
||||
"github.com/strukturag/nextcloud-spreed-signaling/internal"
|
||||
"github.com/strukturag/nextcloud-spreed-signaling/log"
|
||||
"github.com/strukturag/nextcloud-spreed-signaling/test"
|
||||
|
|
@ -1440,6 +1441,22 @@ func TestClientHelloResume(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type throttlerTiming struct {
|
||||
t *testing.T
|
||||
|
||||
now time.Time
|
||||
expectedSleep time.Duration
|
||||
}
|
||||
|
||||
func (t *throttlerTiming) getNow() time.Time {
|
||||
return t.now
|
||||
}
|
||||
|
||||
func (t *throttlerTiming) doDelay(ctx context.Context, duration time.Duration) {
|
||||
t.t.Helper()
|
||||
assert.Equal(t.t, t.expectedSleep, duration)
|
||||
}
|
||||
|
||||
func TestClientHelloResumeThrottle(t *testing.T) {
|
||||
t.Parallel()
|
||||
require := require.New(t)
|
||||
|
|
@ -1450,12 +1467,12 @@ func TestClientHelloResumeThrottle(t *testing.T) {
|
|||
t: t,
|
||||
now: time.Now(),
|
||||
}
|
||||
throttler := newMemoryThrottlerForTest(t)
|
||||
th, ok := throttler.(*memoryThrottler)
|
||||
require.True(ok, "expected memoryThrottler, got %T", throttler)
|
||||
th.getNow = timing.getNow
|
||||
th.doDelay = timing.doDelay
|
||||
hub.throttler = th
|
||||
throttler, err := async.NewCustomMemoryThrottler(timing.getNow, timing.doDelay)
|
||||
require.NoError(err)
|
||||
t.Cleanup(func() {
|
||||
throttler.Close()
|
||||
})
|
||||
hub.throttler = throttler
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue