wireguard-ui/store/jsondb/jsondb.go
0xCA bee5c54127 Further session protections and fixes
Use MaxAge instead of Expires
Verify if the cookie is not too old and not from the future
Verify if the user exists and unchanged
Refresh not sooner than 24h
Do not refresh temporary sessions
Delete cookies on logout
2023-12-29 15:08:50 +05:00

413 lines
12 KiB
Go

package jsondb
import (
"encoding/base64"
"encoding/json"
"fmt"
"os"
"path"
"strconv"
"time"
"github.com/sdomino/scribble"
"github.com/skip2/go-qrcode"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/ngoduykhanh/wireguard-ui/model"
"github.com/ngoduykhanh/wireguard-ui/util"
)
type JsonDB struct {
conn *scribble.Driver
dbPath string
}
// New returns a new pointer JsonDB
func New(dbPath string) (*JsonDB, error) {
conn, err := scribble.New(dbPath, nil)
if err != nil {
return nil, err
}
ans := JsonDB{
conn: conn,
dbPath: dbPath,
}
return &ans, nil
}
func (o *JsonDB) Init() error {
var clientPath = path.Join(o.dbPath, "clients")
var serverPath = path.Join(o.dbPath, "server")
var userPath = path.Join(o.dbPath, "users")
var wakeOnLanHostsPath = path.Join(o.dbPath, "wake_on_lan_hosts")
var serverInterfacePath = path.Join(serverPath, "interfaces.json")
var serverKeyPairPath = path.Join(serverPath, "keypair.json")
var globalSettingPath = path.Join(serverPath, "global_settings.json")
var hashesPath = path.Join(serverPath, "hashes.json")
// create directories if they do not exist
if _, err := os.Stat(clientPath); os.IsNotExist(err) {
os.MkdirAll(clientPath, os.ModePerm)
}
if _, err := os.Stat(serverPath); os.IsNotExist(err) {
os.MkdirAll(serverPath, os.ModePerm)
}
if _, err := os.Stat(userPath); os.IsNotExist(err) {
os.MkdirAll(userPath, os.ModePerm)
}
if _, err := os.Stat(wakeOnLanHostsPath); os.IsNotExist(err) {
os.MkdirAll(wakeOnLanHostsPath, os.ModePerm)
}
// server's interface
if _, err := os.Stat(serverInterfacePath); os.IsNotExist(err) {
serverInterface := new(model.ServerInterface)
serverInterface.Addresses = util.LookupEnvOrStrings(util.ServerAddressesEnvVar, []string{util.DefaultServerAddress})
serverInterface.ListenPort = util.LookupEnvOrInt(util.ServerListenPortEnvVar, util.DefaultServerPort)
serverInterface.PostUp = util.LookupEnvOrString(util.ServerPostUpScriptEnvVar, "")
serverInterface.PostDown = util.LookupEnvOrString(util.ServerPostDownScriptEnvVar, "")
serverInterface.UpdatedAt = time.Now().UTC()
o.conn.Write("server", "interfaces", serverInterface)
err := util.ManagePerms(serverInterfacePath)
if err != nil {
return err
}
}
// server's key pair
if _, err := os.Stat(serverKeyPairPath); os.IsNotExist(err) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return scribble.ErrMissingCollection
}
serverKeyPair := new(model.ServerKeypair)
serverKeyPair.PrivateKey = key.String()
serverKeyPair.PublicKey = key.PublicKey().String()
serverKeyPair.UpdatedAt = time.Now().UTC()
o.conn.Write("server", "keypair", serverKeyPair)
err = util.ManagePerms(serverKeyPairPath)
if err != nil {
return err
}
}
// global settings
if _, err := os.Stat(globalSettingPath); os.IsNotExist(err) {
endpointAddress := util.LookupEnvOrString(util.EndpointAddressEnvVar, "")
if endpointAddress == "" {
// automatically find an external IP address
publicInterface, err := util.GetPublicIP()
if err != nil {
return err
}
endpointAddress = publicInterface.IPAddress
}
globalSetting := new(model.GlobalSetting)
globalSetting.EndpointAddress = endpointAddress
globalSetting.DNSServers = util.LookupEnvOrStrings(util.DNSEnvVar, []string{util.DefaultDNS})
globalSetting.MTU = util.LookupEnvOrInt(util.MTUEnvVar, util.DefaultMTU)
globalSetting.PersistentKeepalive = util.LookupEnvOrInt(util.PersistentKeepaliveEnvVar, util.DefaultPersistentKeepalive)
globalSetting.FirewallMark = util.LookupEnvOrString(util.FirewallMarkEnvVar, util.DefaultFirewallMark)
globalSetting.Table = util.LookupEnvOrString(util.TableEnvVar, util.DefaultTable)
globalSetting.ConfigFilePath = util.LookupEnvOrString(util.ConfigFilePathEnvVar, util.DefaultConfigFilePath)
globalSetting.UpdatedAt = time.Now().UTC()
o.conn.Write("server", "global_settings", globalSetting)
err := util.ManagePerms(globalSettingPath)
if err != nil {
return err
}
}
// hashes
if _, err := os.Stat(hashesPath); os.IsNotExist(err) {
clientServerHashes := new(model.ClientServerHashes)
clientServerHashes.Client = "none"
clientServerHashes.Server = "none"
o.conn.Write("server", "hashes", clientServerHashes)
err := util.ManagePerms(hashesPath)
if err != nil {
return err
}
}
// user info
results, err := o.conn.ReadAll("users")
if err != nil || len(results) < 1 {
user := new(model.User)
user.Username = util.LookupEnvOrString(util.UsernameEnvVar, util.DefaultUsername)
user.Admin = util.DefaultIsAdmin
user.PasswordHash = util.LookupEnvOrString(util.PasswordHashEnvVar, "")
if user.PasswordHash == "" {
user.PasswordHash = util.LookupEnvOrFile(util.PasswordHashFileEnvVar, "")
if user.PasswordHash == "" {
plaintext := util.LookupEnvOrString(util.PasswordEnvVar, util.DefaultPassword)
if plaintext == util.DefaultPassword {
plaintext = util.LookupEnvOrFile(util.PasswordFileEnvVar, util.DefaultPassword)
}
hash, err := util.HashPassword(plaintext)
if err != nil {
return err
}
user.PasswordHash = hash
}
}
o.conn.Write("users", user.Username, user)
err = util.ManagePerms(path.Join(path.Join(o.dbPath, "users"), user.Username+".json"))
if err != nil {
return err
}
}
// init cache
for _, i := range results {
user := model.User{}
if err := json.Unmarshal([]byte(i), &user); err == nil {
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
}
}
clients, err := o.GetClients(false)
if err != nil {
return nil
}
for _, cl := range clients {
client := cl.Client
if client.Enabled && len(client.TgUserid) > 0 {
if userid, err := strconv.ParseInt(client.TgUserid, 10, 64); err == nil {
util.UpdateTgToClientID(userid, client.ID)
}
}
}
return nil
}
// GetUsers func to get all users from the database
func (o *JsonDB) GetUsers() ([]model.User, error) {
var users []model.User
results, err := o.conn.ReadAll("users")
if err != nil {
return users, err
}
for _, i := range results {
user := model.User{}
if err := json.Unmarshal(i, &user); err != nil {
return users, fmt.Errorf("cannot decode user json structure: %v", err)
}
users = append(users, user)
}
return users, err
}
// GetUserByName func to get single user from the database
func (o *JsonDB) GetUserByName(username string) (model.User, error) {
user := model.User{}
if err := o.conn.Read("users", username, &user); err != nil {
return user, err
}
return user, nil
}
// SaveUser func to save user in the database
func (o *JsonDB) SaveUser(user model.User) error {
userPath := path.Join(path.Join(o.dbPath, "users"), user.Username+".json")
output := o.conn.Write("users", user.Username, user)
err := util.ManagePerms(userPath)
if err != nil {
return err
}
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
return output
}
// DeleteUser func to remove user from the database
func (o *JsonDB) DeleteUser(username string) error {
delete(util.DBUsersToCRC32, username)
return o.conn.Delete("users", username)
}
// GetGlobalSettings func to query global settings from the database
func (o *JsonDB) GetGlobalSettings() (model.GlobalSetting, error) {
settings := model.GlobalSetting{}
return settings, o.conn.Read("server", "global_settings", &settings)
}
// GetServer func to query Server settings from the database
func (o *JsonDB) GetServer() (model.Server, error) {
server := model.Server{}
// read server interface information
serverInterface := model.ServerInterface{}
if err := o.conn.Read("server", "interfaces", &serverInterface); err != nil {
return server, err
}
// read server key pair information
serverKeyPair := model.ServerKeypair{}
if err := o.conn.Read("server", "keypair", &serverKeyPair); err != nil {
return server, err
}
// create Server object and return
server.Interface = &serverInterface
server.KeyPair = &serverKeyPair
return server, nil
}
func (o *JsonDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {
var clients []model.ClientData
// read all client json files in "clients" directory
records, err := o.conn.ReadAll("clients")
if err != nil {
return clients, err
}
// build the ClientData list
for _, f := range records {
client := model.Client{}
clientData := model.ClientData{}
// get client info
if err := json.Unmarshal(f, &client); err != nil {
return clients, fmt.Errorf("cannot decode client json structure: %v", err)
}
// generate client qrcode image in base64
if hasQRCode && client.PrivateKey != "" {
server, _ := o.GetServer()
globalSettings, _ := o.GetGlobalSettings()
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString(png)
} else {
fmt.Print("Cannot generate QR code: ", err)
}
}
// create the list of clients and their qrcode data
clientData.Client = &client
clients = append(clients, clientData)
}
return clients, nil
}
func (o *JsonDB) GetClientByID(clientID string, qrCodeSettings model.QRCodeSettings) (model.ClientData, error) {
client := model.Client{}
clientData := model.ClientData{}
// read client information
if err := o.conn.Read("clients", clientID, &client); err != nil {
return clientData, err
}
// generate client qrcode image in base64
if qrCodeSettings.Enabled && client.PrivateKey != "" {
server, _ := o.GetServer()
globalSettings, _ := o.GetGlobalSettings()
client := client
if !qrCodeSettings.IncludeDNS {
globalSettings.DNSServers = []string{}
}
if !qrCodeSettings.IncludeMTU {
globalSettings.MTU = 0
}
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString(png)
} else {
fmt.Print("Cannot generate QR code: ", err)
}
}
clientData.Client = &client
return clientData, nil
}
func (o *JsonDB) SaveClient(client model.Client) error {
clientPath := path.Join(path.Join(o.dbPath, "clients"), client.ID+".json")
output := o.conn.Write("clients", client.ID, client)
if output == nil {
if client.Enabled && len(client.TgUserid) > 0 {
if userid, err := strconv.ParseInt(client.TgUserid, 10, 64); err == nil {
util.UpdateTgToClientID(userid, client.ID)
}
} else {
util.RemoveTgToClientID(client.ID)
}
} else {
util.RemoveTgToClientID(client.ID)
}
err := util.ManagePerms(clientPath)
if err != nil {
return err
}
return output
}
func (o *JsonDB) DeleteClient(clientID string) error {
util.RemoveTgToClientID(clientID)
return o.conn.Delete("clients", clientID)
}
func (o *JsonDB) SaveServerInterface(serverInterface model.ServerInterface) error {
serverInterfacePath := path.Join(path.Join(o.dbPath, "server"), "interfaces.json")
output := o.conn.Write("server", "interfaces", serverInterface)
err := util.ManagePerms(serverInterfacePath)
if err != nil {
return err
}
return output
}
func (o *JsonDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error {
serverKeyPairPath := path.Join(path.Join(o.dbPath, "server"), "keypair.json")
output := o.conn.Write("server", "keypair", serverKeyPair)
err := util.ManagePerms(serverKeyPairPath)
if err != nil {
return err
}
return output
}
func (o *JsonDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error {
globalSettingsPath := path.Join(path.Join(o.dbPath, "server"), "global_settings.json")
output := o.conn.Write("server", "global_settings", globalSettings)
err := util.ManagePerms(globalSettingsPath)
if err != nil {
return err
}
return output
}
func (o *JsonDB) GetPath() string {
return o.dbPath
}
func (o *JsonDB) GetHashes() (model.ClientServerHashes, error) {
hashes := model.ClientServerHashes{}
return hashes, o.conn.Read("server", "hashes", &hashes)
}
func (o *JsonDB) SaveHashes(hashes model.ClientServerHashes) error {
hashesPath := path.Join(path.Join(o.dbPath, "server"), "hashes.json")
output := o.conn.Write("server", "hashes", hashes)
err := util.ManagePerms(hashesPath)
if err != nil {
return err
}
return output
}