JWT: replace jwtauth/jwx with lightweight wrapper around go-jose

We replaced the jwtauth and jwx libraries with a minimal custom wrapper
around go-jose because we don’t need the full feature set provided by jwx.
Implementing our own wrapper simplifies the codebase and improves
maintainability.

Moreover, go-jose depends only on the standard library, resulting in a
leaner dependency that still meets all our requirements.

This change also reduces the SFTPGo binary size by approximately 1MB

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2025-10-08 18:10:39 +02:00
commit 0ae2354fed
No known key found for this signature in database
GPG key ID: 935D2952DEC4EECF
31 changed files with 1222 additions and 967 deletions

12
go.mod
View file

@ -25,8 +25,8 @@ require (
github.com/fclairamb/go-log v0.6.0
github.com/go-acme/lego/v4 v4.26.0
github.com/go-chi/chi/v5 v5.2.3
github.com/go-chi/jwtauth/v5 v5.3.3
github.com/go-chi/render v1.0.3
github.com/go-jose/go-jose/v4 v4.1.3
github.com/go-sql-driver/mysql v1.9.3
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/google/uuid v1.6.0
@ -36,7 +36,6 @@ require (
github.com/jackc/pgx/v5 v5.7.6
github.com/jlaffaye/ftp v0.2.0
github.com/klauspost/compress v1.18.0
github.com/lestrrat-go/jwx/v2 v2.1.6
github.com/lithammer/shortuuid/v4 v4.2.0
github.com/mattn/go-sqlite3 v1.14.32
github.com/mhale/smtpd v0.8.3
@ -110,18 +109,15 @@ require (
github.com/coreos/go-systemd/v22 v22.6.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.35.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/s2a-go v0.1.9 // indirect
@ -137,11 +133,6 @@ require (
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.4 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.6 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/lufia/plan9stats v0.0.0-20250827001030-24949be3fa54 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
@ -159,7 +150,6 @@ require (
github.com/prometheus/procfs v0.17.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/sagikazarmark/locafero v0.12.0 // indirect
github.com/segmentio/asm v1.2.1 // indirect
github.com/shoenig/go-m1cpu v0.1.7 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect

20
go.sum
View file

@ -124,8 +124,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40=
github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66Fhh4S8VEtURA6QHZqGeSRE9Nb2/U=
github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f h1:S9JUlrOzjK58UKoLqqb40YLyVlt0bcIFtYrvnanV3zc=
@ -159,8 +157,6 @@ github.com/go-acme/lego/v4 v4.26.0 h1:521aEQxNstXvPQcFDDPrJiFfixcCQuvAvm35R4GbyY
github.com/go-acme/lego/v4 v4.26.0/go.mod h1:BQVAWgcyzW4IT9eIKHY/RxYlVhoyKyOMXOkq7jK1eEQ=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-chi/jwtauth/v5 v5.3.3 h1:50Uzmacu35/ZP9ER2Ht6SazwPsnLQ9LRJy6zTZJpHEo=
github.com/go-chi/jwtauth/v5 v5.3.3/go.mod h1:O4QvPRuZLZghl9WvfVaON+ARfGzpD2PBX/QY5vUz7aQ=
github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4=
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
@ -181,8 +177,6 @@ github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E=
github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0=
@ -246,18 +240,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lestrrat-go/blackmagic v1.0.4 h1:IwQibdnf8l2KoO+qC3uT4OaTWsW7tuRQXy9TRN9QanA=
github.com/lestrrat-go/blackmagic v1.0.4/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k=
github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx/v2 v2.1.6 h1:hxM1gfDILk/l5ylers6BX/Eq1m/pnxe9NBwW6lVfecA=
github.com/lestrrat-go/jwx/v2 v2.1.6/go.mod h1:Y722kU5r/8mV7fYDifjug0r8FK8mZdw0K0GpJw/l8pU=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lithammer/shortuuid/v4 v4.2.0 h1:LMFOzVB3996a7b8aBuEXxqOBflbfPQAiVzkIcHO0h8c=
@ -331,8 +313,6 @@ github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88ee
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJxNzj3QBOf7dZwupeVC+mG1Lo=
github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4/go.mod h1:MnkX001NG75g3p8bhFycnyIjeQoOjGL6CEIsdE/nKSY=
github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0=
github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/sftpgo/sdk v0.1.9-0.20241011171103-64fc18a344f9 h1:wlXBnaNfJJJRZjHO2AerSS5gp0ckkYUgBzSXivUo0Wo=
github.com/sftpgo/sdk v0.1.9-0.20241011171103-64fc18a344f9/go.mod h1:ehimvlTP+XTEiE3t1CPwWx9n7+6A6OGvMGlZ7ouvKFk=
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=

View file

@ -28,11 +28,10 @@ import (
"strings"
"time"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/rs/xid"
"github.com/go-jose/go-jose/v4"
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
@ -46,7 +45,8 @@ const (
const (
// NodeTokenHeader defines the header to use for the node auth token
NodeTokenHeader = "X-SFTPGO-Node"
NodeTokenHeader = "X-SFTPGO-Node"
nodeTokenAudience = "node"
)
var (
@ -132,35 +132,26 @@ func (n *Node) validate() error {
return n.Data.validate()
}
func (n *Node) authenticate(token string) (string, string, []string, error) {
func (n *Node) authenticate(token string) (*jwt.Claims, error) {
if err := n.Data.Key.TryDecrypt(); err != nil {
providerLog(logger.LevelError, "unable to decrypt node key: %v", err)
return "", "", nil, err
return nil, err
}
if token == "" {
return "", "", nil, ErrInvalidCredentials
return nil, ErrInvalidCredentials
}
t, err := jwt.Parse([]byte(token), jwt.WithKey(jwa.HS256, []byte(n.Data.Key.GetPayload())), jwt.WithValidate(true))
claims, err := jwt.VerifyTokenWithKey(token, []jose.SignatureAlgorithm{jose.HS256}, []byte(n.Data.Key.GetPayload()))
if err != nil {
return "", "", nil, fmt.Errorf("unable to parse and validate token: %v", err)
return nil, fmt.Errorf("unable to parse and validate token: %v", err)
}
var adminUsername, role string
if admin, ok := t.Get("admin"); ok {
if val, ok := admin.(string); ok && val != "" {
adminUsername = val
}
if claims.Username == "" {
return nil, errors.New("no admin username associated with node token")
}
if adminUsername == "" {
return "", "", nil, errors.New("no admin username associated with node token")
if !claims.Audience.Contains(nodeTokenAudience) {
return nil, errors.New("invalid node token audience")
}
if r, ok := t.Get("role"); ok {
if val, ok := r.(string); ok && val != "" {
role = val
}
}
perms := getPermsFromToken(t)
return adminUsername, role, perms, nil
return claims, nil
}
// getBaseURL returns the base URL for this node
@ -181,22 +172,22 @@ func (n *Node) generateAuthToken(username, role string, permissions []string) (s
if err := n.Data.Key.TryDecrypt(); err != nil {
return "", fmt.Errorf("unable to decrypt node key: %w", err)
}
now := time.Now().UTC()
t := jwt.New()
t.Set("admin", username) //nolint:errcheck
t.Set("role", role) //nolint:errcheck
t.Set("perms", permissions) //nolint:errcheck
t.Set(jwt.IssuedAtKey, now) //nolint:errcheck
t.Set(jwt.JwtIDKey, xid.New().String()) //nolint:errcheck
t.Set(jwt.NotBeforeKey, now.Add(-30*time.Second)) //nolint:errcheck
t.Set(jwt.ExpirationKey, now.Add(1*time.Minute)) //nolint:errcheck
payload, err := jwt.Sign(t, jwt.WithKey(jwa.HS256, []byte(n.Data.Key.GetPayload())))
signer, err := jwt.NewSigner(jose.HS256, []byte(n.Data.Key.GetPayload()))
if err != nil {
return "", fmt.Errorf("unable to create signer: %w", err)
}
claims := &jwt.Claims{
Username: username,
Role: role,
Permissions: permissions,
}
claims.Audience = []string{nodeTokenAudience}
claims.SetExpiry(time.Now().Add(1 * time.Minute))
payload, err := signer.Sign(claims)
if err != nil {
return "", fmt.Errorf("unable to sign authentication token: %w", err)
}
return util.BytesToString(payload), nil
return payload, nil
}
func (n *Node) prepareRequest(ctx context.Context, username, role, relativeURL, method string,
@ -273,9 +264,9 @@ func (n *Node) SendDeleteRequest(username, role, relativeURL string, permissions
}
// AuthenticateNodeToken check the validity of the provided token
func AuthenticateNodeToken(token string) (string, string, []string, error) {
func AuthenticateNodeToken(token string) (*jwt.Claims, error) {
if currentNode == nil {
return "", "", nil, errNoClusterNodes
return nil, errNoClusterNodes
}
return currentNode.authenticate(token)
}
@ -287,21 +278,3 @@ func GetNodeName() string {
}
return currentNode.Name
}
func getPermsFromToken(t jwt.Token) []string {
var perms []string
if p, ok := t.Get("perms"); ok {
switch v := p.(type) {
case []any:
for _, elem := range v {
switch elemValue := elem.(type) {
case string:
perms = append(perms, elemValue)
}
}
case []string:
perms = v
}
}
return perms
}

View file

@ -21,10 +21,10 @@ import (
"net/http"
"net/url"
"github.com/go-chi/jwtauth/v5"
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -68,7 +68,7 @@ func renderAdmin(w http.ResponseWriter, r *http.Request, username string, status
func addAdmin(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -90,7 +90,7 @@ func addAdmin(w http.ResponseWriter, r *http.Request) {
func disableAdmin2FA(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -138,7 +138,7 @@ func updateAdmin(w http.ResponseWriter, r *http.Request) {
return
}
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -182,7 +182,7 @@ func updateAdmin(w http.ResponseWriter, r *http.Request) {
func deleteAdmin(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
username := getURLParam(r, "username")
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -202,7 +202,7 @@ func deleteAdmin(w http.ResponseWriter, r *http.Request) {
func getAdminProfile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -224,7 +224,7 @@ func getAdminProfile(w http.ResponseWriter, r *http.Request) {
func updateAdminProfile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -317,7 +317,7 @@ func doChangeAdminPassword(r *http.Request, currentPassword, newPassword, confir
util.I18nErrorChangePwdNoDifferent,
)
}
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil {
return util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)
}
@ -335,14 +335,3 @@ func doChangeAdminPassword(r *http.Request, currentPassword, newPassword, confir
return dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
}
func getTokenClaims(r *http.Request) (jwtTokenClaims, error) {
tokenClaims := jwtTokenClaims{}
_, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
return tokenClaims, err
}
tokenClaims.Decode(claims)
return tokenClaims, nil
}

View file

@ -24,6 +24,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -42,7 +43,7 @@ func getEventActions(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, actions)
}
func renderEventAction(w http.ResponseWriter, r *http.Request, name string, claims *jwtTokenClaims, status int) {
func renderEventAction(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) {
action, err := dataprovider.EventActionExists(name)
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
@ -61,19 +62,19 @@ func renderEventAction(w http.ResponseWriter, r *http.Request, name string, clai
func getEventActionByName(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
name := getURLParam(r, "name")
renderEventAction(w, r, name, &claims, http.StatusOK)
renderEventAction(w, r, name, claims, http.StatusOK)
}
func addEventAction(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -91,12 +92,12 @@ func addEventAction(w http.ResponseWriter, r *http.Request) {
return
}
w.Header().Add("Location", fmt.Sprintf("%s/%s", eventActionsPath, url.PathEscape(action.Name)))
renderEventAction(w, r, action.Name, &claims, http.StatusCreated)
renderEventAction(w, r, action.Name, claims, http.StatusCreated)
}
func updateEventAction(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -136,7 +137,7 @@ func updateEventAction(w http.ResponseWriter, r *http.Request) {
func deleteEventAction(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -165,7 +166,7 @@ func getEventRules(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, rules)
}
func renderEventRule(w http.ResponseWriter, r *http.Request, name string, claims *jwtTokenClaims, status int) {
func renderEventRule(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) {
rule, err := dataprovider.EventRuleExists(name)
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
@ -184,19 +185,19 @@ func renderEventRule(w http.ResponseWriter, r *http.Request, name string, claims
func getEventRuleByName(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
name := getURLParam(r, "name")
renderEventRule(w, r, name, &claims, http.StatusOK)
renderEventRule(w, r, name, claims, http.StatusOK)
}
func addEventRule(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -213,12 +214,12 @@ func addEventRule(w http.ResponseWriter, r *http.Request) {
return
}
w.Header().Add("Location", fmt.Sprintf("%s/%s", eventRulesPath, url.PathEscape(rule.Name)))
renderEventRule(w, r, rule.Name, &claims, http.StatusCreated)
renderEventRule(w, r, rule.Name, claims, http.StatusCreated)
}
func updateEventRule(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -249,7 +250,7 @@ func updateEventRule(w http.ResponseWriter, r *http.Request) {
func deleteEventRule(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -27,6 +27,7 @@ import (
"github.com/sftpgo/sdk/plugin/notifier"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -143,7 +144,7 @@ func getLogSearchParamsFromRequest(r *http.Request) (eventsearcher.LogEventSearc
func searchFsEvents(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -176,7 +177,7 @@ func searchFsEvents(w http.ResponseWriter, r *http.Request) {
func searchProviderEvents(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -211,7 +212,7 @@ func searchProviderEvents(w http.ResponseWriter, r *http.Request) {
func searchLogEvents(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -23,6 +23,7 @@ import (
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
@ -45,7 +46,7 @@ func getFolders(w http.ResponseWriter, r *http.Request) {
func addFolder(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -62,12 +63,12 @@ func addFolder(w http.ResponseWriter, r *http.Request) {
return
}
w.Header().Add("Location", fmt.Sprintf("%s/%s", folderPath, url.PathEscape(folder.Name)))
renderFolder(w, r, folder.Name, &claims, http.StatusCreated)
renderFolder(w, r, folder.Name, claims, http.StatusCreated)
}
func updateFolder(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -100,7 +101,7 @@ func updateFolder(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, nil, "Folder updated", http.StatusOK)
}
func renderFolder(w http.ResponseWriter, r *http.Request, name string, claims *jwtTokenClaims, status int) {
func renderFolder(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) {
folder, err := dataprovider.GetFolderByName(name)
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
@ -119,18 +120,18 @@ func renderFolder(w http.ResponseWriter, r *http.Request, name string, claims *j
func getFolderByName(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
name := getURLParam(r, "name")
renderFolder(w, r, name, &claims, http.StatusOK)
renderFolder(w, r, name, claims, http.StatusOK)
}
func deleteFolder(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -23,6 +23,7 @@ import (
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -44,7 +45,7 @@ func getGroups(w http.ResponseWriter, r *http.Request) {
func addGroup(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -61,12 +62,12 @@ func addGroup(w http.ResponseWriter, r *http.Request) {
return
}
w.Header().Add("Location", fmt.Sprintf("%s/%s", groupPath, url.PathEscape(group.Name)))
renderGroup(w, r, group.Name, &claims, http.StatusCreated)
renderGroup(w, r, group.Name, claims, http.StatusCreated)
}
func updateGroup(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -98,7 +99,7 @@ func updateGroup(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, nil, "Group updated", http.StatusOK)
}
func renderGroup(w http.ResponseWriter, r *http.Request, name string, claims *jwtTokenClaims, status int) {
func renderGroup(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) {
group, err := dataprovider.GroupExists(name)
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
@ -117,18 +118,18 @@ func renderGroup(w http.ResponseWriter, r *http.Request, name string, claims *jw
func getGroupByName(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
name := getURLParam(r, "name")
renderGroup(w, r, name, &claims, http.StatusOK)
renderGroup(w, r, name, claims, http.StatusOK)
}
func deleteGroup(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -31,12 +31,13 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, error) {
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return nil, fmt.Errorf("invalid token claims %w", err)
@ -457,7 +458,7 @@ func getUserFilesAsZipStream(w http.ResponseWriter, r *http.Request) {
func getUserProfile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -482,7 +483,7 @@ func getUserProfile(w http.ResponseWriter, r *http.Request) {
func updateUserProfile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -557,7 +558,7 @@ func doChangeUserPassword(r *http.Request, currentPassword, newPassword, confirm
util.I18nErrorChangePwdNoDifferent,
)
}
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
return util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)
}

View file

@ -25,6 +25,7 @@ import (
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -68,7 +69,7 @@ func getIPListEntry(w http.ResponseWriter, r *http.Request) {
func addIPListEntry(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -91,7 +92,7 @@ func addIPListEntry(w http.ResponseWriter, r *http.Request) {
func updateIPListEntry(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -125,7 +126,7 @@ func updateIPListEntry(w http.ResponseWriter, r *http.Request) {
func deleteIPListEntry(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -23,6 +23,7 @@ import (
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -56,7 +57,7 @@ func getAPIKeyByID(w http.ResponseWriter, r *http.Request) {
func addAPIKey(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -87,7 +88,7 @@ func addAPIKey(w http.ResponseWriter, r *http.Request) {
func updateAPIKey(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -119,7 +120,7 @@ func updateAPIKey(w http.ResponseWriter, r *http.Request) {
func deleteAPIKey(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
keyID := getURLParam(r, "id")
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -29,6 +29,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
@ -115,7 +116,7 @@ func dumpData(w http.ResponseWriter, r *http.Request) {
func loadDataFromRequest(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, MaxRestoreSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -143,7 +144,7 @@ func loadDataFromRequest(w http.ResponseWriter, r *http.Request) {
func loadData(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -27,6 +27,7 @@ import (
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/mfa"
"github.com/drakkan/sftpgo/v2/internal/util"
@ -66,13 +67,13 @@ func getTOTPConfigs(w http.ResponseWriter, r *http.Request) {
func generateTOTPSecret(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
var accountName string
if claims.hasUserAudience() {
if hasUserAudience(claims) {
accountName = fmt.Sprintf("User %q", claims.Username)
} else {
accountName = fmt.Sprintf("Admin %q", claims.Username)
@ -113,7 +114,7 @@ func getQRCode(w http.ResponseWriter, r *http.Request) {
func saveTOTPConfig(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -124,7 +125,7 @@ func saveTOTPConfig(w http.ResponseWriter, r *http.Request) {
recoveryCodes = append(recoveryCodes, dataprovider.RecoveryCode{Secret: kms.NewPlainSecret(code)})
}
baseURL := webBaseClientPath
if claims.hasUserAudience() {
if hasUserAudience(claims) {
if err := saveUserTOTPConfig(claims.Username, r, recoveryCodes); err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
@ -164,14 +165,14 @@ func validateTOTPPasscode(w http.ResponseWriter, r *http.Request) {
func getRecoveryCodes(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
recoveryCodes := make([]recoveryCode, 0, 12)
var accountRecoveryCodes []dataprovider.RecoveryCode
if claims.hasUserAudience() {
if hasUserAudience(claims) {
user, err := dataprovider.UserExists(claims.Username, "")
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
@ -210,7 +211,7 @@ func getRecoveryCodes(w http.ResponseWriter, r *http.Request) {
func generateRecoveryCodes(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -222,7 +223,7 @@ func generateRecoveryCodes(w http.ResponseWriter, r *http.Request) {
recoveryCodes = append(recoveryCodes, code)
accountRecoveryCodes = append(accountRecoveryCodes, dataprovider.RecoveryCode{Secret: kms.NewPlainSecret(code)})
}
if claims.hasUserAudience() {
if hasUserAudience(claims) {
user, err := dataprovider.UserExists(claims.Username, "")
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))

View file

@ -23,6 +23,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
@ -44,7 +45,7 @@ type transferQuotaUsage struct {
func getUsersQuotaScans(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -91,7 +92,7 @@ func startFolderQuotaScan(w http.ResponseWriter, r *http.Request) {
func updateUserTransferQuotaUsage(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -132,7 +133,7 @@ func updateUserTransferQuotaUsage(w http.ResponseWriter, r *http.Request) {
}
func doUpdateUserQuotaUsage(w http.ResponseWriter, r *http.Request, username string, usage quotaUsage) {
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -204,7 +205,7 @@ func doStartUserQuotaScan(w http.ResponseWriter, r *http.Request, username strin
sendAPIResponse(w, r, nil, "Quota tracking is disabled!", http.StatusForbidden)
return
}
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -20,11 +20,12 @@ import (
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/jwt"
)
func getRetentionChecks(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -23,6 +23,7 @@ import (
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -44,7 +45,7 @@ func getRoles(w http.ResponseWriter, r *http.Request) {
func addRole(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -67,7 +68,7 @@ func addRole(w http.ResponseWriter, r *http.Request) {
func updateRole(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -119,7 +120,7 @@ func getRoleByName(w http.ResponseWriter, r *http.Request) {
func deleteRole(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -26,20 +26,20 @@ import (
"strings"
"time"
"github.com/go-chi/jwtauth/v5"
"github.com/go-chi/render"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
func getShares(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -59,7 +59,7 @@ func getShares(w http.ResponseWriter, r *http.Request) {
func getShareByID(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -77,7 +77,7 @@ func getShareByID(w http.ResponseWriter, r *http.Request) {
func addShare(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -126,7 +126,7 @@ func addShare(w http.ResponseWriter, r *http.Request) {
func updateShare(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -177,7 +177,7 @@ func updateShare(w http.ResponseWriter, r *http.Request) {
func deleteShare(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
shareID := getURLParam(r, "id")
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -432,16 +432,16 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request)
}
}
func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (context.Context, *jwtTokenClaims, error) {
token, err := jwtauth.VerifyRequest(s.tokenAuth, r, jwtauth.TokenFromCookie)
func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (context.Context, *jwt.Claims, error) {
token, err := jwt.VerifyRequest(s.tokenAuth, r, jwt.TokenFromCookie)
if err != nil || token == nil {
return nil, nil, errInvalidToken
}
tokenString := jwtauth.TokenFromCookie(r)
tokenString := jwt.TokenFromCookie(r)
if tokenString == "" || invalidatedJWTTokens.Get(tokenString) {
return nil, nil, errInvalidToken
}
if !slices.Contains(token.Audience(), tokenAudienceWebShare) {
if !token.Audience.Contains(tokenAudienceWebShare) {
logger.Debug(logSender, "", "invalid token audience for share %q", shareID)
return nil, nil, errInvalidToken
}
@ -450,13 +450,12 @@ func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (context.C
logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", shareID, ipAddr)
return nil, nil, err
}
ctx := jwtauth.NewContext(r.Context(), token, nil)
claims, err := getTokenClaims(r.WithContext(ctx))
if err != nil || claims.Username != shareID {
if token.Username != shareID {
logger.Debug(logSender, "", "token not valid for share %q", shareID)
return nil, nil, errInvalidToken
}
return ctx, &claims, nil
ctx := jwt.NewContext(r.Context(), token, nil)
return ctx, token, nil
}
func (s *httpdServer) checkWebClientShareCredentials(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) error {

View file

@ -27,6 +27,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
@ -39,7 +40,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -55,16 +56,16 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
func getUserByUsername(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
}
username := getURLParam(r, "username")
renderUser(w, r, username, &claims, http.StatusOK)
renderUser(w, r, username, claims, http.StatusOK)
}
func renderUser(w http.ResponseWriter, r *http.Request, username string, claims *jwtTokenClaims, status int) {
func renderUser(w http.ResponseWriter, r *http.Request, username string, claims *jwt.Claims, status int) {
user, err := dataprovider.UserExists(username, claims.Role)
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
@ -84,7 +85,7 @@ func renderUser(w http.ResponseWriter, r *http.Request, username string, claims
func addUser(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -117,12 +118,12 @@ func addUser(w http.ResponseWriter, r *http.Request) {
return
}
w.Header().Add("Location", fmt.Sprintf("%s/%s", userPath, url.PathEscape(user.Username)))
renderUser(w, r, user.Username, &claims, http.StatusCreated)
renderUser(w, r, user.Username, claims, http.StatusCreated)
}
func disableUser2FA(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -150,7 +151,7 @@ func disableUser2FA(w http.ResponseWriter, r *http.Request) {
func updateUser(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -202,7 +203,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
func deleteUser(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

View file

@ -42,6 +42,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/metric"
"github.com/drakkan/sftpgo/v2/internal/plugin"
@ -177,7 +178,7 @@ func getBoolQueryParam(r *http.Request, param string) bool {
func getActiveConnections(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -191,7 +192,7 @@ func getActiveConnections(w http.ResponseWriter, r *http.Request) {
func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -943,8 +944,8 @@ func getProtocolFromRequest(r *http.Request) string {
return common.ProtocolHTTP
}
func hideConfidentialData(claims *jwtTokenClaims, r *http.Request) bool {
if !claims.hasPerm(dataprovider.PermAdminAny) {
func hideConfidentialData(claims *jwt.Claims, r *http.Request) bool {
if !claims.HasPerm(dataprovider.PermAdminAny) {
return true
}
return r.URL.Query().Get("confidential_data") != "1"

View file

@ -15,17 +15,14 @@
package httpd
import (
"crypto/rand"
"errors"
"fmt"
"net/http"
"slices"
"time"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/rs/xid"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -52,18 +49,8 @@ const (
)
const (
claimUsernameKey = "username"
claimPermissionsKey = "permissions"
claimRole = "role"
claimAPIKey = "api_key"
claimNodeID = "node_id"
claimMustChangePasswordKey = "chpwd"
claimMustSetSecondFactorKey = "2fa_required"
claimRequiredTwoFactorProtocols = "2fa_protos"
claimHideUserPageSection = "hus"
claimRef = "ref"
basicRealm = "Basic realm=\"SFTPGo\""
jwtCookieKey = "jwt"
basicRealm = "Basic realm=\"SFTPGo\""
jwtCookieKey = "jwt"
)
var (
@ -129,212 +116,26 @@ func getMaxCookieDuration() time.Duration {
return result
}
type jwtTokenClaims struct {
Username string
Permissions []string
Role string
Signature string
Audience []string
APIKeyID string
NodeID string
MustSetTwoFactorAuth bool
MustChangePassword bool
RequiredTwoFactorProtocols []string
HideUserPageSections int
JwtID string
JwtIssuedAt time.Time
Ref string
func hasUserAudience(claims *jwt.Claims) bool {
return claims.HasAnyAudience([]string{tokenAudienceWebClient, tokenAudienceAPIUser})
}
func (c *jwtTokenClaims) hasUserAudience() bool {
for _, audience := range c.Audience {
if audience == tokenAudienceWebClient || audience == tokenAudienceAPIUser {
return true
}
}
return false
}
func (c *jwtTokenClaims) asMap() map[string]any {
claims := make(map[string]any)
claims[claimUsernameKey] = c.Username
claims[claimPermissionsKey] = c.Permissions
if c.JwtID != "" {
claims[jwt.JwtIDKey] = c.JwtID
}
if !c.JwtIssuedAt.IsZero() {
claims[jwt.IssuedAtKey] = c.JwtIssuedAt
}
if c.Ref != "" {
claims[claimRef] = c.Ref
}
if c.Role != "" {
claims[claimRole] = c.Role
}
if c.APIKeyID != "" {
claims[claimAPIKey] = c.APIKeyID
}
if c.NodeID != "" {
claims[claimNodeID] = c.NodeID
}
claims[jwt.SubjectKey] = c.Signature
if c.MustChangePassword {
claims[claimMustChangePasswordKey] = c.MustChangePassword
}
if c.MustSetTwoFactorAuth {
claims[claimMustSetSecondFactorKey] = c.MustSetTwoFactorAuth
}
if len(c.RequiredTwoFactorProtocols) > 0 {
claims[claimRequiredTwoFactorProtocols] = c.RequiredTwoFactorProtocols
}
if c.HideUserPageSections > 0 {
claims[claimHideUserPageSection] = c.HideUserPageSections
}
return claims
}
func (c *jwtTokenClaims) decodeSliceString(val any) []string {
switch v := val.(type) {
case []any:
result := make([]string, 0, len(v))
for _, elem := range v {
switch elemValue := elem.(type) {
case string:
result = append(result, elemValue)
}
}
return result
case []string:
return v
default:
return nil
}
}
func (c *jwtTokenClaims) decodeBoolean(val any) bool {
switch v := val.(type) {
case bool:
return v
default:
return false
}
}
func (c *jwtTokenClaims) decodeString(val any) string {
switch v := val.(type) {
case string:
return v
default:
return ""
}
}
func (c *jwtTokenClaims) Decode(token map[string]any) {
c.Permissions = nil
c.Username = c.decodeString(token[claimUsernameKey])
c.Signature = c.decodeString(token[jwt.SubjectKey])
c.JwtID = c.decodeString(token[jwt.JwtIDKey])
audience := token[jwt.AudienceKey]
switch v := audience.(type) {
case []string:
c.Audience = v
}
if val, ok := token[claimRef]; ok {
c.Ref = c.decodeString(val)
}
if val, ok := token[claimAPIKey]; ok {
c.APIKeyID = c.decodeString(val)
}
if val, ok := token[claimNodeID]; ok {
c.NodeID = c.decodeString(val)
}
if val, ok := token[claimRole]; ok {
c.Role = c.decodeString(val)
}
permissions := token[claimPermissionsKey]
c.Permissions = c.decodeSliceString(permissions)
if val, ok := token[claimMustChangePasswordKey]; ok {
c.MustChangePassword = c.decodeBoolean(val)
}
if val, ok := token[claimMustSetSecondFactorKey]; ok {
c.MustSetTwoFactorAuth = c.decodeBoolean(val)
}
if val, ok := token[claimRequiredTwoFactorProtocols]; ok {
c.RequiredTwoFactorProtocols = c.decodeSliceString(val)
}
if val, ok := token[claimHideUserPageSection]; ok {
switch v := val.(type) {
case float64:
c.HideUserPageSections = int(v)
}
}
}
func (c *jwtTokenClaims) hasPerm(perm string) bool {
if slices.Contains(c.Permissions, dataprovider.PermAdminAny) {
return true
}
return slices.Contains(c.Permissions, perm)
}
func (c *jwtTokenClaims) createToken(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (jwt.Token, string, error) {
claims := c.asMap()
now := time.Now().UTC()
if _, ok := claims[jwt.JwtIDKey]; !ok {
claims[jwt.JwtIDKey] = xid.New().String()
}
if _, ok := claims[jwt.IssuedAtKey]; !ok {
claims[jwt.IssuedAtKey] = now
}
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(getTokenDuration(audience))
claims[jwt.AudienceKey] = []string{audience, ip}
return tokenAuth.Encode(claims)
}
func (c *jwtTokenClaims) createTokenResponse(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (map[string]any, error) {
token, tokenString, err := c.createToken(tokenAuth, audience, ip)
if err != nil {
return nil, err
}
response := make(map[string]any)
response["access_token"] = tokenString
response["expires_at"] = token.Expiration().Format(time.RFC3339)
return response, nil
}
func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Request, tokenAuth *jwtauth.JWTAuth,
func createAndSetCookie(w http.ResponseWriter, r *http.Request, claims *jwt.Claims, tokenAuth *jwt.Signer,
audience tokenAudience, ip string,
) error {
resp, err := c.createTokenResponse(tokenAuth, audience, ip)
duration := getTokenDuration(audience)
token, err := tokenAuth.SignWithParams(claims, audience, ip, duration)
if err != nil {
return err
}
resp := claims.BuildTokenResponse(token)
var basePath string
if audience == tokenAudienceWebAdmin || audience == tokenAudienceWebAdminPartial {
basePath = webBaseAdminPath
} else {
basePath = webBaseClientPath
}
setCookie(w, r, basePath, resp["access_token"].(string), getTokenDuration(audience))
setCookie(w, r, basePath, resp.Token, duration)
return nil
}
@ -386,8 +187,8 @@ func isTLS(r *http.Request) bool {
func isTokenInvalidated(r *http.Request) bool {
var findTokenFns []func(r *http.Request) string
findTokenFns = append(findTokenFns, jwtauth.TokenFromHeader)
findTokenFns = append(findTokenFns, jwtauth.TokenFromCookie)
findTokenFns = append(findTokenFns, jwt.TokenFromHeader)
findTokenFns = append(findTokenFns, jwt.TokenFromCookie)
findTokenFns = append(findTokenFns, oidcTokenFromContext)
isTokenFound := false
@ -405,89 +206,78 @@ func isTokenInvalidated(r *http.Request) bool {
}
func invalidateToken(r *http.Request) {
tokenString := jwtauth.TokenFromHeader(r)
tokenString := jwt.TokenFromHeader(r)
if tokenString != "" {
invalidateTokenString(r, tokenString, apiTokenDuration)
}
tokenString = jwtauth.TokenFromCookie(r)
tokenString = jwt.TokenFromCookie(r)
if tokenString != "" {
invalidateTokenString(r, tokenString, getMaxCookieDuration())
}
}
func invalidateTokenString(r *http.Request, tokenString string, fallbackDuration time.Duration) {
token, _, err := jwtauth.FromContext(r.Context())
if err != nil || token == nil {
token, err := jwt.FromContext(r.Context())
if err != nil {
invalidatedJWTTokens.Add(tokenString, time.Now().Add(fallbackDuration).UTC())
return
}
invalidatedJWTTokens.Add(tokenString, token.Expiration().Add(1*time.Minute).UTC())
invalidatedJWTTokens.Add(tokenString, token.Expiry.Time().Add(1*time.Minute).UTC())
}
func getUserFromToken(r *http.Request) *dataprovider.User {
user := &dataprovider.User{}
_, claims, err := jwtauth.FromContext(r.Context())
claims, err := jwt.FromContext(r.Context())
if err != nil {
return user
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
user.Username = tokenClaims.Username
user.Filters.WebClient = tokenClaims.Permissions
user.Role = tokenClaims.Role
user.Username = claims.Username
user.Filters.WebClient = claims.Permissions
user.Role = claims.Role
return user
}
func getAdminFromToken(r *http.Request) *dataprovider.Admin {
admin := &dataprovider.Admin{}
_, claims, err := jwtauth.FromContext(r.Context())
claims, err := jwt.FromContext(r.Context())
if err != nil {
return admin
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
admin.Username = tokenClaims.Username
admin.Permissions = tokenClaims.Permissions
admin.Filters.Preferences.HideUserPageSections = tokenClaims.HideUserPageSections
admin.Role = tokenClaims.Role
admin.Username = claims.Username
admin.Permissions = claims.Permissions
admin.Filters.Preferences.HideUserPageSections = claims.HideUserPageSections
admin.Role = claims.Role
return admin
}
func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID, basePath, ip string,
func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwt.Signer, tokenID, basePath, ip string,
) {
c := jwtTokenClaims{
JwtID: tokenID,
}
resp, err := c.createTokenResponse(csrfTokenAuth, tokenAudienceWebLogin, ip)
c := jwt.NewClaims(tokenAudienceWebLogin, ip, getTokenDuration(tokenAudienceWebLogin))
c.ID = tokenID
resp, err := c.GenerateTokenResponse(csrfTokenAuth)
if err != nil {
return
}
setCookie(w, r, basePath, resp["access_token"].(string), csrfTokenDuration)
setCookie(w, r, basePath, resp.Token, csrfTokenDuration)
}
func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID,
func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwt.Signer, tokenID,
basePath string,
) string {
ip := util.GetIPFromRemoteAddress(r.RemoteAddr)
claims := make(map[string]any)
now := time.Now().UTC()
claims[jwt.JwtIDKey] = xid.New().String()
claims[jwt.IssuedAtKey] = now
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceCSRF, ip}
claims := jwt.NewClaims(tokenAudienceCSRF, ip, csrfTokenDuration)
claims.ID = rand.Text()
if tokenID != "" {
createLoginCookie(w, r, csrfTokenAuth, tokenID, basePath, ip)
claims[claimRef] = tokenID
claims.Ref = tokenID
} else {
if c, err := getTokenClaims(r); err == nil {
claims[claimRef] = c.JwtID
if c, err := jwt.FromContext(r.Context()); err == nil {
claims.Ref = c.ID
} else {
logger.Error(logSender, "", "unable to add reference to CSRF token: %v", err)
}
}
_, tokenString, err := csrfTokenAuth.Encode(claims)
tokenString, err := csrfTokenAuth.Sign(claims)
if err != nil {
logger.Debug(logSender, "", "unable to create CSRF token: %v", err)
return ""
@ -495,15 +285,15 @@ func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwta
return tokenString
}
func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwt.Signer) error {
tokenString := r.Form.Get(csrfFormToken)
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
token, err := jwt.VerifyToken(csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err)
return fmt.Errorf("unable to verify form token: %v", err)
}
if !slices.Contains(token.Audience(), tokenAudienceCSRF) {
if !token.Audience.Contains(tokenAudienceCSRF) {
logger.Debug(logSender, "", "error validating CSRF token audience")
return errors.New("the form token is not valid")
}
@ -515,19 +305,18 @@ func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
return checkCSRFTokenRef(r, token)
}
func checkCSRFTokenRef(r *http.Request, token jwt.Token) error {
claims, err := getTokenClaims(r)
func checkCSRFTokenRef(r *http.Request, token *jwt.Claims) error {
claims, err := jwt.FromContext(r.Context())
if err != nil {
logger.Debug(logSender, "", "error getting token claims for CSRF validation: %v", err)
return err
}
ref, ok := token.Get(claimRef)
if !ok {
if token.ID == "" {
logger.Debug(logSender, "", "error validating CSRF token, missing reference")
return errors.New("the form token is not valid")
}
if claims.JwtID == "" || claims.JwtID != ref.(string) {
logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.JwtID, ref)
if claims.ID != token.Ref {
logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.ID, token.ID)
return errors.New("unexpected form token")
}
@ -535,8 +324,8 @@ func checkCSRFTokenRef(r *http.Request, token jwt.Token) error {
}
func verifyLoginCookie(r *http.Request) error {
token, _, err := jwtauth.FromContext(r.Context())
if err != nil || token == nil {
token, err := jwt.FromContext(r.Context())
if err != nil {
logger.Debug(logSender, "", "error getting login token: %v", err)
return errInvalidToken
}
@ -544,8 +333,8 @@ func verifyLoginCookie(r *http.Request) error {
logger.Debug(logSender, "", "the login token has been invalidated")
return errInvalidToken
}
if !slices.Contains(token.Audience(), tokenAudienceWebLogin) {
logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.JwtID(), tokenAudienceWebLogin)
if !token.Audience.Contains(tokenAudienceWebLogin) {
logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.ID, tokenAudienceWebLogin)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
@ -555,7 +344,7 @@ func verifyLoginCookie(r *http.Request) error {
return nil
}
func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwt.Signer) error {
if err := verifyLoginCookie(r); err != nil {
return err
}
@ -565,17 +354,11 @@ func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAu
return nil
}
func createOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, state, ip string) string {
claims := make(map[string]any)
now := time.Now().UTC()
func createOAuth2Token(csrfTokenAuth *jwt.Signer, state, ip string) string {
claims := jwt.NewClaims(tokenAudienceOAuth2, ip, getTokenDuration(tokenAudienceOAuth2))
claims.ID = state
claims[jwt.JwtIDKey] = state
claims[jwt.IssuedAtKey] = now
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(getTokenDuration(tokenAudienceOAuth2))
claims[jwt.AudienceKey] = []string{tokenAudienceOAuth2, ip}
_, tokenString, err := csrfTokenAuth.Encode(claims)
tokenString, err := csrfTokenAuth.Sign(claims)
if err != nil {
logger.Debug(logSender, "", "unable to create OAuth2 token: %v", err)
return ""
@ -583,8 +366,8 @@ func createOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, state, ip string) string
return tokenString
}
func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) (string, error) {
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
func verifyOAuth2Token(csrfTokenAuth *jwt.Signer, tokenString, ip string) (string, error) {
token, err := jwt.VerifyToken(csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err)
return "", util.NewI18nError(
@ -593,7 +376,7 @@ func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) (
)
}
if !slices.Contains(token.Audience(), tokenAudienceOAuth2) {
if !token.Audience.Contains(tokenAudienceOAuth2) {
logger.Debug(logSender, "", "error validating OAuth2 token audience")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
@ -602,31 +385,29 @@ func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) (
logger.Debug(logSender, "", "error validating OAuth2 token IP audience")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
if val, ok := token.Get(jwt.JwtIDKey); ok {
if state, ok := val.(string); ok {
return state, nil
}
if token.ID != "" {
return token.ID, nil
}
logger.Debug(logSender, "", "jti not found in OAuth2 token")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
func validateIPForToken(token jwt.Token, ip string) error {
func validateIPForToken(token *jwt.Claims, ip string) error {
if tokenValidationMode&tokenValidationModeNoIPMatch == 0 {
if !slices.Contains(token.Audience(), ip) {
if !token.Audience.Contains(ip) {
return errInvalidToken
}
}
return nil
}
func checkTokenSignature(r *http.Request, token jwt.Token) error {
func checkTokenSignature(r *http.Request, token *jwt.Claims) error {
if _, ok := r.Context().Value(oidcTokenKey).(string); ok {
return nil
}
var err error
if tokenValidationMode&tokenValidationModeUserSignature != 0 {
for _, audience := range token.Audience() {
for _, audience := range token.Audience {
switch audience {
case tokenAudienceAPI, tokenAudienceWebAdmin:
err = validateSignatureForToken(token, dataprovider.GetAdminSignature)
@ -641,22 +422,16 @@ func checkTokenSignature(r *http.Request, token jwt.Token) error {
return err
}
func validateSignatureForToken(token jwt.Token, getter func(string) (string, error)) error {
username := ""
if u, ok := token.Get(claimUsernameKey); ok {
c := jwtTokenClaims{}
username = c.decodeString(u)
}
signature, err := getter(username)
func validateSignatureForToken(token *jwt.Claims, getter func(string) (string, error)) error {
signature, err := getter(token.Username)
if err != nil {
logger.Debug(logSender, "", "unable to get signature for username %q: %v", username, err)
logger.Debug(logSender, "", "unable to get signature for username %q: %v", token.Username, err)
return errInvalidToken
}
if signature != "" && signature == token.Subject() {
if signature != "" && signature == token.Subject {
return nil
}
logger.Debug(logSender, "", "signature mismatch for username %q, signature %q, token signature %q",
username, signature, token.Subject())
token.Username, signature, token.Subject)
return errInvalidToken
}

View file

@ -1334,10 +1334,12 @@ func updateWebAdminURLs(baseURL string) {
}
// GetHTTPRouter returns an HTTP handler suitable to use for test cases
func GetHTTPRouter(b Binding) http.Handler {
func GetHTTPRouter(b Binding) (http.Handler, error) {
server := newHttpdServer(b, filepath.Join("..", "..", "static"), "", CorsConfig{}, filepath.Join("..", "..", "openapi"))
server.initializeRouter()
return server.router
if err := server.initializeRouter(); err != nil {
return nil, err
}
return server.router, nil
}
// the ticker cannot be started/stopped from multiple goroutines

View file

@ -328,7 +328,7 @@ type recoveryCode struct {
Used bool `json:"used"`
}
func TestMain(m *testing.M) {
func TestMain(m *testing.M) { //nolint:gocyclo
homeBasePath = os.TempDir()
logfilePath := filepath.Join(configDir, "sftpgo_api_test.log")
logger.InitLogger(logfilePath, 5, 1, 28, false, false, zerolog.DebugLevel)
@ -480,7 +480,12 @@ func TestMain(m *testing.M) {
waitTCPListening(httpdConf.Bindings[0].GetAddress())
httpd.ReloadCertificateMgr() //nolint:errcheck
testServer = httptest.NewServer(httpd.GetHTTPRouter(httpdConf.Bindings[0]))
handler, err := httpd.GetHTTPRouter(httpdConf.Bindings[0])
if err != nil {
logger.ErrorToConsole("unable to get http test handler: %v", err)
os.Exit(1)
}
testServer = httptest.NewServer(handler)
defer testServer.Close()
exitCode := m.Run()

File diff suppressed because it is too large Load diff

View file

@ -24,12 +24,12 @@ import (
"strings"
"time"
"github.com/go-chi/jwtauth/v5"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -48,7 +48,7 @@ func (k *contextKey) String() string {
}
func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error {
token, _, err := jwtauth.FromContext(r.Context())
token, err := jwt.FromContext(r.Context())
var redirectPath string
if audience == tokenAudienceWebAdmin {
@ -70,7 +70,7 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
}
}
if err != nil || token == nil {
if err != nil {
logger.Debug(logSender, "", "error getting jwt token: %v", err)
doRedirect(http.StatusText(http.StatusUnauthorized), err)
return errInvalidToken
@ -82,17 +82,17 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
return errInvalidToken
}
// a user with a partial token will be always redirected to the appropriate two factor auth page
if err := checkPartialAuth(w, r, audience, token.Audience()); err != nil {
if err := checkPartialAuth(w, r, audience, token.Audience); err != nil {
return err
}
if !slices.Contains(token.Audience(), audience) {
if !token.Audience.Contains(audience) {
logger.Debug(logSender, "", "the token is not valid for audience %q", audience)
doRedirect("Your token audience is not valid", nil)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.ID, ipAddr)
doRedirect("Your token is not valid", nil)
return err
}
@ -104,14 +104,14 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
}
func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error {
token, _, err := jwtauth.FromContext(r.Context())
token, err := jwt.FromContext(r.Context())
var notFoundFunc func(w http.ResponseWriter, r *http.Request, err error)
if audience == tokenAudienceWebAdminPartial {
notFoundFunc = s.renderNotFoundPage
} else {
notFoundFunc = s.renderClientNotFoundPage
}
if err != nil || token == nil {
if err != nil {
notFoundFunc(w, r, nil)
return errInvalidToken
}
@ -119,14 +119,14 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req
notFoundFunc(w, r, nil)
return errInvalidToken
}
if !slices.Contains(token.Audience(), audience) {
logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.JwtID(), audience)
if !token.Audience.Contains(audience) {
logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.ID, audience)
notFoundFunc(w, r, nil)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.ID, ipAddr)
notFoundFunc(w, r, nil)
return err
}
@ -194,7 +194,7 @@ func jwtAuthenticatorWebClient(next http.Handler) http.Handler {
func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())
claims, err := jwt.FromContext(r.Context())
if err != nil {
if isWebRequest(r) {
s.renderClientBadRequestPage(w, r, err)
@ -203,10 +203,8 @@ func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) htt
}
return
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
// for web client perms are negated and not granted
if tokenClaims.hasPerm(perm) {
if claims.HasPerm(perm) {
if isWebRequest(r) {
s.renderClientForbiddenPage(w, r, errors.New("you don't have permission for this action"))
} else {
@ -223,7 +221,7 @@ func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) htt
// checkAuthRequirements checks if the user must set a second factor auth or change the password
func (s *httpdServer) checkAuthRequirements(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())
claims, err := jwt.FromContext(r.Context())
if err != nil {
if isWebRequest(r) {
if isWebClientRequest(r) {
@ -236,13 +234,11 @@ func (s *httpdServer) checkAuthRequirements(next http.Handler) http.Handler {
}
return
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
if tokenClaims.MustSetTwoFactorAuth || tokenClaims.MustChangePassword {
if claims.MustSetTwoFactorAuth || claims.MustChangePassword {
var err error
if tokenClaims.MustSetTwoFactorAuth {
if len(tokenClaims.RequiredTwoFactorProtocols) > 0 {
protocols := strings.Join(tokenClaims.RequiredTwoFactorProtocols, ", ")
if claims.MustSetTwoFactorAuth {
if len(claims.RequiredTwoFactorProtocols) > 0 {
protocols := strings.Join(claims.RequiredTwoFactorProtocols, ", ")
err = util.NewI18nError(
util.NewGenericError(
fmt.Sprintf("Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols: %v",
@ -301,7 +297,7 @@ func (s *httpdServer) requireBuiltinLogin(next http.Handler) http.Handler {
func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())
claims, err := jwt.FromContext(r.Context())
if err != nil {
if isWebRequest(r) {
s.renderBadRequestPage(w, r, err)
@ -310,11 +306,9 @@ func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.H
}
return
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
for _, perm := range perms {
if !tokenClaims.hasPerm(perm) {
if !claims.HasPerm(perm) {
if isWebRequest(r) {
s.renderForbiddenPage(w, r, util.NewI18nError(fs.ErrPermission, util.I18nError403Message))
} else {
@ -332,14 +326,14 @@ func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.H
func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString := r.Header.Get(csrfHeaderToken)
token, err := jwtauth.VerifyToken(s.csrfTokenAuth, tokenString)
token, err := jwt.VerifyToken(s.csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating CSRF header: %v", err)
sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden)
return
}
if !slices.Contains(token.Audience(), tokenAudienceCSRF) {
if !token.Audience.Contains(tokenAudienceCSRF) {
logger.Debug(logSender, "", "error validating CSRF header token audience")
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
return
@ -359,49 +353,52 @@ func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler {
})
}
func checkNodeToken(tokenAuth *jwtauth.JWTAuth) func(next http.Handler) http.Handler {
func checkNodeToken(tokenAuth *jwt.Signer) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get(dataprovider.NodeTokenHeader)
if token == "" {
bearer := r.Header.Get(dataprovider.NodeTokenHeader)
if bearer == "" {
next.ServeHTTP(w, r)
return
}
if len(token) > 7 && strings.ToUpper(token[0:6]) == "BEARER" {
token = token[7:]
const prefix = "Bearer "
if len(bearer) >= len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) {
bearer = bearer[len(prefix):]
}
if invalidatedJWTTokens.Get(token) {
if invalidatedJWTTokens.Get(bearer) {
logger.Debug(logSender, "", "the node token has been invalidated")
sendAPIResponse(w, r, fmt.Errorf("the provided token is not valid"), "", http.StatusUnauthorized)
return
}
admin, role, perms, err := dataprovider.AuthenticateNodeToken(token)
claims, err := dataprovider.AuthenticateNodeToken(bearer)
if err != nil {
logger.Debug(logSender, "", "unable to authenticate node token %q: %v", token, err)
logger.Debug(logSender, "", "unable to authenticate node token %q: %v", bearer, err)
sendAPIResponse(w, r, fmt.Errorf("the provided token cannot be authenticated"), "", http.StatusUnauthorized)
return
}
defer invalidatedJWTTokens.Add(token, time.Now().Add(2*time.Minute).UTC())
defer invalidatedJWTTokens.Add(bearer, time.Now().Add(2*time.Minute).UTC())
c := jwtTokenClaims{
Username: admin,
Permissions: perms,
c := &jwt.Claims{
Username: claims.Username,
Permissions: claims.Permissions,
NodeID: dataprovider.GetNodeName(),
Role: role,
Role: claims.Role,
}
resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr))
token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr), getTokenDuration(tokenAudienceAPI))
if err != nil {
sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
resp := c.BuildTokenResponse(token)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token))
next.ServeHTTP(w, r)
})
}
}
func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler {
func checkAPIKeyAuth(tokenAuth *jwt.Signer, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiKey := r.Header.Get("X-SFTPGO-API-KEY")
@ -484,7 +481,7 @@ func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope)
func forbidAPIKeyAuthentication(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@ -498,7 +495,7 @@ func forbidAPIKeyAuthentication(next http.Handler) http.Handler {
})
}
func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAuth, r *http.Request) error {
func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwt.Signer, r *http.Request) error {
if username == "" {
return errors.New("the provided key is not associated with any admin and no username was provided")
}
@ -513,25 +510,26 @@ func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTA
if err := admin.CanLogin(ipAddr); err != nil {
return err
}
c := jwtTokenClaims{
c := &jwt.Claims{
Username: admin.Username,
Permissions: admin.Permissions,
Signature: admin.GetSignature(),
Role: admin.Role,
APIKeyID: keyID,
}
c.Subject = admin.GetSignature()
resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPI, ipAddr)
token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, ipAddr, getTokenDuration(tokenAudienceAPI))
if err != nil {
return err
}
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
resp := c.BuildTokenResponse(token)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token))
dataprovider.UpdateAdminLastLogin(&admin)
common.DelayLogin(nil)
return nil
}
func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAuth, r *http.Request) error {
func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwt.Signer, r *http.Request) error {
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
protocol := common.ProtocolHTTP
if username == "" {
@ -569,20 +567,21 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
return common.ErrInternalFailure
}
c := jwtTokenClaims{
c := &jwt.Claims{
Username: user.Username,
Permissions: user.Filters.WebClient,
Signature: user.GetSignature(),
Role: user.Role,
APIKeyID: keyID,
}
c.Subject = user.GetSignature()
resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPIUser, ipAddr)
token, err := tokenAuth.SignWithParams(c, tokenAudienceAPIUser, ipAddr, getTokenDuration(tokenAudienceAPIUser))
if err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
return err
}
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
resp := c.BuildTokenResponse(token)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token))
dataprovider.UpdateLastLogin(&user)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, nil, r)

View file

@ -31,6 +31,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@ -551,19 +552,20 @@ func (s *httpdServer) oidcTokenAuthenticator(audience tokenAudience) func(next h
if err != nil {
return
}
jwtTokenClaims := jwtTokenClaims{
JwtID: token.Cookie,
claims := jwt.Claims{
Username: dataprovider.ConvertName(token.Username),
Permissions: token.Permissions,
Role: token.TokenRole,
HideUserPageSections: token.HideUserPageSections,
}
claims.ID = token.Cookie
if audience == tokenAudienceWebClient {
jwtTokenClaims.MustSetTwoFactorAuth = token.MustSetTwoFactorAuth
jwtTokenClaims.MustChangePassword = token.MustChangePassword
jwtTokenClaims.RequiredTwoFactorProtocols = token.RequiredTwoFactorProtocols
claims.MustSetTwoFactorAuth = token.MustSetTwoFactorAuth
claims.MustChangePassword = token.MustChangePassword
claims.RequiredTwoFactorProtocols = token.RequiredTwoFactorProtocols
}
_, tokenString, err := jwtTokenClaims.createToken(s.tokenAuth, audience, util.GetIPFromRemoteAddress(r.RemoteAddr))
tokenString, err := s.tokenAuth.SignWithParams(&claims, audience, util.GetIPFromRemoteAddress(r.RemoteAddr),
getTokenDuration(audience))
if err != nil {
setFlashMessage(w, r, newFlashMessage("Unable to create cookie", util.I18nError500Message))
if audience == tokenAudienceWebAdmin {

View file

@ -32,7 +32,6 @@ import (
"unsafe"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-chi/jwtauth/v5"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert"
@ -41,6 +40,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
@ -142,7 +142,8 @@ func TestOIDCLoginLogout(t *testing.T) {
server := getTestOIDCServer()
err := server.binding.OIDC.initialize()
assert.NoError(t, err)
server.initializeRouter()
err = server.initializeRouter()
require.NoError(t, err)
rr := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil)
@ -768,7 +769,8 @@ func TestValidateOIDCToken(t *testing.T) {
server := getTestOIDCServer()
err := server.binding.OIDC.initialize()
assert.NoError(t, err)
server.initializeRouter()
err = server.initializeRouter()
require.NoError(t, err)
rr := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
@ -796,7 +798,7 @@ func TestValidateOIDCToken(t *testing.T) {
oidcMgr.removeToken(token.Cookie)
assert.Len(t, oidcMgr.tokens, 0)
server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
server.tokenAuth.SetSigner(&failingJoseSigner{})
token = oidcToken{
Cookie: util.GenerateOpaqueString(),
AccessToken: util.GenerateUniqueID(),
@ -833,11 +835,12 @@ func TestSkipOIDCAuth(t *testing.T) {
server := getTestOIDCServer()
err := server.binding.OIDC.initialize()
assert.NoError(t, err)
server.initializeRouter()
jwtTokenClaims := jwtTokenClaims{
Username: "user",
}
_, tokenString, err := jwtTokenClaims.createToken(server.tokenAuth, tokenAudienceWebClient, "")
err = server.initializeRouter()
require.NoError(t, err)
claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient))
claims.Username = "user"
tokenString, err := server.tokenAuth.Sign(claims)
assert.NoError(t, err)
rr := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
@ -968,7 +971,8 @@ func TestOIDCImplicitRoles(t *testing.T) {
server.binding.OIDC.ImplicitRoles = true
err := server.binding.OIDC.initialize()
assert.NoError(t, err)
server.initializeRouter()
err = server.initializeRouter()
require.NoError(t, err)
authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
oidcMgr.addPendingAuth(authReq)
@ -1241,7 +1245,8 @@ func TestOIDCEvMgrIntegration(t *testing.T) {
server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"}
err = server.binding.OIDC.initialize()
assert.NoError(t, err)
server.initializeRouter()
err = server.initializeRouter()
require.NoError(t, err)
// login a user with OIDC
_, err = dataprovider.UserExists(username, "")
assert.ErrorIs(t, err, util.ErrNotFound)
@ -1378,7 +1383,8 @@ func TestOIDCPreLoginHook(t *testing.T) {
server.binding.OIDC.CustomFields = []string{"field1", "field2"}
err = server.binding.OIDC.initialize()
assert.NoError(t, err)
server.initializeRouter()
err = server.initializeRouter()
require.NoError(t, err)
_, err = dataprovider.UserExists(username, "")
assert.ErrorIs(t, err, util.ErrNotFound)
@ -1554,7 +1560,8 @@ func TestOIDCWithLoginFormsDisabled(t *testing.T) {
server.binding.EnableWebClient = true
err := server.binding.OIDC.initialize()
assert.NoError(t, err)
server.initializeRouter()
err = server.initializeRouter()
require.NoError(t, err)
// login with an admin user
authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
oidcMgr.addPendingAuth(authReq)

View file

@ -16,6 +16,7 @@ package httpd
import (
"context"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
@ -32,9 +33,8 @@ import (
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/jwtauth/v5"
"github.com/go-chi/render"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/go-jose/go-jose/v4"
"github.com/rs/cors"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
@ -43,6 +43,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/acme"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/mfa"
"github.com/drakkan/sftpgo/v2/internal/smtp"
@ -69,8 +70,8 @@ type httpdServer struct {
renderOpenAPI bool
isShared int
router *chi.Mux
tokenAuth *jwtauth.JWTAuth
csrfTokenAuth *jwtauth.JWTAuth
tokenAuth *jwt.Signer
csrfTokenAuth *jwt.Signer
signingPassphrase string
cors CorsConfig
}
@ -99,7 +100,9 @@ func (s *httpdServer) setShared(value int) {
}
func (s *httpdServer) listenAndServe() error {
s.initializeRouter()
if err := s.initializeRouter(); err != nil {
return err
}
httpServer := &http.Server{
Handler: s.router,
ReadHeaderTimeout: 30 * time.Second,
@ -173,7 +176,7 @@ func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Reque
Title: util.I18nLoginTitle,
CurrentURL: webClientLoginPath,
Error: err,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath),
Branding: s.binding.webClientBranding(),
Languages: s.binding.languages(),
FormDisabled: s.binding.isWebClientLoginFormDisabled(),
@ -327,7 +330,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil {
s.renderNotFoundPage(w, r, nil)
return
@ -393,7 +396,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil {
s.renderNotFoundPage(w, r, nil)
return
@ -451,7 +454,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil {
s.renderNotFoundPage(w, r, nil)
return
@ -511,7 +514,7 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter,
func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil {
s.renderNotFoundPage(w, r, nil)
return
@ -592,7 +595,7 @@ func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Reques
Title: util.I18nLoginTitle,
CurrentURL: webAdminLoginPath,
Error: err,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath),
Branding: s.binding.webAdminBranding(),
Languages: s.binding.languages(),
FormDisabled: s.binding.isWebAdminLoginFormDisabled(),
@ -735,15 +738,15 @@ func (s *httpdServer) loginUser(
w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string,
isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError),
) {
c := jwtTokenClaims{
c := &jwt.Claims{
Username: user.Username,
Permissions: user.Filters.WebClient,
Signature: user.GetSignature(),
Role: user.Role,
MustSetTwoFactorAuth: user.MustSetSecondFactor(),
MustChangePassword: user.MustChangePassword(),
RequiredTwoFactorProtocols: user.Filters.TwoFactorAuthProtocols,
}
c.Subject = user.GetSignature()
audience := tokenAudienceWebClient
if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) &&
@ -751,7 +754,7 @@ func (s *httpdServer) loginUser(
audience = tokenAudienceWebClientPartial
}
err := c.createAndSetCookie(w, r, s.tokenAuth, audience, ipAddr)
err := createAndSetCookie(w, r, c, s.tokenAuth, audience, ipAddr)
if err != nil {
logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err)
updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
@ -781,22 +784,22 @@ func (s *httpdServer) loginAdmin(
isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError),
ipAddr string,
) {
c := jwtTokenClaims{
c := &jwt.Claims{
Username: admin.Username,
Permissions: admin.Permissions,
Role: admin.Role,
Signature: admin.GetSignature(),
HideUserPageSections: admin.Filters.Preferences.HideUserPageSections,
MustSetTwoFactorAuth: admin.Filters.RequireTwoFactor && !admin.Filters.TOTPConfig.Enabled,
MustChangePassword: admin.Filters.RequirePasswordChange,
}
c.Subject = admin.GetSignature()
audience := tokenAudienceWebAdmin
if admin.Filters.TOTPConfig.Enabled && admin.CanManageMFA() && !isSecondFactorAuth {
audience = tokenAudienceWebAdminPartial
}
err := c.createAndSetCookie(w, r, s.tokenAuth, audience, ipAddr)
err := createAndSetCookie(w, r, c, s.tokenAuth, audience, ipAddr)
if err != nil {
logger.Warn(logSender, "", "unable to set admin login cookie %v", err)
if errorFunc == nil {
@ -907,17 +910,17 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
}
func (s *httpdServer) generateAndSendUserToken(w http.ResponseWriter, r *http.Request, ipAddr string, user dataprovider.User) {
c := jwtTokenClaims{
c := &jwt.Claims{
Username: user.Username,
Permissions: user.Filters.WebClient,
Signature: user.GetSignature(),
Role: user.Role,
MustSetTwoFactorAuth: user.MustSetSecondFactor(),
MustChangePassword: user.MustChangePassword(),
RequiredTwoFactorProtocols: user.Filters.TwoFactorAuthProtocols,
}
c.Subject = user.GetSignature()
resp, err := c.createTokenResponse(s.tokenAuth, tokenAudienceAPIUser, ipAddr)
token, err := s.tokenAuth.SignWithParams(c, tokenAudienceAPIUser, ipAddr, getTokenDuration(tokenAudienceAPIUser))
if err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -926,7 +929,7 @@ func (s *httpdServer) generateAndSendUserToken(w http.ResponseWriter, r *http.Re
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
dataprovider.UpdateLastLogin(&user)
render.JSON(w, r, resp)
render.JSON(w, r, c.BuildTokenResponse(token))
}
func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) {
@ -976,17 +979,16 @@ func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) {
}
func (s *httpdServer) generateAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin, ip string) {
c := jwtTokenClaims{
c := &jwt.Claims{
Username: admin.Username,
Permissions: admin.Permissions,
Role: admin.Role,
Signature: admin.GetSignature(),
MustSetTwoFactorAuth: admin.Filters.RequireTwoFactor && !admin.Filters.TOTPConfig.Enabled,
MustChangePassword: admin.Filters.RequirePasswordChange,
}
c.Subject = admin.GetSignature()
resp, err := c.createTokenResponse(s.tokenAuth, tokenAudienceAPI, ip)
token, err := s.tokenAuth.SignWithParams(c, tokenAudienceAPI, ip, getTokenDuration(tokenAudienceAPI))
if err != nil {
sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
@ -994,42 +996,39 @@ func (s *httpdServer) generateAndSendToken(w http.ResponseWriter, r *http.Reques
dataprovider.UpdateAdminLastLogin(&admin)
common.DelayLogin(nil)
render.JSON(w, r, resp)
render.JSON(w, r, c.BuildTokenResponse(token))
}
func (s *httpdServer) checkCookieExpiration(w http.ResponseWriter, r *http.Request) {
if _, ok := r.Context().Value(oidcTokenKey).(string); ok {
return
}
token, claims, err := jwtauth.FromContext(r.Context())
if err != nil || token == nil {
claims, err := jwt.FromContext(r.Context())
if err != nil {
return
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
if tokenClaims.Username == "" || tokenClaims.Signature == "" {
if claims.Username == "" || claims.Subject == "" {
return
}
if time.Until(token.Expiration()) > cookieRefreshThreshold {
if time.Until(claims.Expiry.Time()) > cookieRefreshThreshold {
return
}
if (time.Since(token.IssuedAt()) + cookieTokenDuration) > maxTokenDuration {
if (time.Since(claims.IssuedAt.Time()) + cookieTokenDuration) > maxTokenDuration {
return
}
tokenClaims.JwtIssuedAt = token.IssuedAt()
if slices.Contains(token.Audience(), tokenAudienceWebClient) {
s.refreshClientToken(w, r, &tokenClaims)
if claims.Audience.Contains(tokenAudienceWebClient) {
s.refreshClientToken(w, r, claims)
} else {
s.refreshAdminToken(w, r, &tokenClaims)
s.refreshAdminToken(w, r, claims)
}
}
func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) {
func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwt.Claims) {
user, err := dataprovider.GetUserWithGroupSettings(tokenClaims.Username, "")
if err != nil {
return
}
if user.GetSignature() != tokenClaims.Signature {
if user.GetSignature() != tokenClaims.Subject {
logger.Debug(logSender, "", "signature mismatch for user %q, unable to refresh cookie", user.Username)
return
}
@ -1045,15 +1044,15 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request,
tokenClaims.Permissions = user.Filters.WebClient
tokenClaims.Role = user.Role
logger.Debug(logSender, "", "cookie refreshed for user %q", user.Username)
tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck
createAndSetCookie(w, r, tokenClaims, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck
}
func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) {
func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwt.Claims) {
admin, err := dataprovider.AdminExists(tokenClaims.Username)
if err != nil {
return
}
if admin.GetSignature() != tokenClaims.Signature {
if admin.GetSignature() != tokenClaims.Subject {
logger.Debug(logSender, "", "signature mismatch for admin %q, unable to refresh cookie", admin.Username)
return
}
@ -1066,18 +1065,18 @@ func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request,
tokenClaims.Role = admin.Role
tokenClaims.HideUserPageSections = admin.Filters.Preferences.HideUserPageSections
logger.Debug(logSender, "", "cookie refreshed for admin %q", admin.Username)
tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebAdmin, ipAddr) //nolint:errcheck
createAndSetCookie(w, r, tokenClaims, s.tokenAuth, tokenAudienceWebAdmin, ipAddr) //nolint:errcheck
}
func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request {
token, _, err := jwtauth.FromContext(r.Context())
if token == nil || err != nil {
_, err := jwt.FromContext(r.Context())
if err != nil {
_, err = r.Cookie(jwtCookieKey)
if err != nil {
return r
}
token, err = jwtauth.VerifyRequest(s.tokenAuth, r, jwtauth.TokenFromCookie)
ctx := jwtauth.NewContext(r.Context(), token, err)
token, err := jwt.VerifyRequest(s.tokenAuth, r, jwt.TokenFromCookie)
ctx := jwt.NewContext(r.Context(), token, err)
return r.WithContext(ctx)
}
return r
@ -1235,10 +1234,18 @@ func (s *httpdServer) mustCheckPath(r *http.Request) bool {
return !strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI)
}
func (s *httpdServer) initializeRouter() {
func (s *httpdServer) initializeRouter() error {
signer, err := jwt.NewSigner(jose.HS256, getSigningKey(s.signingPassphrase))
if err != nil {
return err
}
csrfSigner, err := jwt.NewSigner(jose.HS256, getSigningKey(s.signingPassphrase))
if err != nil {
return err
}
var hasHTTPSRedirect bool
s.tokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil)
s.csrfTokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil)
s.tokenAuth = signer
s.csrfTokenAuth = csrfSigner
s.router = chi.NewRouter()
s.router.Use(middleware.RequestID)
@ -1336,6 +1343,7 @@ func (s *httpdServer) initializeRouter() {
s.setupWebClientRoutes()
s.setupWebAdminRoutes()
return nil
}
func (s *httpdServer) setupRESTAPIRoutes() {
@ -1351,7 +1359,7 @@ func (s *httpdServer) setupRESTAPIRoutes() {
if !s.binding.isAdminAPIKeyAuthDisabled() {
router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeAdmin))
}
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader))
router.Use(jwt.Verify(s.tokenAuth, jwt.TokenFromHeader))
router.Use(jwtAuthenticatorAPI)
router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) {
@ -1480,7 +1488,7 @@ func (s *httpdServer) setupRESTAPIRoutes() {
if !s.binding.isUserAPIKeyAuthDisabled() {
router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeUser))
}
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader))
router.Use(jwt.Verify(s.tokenAuth, jwt.TokenFromHeader))
router.Use(jwtAuthenticatorAPIUser)
router.With(forbidAPIKeyAuthentication).Get(userLogoutPath, s.logout)
@ -1568,31 +1576,31 @@ func (s *httpdServer) setupWebClientRoutes() {
s.router.Get(webClientOIDCLoginPath, s.handleWebClientOIDCLogin)
}
if !s.binding.isWebClientLoginFormDisabled() {
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Post(webClientLoginPath, s.handleWebClientLoginPost)
s.router.Get(webClientForgotPwdPath, s.handleWebClientForgotPwd)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Get(webClientResetPwdPath, s.handleWebClientPasswordReset)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
Get(webClientTwoFactorPath, s.handleWebClientTwoFactor)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
Post(webClientTwoFactorPath, s.handleWebClientTwoFactorPost)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
Get(webClientTwoFactorRecoveryPath, s.handleWebClientTwoFactorRecovery)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
Post(webClientTwoFactorRecoveryPath, s.handleWebClientTwoFactorRecoveryPost)
}
// share routes available to external users
s.router.Get(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginGet)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost)
s.router.Get(webClientPubSharesPath+"/{id}/logout", s.handleClientShareLogout)
s.router.Get(webClientPubSharesPath+"/{id}", s.downloadFromShare)
@ -1611,7 +1619,7 @@ func (s *httpdServer) setupWebClientRoutes() {
if s.binding.OIDC.isEnabled() {
router.Use(s.oidcTokenAuthenticator(tokenAudienceWebClient))
}
router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie))
router.Use(jwt.Verify(s.tokenAuth, oidcTokenFromContext, jwt.TokenFromCookie))
router.Use(jwtAuthenticatorWebClient)
router.Get(webClientLogoutPath, s.handleWebClientLogout)
@ -1702,29 +1710,29 @@ func (s *httpdServer) setupWebAdminRoutes() {
}
s.router.Get(webOAuth2RedirectPath, s.handleOAuth2TokenRedirect)
s.router.Get(webAdminSetupPath, s.handleWebAdminSetupGet)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Post(webAdminSetupPath, s.handleWebAdminSetupPost)
if !s.binding.isWebAdminLoginFormDisabled() {
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Post(webAdminLoginPath, s.handleWebAdminLoginPost)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
Get(webAdminTwoFactorPath, s.handleWebAdminTwoFactor)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
Post(webAdminTwoFactorPath, s.handleWebAdminTwoFactorPost)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
Get(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecovery)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
Post(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecoveryPost)
s.router.Get(webAdminForgotPwdPath, s.handleWebAdminForgotPwd)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Get(webAdminResetPwdPath, s.handleWebAdminPasswordReset)
s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost)
}
@ -1732,7 +1740,7 @@ func (s *httpdServer) setupWebAdminRoutes() {
if s.binding.OIDC.isEnabled() {
router.Use(s.oidcTokenAuthenticator(tokenAudienceWebAdmin))
}
router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie))
router.Use(jwt.Verify(s.tokenAuth, oidcTokenFromContext, jwt.TokenFromCookie))
router.Use(jwtAuthenticatorWebAdmin)
router.Get(webLogoutPath, s.handleWebAdminLogout)

View file

@ -16,6 +16,7 @@ package httpd
import (
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
@ -31,7 +32,6 @@ import (
"strings"
"time"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
sdkkms "github.com/sftpgo/sdk/kms"
@ -39,6 +39,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/ftpd"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/mfa"
@ -726,7 +727,7 @@ func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request
commonBasePage: getCommonBasePage(r),
CurrentURL: webAdminForgotPwdPath,
Error: err,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath),
LoginURL: webAdminLoginPath,
Title: util.I18nForgotPwdTitle,
Branding: s.binding.webAdminBranding(),
@ -863,7 +864,7 @@ func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Reques
commonBasePage: getCommonBasePage(r),
Title: util.I18nSetupTitle,
CurrentURL: webAdminSetupPath,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath),
Username: username,
HasInstallationCode: installationCode != "",
InstallationCodeHint: installationCodeHint,
@ -2964,7 +2965,7 @@ func (s *httpdServer) handleWebAdminProfilePost(w http.ResponseWriter, r *http.R
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderProfilePage(w, r, util.NewI18nError(err, util.I18nErrorInvalidToken))
return
@ -2992,7 +2993,7 @@ func (s *httpdServer) handleWebMaintenance(w http.ResponseWriter, r *http.Reques
func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, MaxRestoreSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3045,7 +3046,7 @@ func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) {
func getAllAdmins(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden)
return
@ -3103,7 +3104,7 @@ func (s *httpdServer) handleWebUpdateAdminGet(w http.ResponseWriter, r *http.Req
func (s *httpdServer) handleWebAddAdminPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3163,7 +3164,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re
}
updatedAdmin.Filters.TOTPConfig = admin.Filters.TOTPConfig
updatedAdmin.Filters.RecoveryCodes = admin.Filters.RecoveryCodes
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderAddUpdateAdminPage(w, r, &updatedAdmin, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken), false)
return
@ -3214,7 +3215,7 @@ func (s *httpdServer) handleWebDefenderPage(w http.ResponseWriter, r *http.Reque
func getAllUsers(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden)
return
@ -3234,7 +3235,7 @@ func getAllUsers(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleGetWebUsers(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3264,7 +3265,7 @@ func (s *httpdServer) handleWebTemplateFolderGet(w http.ResponseWriter, r *http.
func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3361,7 +3362,7 @@ func (s *httpdServer) handleWebTemplateUserGet(w http.ResponseWriter, r *http.Re
func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3429,7 +3430,7 @@ func (s *httpdServer) handleWebAddUserGet(w http.ResponseWriter, r *http.Request
func (s *httpdServer) handleWebUpdateUserGet(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3447,7 +3448,7 @@ func (s *httpdServer) handleWebUpdateUserGet(w http.ResponseWriter, r *http.Requ
func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3485,7 +3486,7 @@ func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Reques
func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3552,7 +3553,7 @@ func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request)
func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3569,7 +3570,7 @@ func (s *httpdServer) handleWebAddFolderGet(w http.ResponseWriter, r *http.Reque
func (s *httpdServer) handleWebAddFolderPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3621,7 +3622,7 @@ func (s *httpdServer) handleWebUpdateFolderGet(w http.ResponseWriter, r *http.Re
func (s *httpdServer) handleWebUpdateFolderPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3756,7 +3757,7 @@ func (s *httpdServer) handleWebAddGroupGet(w http.ResponseWriter, r *http.Reques
func (s *httpdServer) handleWebAddGroupPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3794,7 +3795,7 @@ func (s *httpdServer) handleWebUpdateGroupGet(w http.ResponseWriter, r *http.Req
func (s *httpdServer) handleWebUpdateGroupPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3881,7 +3882,7 @@ func (s *httpdServer) handleWebAddEventActionGet(w http.ResponseWriter, r *http.
func (s *httpdServer) handleWebAddEventActionPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3918,7 +3919,7 @@ func (s *httpdServer) handleWebUpdateEventActionGet(w http.ResponseWriter, r *ht
func (s *httpdServer) handleWebUpdateEventActionPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -3992,7 +3993,7 @@ func (s *httpdServer) handleWebAddEventRuleGet(w http.ResponseWriter, r *http.Re
func (s *httpdServer) handleWebAddEventRulePost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -4030,7 +4031,7 @@ func (s *httpdServer) handleWebUpdateEventRuleGet(w http.ResponseWriter, r *http
func (s *httpdServer) handleWebUpdateEventRulePost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -4114,7 +4115,7 @@ func (s *httpdServer) handleWebAddRolePost(w http.ResponseWriter, r *http.Reques
s.renderRolePage(w, r, role, genericPageModeAdd, err)
return
}
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -4146,7 +4147,7 @@ func (s *httpdServer) handleWebUpdateRoleGet(w http.ResponseWriter, r *http.Requ
func (s *httpdServer) handleWebUpdateRolePost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -4228,7 +4229,7 @@ func (s *httpdServer) handleWebAddIPListEntryPost(w http.ResponseWriter, r *http
return
}
entry.Type = listType
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -4265,7 +4266,7 @@ func (s *httpdServer) handleWebUpdateIPListEntryGet(w http.ResponseWriter, r *ht
func (s *httpdServer) handleWebUpdateIPListEntryPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -4315,7 +4316,7 @@ func (s *httpdServer) handleWebConfigs(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return

View file

@ -16,6 +16,7 @@ package httpd
import (
"bytes"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
@ -38,6 +39,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/mfa"
"github.com/drakkan/sftpgo/v2/internal/smtp"
@ -568,7 +570,7 @@ func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.R
commonBasePage: getCommonBasePage(r),
CurrentURL: webClientForgotPwdPath,
Error: err,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath),
LoginURL: webClientLoginPath,
Title: util.I18nForgotPwdTitle,
Branding: s.binding.webClientBranding(),
@ -597,7 +599,7 @@ func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Reques
Title: util.I18nShareLoginTitle,
CurrentURL: r.RequestURI,
Error: err,
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath),
Branding: s.binding.webClientBranding(),
Languages: s.binding.languages(),
CheckRedirect: false,
@ -878,7 +880,7 @@ func (s *httpdServer) renderClientChangePasswordPage(w http.ResponseWriter, r *h
func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxMultipartMem)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -1175,7 +1177,7 @@ func (s *httpdServer) handleShareGetPDF(w http.ResponseWriter, r *http.Request)
func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, nil, util.I18nErrorDirList403, http.StatusForbidden)
return
@ -1261,7 +1263,7 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -1319,7 +1321,7 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -1395,7 +1397,7 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques
func (s *httpdServer) handleClientAddShareGet(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -1437,7 +1439,7 @@ func (s *httpdServer) handleClientAddShareGet(w http.ResponseWriter, r *http.Req
func (s *httpdServer) handleClientUpdateShareGet(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -1455,7 +1457,7 @@ func (s *httpdServer) handleClientUpdateShareGet(w http.ResponseWriter, r *http.
func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -1514,7 +1516,7 @@ func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Re
func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -1583,7 +1585,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http
func getAllShares(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden)
return
@ -1633,7 +1635,7 @@ func (s *httpdServer) handleWebClientProfilePost(w http.ResponseWriter, r *http.
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
}
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -1807,7 +1809,7 @@ func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request
func (s *httpdServer) handleClientGetPDF(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return
@ -1914,13 +1916,13 @@ func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.
next := path.Clean(r.URL.Query().Get("next"))
baseShareURL := path.Join(webClientPubSharesPath, share.ShareID)
isRedirect, redirectTo := checkShareRedirectURL(next, baseShareURL)
c := jwtTokenClaims{
c := &jwt.Claims{
Username: shareID,
}
if isRedirect {
c.Ref = next
}
err = c.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebShare, ipAddr)
err = createAndSetCookie(w, r, c, s.tokenAuth, tokenAudienceWebShare, ipAddr)
if err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message))
return
@ -2082,7 +2084,7 @@ func checkShareRedirectURL(next, base string) (bool, string) {
func getWebTask(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return

264
internal/jwt/jwt.go Normal file
View file

@ -0,0 +1,264 @@
// Copyright (C) 2025 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// Package jwt provides functionality for creating, parsing, and validating
// JSON Web Tokens (JWT) used in authentication and authorization workflows.
package jwt
import (
"context"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/rs/xid"
)
var (
TokenCtxKey = &contextKey{"Token"}
ErrorCtxKey = &contextKey{"Error"}
)
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation. This technique
// for defining context keys was copied from Go 1.7's new use of context in net/http.
type contextKey struct {
name string
}
func (k *contextKey) String() string {
return "jwt context value " + k.name
}
func NewClaims(audience, ip string, duration time.Duration) *Claims {
now := time.Now()
claims := &Claims{}
claims.IssuedAt = jwt.NewNumericDate(now)
claims.NotBefore = jwt.NewNumericDate(now.Add(-10 * time.Second))
claims.Expiry = jwt.NewNumericDate(now.Add(duration))
claims.Audience = []string{audience, ip}
return claims
}
type Claims struct {
jwt.Claims
Username string `json:"username,omitempty"`
Permissions []string `json:"permissions,omitempty"`
Role string `json:"role,omitempty"`
APIKeyID string `json:"api_key,omitempty"`
NodeID string `json:"node_id,omitempty"`
MustSetTwoFactorAuth bool `json:"2fa_required,omitempty"`
MustChangePassword bool `json:"chpwd,omitempty"`
RequiredTwoFactorProtocols []string `json:"2fa_protos,omitempty"`
HideUserPageSections int `json:"hus,omitempty"`
Ref string `json:"ref,omitempty"`
}
func (c *Claims) SetIssuedAt(t time.Time) {
c.IssuedAt = jwt.NewNumericDate(t)
}
func (c *Claims) SetNotBefore(t time.Time) {
c.NotBefore = jwt.NewNumericDate(t)
}
func (c *Claims) SetExpiry(t time.Time) {
c.Expiry = jwt.NewNumericDate(t)
}
func (c *Claims) HasPerm(perm string) bool {
for _, p := range c.Permissions {
if p == "*" || p == perm {
return true
}
}
return false
}
func (c *Claims) HasAnyAudience(audiences []string) bool {
for _, a := range c.Audience {
if slices.Contains(audiences, a) {
return true
}
}
return false
}
func (c *Claims) GenerateTokenResponse(signer *Signer) (TokenResponse, error) {
token, err := signer.Sign(c)
if err != nil {
return TokenResponse{}, err
}
return c.BuildTokenResponse(token), nil
}
func (c *Claims) BuildTokenResponse(token string) TokenResponse {
return TokenResponse{Token: token, Expiry: c.Expiry.Time().UTC().Format(time.RFC3339)}
}
type TokenResponse struct {
Token string `json:"access_token"`
Expiry string `json:"expires_at"`
}
func NewSigner(algo jose.SignatureAlgorithm, key any) (*Signer, error) {
opts := (&jose.SignerOptions{}).WithType("JWT")
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: algo, Key: key}, opts)
if err != nil {
return nil, err
}
return &Signer{
signer: signer,
algo: []jose.SignatureAlgorithm{algo},
key: key,
}, nil
}
type Signer struct {
algo []jose.SignatureAlgorithm
signer jose.Signer
key any
}
func (s *Signer) Sign(claims *Claims) (string, error) {
if claims.ID == "" {
claims.ID = xid.New().String()
}
if claims.IssuedAt == nil {
claims.IssuedAt = jwt.NewNumericDate(time.Now())
}
if claims.NotBefore == nil {
claims.NotBefore = jwt.NewNumericDate(time.Now().Add(-10 * time.Second))
}
if claims.Expiry == nil {
return "", errors.New("expiration must be set")
}
if len(claims.Audience) == 0 {
return "", errors.New("audience must be set")
}
return jwt.Signed(s.signer).Claims(claims).Serialize()
}
func (s *Signer) Signer() jose.Signer {
return s.signer
}
func (s *Signer) SetSigner(signer jose.Signer) {
s.signer = signer
}
func (s *Signer) SignWithParams(claims *Claims, audience, ip string, duration time.Duration) (string, error) {
claims.Expiry = jwt.NewNumericDate(time.Now().Add(duration))
claims.Audience = []string{audience, ip}
return s.Sign(claims)
}
func NewContext(ctx context.Context, claims *Claims, err error) context.Context {
ctx = context.WithValue(ctx, TokenCtxKey, claims)
ctx = context.WithValue(ctx, ErrorCtxKey, err)
return ctx
}
func FromContext(ctx context.Context) (*Claims, error) {
val := ctx.Value(TokenCtxKey)
token, ok := val.(*Claims)
if !ok && val != nil {
return nil, fmt.Errorf("invalid type for TokenCtxKey: %T", val)
}
valErr := ctx.Value(ErrorCtxKey)
err, ok := valErr.(error)
if !ok && valErr != nil {
return nil, fmt.Errorf("invalid type for ErrorCtxKey: %T", valErr)
}
if token == nil {
return nil, errors.New("no token found")
}
return token, err
}
func Verify(s *Signer, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
hfn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, err := VerifyRequest(s, r, findTokenFns...)
ctx = NewContext(ctx, token, err)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(hfn)
}
}
func VerifyRequest(s *Signer, r *http.Request, findTokenFns ...func(r *http.Request) string) (*Claims, error) {
var tokenString string
for _, fn := range findTokenFns {
tokenString = fn(r)
if tokenString != "" {
break
}
}
if tokenString == "" {
return nil, errors.New("no token found")
}
return VerifyToken(s, tokenString)
}
func VerifyToken(s *Signer, payload string) (*Claims, error) {
return VerifyTokenWithKey(payload, s.algo, s.key)
}
func VerifyTokenWithKey(payload string, algo []jose.SignatureAlgorithm, key any) (*Claims, error) {
token, err := jwt.ParseSigned(payload, algo)
if err != nil {
return nil, err
}
var claims Claims
err = token.Claims(key, &claims)
if err != nil {
return nil, err
}
if err := claims.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 15*time.Second); err != nil {
return nil, err
}
return &claims, nil
}
// TokenFromCookie tries to retrieve the token string from a cookie named
// "jwt".
func TokenFromCookie(r *http.Request) string {
cookie, err := r.Cookie("jwt")
if err != nil {
return ""
}
return cookie.Value
}
// TokenFromHeader tries to retrieve the token string from the
// "Authorization" request header: "Authorization: BEARER T".
func TokenFromHeader(r *http.Request) string {
// Get token from authorization header.
bearer := r.Header.Get("Authorization")
const prefix = "Bearer "
if len(bearer) >= len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) {
return bearer[len(prefix):]
}
return ""
}

225
internal/jwt/jwt_test.go Normal file
View file

@ -0,0 +1,225 @@
package jwt
import (
"context"
"errors"
"fmt"
"io/fs"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/v2/internal/util"
)
type failingJoseSigner struct{}
func (s *failingJoseSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) {
return nil, errors.New("sign test error")
}
func (s *failingJoseSigner) Options() jose.SignerOptions {
return jose.SignerOptions{}
}
func TestJWTToken(t *testing.T) {
s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32))
require.NoError(t, err)
username := util.GenerateUniqueID()
claims := Claims{
Username: username,
Claims: jwt.Claims{
Audience: jwt.Audience{"test"},
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
NotBefore: jwt.NewNumericDate(time.Now()),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
token, err := s.Sign(&claims)
require.NoError(t, err)
require.NotEmpty(t, token)
parsed, err := VerifyToken(s, token)
require.NoError(t, err)
require.Equal(t, username, parsed.Username)
ja1, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32))
require.NoError(t, err)
token, err = ja1.Sign(&claims)
require.NoError(t, err)
require.NotEmpty(t, token)
_, err = VerifyToken(s, token)
require.Error(t, err)
_, err = VerifyToken(ja1, token)
require.NoError(t, err)
}
func TestClaims(t *testing.T) {
claims := NewClaims(util.GenerateUniqueID(), "", 10*time.Minute)
s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32))
require.NoError(t, err)
token, err := s.Sign(claims)
require.NoError(t, err)
assert.NotEmpty(t, token)
assert.NotNil(t, claims.Expiry)
assert.NotNil(t, claims.IssuedAt)
assert.NotNil(t, claims.NotBefore)
claims = &Claims{
Permissions: []string{"myperm"},
}
claims.SetExpiry(time.Now().Add(1 * time.Minute))
claims.Audience = []string{"testaudience"}
_, err = s.Sign(claims)
assert.NoError(t, err)
assert.NotNil(t, claims.IssuedAt)
assert.NotNil(t, claims.NotBefore)
assert.True(t, claims.HasAnyAudience([]string{util.GenerateUniqueID(), util.GenerateUniqueID(), "testaudience"}))
assert.False(t, claims.HasAnyAudience([]string{util.GenerateUniqueID()}))
assert.True(t, claims.HasPerm("myperm"))
assert.False(t, claims.HasPerm(util.GenerateUniqueID()))
resp, err := claims.GenerateTokenResponse(s)
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, claims.Expiry.Time().UTC().Format(time.RFC3339), resp.Expiry)
claims.SetIssuedAt(time.Now())
claims.SetNotBefore(time.Now().Add(10 * time.Minute))
token, err = s.SignWithParams(claims, util.GenerateUniqueID(), "127.0.0.1", time.Minute)
assert.NoError(t, err)
_, err = VerifyToken(s, token)
assert.ErrorContains(t, err, "nbf")
claims = &Claims{}
_, err = s.Sign(claims)
assert.ErrorContains(t, err, "expiration must be set")
claims.SetExpiry(time.Now())
_, err = s.Sign(claims)
assert.ErrorContains(t, err, "audience must be set")
claims = &Claims{}
_, err = s.SignWithParams(claims, util.GenerateUniqueID(), "", time.Minute)
assert.NoError(t, err)
}
func TestClaimsPermissions(t *testing.T) {
c := Claims{
Permissions: []string{"*"},
}
assert.True(t, c.HasPerm(util.GenerateUniqueID()))
c.Permissions = []string{"list"}
assert.False(t, c.HasPerm(util.GenerateUniqueID()))
assert.True(t, c.HasPerm("list"))
}
func TestErrors(t *testing.T) {
s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32))
require.NoError(t, err)
_, err = VerifyToken(s, util.GenerateUniqueID())
assert.Error(t, err)
claims := &Claims{}
claims.SetExpiry(time.Now().Add(-1 * time.Minute))
token, err := jwt.Signed(s.Signer()).Claims(claims).Serialize()
assert.NoError(t, err)
_, err = VerifyToken(s, token)
assert.ErrorContains(t, err, "exp")
claims.SetExpiry(time.Now().Add(2 * time.Minute))
claims.SetIssuedAt(time.Now().Add(1 * time.Minute))
token, err = jwt.Signed(s.Signer()).Claims(claims).Serialize()
assert.NoError(t, err)
_, err = VerifyToken(s, token)
assert.ErrorContains(t, err, "iat")
claims.SetIssuedAt(time.Now())
claims.SetNotBefore(time.Now().Add(1 * time.Minute))
token, err = jwt.Signed(s.Signer()).Claims(claims).Serialize()
assert.NoError(t, err)
_, err = VerifyToken(s, token)
assert.ErrorContains(t, err, "nbf")
s.SetSigner(&failingJoseSigner{})
claims = NewClaims(util.GenerateUniqueID(), "", time.Minute)
_, err = s.Sign(claims)
assert.Error(t, err)
_, err = claims.GenerateTokenResponse(s)
assert.Error(t, err)
// Wrong algorithm
_, err = NewSigner("PS256", util.GenerateRandomBytes(32))
assert.Error(t, err)
}
func TestTokenFromRequest(t *testing.T) {
claims := NewClaims(util.GenerateUniqueID(), "", 10*time.Minute)
s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32))
require.NoError(t, err)
token, err := s.Sign(claims)
require.NoError(t, err)
assert.NotEmpty(t, token)
req, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token))
cookie := TokenFromCookie(req)
assert.Equal(t, token, cookie)
req, err = http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
_, err = VerifyRequest(s, req, TokenFromHeader)
assert.NoError(t, err)
req.Header.Set("Authorization", token)
assert.Empty(t, TokenFromHeader(req))
assert.Empty(t, TokenFromCookie(req))
_, err = VerifyRequest(s, req, TokenFromCookie)
assert.ErrorContains(t, err, "no token found")
}
func TestContext(t *testing.T) {
claims := &Claims{
Username: util.GenerateUniqueID(),
}
s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32))
require.NoError(t, err)
token, err := s.SignWithParams(claims, util.GenerateUniqueID(), "", time.Minute)
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
h := Verify(s, TokenFromHeader)
wrapped := h(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := FromContext(r.Context())
assert.Nil(t, err)
assert.Equal(t, claims.Username, token.Username)
w.WriteHeader(http.StatusOK)
}))
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
_, err = FromContext(context.Background())
assert.ErrorContains(t, err, "no token found")
ctx := NewContext(context.Background(), &Claims{}, fs.ErrClosed)
_, err = FromContext(ctx)
assert.Equal(t, fs.ErrClosed, err)
ctx = context.WithValue(context.Background(), TokenCtxKey, "1")
_, err = FromContext(ctx)
assert.ErrorContains(t, err, "invalid type for TokenCtxKey")
ctx = context.WithValue(context.Background(), ErrorCtxKey, 2)
_, err = FromContext(ctx)
assert.ErrorContains(t, err, "invalid type for ErrorCtxKey")
claims = NewClaims(util.GenerateUniqueID(), "127.1.1.1", time.Minute)
_, err = s.Sign(claims)
require.NoError(t, err)
ctx = context.WithValue(context.Background(), TokenCtxKey, claims)
claimsFromContext, err := FromContext(ctx)
assert.NoError(t, err)
assert.Equal(t, claims, claimsFromContext)
assert.Equal(t, "jwt context value Token", TokenCtxKey.String())
}