From 38c1f3a3026cfdfd288aed77c5b4be1917ac24dc Mon Sep 17 00:00:00 2001 From: Khanh Ngo Date: Thu, 23 Apr 2020 18:01:40 +0700 Subject: [PATCH] DB query refactoring --- handler/routes.go | 193 ++++++++++------------------------------------ main.go | 10 +++ util/db.go | 167 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 216 insertions(+), 154 deletions(-) create mode 100644 util/db.go diff --git a/handler/routes.go b/handler/routes.go index b4dbdd4..409dc48 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -1,7 +1,6 @@ package handler import ( - "encoding/base64" "encoding/json" "fmt" "net/http" @@ -12,68 +11,15 @@ import ( "github.com/ngoduykhanh/wireguard-ui/model" "github.com/ngoduykhanh/wireguard-ui/util" "github.com/rs/xid" - "github.com/sdomino/scribble" - "github.com/skip2/go-qrcode" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) // WireGuardClients handler func WireGuardClients() echo.HandlerFunc { return func(c echo.Context) error { - // initialize database directory - dir := "./db" - db, err := scribble.New(dir, nil) + clientDataList, err := util.GetClients(true) if err != nil { - log.Error("Cannot initialize the database: ", err) - } - - // read server information - serverInterface := model.ServerInterface{} - if err := db.Read("server", "interfaces", &serverInterface); err != nil { - log.Error("Cannot fetch server interface config from database: ", err) - } - - serverKeyPair := model.ServerKeypair{} - if err := db.Read("server", "keypair", &serverKeyPair); err != nil { - log.Error("Cannot fetch server key pair from database: ", err) - } - - // read global settings - globalSettings := model.GlobalSetting{} - if err := db.Read("server", "global_settings", &globalSettings); err != nil { - log.Error("Cannot fetch global settings from database: ", err) - } - - server := model.Server{} - server.Interface = &serverInterface - server.KeyPair = &serverKeyPair - - // read client information and build a client list - records, err := db.ReadAll("clients") - if err != nil { - log.Error("Cannot fetch clients from database: ", err) - } - - clientDataList := []model.ClientData{} - for _, f := range records { - client := model.Client{} - clientData := model.ClientData{} - - // get client info - if err := json.Unmarshal([]byte(f), &client); err != nil { - log.Error("Cannot decode client json structure: ", err) - } - clientData.Client = &client - - // generate client qrcode image in base64 - png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) - if err != nil { - log.Error("Cannot generate QRCode: ", err) - } - clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) - - // create the list of clients and their qrcode data - clientDataList = append(clientDataList, clientData) + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, fmt.Sprintf("Cannot get client list: %v", err)}) } return c.Render(http.StatusOK, "clients.html", map[string]interface{}{ @@ -89,11 +35,9 @@ func NewClient() echo.HandlerFunc { client := new(model.Client) c.Bind(client) - // initialize db - dir := "./db" - db, err := scribble.New(dir, nil) + db, err := util.DBConn() if err != nil { - log.Error("Cannot initialize the database: ", err) + log.Error("Cannot initialize database: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) } @@ -152,11 +96,10 @@ func SetClientStatus() echo.HandlerFunc { clientID := data["id"].(string) status := data["status"].(bool) - // initialize database directory - dir := "./db" - db, err := scribble.New(dir, nil) + db, err := util.DBConn() if err != nil { - log.Error("Cannot initialize the database: ", err) + log.Error("Cannot initialize database: ", err) + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) } client := model.Client{} @@ -179,12 +122,12 @@ func RemoveClient() echo.HandlerFunc { c.Bind(client) // delete client from database - dir := "./db" - db, err := scribble.New(dir, nil) + db, err := util.DBConn() if err != nil { - log.Error("Cannot initialize the database: ", err) + log.Error("Cannot initialize database: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) } + if err := db.Delete("clients", client.ID); err != nil { log.Error("Cannot delete wireguard client: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot delete client from database"}) @@ -198,27 +141,15 @@ func RemoveClient() echo.HandlerFunc { // WireGuardServer handler func WireGuardServer() echo.HandlerFunc { return func(c echo.Context) error { - // initialize database directory - dir := "./db" - db, err := scribble.New(dir, nil) + server, err := util.GetServer() if err != nil { - log.Error("Cannot initialize the database: ", err) - } - - serverInterface := model.ServerInterface{} - if err := db.Read("server", "interfaces", &serverInterface); err != nil { - log.Error("Cannot fetch server interface config from database: ", err) - } - - serverKeyPair := model.ServerKeypair{} - if err := db.Read("server", "keypair", &serverKeyPair); err != nil { - log.Error("Cannot fetch server key pair from database: ", err) + log.Error("Cannot get server config: ", err) } return c.Render(http.StatusOK, "server.html", map[string]interface{}{ "baseData": model.BaseData{Active: "wg-server"}, - "serverInterface": serverInterface, - "serverKeyPair": serverKeyPair, + "serverInterface": server.Interface, + "serverKeyPair": server.KeyPair, }) } } @@ -238,12 +169,12 @@ func WireGuardServerInterfaces() echo.HandlerFunc { serverInterface.UpdatedAt = time.Now().UTC() // write config to the database - dir := "./db" - db, err := scribble.New(dir, nil) + db, err := util.DBConn() if err != nil { - log.Error("Cannot initialize the database: ", err) + log.Error("Cannot initialize database: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) } + db.Write("server", "interfaces", serverInterface) log.Infof("Updated wireguard server interfaces settings: %v", serverInterface) @@ -267,12 +198,12 @@ func WireGuardServerKeyPair() echo.HandlerFunc { serverKeyPair.UpdatedAt = time.Now().UTC() // write config to the database - dir := "./db" - db, err := scribble.New(dir, nil) + db, err := util.DBConn() if err != nil { - log.Error("Cannot initialize the database: ", err) + log.Error("Cannot initialize database: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) } + db.Write("server", "keypair", serverKeyPair) log.Infof("Updated wireguard server interfaces settings: %v", serverKeyPair) @@ -283,16 +214,9 @@ func WireGuardServerKeyPair() echo.HandlerFunc { // GlobalSettings handler func GlobalSettings() echo.HandlerFunc { return func(c echo.Context) error { - // initialize database directory - dir := "./db" - db, err := scribble.New(dir, nil) + globalSettings, err := util.GetGlobalSettings() if err != nil { - log.Error("Cannot initialize the database: ", err) - } - - globalSettings := model.GlobalSetting{} - if err := db.Read("server", "global_settings", &globalSettings); err != nil { - log.Error("Cannot fetch global settings from database: ", err) + log.Error("Cannot get global settings: ", err) } return c.Render(http.StatusOK, "global_settings.html", map[string]interface{}{ @@ -317,12 +241,12 @@ func GlobalSettingSubmit() echo.HandlerFunc { globalSettings.UpdatedAt = time.Now().UTC() // write config to the database - dir := "./db" - db, err := scribble.New(dir, nil) + db, err := util.DBConn() if err != nil { - log.Error("Cannot initialize the database: ", err) + log.Error("Cannot initialize database: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) } + db.Write("server", "global_settings", globalSettings) log.Infof("Updated global settings: %v", globalSettings) @@ -355,17 +279,9 @@ func MachineIPAddresses() echo.HandlerFunc { // SuggestIPAllocation handler to get the list of ip address for client func SuggestIPAllocation() echo.HandlerFunc { return func(c echo.Context) error { - // initialize database directory - dir := "./db" - db, err := scribble.New(dir, nil) + server, err := util.GetServer() if err != nil { - log.Error("Cannot initialize the database: ", err) - } - - // read server information - serverInterface := model.ServerInterface{} - if err := db.Read("server", "interfaces", &serverInterface); err != nil { - log.Error("Cannot fetch server interface config from database: ", err) + log.Error("Cannot fetch server config from database: ", err) } // return the list of suggestedIPs @@ -377,7 +293,7 @@ func SuggestIPAllocation() echo.HandlerFunc { log.Error("Cannot suggest ip allocation. Failed to get list of allocated ip addresses: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot suggest ip allocation: failed to get list of allocated ip addresses"}) } - for _, cidr := range serverInterface.Addresses { + for _, cidr := range server.Interface.Addresses { ip, err := util.GetAvailableIP(cidr, allocatedIPs) if err != nil { log.Error("Failed to get available ip from a CIDR: ", err) @@ -393,57 +309,26 @@ func SuggestIPAllocation() echo.HandlerFunc { // ApplyServerConfig handler to write config file and restart Wireguard server func ApplyServerConfig() echo.HandlerFunc { return func(c echo.Context) error { - // initialize database directory - dir := "./db" - db, err := scribble.New(dir, nil) + server, err := util.GetServer() if err != nil { - log.Error("Cannot initialize the database: ", err) + log.Error("Cannot get server config: ", err) + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot get server config"}) } - // read server information - serverInterface := model.ServerInterface{} - if err := db.Read("server", "interfaces", &serverInterface); err != nil { - log.Error("Cannot fetch server interface config from database: ", err) - } - - serverKeyPair := model.ServerKeypair{} - if err := db.Read("server", "keypair", &serverKeyPair); err != nil { - log.Error("Cannot fetch server key pair from database: ", err) - } - - server := model.Server{} - server.Interface = &serverInterface - server.KeyPair = &serverKeyPair - - // read global settings - globalSettings := model.GlobalSetting{} - if err := db.Read("server", "global_settings", &globalSettings); err != nil { - log.Error("Cannot fetch global settings from database: ", err) - } - - // read client information and build a client list - records, err := db.ReadAll("clients") + clients, err := util.GetClients(false) if err != nil { - log.Error("Cannot fetch clients from database: ", err) + log.Error("Cannot get client config: ", err) + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot get client config"}) } - clientDataList := []model.ClientData{} - for _, f := range records { - client := model.Client{} - clientData := model.ClientData{} - - // get client info - if err := json.Unmarshal([]byte(f), &client); err != nil { - log.Error("Cannot decode client json structure: ", err) - } - clientData.Client = &client - - // create the list of clients and their qrcode data - clientDataList = append(clientDataList, clientData) + settings, err := util.GetGlobalSettings() + if err != nil { + log.Error("Cannot get global settings: ", err) + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot get global settings"}) } // Write config file - err = util.WriteWireGuardServerConfig(server, clientDataList, globalSettings) + err = util.WriteWireGuardServerConfig(server, clients, settings) if err != nil { log.Error("Cannot apply server config: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, fmt.Sprintf("Cannot apply server config: %v", err)}) diff --git a/main.go b/main.go index 5adb242..972474e 100644 --- a/main.go +++ b/main.go @@ -1,11 +1,21 @@ package main import ( + "fmt" + "github.com/ngoduykhanh/wireguard-ui/handler" "github.com/ngoduykhanh/wireguard-ui/router" + "github.com/ngoduykhanh/wireguard-ui/util" ) func main() { + // initialize DB + err := util.InitDB() + if err != nil { + fmt.Print("Cannot init database: ", err) + } + + // register routes app := router.New() app.GET("/", handler.WireGuardClients()) diff --git a/util/db.go b/util/db.go new file mode 100644 index 0000000..536e007 --- /dev/null +++ b/util/db.go @@ -0,0 +1,167 @@ +package util + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path" + "time" + + "github.com/ngoduykhanh/wireguard-ui/model" + "github.com/sdomino/scribble" + "github.com/skip2/go-qrcode" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const dbPath = "./db" +const defaultServerAddress = "10.252.1.0/24" +const defaultServerPort = 51820 + +// DBConn to initialize the database connection +func DBConn() (*scribble.Driver, error) { + db, err := scribble.New(dbPath, nil) + if err != nil { + return nil, err + } + return db, nil +} + +// InitDB to create the default database +func InitDB() error { + var clientPath string = path.Join(dbPath, "clients") + var serverPath string = path.Join(dbPath, "server") + var serverInterfacePath string = path.Join(serverPath, "interfaces.json") + var serverKeyPairPath string = path.Join(serverPath, "keypair.json") + + // create directories if they do not exist + if _, err := os.Stat(clientPath); os.IsNotExist(err) { + os.Mkdir(clientPath, os.ModePerm) + } + if _, err := os.Stat(serverPath); os.IsNotExist(err) { + os.Mkdir(serverPath, os.ModePerm) + } + + // server's interface + if _, err := os.Stat(serverInterfacePath); os.IsNotExist(err) { + db, err := DBConn() + if err != nil { + return err + } + + serverInterface := new(model.ServerInterface) + serverInterface.Addresses = []string{defaultServerAddress} + serverInterface.ListenPort = defaultServerPort + serverInterface.UpdatedAt = time.Now().UTC() + db.Write("server", "interfaces", serverInterface) + } + + // server's key pair + if _, err := os.Stat(serverKeyPairPath); os.IsNotExist(err) { + db, err := DBConn() + if err != nil { + return err + } + + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return scribble.ErrMissingCollection + } + serverKeyPair := new(model.ServerKeypair) + serverKeyPair.PrivateKey = key.String() + serverKeyPair.PublicKey = key.PublicKey().String() + serverKeyPair.UpdatedAt = time.Now().UTC() + db.Write("server", "keypair", serverKeyPair) + } + + return nil +} + +// GetGlobalSettings func to query global settings from the database +func GetGlobalSettings() (model.GlobalSetting, error) { + settings := model.GlobalSetting{} + + db, err := DBConn() + if err != nil { + return settings, err + } + + if err := db.Read("server", "global_settings", &settings); err != nil { + return settings, err + } + + return settings, nil +} + +// GetServer func to query Server setting from the database +func GetServer() (model.Server, error) { + server := model.Server{} + + db, err := DBConn() + if err != nil { + return server, err + } + + // read server interface information + serverInterface := model.ServerInterface{} + if err := db.Read("server", "interfaces", &serverInterface); err != nil { + return server, err + } + + // read server key pair information + serverKeyPair := model.ServerKeypair{} + if err := db.Read("server", "keypair", &serverKeyPair); err != nil { + return server, err + } + + // create Server object and return + server.Interface = &serverInterface + server.KeyPair = &serverKeyPair + return server, nil +} + +// GetClients to get all clients from the database +func GetClients(hasQRCode bool) ([]model.ClientData, error) { + clients := []model.ClientData{} + + db, err := DBConn() + if err != nil { + return clients, err + } + + // read all client json file in "clients" directory + records, err := db.ReadAll("clients") + if err != nil { + return clients, err + } + + // build the ClientData list + for _, f := range records { + client := model.Client{} + clientData := model.ClientData{} + + // get client info + if err := json.Unmarshal([]byte(f), &client); err != nil { + return clients, fmt.Errorf("Cannot decode client json structure: %v", err) + } + + // generate client qrcode image in base64 + if hasQRCode { + server, _ := GetServer() + globalSettings, _ := GetGlobalSettings() + + png, _ := qrcode.Encode(BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) + if err == nil { + clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) + } else { + fmt.Print("Cannot generate QR code: ", err) + } + } + + // create the list of clients and their qrcode data + clientData.Client = &client + clients = append(clients, clientData) + } + + return clients, nil +}