mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2026-03-14 14:35:44 +01:00
Implement custom session id codec that generates shorter ids.
Inspired by the previously used https://github.com/gorilla/securecookie and simplified / adjusted to our use case.
This commit is contained in:
parent
f4fca4f52b
commit
f3a81c23c3
8 changed files with 360 additions and 100 deletions
|
|
@ -24,9 +24,12 @@ package signaling
|
|||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Benchmark_GetSubjectForSessionId(b *testing.B) {
|
||||
require := require.New(b)
|
||||
backend := &Backend{
|
||||
id: "compat",
|
||||
}
|
||||
|
|
@ -35,11 +38,10 @@ func Benchmark_GetSubjectForSessionId(b *testing.B) {
|
|||
Created: time.Now().UnixMicro(),
|
||||
BackendId: backend.Id(),
|
||||
}
|
||||
codec := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
codec, err := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
require.NoError(err)
|
||||
sid, err := codec.EncodePublic(data)
|
||||
if err != nil {
|
||||
b.Fatalf("could not create session id: %s", err)
|
||||
}
|
||||
require.NoError(err, "could not create session id")
|
||||
for b.Loop() {
|
||||
GetSubjectForSessionId(sid, backend)
|
||||
}
|
||||
|
|
|
|||
1
go.mod
1
go.mod
|
|
@ -8,7 +8,6 @@ require (
|
|||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/gorilla/securecookie v1.1.2
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/mailru/easyjson v0.9.1
|
||||
github.com/nats-io/nats-server/v2 v2.12.2
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -40,14 +40,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
|||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-tpm v0.9.6 h1:Ku42PT4LmjDu1H5C5ISWLlpI1mj+Zq7sPGKoRw2XROA=
|
||||
github.com/google/go-tpm v0.9.6/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1 h1:qnpSQwGEnkcRpTqNOIR6bJbR0gAorgP9CSALpRcKoAA=
|
||||
|
|
|
|||
21
hub.go
21
hub.go
|
|
@ -141,7 +141,7 @@ type Hub struct {
|
|||
logger Logger
|
||||
events AsyncEvents
|
||||
upgrader websocket.Upgrader
|
||||
cookie *SessionIdCodec
|
||||
sessionIds *SessionIdCodec
|
||||
info *WelcomeServerMessage
|
||||
infoInternal *WelcomeServerMessage
|
||||
welcome atomic.Value // *ServerMessage
|
||||
|
|
@ -240,6 +240,11 @@ func NewHub(ctx context.Context, config *goconf.ConfigFile, events AsyncEvents,
|
|||
return nil, fmt.Errorf("the sessions block key must be 16, 24 or 32 bytes but is %d bytes", len(blockKey))
|
||||
}
|
||||
|
||||
sessionIds, err := NewSessionIdCodec([]byte(hashKey), blockBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating session id codec: %w", err)
|
||||
}
|
||||
|
||||
internalClientsSecret, _ := GetStringOptionWithEnv(config, "clients", "internalsecret")
|
||||
if internalClientsSecret == "" {
|
||||
logger.Println("WARNING: No shared secret has been set for internal clients.")
|
||||
|
|
@ -360,7 +365,7 @@ func NewHub(ctx context.Context, config *goconf.ConfigFile, events AsyncEvents,
|
|||
JanusEventsSubprotocol,
|
||||
},
|
||||
},
|
||||
cookie: NewSessionIdCodec([]byte(hashKey), blockBytes),
|
||||
sessionIds: sessionIds,
|
||||
info: NewWelcomeServerMessage(version, DefaultFeatures...),
|
||||
infoInternal: NewWelcomeServerMessage(version, DefaultFeaturesInternal...),
|
||||
|
||||
|
|
@ -668,7 +673,7 @@ func (h *Hub) decodePrivateSessionId(id PrivateSessionId) *SessionIdData {
|
|||
return result
|
||||
}
|
||||
|
||||
data, err := h.cookie.DecodePrivate(id)
|
||||
data, err := h.sessionIds.DecodePrivate(id)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -688,7 +693,7 @@ func (h *Hub) decodePublicSessionId(id PublicSessionId) *SessionIdData {
|
|||
return result
|
||||
}
|
||||
|
||||
data, err := h.cookie.DecodePublic(id)
|
||||
data, err := h.sessionIds.DecodePublic(id)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -968,12 +973,12 @@ func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend *
|
|||
}
|
||||
|
||||
sessionIdData := h.newSessionIdData(backend)
|
||||
privateSessionId, err := h.cookie.EncodePrivate(sessionIdData)
|
||||
privateSessionId, err := h.sessionIds.EncodePrivate(sessionIdData)
|
||||
if err != nil {
|
||||
client.SendMessage(message.NewWrappedErrorServerMessage(err))
|
||||
return
|
||||
}
|
||||
publicSessionId, err := h.cookie.EncodePublic(sessionIdData)
|
||||
publicSessionId, err := h.sessionIds.EncodePublic(sessionIdData)
|
||||
if err != nil {
|
||||
client.SendMessage(message.NewWrappedErrorServerMessage(err))
|
||||
return
|
||||
|
|
@ -2467,12 +2472,12 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) {
|
|||
}
|
||||
|
||||
sessionIdData := h.newSessionIdData(session.Backend())
|
||||
privateSessionId, err := h.cookie.EncodePrivate(sessionIdData)
|
||||
privateSessionId, err := h.sessionIds.EncodePrivate(sessionIdData)
|
||||
if err != nil {
|
||||
h.logger.Printf("Could not encode private virtual session id: %s", err)
|
||||
return
|
||||
}
|
||||
publicSessionId, err := h.cookie.EncodePublic(sessionIdData)
|
||||
publicSessionId, err := h.sessionIds.EncodePublic(sessionIdData)
|
||||
if err != nil {
|
||||
h.logger.Printf("Could not encode public virtual session id: %s", err)
|
||||
return
|
||||
|
|
|
|||
28
hub_test.go
28
hub_test.go
|
|
@ -827,7 +827,8 @@ func performHousekeeping(hub *Hub, now time.Time) *sync.WaitGroup {
|
|||
return &wg
|
||||
}
|
||||
|
||||
func Benchmark_DecodePrivateSessionId(b *testing.B) {
|
||||
func Benchmark_DecodePrivateSessionIdCached(b *testing.B) {
|
||||
require := require.New(b)
|
||||
decodeCaches := make([]*LruCache[*SessionIdData], 0, numDecodeCaches)
|
||||
for range numDecodeCaches {
|
||||
decodeCaches = append(decodeCaches, NewLruCache[*SessionIdData](decodeCacheSize))
|
||||
|
|
@ -840,23 +841,23 @@ func Benchmark_DecodePrivateSessionId(b *testing.B) {
|
|||
Created: time.Now().UnixMicro(),
|
||||
BackendId: backend.Id(),
|
||||
}
|
||||
codec := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
codec, err := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
require.NoError(err)
|
||||
sid, err := codec.EncodePrivate(data)
|
||||
if err != nil {
|
||||
b.Fatalf("could not create session id: %s", err)
|
||||
}
|
||||
require.NoError(err, "could not create session id")
|
||||
hub := &Hub{
|
||||
cookie: codec,
|
||||
sessionIds: codec,
|
||||
decodeCaches: decodeCaches,
|
||||
}
|
||||
// Decode once to populate cache.
|
||||
hub.decodePrivateSessionId(sid)
|
||||
require.NotNil(hub.decodePrivateSessionId(sid))
|
||||
for b.Loop() {
|
||||
hub.decodePrivateSessionId(sid)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodePublicSessionId(b *testing.B) {
|
||||
func Benchmark_DecodePublicSessionIdCached(b *testing.B) {
|
||||
require := require.New(b)
|
||||
decodeCaches := make([]*LruCache[*SessionIdData], 0, numDecodeCaches)
|
||||
for range numDecodeCaches {
|
||||
decodeCaches = append(decodeCaches, NewLruCache[*SessionIdData](decodeCacheSize))
|
||||
|
|
@ -869,17 +870,16 @@ func Benchmark_DecodePublicSessionId(b *testing.B) {
|
|||
Created: time.Now().UnixMicro(),
|
||||
BackendId: backend.Id(),
|
||||
}
|
||||
codec := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
codec, err := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
require.NoError(err)
|
||||
sid, err := codec.EncodePublic(data)
|
||||
if err != nil {
|
||||
b.Fatalf("could not create session id: %s", err)
|
||||
}
|
||||
require.NoError(err, "could not create session id")
|
||||
hub := &Hub{
|
||||
cookie: codec,
|
||||
sessionIds: codec,
|
||||
decodeCaches: decodeCaches,
|
||||
}
|
||||
// Decode once to populate cache.
|
||||
hub.decodePublicSessionId(sid)
|
||||
require.NotNil(hub.decodePublicSessionId(sid))
|
||||
for b.Loop() {
|
||||
hub.decodePublicSessionId(sid)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -226,8 +226,12 @@ func NewProxyServer(ctx context.Context, r *mux.Router, version string, config *
|
|||
return nil, fmt.Errorf("could not generate random block key: %s", err)
|
||||
}
|
||||
|
||||
sessionIds, err := signaling.NewSessionIdCodec(hashKey, blockKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating session id codec: %w", err)
|
||||
}
|
||||
|
||||
var tokens ProxyTokens
|
||||
var err error
|
||||
tokenType, _ := config.GetString("app", "tokentype")
|
||||
if tokenType == "" {
|
||||
tokenType = TokenTypeDefault
|
||||
|
|
@ -367,7 +371,7 @@ func NewProxyServer(ctx context.Context, r *mux.Router, version string, config *
|
|||
|
||||
tokens: tokens,
|
||||
|
||||
cookie: signaling.NewSessionIdCodec(hashKey, blockKey),
|
||||
cookie: sessionIds,
|
||||
sessions: make(map[uint64]*ProxySession),
|
||||
|
||||
clients: make(map[string]signaling.McuClient),
|
||||
|
|
|
|||
|
|
@ -22,74 +22,237 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type protoSerializer struct {
|
||||
}
|
||||
|
||||
func (s *protoSerializer) Serialize(src any) ([]byte, error) {
|
||||
msg, ok := src.(proto.Message)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can't serialize type %T", src)
|
||||
}
|
||||
return proto.Marshal(msg)
|
||||
}
|
||||
|
||||
func (s *protoSerializer) Deserialize(src []byte, dst any) error {
|
||||
msg, ok := dst.(proto.Message)
|
||||
if !ok {
|
||||
return fmt.Errorf("can't deserialize type %T", src)
|
||||
}
|
||||
return proto.Unmarshal(src, msg)
|
||||
}
|
||||
|
||||
const (
|
||||
privateSessionName = "private-session"
|
||||
publicSessionName = "public-session"
|
||||
|
||||
// hmacLength specifies the length of the HMAC to use. 80 bits should be enough
|
||||
// to prevent tampering.
|
||||
hmacLength = 10
|
||||
)
|
||||
|
||||
type SessionIdCodec struct {
|
||||
cookie *securecookie.SecureCookie
|
||||
var (
|
||||
sessionHashFunc = sha256.New
|
||||
sessionEncoding = base64.URLEncoding.WithPadding(base64.NoPadding)
|
||||
sessionMarshalOptions = proto.MarshalOptions{
|
||||
UseCachedSize: true,
|
||||
}
|
||||
sessionUnmarshalOptions = proto.UnmarshalOptions{}
|
||||
sessionSeparator = []byte{'|'}
|
||||
)
|
||||
|
||||
type bytesPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func NewSessionIdCodec(hashKey []byte, blockKey []byte) *SessionIdCodec {
|
||||
cookie := securecookie.New(hashKey, blockKey).
|
||||
MaxAge(0).
|
||||
SetSerializer(&protoSerializer{})
|
||||
return &SessionIdCodec{
|
||||
cookie: cookie,
|
||||
func (p *bytesPool) Get(size int) []byte {
|
||||
bb := p.pool.Get()
|
||||
if bb == nil {
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
||||
b := *(bb.(*[]byte))
|
||||
if cap(b) < size {
|
||||
b = make([]byte, size)
|
||||
} else {
|
||||
b = b[:size]
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (p *bytesPool) Put(b []byte) {
|
||||
p.pool.Put(&b)
|
||||
}
|
||||
|
||||
// SessionIdCodec encodes and decodes session ids.
|
||||
//
|
||||
// Inspired by https://github.com/gorilla/securecookie
|
||||
type SessionIdCodec struct {
|
||||
hashKey []byte
|
||||
cipher cipher.Block
|
||||
|
||||
bytesPool bytesPool
|
||||
hmacPool sync.Pool
|
||||
dataPool sync.Pool
|
||||
}
|
||||
|
||||
func NewSessionIdCodec(hashKey []byte, blockKey []byte) (*SessionIdCodec, error) {
|
||||
if len(hashKey) == 0 {
|
||||
return nil, errors.New("hash key is not set")
|
||||
}
|
||||
|
||||
codec := &SessionIdCodec{
|
||||
hashKey: hashKey,
|
||||
hmacPool: sync.Pool{
|
||||
New: func() any {
|
||||
return hmac.New(sessionHashFunc, hashKey)
|
||||
},
|
||||
},
|
||||
dataPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &SessionIdData{}
|
||||
},
|
||||
},
|
||||
}
|
||||
if len(blockKey) > 0 {
|
||||
block, err := aes.NewCipher(blockKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating cipher: %w", err)
|
||||
}
|
||||
codec.cipher = block
|
||||
}
|
||||
return codec, nil
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) encrypt(data []byte) ([]byte, error) {
|
||||
iv := c.bytesPool.Get(c.cipher.BlockSize() + len(data))[:c.cipher.BlockSize()]
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return nil, fmt.Errorf("error creating iv: %w", err)
|
||||
}
|
||||
|
||||
ctr := cipher.NewCTR(c.cipher, iv)
|
||||
ctr.XORKeyStream(data, data)
|
||||
return append(iv, data...), nil
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) decrypt(data []byte) ([]byte, error) {
|
||||
bs := c.cipher.BlockSize()
|
||||
if len(data) <= bs {
|
||||
return nil, errors.New("no iv found in data")
|
||||
}
|
||||
|
||||
iv := data[:bs]
|
||||
data = data[bs:]
|
||||
ctr := cipher.NewCTR(c.cipher, iv)
|
||||
ctr.XORKeyStream(data, data)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) encodeToString(b []byte) string {
|
||||
s := c.bytesPool.Get(sessionEncoding.EncodedLen(len(b)))
|
||||
defer c.bytesPool.Put(s)
|
||||
|
||||
sessionEncoding.Encode(s, b)
|
||||
return string(s)
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) decodeFromString(s string) ([]byte, error) {
|
||||
b := c.bytesPool.Get(sessionEncoding.DecodedLen(len(s)))
|
||||
n, err := sessionEncoding.Decode(b, []byte(s))
|
||||
if err != nil {
|
||||
c.bytesPool.Put(b)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b[:n], nil
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) encodeRaw(name string, data *SessionIdData) ([]byte, error) {
|
||||
body := c.bytesPool.Get(sessionMarshalOptions.Size(data))
|
||||
defer c.bytesPool.Put(body)
|
||||
|
||||
body, err := sessionMarshalOptions.MarshalAppend(body[:0], data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling data: %w", err)
|
||||
}
|
||||
|
||||
if c.cipher != nil {
|
||||
body, err = c.encrypt(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error encrypting data: %w", err)
|
||||
}
|
||||
|
||||
defer c.bytesPool.Put(body)
|
||||
}
|
||||
|
||||
h := c.hmacPool.Get().(hash.Hash)
|
||||
defer c.hmacPool.Put(h)
|
||||
h.Reset()
|
||||
h.Write(unsafe.Slice(unsafe.StringData(name), len(name))) // nolint
|
||||
h.Write(sessionSeparator) // nolint
|
||||
h.Write(body) // nolint
|
||||
mac := c.bytesPool.Get(h.Size())
|
||||
defer c.bytesPool.Put(mac)
|
||||
mac = h.Sum(mac[:0])
|
||||
|
||||
result := c.bytesPool.Get(len(body) + hmacLength)[:0]
|
||||
result = append(result, body...)
|
||||
result = append(result, mac[:hmacLength]...)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) decodeRaw(name string, value []byte) (*SessionIdData, error) {
|
||||
h := c.hmacPool.Get().(hash.Hash)
|
||||
defer c.hmacPool.Put(h)
|
||||
size := min(hmacLength, h.Size())
|
||||
if len(value) <= size {
|
||||
return nil, errors.New("no hmac found in session id")
|
||||
}
|
||||
|
||||
h.Reset()
|
||||
mac := value[len(value)-size:]
|
||||
decoded := value[:len(value)-size]
|
||||
|
||||
h.Write(unsafe.Slice(unsafe.StringData(name), len(name))) // nolint
|
||||
h.Write(sessionSeparator) // nolint
|
||||
h.Write(decoded) // nolint
|
||||
check := c.bytesPool.Get(h.Size())
|
||||
defer c.bytesPool.Put(check)
|
||||
if subtle.ConstantTimeCompare(mac, h.Sum(check[:0])[:hmacLength]) == 0 {
|
||||
return nil, errors.New("invalid hmac in session id")
|
||||
}
|
||||
|
||||
if c.cipher != nil {
|
||||
var err error
|
||||
if decoded, err = c.decrypt(decoded); err != nil {
|
||||
return nil, fmt.Errorf("invalid session id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
data := c.dataPool.Get().(*SessionIdData)
|
||||
if err := sessionUnmarshalOptions.Unmarshal(decoded, data); err != nil {
|
||||
c.dataPool.Put(data)
|
||||
return nil, fmt.Errorf("invalid session id: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) EncodePrivate(sessionData *SessionIdData) (PrivateSessionId, error) {
|
||||
id, err := c.cookie.Encode(privateSessionName, sessionData)
|
||||
id, err := c.encodeRaw(privateSessionName, sessionData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return PrivateSessionId(id), nil
|
||||
defer c.bytesPool.Put(id)
|
||||
return PrivateSessionId(c.encodeToString(id)), nil
|
||||
}
|
||||
|
||||
func reverseSessionId(s string) (string, error) {
|
||||
// Note that we are assuming base64 encoded strings here.
|
||||
decoded, err := base64.URLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
func (c *SessionIdCodec) reverseSessionId(data []byte) {
|
||||
for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 {
|
||||
data[i], data[j] = data[j], data[i]
|
||||
}
|
||||
|
||||
for i, j := 0, len(decoded)-1; i < j; i, j = i+1, j-1 {
|
||||
decoded[i], decoded[j] = decoded[j], decoded[i]
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(decoded), nil
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) EncodePublic(sessionData *SessionIdData) (PublicSessionId, error) {
|
||||
encoded, err := c.cookie.Encode(publicSessionName, sessionData)
|
||||
encoded, err := c.encodeRaw(publicSessionName, sessionData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
@ -99,33 +262,36 @@ func (c *SessionIdCodec) EncodePublic(sessionData *SessionIdData) (PublicSession
|
|||
// (a timestamp) but the suffix the (random) hash.
|
||||
// By reversing we move the hash to the front, making the comparison of
|
||||
// session ids "random".
|
||||
id, err := reverseSessionId(encoded)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c.reverseSessionId(encoded)
|
||||
|
||||
return PublicSessionId(id), nil
|
||||
defer c.bytesPool.Put(encoded)
|
||||
return PublicSessionId(c.encodeToString(encoded)), nil
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) DecodePrivate(encodedData PrivateSessionId) (*SessionIdData, error) {
|
||||
var data SessionIdData
|
||||
if err := c.cookie.Decode(privateSessionName, string(encodedData), &data); err != nil {
|
||||
return nil, err
|
||||
decoded, err := c.decodeFromString(string(encodedData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid session id: %w", err)
|
||||
}
|
||||
defer c.bytesPool.Put(decoded)
|
||||
|
||||
return &data, nil
|
||||
return c.decodeRaw(privateSessionName, decoded)
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) DecodePublic(encodedData PublicSessionId) (*SessionIdData, error) {
|
||||
reversed, err := reverseSessionId(string(encodedData))
|
||||
decoded, err := c.decodeFromString(string(encodedData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("invalid session id: %w", err)
|
||||
}
|
||||
defer c.bytesPool.Put(decoded)
|
||||
|
||||
var data SessionIdData
|
||||
if err := c.cookie.Decode(publicSessionName, reversed, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
c.reverseSessionId(decoded)
|
||||
return c.decodeRaw(publicSessionName, decoded)
|
||||
}
|
||||
|
||||
func (c *SessionIdCodec) Put(data *SessionIdData) {
|
||||
if data != nil {
|
||||
data.Reset()
|
||||
c.dataPool.Put(data)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -33,24 +32,97 @@ import (
|
|||
func TestReverseSessionId(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
a := base64.URLEncoding.EncodeToString([]byte("12345"))
|
||||
ar, err := reverseSessionId(a)
|
||||
codec, err := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
require.NoError(err)
|
||||
require.NotEqual(a, ar)
|
||||
b := base64.URLEncoding.EncodeToString([]byte("54321"))
|
||||
br, err := reverseSessionId(b)
|
||||
require.NoError(err)
|
||||
require.NotEqual(b, br)
|
||||
assert.Equal(b, ar)
|
||||
assert.Equal(a, br)
|
||||
a := []byte("12345")
|
||||
codec.reverseSessionId(a)
|
||||
assert.Equal([]byte("54321"), a)
|
||||
b := []byte("4321")
|
||||
codec.reverseSessionId(b)
|
||||
assert.Equal([]byte("1234"), b)
|
||||
}
|
||||
|
||||
// Invalid base64.
|
||||
if s, err := reverseSessionId("hello world!"); !assert.Error(err) {
|
||||
assert.Fail("should have failed", "received %s", s)
|
||||
func Benchmark_EncodePrivateSessionId(b *testing.B) {
|
||||
require := require.New(b)
|
||||
backend := &Backend{
|
||||
id: "compat",
|
||||
}
|
||||
// Invalid base64 length.
|
||||
if s, err := reverseSessionId("123"); !assert.Error(err) {
|
||||
assert.Fail("should have failed", "received %s", s)
|
||||
data := &SessionIdData{
|
||||
Sid: 1,
|
||||
Created: time.Now().UnixMicro(),
|
||||
BackendId: backend.Id(),
|
||||
}
|
||||
codec, err := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
require.NoError(err)
|
||||
for b.Loop() {
|
||||
if _, err := codec.EncodePrivate(data); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodePrivateSessionId(b *testing.B) {
|
||||
require := require.New(b)
|
||||
backend := &Backend{
|
||||
id: "compat",
|
||||
}
|
||||
data := &SessionIdData{
|
||||
Sid: 1,
|
||||
Created: time.Now().UnixMicro(),
|
||||
BackendId: backend.Id(),
|
||||
}
|
||||
codec, err := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
require.NoError(err)
|
||||
sid, err := codec.EncodePrivate(data)
|
||||
require.NoError(err)
|
||||
for b.Loop() {
|
||||
if decoded, err := codec.DecodePrivate(sid); err != nil {
|
||||
b.Fatal(err)
|
||||
} else {
|
||||
codec.Put(decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_EncodePublicSessionId(b *testing.B) {
|
||||
require := require.New(b)
|
||||
backend := &Backend{
|
||||
id: "compat",
|
||||
}
|
||||
data := &SessionIdData{
|
||||
Sid: 1,
|
||||
Created: time.Now().UnixMicro(),
|
||||
BackendId: backend.Id(),
|
||||
}
|
||||
codec, err := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
require.NoError(err)
|
||||
for b.Loop() {
|
||||
if _, err := codec.EncodePublic(data); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_DecodePublicSessionId(b *testing.B) {
|
||||
require := require.New(b)
|
||||
backend := &Backend{
|
||||
id: "compat",
|
||||
}
|
||||
data := &SessionIdData{
|
||||
Sid: 1,
|
||||
Created: time.Now().UnixMicro(),
|
||||
BackendId: backend.Id(),
|
||||
}
|
||||
codec, err := NewSessionIdCodec([]byte("12345678901234567890123456789012"), []byte("09876543210987654321098765432109"))
|
||||
require.NoError(err)
|
||||
sid, err := codec.EncodePublic(data)
|
||||
require.NoError(err)
|
||||
for b.Loop() {
|
||||
if decoded, err := codec.DecodePublic(sid); err != nil {
|
||||
b.Fatal(err)
|
||||
} else {
|
||||
codec.Put(decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -63,17 +135,33 @@ func TestPublicPrivate(t *testing.T) {
|
|||
BackendId: "foo",
|
||||
}
|
||||
|
||||
codec := NewSessionIdCodec([]byte("0123456789012345"), []byte("0123456789012345"))
|
||||
codec, err := NewSessionIdCodec([]byte("0123456789012345"), []byte("0123456789012345"))
|
||||
require.NoError(err)
|
||||
private, err := codec.EncodePrivate(sd)
|
||||
require.NoError(err)
|
||||
public, err := codec.EncodePublic(sd)
|
||||
require.NoError(err)
|
||||
assert.NotEqual(private, public)
|
||||
|
||||
if data, err := codec.DecodePublic(public); assert.NoError(err) {
|
||||
assert.Equal(sd.Sid, data.Sid)
|
||||
assert.Equal(sd.Created, data.Created)
|
||||
assert.Equal(sd.BackendId, data.BackendId)
|
||||
codec.Put(data)
|
||||
}
|
||||
if data, err := codec.DecodePrivate(private); assert.NoError(err) {
|
||||
assert.Equal(sd.Sid, data.Sid)
|
||||
assert.Equal(sd.Created, data.Created)
|
||||
assert.Equal(sd.BackendId, data.BackendId)
|
||||
codec.Put(data)
|
||||
}
|
||||
|
||||
if data, err := codec.DecodePublic(PublicSessionId(private)); !assert.Error(err) {
|
||||
assert.Fail("should have failed", "received %+v", data)
|
||||
codec.Put(data)
|
||||
}
|
||||
if data, err := codec.DecodePrivate(PrivateSessionId(public)); !assert.Error(err) {
|
||||
assert.Fail("should have failed", "received %+v", data)
|
||||
codec.Put(data)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue