Session improvements (#510)

This commit is contained in:
0xCA 2024-01-06 13:11:20 +05:00 committed by GitHub
parent 46b09348e3
commit fa33d3f66e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 274 additions and 36 deletions

View file

@ -42,6 +42,7 @@ docker-compose up
| `BIND_ADDRESS` | The addresses that can access to the web interface and the port, use unix:///abspath/to/file.socket for unix domain socket. | 0.0.0.0:80 |
| `SESSION_SECRET` | The secret key used to encrypt the session cookies. Set this to a random value | N/A |
| `SESSION_SECRET_FILE` | Optional filepath for the secret key used to encrypt the session cookies. Leave `SESSION_SECRET` blank to take effect | N/A |
| `SESSION_MAX_DURATION` | Max time in days a remembered session is refreshed and valid. Non-refreshed session is valid for 7 days max, regardless of this setting. | 90 |
| `SUBNET_RANGES` | The list of address subdivision ranges. Format: `SR Name:10.0.1.0/24; SR2:10.0.2.0/24,10.0.3.0/24` Each CIDR must be inside one of the server interfaces. | N/A |
| `WGUI_USERNAME` | The username for the login page. Used for db initialization only | `admin` |
| `WGUI_PASSWORD` | The password for the user on the login page. Will be hashed automatically. Used for db initialization only | `admin` |

View file

@ -93,32 +93,41 @@ func Login(db store.IStore) echo.HandlerFunc {
}
if userCorrect && passwordCorrect {
// TODO: refresh the token
ageMax := 0
expiration := time.Now().Add(24 * time.Hour)
if rememberMe {
ageMax = 86400
expiration.Add(144 * time.Hour)
ageMax = 86400 * 7
}
cookiePath := util.GetCookiePath()
sess, _ := session.Get("session", c)
sess.Options = &sessions.Options{
Path: util.BasePath,
Path: cookiePath,
MaxAge: ageMax,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
// set session_token
tokenUID := xid.New().String()
now := time.Now().UTC().Unix()
sess.Values["username"] = dbuser.Username
sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser)
sess.Values["admin"] = dbuser.Admin
sess.Values["session_token"] = tokenUID
sess.Values["max_age"] = ageMax
sess.Values["created_at"] = now
sess.Values["updated_at"] = now
sess.Save(c.Request(), c.Response())
// set session_token in cookie
cookie := new(http.Cookie)
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.Value = tokenUID
cookie.Expires = expiration
cookie.MaxAge = ageMax
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Logged in successfully"})
@ -256,7 +265,7 @@ func UpdateUser(db store.IStore) echo.HandlerFunc {
log.Infof("Updated user information successfully")
if previousUsername == currentUser(c) {
setUser(c, user.Username, user.Admin)
setUser(c, user.Username, user.Admin, util.GetDBUserCRC32(user))
}
return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Updated user information successfully"})

View file

@ -3,7 +3,9 @@ package handler
import (
"fmt"
"net/http"
"time"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/ngoduykhanh/wireguard-ui/util"
@ -23,6 +25,15 @@ func ValidSession(next echo.HandlerFunc) echo.HandlerFunc {
}
}
// RefreshSession must only be used after ValidSession middleware
// RefreshSession checks if the session is eligible for the refresh, but doesn't check if it's fully valid
func RefreshSession(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
doRefreshSession(c)
return next(c)
}
}
func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !isAdmin(c) {
@ -41,9 +52,146 @@ func isValidSession(c echo.Context) bool {
if err != nil || sess.Values["session_token"] != cookie.Value {
return false
}
// Check time bounds
createdAt := getCreatedAt(sess)
updatedAt := getUpdatedAt(sess)
maxAge := getMaxAge(sess)
// Temporary session is considered valid within 24h if browser is not closed before
// This value is not saved and is used as virtual expiration
if maxAge == 0 {
maxAge = 86400
}
expiration := updatedAt + int64(maxAge)
now := time.Now().UTC().Unix()
if updatedAt > now || expiration < now || createdAt+util.SessionMaxDuration < now {
return false
}
// Check if user still exists and unchanged
username := fmt.Sprintf("%s", sess.Values["username"])
userHash := getUserHash(sess)
if uHash, ok := util.DBUsersToCRC32[username]; !ok || userHash != uHash {
return false
}
return true
}
// Refreshes a "remember me" session when the user visits web pages (not API)
// Session must be valid before calling this function
// Refresh is performed at most once per 24h
func doRefreshSession(c echo.Context) {
if util.DisableLogin {
return
}
sess, _ := session.Get("session", c)
maxAge := getMaxAge(sess)
if maxAge <= 0 {
return
}
oldCookie, err := c.Cookie("session_token")
if err != nil || sess.Values["session_token"] != oldCookie.Value {
return
}
// Refresh no sooner than 24h
createdAt := getCreatedAt(sess)
updatedAt := getUpdatedAt(sess)
expiration := updatedAt + int64(getMaxAge(sess))
now := time.Now().UTC().Unix()
if updatedAt > now || expiration < now || now-updatedAt < 86_400 || createdAt+util.SessionMaxDuration < now {
return
}
cookiePath := util.GetCookiePath()
sess.Values["updated_at"] = now
sess.Options = &sessions.Options{
Path: cookiePath,
MaxAge: maxAge,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
sess.Save(c.Request(), c.Response())
cookie := new(http.Cookie)
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.Value = oldCookie.Value
cookie.MaxAge = maxAge
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
}
// Get time in seconds this session is valid without updating
func getMaxAge(sess *sessions.Session) int {
if util.DisableLogin {
return 0
}
maxAge := sess.Values["max_age"]
switch typedMaxAge := maxAge.(type) {
case int:
return typedMaxAge
default:
return 0
}
}
// Get a timestamp in seconds of the time the session was created
func getCreatedAt(sess *sessions.Session) int64 {
if util.DisableLogin {
return 0
}
createdAt := sess.Values["created_at"]
switch typedCreatedAt := createdAt.(type) {
case int64:
return typedCreatedAt
default:
return 0
}
}
// Get a timestamp in seconds of the last session update
func getUpdatedAt(sess *sessions.Session) int64 {
if util.DisableLogin {
return 0
}
lastUpdate := sess.Values["updated_at"]
switch typedLastUpdate := lastUpdate.(type) {
case int64:
return typedLastUpdate
default:
return 0
}
}
// Get CRC32 of a user at the moment of log in
// Any changes to user will result in logout of other (not updated) sessions
func getUserHash(sess *sessions.Session) uint32 {
if util.DisableLogin {
return 0
}
userHash := sess.Values["user_hash"]
switch typedUserHash := userHash.(type) {
case uint32:
return typedUserHash
default:
return 0
}
}
// currentUser to get username of logged in user
func currentUser(c echo.Context) string {
if util.DisableLogin {
@ -66,9 +214,10 @@ func isAdmin(c echo.Context) bool {
return admin == "true"
}
func setUser(c echo.Context, username string, admin bool) {
func setUser(c echo.Context, username string, admin bool, userCRC32 uint32) {
sess, _ := session.Get("session", c)
sess.Values["username"] = username
sess.Values["user_hash"] = userCRC32
sess.Values["admin"] = admin
sess.Save(c.Request(), c.Response())
}
@ -77,7 +226,24 @@ func setUser(c echo.Context, username string, admin bool) {
func clearSession(c echo.Context) {
sess, _ := session.Get("session", c)
sess.Values["username"] = ""
sess.Values["user_hash"] = 0
sess.Values["admin"] = false
sess.Values["session_token"] = ""
sess.Values["max_age"] = -1
sess.Options.MaxAge = -1
sess.Save(c.Request(), c.Response())
cookiePath := util.GetCookiePath()
cookie, err := c.Cookie("session_token")
if err != nil {
cookie = new(http.Cookie)
}
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.MaxAge = -1
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
}

20
main.go
View file

@ -1,6 +1,7 @@
package main
import (
"crypto/sha512"
"embed"
"flag"
"fmt"
@ -48,6 +49,7 @@ var (
flagTelegramAllowConfRequest = false
flagTelegramFloodWait = 60
flagSessionSecret = util.RandomString(32)
flagSessionMaxDuration = 90
flagWgConfTemplate string
flagBasePath string
flagSubnetRanges string
@ -91,6 +93,7 @@ func init() {
flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.")
flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL")
flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.")
flag.IntVar(&flagSessionMaxDuration, "session-max-duration", util.LookupEnvOrInt("SESSION_MAX_DURATION", flagSessionMaxDuration), "Max time in days a remembered session is refreshed and valid.")
var (
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
@ -135,7 +138,8 @@ func init() {
util.SendgridApiKey = flagSendgridApiKey
util.EmailFrom = flagEmailFrom
util.EmailFromName = flagEmailFromName
util.SessionSecret = []byte(flagSessionSecret)
util.SessionSecret = sha512.Sum512([]byte(flagSessionSecret))
util.SessionMaxDuration = int64(flagSessionMaxDuration) * 86_400 // Store in seconds
util.WgConfTemplate = flagWgConfTemplate
util.BasePath = util.ParseBasePath(flagBasePath)
util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges)
@ -204,7 +208,7 @@ func main() {
// register routes
app := router.New(tmplDir, extraData, util.SessionSecret)
app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession)
app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession)
// Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to
// mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on
@ -214,8 +218,8 @@ func main() {
app.GET(util.BasePath+"/login", handler.LoginPage())
app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson)
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
@ -241,19 +245,19 @@ func main() {
app.POST(util.BasePath+"/client/set-status", handler.SetClientStatus(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/remove-client", handler.RemoveClient(db), handler.ValidSession, handler.ContentTypeJson)
app.GET(util.BasePath+"/download", handler.DownloadClient(db), handler.ValidSession)
app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/interfaces", handler.WireGuardServerInterfaces(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/keypair", handler.WireGuardServerKeyPair(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/global-settings", handler.GlobalSettingSubmit(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession)
app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/api/clients", handler.GetClients(db), handler.ValidSession)
app.GET(util.BasePath+"/api/client/:id", handler.GetClient(db), handler.ValidSession)
app.GET(util.BasePath+"/api/machine-ips", handler.MachineIPAddresses(), handler.ValidSession)
app.GET(util.BasePath+"/api/subnet-ranges", handler.GetOrderedSubnetRanges(), handler.ValidSession)
app.GET(util.BasePath+"/api/suggest-client-ips", handler.SuggestIPAllocation(db), handler.ValidSession)
app.POST(util.BasePath+"/api/apply-wg-config", handler.ApplyServerConfig(db, tmplDir), handler.ValidSession, handler.ContentTypeJson)
app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession)
app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession, handler.RefreshSession)
app.POST(util.BasePath+"/wake_on_lan_host", handler.SaveWakeOnLanHost(db), handler.ValidSession, handler.ContentTypeJson)
app.DELETE(util.BasePath+"/wake_on_lan_host/:mac_address", handler.DeleteWakeOnHost(db), handler.ValidSession, handler.ContentTypeJson)
app.PUT(util.BasePath+"/wake_on_lan_host/:mac_address", handler.WakeOnHost(db), handler.ValidSession, handler.ContentTypeJson)

View file

@ -48,9 +48,17 @@ func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c
}
// New function
func New(tmplDir fs.FS, extraData map[string]interface{}, secret []byte) *echo.Echo {
func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo.Echo {
e := echo.New()
e.Use(session.Middleware(sessions.NewCookieStore(secret)))
cookiePath := util.GetCookiePath()
cookieStore := sessions.NewCookieStore(secret[:32], secret[32:])
cookieStore.Options.Path = cookiePath
cookieStore.Options.HttpOnly = true
cookieStore.MaxAge(86400 * 7)
e.Use(session.Middleware(cookieStore))
// read html template file to string
tmplBaseString, err := util.StringFromEmbedFile(tmplDir, "base.html")

View file

@ -161,6 +161,14 @@ func (o *JsonDB) Init() error {
}
// init cache
for _, i := range results {
user := model.User{}
if err := json.Unmarshal([]byte(i), &user); err == nil {
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
}
}
clients, err := o.GetClients(false)
if err != nil {
return nil
@ -214,11 +222,13 @@ func (o *JsonDB) SaveUser(user model.User) error {
if err != nil {
return err
}
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
return output
}
// DeleteUser func to remove user from the database
func (o *JsonDB) DeleteUser(username string) error {
delete(util.DBUsersToCRC32, username)
return o.conn.Delete("users", username)
}

View file

@ -5,3 +5,4 @@ import "sync"
var IPToSubnetRange = map[string]uint16{}
var TgUseridToClientID = map[int64][]string{}
var TgUseridToClientIDMutex sync.RWMutex
var DBUsersToCRC32 = map[string]uint32{}

View file

@ -9,24 +9,25 @@ import (
// Runtime config
var (
DisableLogin bool
BindAddress string
SmtpHostname string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpNoTLSCheck bool
SmtpEncryption string
SmtpAuthType string
SmtpHelo string
SendgridApiKey string
EmailFrom string
EmailFromName string
SessionSecret []byte
WgConfTemplate string
BasePath string
SubnetRanges map[string]([]*net.IPNet)
SubnetRangesOrder []string
DisableLogin bool
BindAddress string
SmtpHostname string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpNoTLSCheck bool
SmtpEncryption string
SmtpAuthType string
SmtpHelo string
SendgridApiKey string
EmailFrom string
EmailFromName string
SessionSecret [64]byte
SessionMaxDuration int64
WgConfTemplate string
BasePath string
SubnetRanges map[string]([]*net.IPNet)
SubnetRangesOrder []string
)
const (

View file

@ -2,9 +2,12 @@ package util
import (
"bufio"
"bytes"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io"
"io/fs"
"math/rand"
@ -827,3 +830,38 @@ func filterStringSlice(s []string, excludedStr string) []string {
}
return filtered
}
func GetDBUserCRC32(dbuser model.User) uint32 {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
if err := enc.Encode(dbuser); err != nil {
panic("model.User is gob-incompatible, session verification is impossible")
}
return crc32.ChecksumIEEE(buf.Bytes())
}
func ConcatMultipleSlices(slices ...[]byte) []byte {
var totalLen int
for _, s := range slices {
totalLen += len(s)
}
result := make([]byte, totalLen)
var i int
for _, s := range slices {
i += copy(result[i:], s)
}
return result
}
func GetCookiePath() string {
cookiePath := BasePath
if cookiePath == "" {
cookiePath = "/"
}
return cookiePath
}