diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..05e4ead --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ +build: + CGO_ENABLED=0 go build -v -ldflags="-s -w" + +run: + ./wireguard-ui + +clean: + rm -f wireguard-ui diff --git a/handler/routes.go b/handler/routes.go index aa5461b..11ec048 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -381,9 +381,9 @@ func GetClient(db store.IStore) echo.HandlerFunc { } qrCodeSettings := model.QRCodeSettings{ - Enabled: true, - IncludeDNS: true, - IncludeMTU: true, + Enabled: true, + IncludeDNS: true, + IncludeMTU: true, } clientData, err := db.GetClientByID(clientID, qrCodeSettings) @@ -513,9 +513,9 @@ func EmailClient(db store.IStore, mailer emailer.Emailer, emailSubject, emailCon } qrCodeSettings := model.QRCodeSettings{ - Enabled: true, - IncludeDNS: true, - IncludeMTU: true, + Enabled: true, + IncludeDNS: true, + IncludeMTU: true, } clientData, err := db.GetClientByID(payload.ID, qrCodeSettings) if err != nil { @@ -1078,7 +1078,6 @@ func ApplyServerConfig(db store.IStore, tmplDir fs.FS) echo.HandlerFunc { } } - // GetHashesChanges handler returns if database hashes have changed func GetHashesChanges(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { diff --git a/main.go b/main.go index fd4bc90..4996d89 100644 --- a/main.go +++ b/main.go @@ -45,6 +45,7 @@ var ( flagSessionSecret string = util.RandomString(32) flagWgConfTemplate string flagBasePath string + flagSubnetRanges string ) const ( @@ -81,6 +82,7 @@ func init() { flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.") flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.") flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL") + flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.") var ( smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword) @@ -127,6 +129,7 @@ func init() { util.SessionSecret = []byte(flagSessionSecret) util.WgConfTemplate = flagWgConfTemplate util.BasePath = util.ParseBasePath(flagBasePath) + util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges) // print only if log level is INFO or lower if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO { @@ -145,6 +148,7 @@ func init() { //fmt.Println("Session secret\t:", util.SessionSecret) fmt.Println("Custom wg.conf\t:", util.WgConfTemplate) fmt.Println("Base path\t:", util.BasePath+"/") + fmt.Println("Subnet ranges\t:", util.GetSubnetRangesString()) } } @@ -170,6 +174,17 @@ func main() { // create the wireguard config on start, if it doesn't exist initServerConfig(db, tmplDir) + // Check if subnet ranges are valid for the server configuration + // Remove any non-valid CIDRs + if err := util.ValidateAndFixSubnetRanges(db); err != nil { + panic(err) + } + + // Print valid ranges + if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO { + fmt.Println("Valid subnet ranges:", util.GetSubnetRangesString()) + } + // register routes app := router.New(tmplDir, extraData, util.SessionSecret) diff --git a/util/config.go b/util/config.go index 6d5a8df..2c7791f 100644 --- a/util/config.go +++ b/util/config.go @@ -1,24 +1,31 @@ package util -import "strings" +import ( + "net" + "strings" + + "github.com/labstack/gommon/log" +) // Runtime config var ( - DisableLogin bool - BindAddress string - SmtpHostname string - SmtpPort int - SmtpUsername string - SmtpPassword string - SmtpNoTLSCheck bool - SmtpEncryption string - SmtpAuthType string - SendgridApiKey string - EmailFrom string - EmailFromName string - SessionSecret []byte - WgConfTemplate string - BasePath string + DisableLogin bool + BindAddress string + SmtpHostname string + SmtpPort int + SmtpUsername string + SmtpPassword string + SmtpNoTLSCheck bool + SmtpEncryption string + SmtpAuthType string + SendgridApiKey string + EmailFrom string + EmailFromName string + SessionSecret []byte + WgConfTemplate string + BasePath string + SubnetRanges map[string]([]*net.IPNet) + SubnetRangesOrder []string ) const ( @@ -66,3 +73,45 @@ func ParseBasePath(basePath string) string { } return basePath } + +func ParseSubnetRanges(subnetRangesStr string) map[string]([]*net.IPNet) { + subnetRanges := map[string]([]*net.IPNet){} + if subnetRangesStr == "" { + return subnetRanges + } + cidrSet := map[string]bool{} + subnetRangesStr = strings.TrimSpace(subnetRangesStr) + subnetRangesStr = strings.Trim(subnetRangesStr, ";:,") + ranges := strings.Split(subnetRangesStr, ";") + for _, rng := range ranges { + rng = strings.TrimSpace(rng) + rngSpl := strings.Split(rng, ":") + if len(rngSpl) != 2 { + log.Warnf("Unable to parse subnet range: %v. Skipped.", rng) + continue + } + rngName := strings.TrimSpace(rngSpl[0]) + subnetRanges[rngName] = make([]*net.IPNet, 0) + cidrs := strings.Split(rngSpl[1], ",") + for _, cidr := range cidrs { + cidr = strings.TrimSpace(cidr) + _, net, err := net.ParseCIDR(cidr) + if err != nil { + log.Warnf("[%v] Unable to parse CIDR: %v. Skipped.", rngName, cidr) + continue + } + if cidrSet[net.String()] { + log.Warnf("[%v] CIDR already exists: %v. Skipped.", rngName, net.String()) + continue + } + cidrSet[net.String()] = true + subnetRanges[rngName] = append(subnetRanges[rngName], net) + } + if len(subnetRanges[rngName]) == 0 { + delete(subnetRanges, rngName) + } else { + SubnetRangesOrder = append(SubnetRangesOrder, rngName) + } + } + return subnetRanges +} diff --git a/util/util.go b/util/util.go index f455fc8..431647d 100644 --- a/util/util.go +++ b/util/util.go @@ -95,6 +95,15 @@ func ClientDefaultsFromEnv() model.ClientDefaults { return clientDefaults } +// ContainsCIDR to check if ipnet1 contains ipnet2 +// https://stackoverflow.com/a/40406619/6111641 +// https://go.dev/play/p/Q4J-JEN3sF +func ContainsCIDR(ipnet1, ipnet2 *net.IPNet) bool { + ones1, _ := ipnet1.Mask.Size() + ones2, _ := ipnet2.Mask.Size() + return ones1 <= ones2 && ipnet1.Contains(ipnet2.IP) +} + // ValidateCIDR to validate a network CIDR func ValidateCIDR(cidr string) bool { _, _, err := net.ParseCIDR(cidr) @@ -384,6 +393,88 @@ func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ip return true, nil } +// ValidateAndFixSubnetRanges to check if subnet ranges are valid for the server configuration +// Removes all non-valid CIDRs +func ValidateAndFixSubnetRanges(db store.IStore) error { + if len(SubnetRangesOrder) == 0 { + return nil + } + + server, err := db.GetServer() + if err != nil { + return err + } + var serverSubnets []*net.IPNet + for _, addr := range server.Interface.Addresses { + addr = strings.TrimSpace(addr) + _, net, err := net.ParseCIDR(addr) + if err != nil { + return err + } + serverSubnets = append(serverSubnets, net) + } + + for _, rng := range SubnetRangesOrder { + cidrs := SubnetRanges[rng] + if len(cidrs) > 0 { + newCIDRs := make([]*net.IPNet, 0) + for _, cidr := range cidrs { + valid := false + + for _, serverSubnet := range serverSubnets { + if ContainsCIDR(serverSubnet, cidr) { + valid = true + break + } + } + + if valid { + newCIDRs = append(newCIDRs, cidr) + } else { + log.Warnf("[%v] CIDR is outside of all server subnets: %v. Removed.", rng, cidr) + } + } + + if len(newCIDRs) > 0 { + SubnetRanges[rng] = newCIDRs + } else { + delete(SubnetRanges, rng) + log.Warnf("[%v] No valid CIDRs in this subnet range. Removed.", rng) + } + } + } + + return nil +} + +// GetSubnetRangesString to get a formatted string, representing active subnet ranges +func GetSubnetRangesString() string { + if len(SubnetRangesOrder) == 0 { + return "" + } + + strB := strings.Builder{} + + for _, rng := range SubnetRangesOrder { + cidrs := SubnetRanges[rng] + if len(cidrs) > 0 { + strB.WriteString(rng) + strB.WriteString(":[") + first := true + for _, cidr := range cidrs { + if !first { + strB.WriteString(", ") + } + strB.WriteString(cidr.String()) + first = false + } + strB.WriteString("] ") + } + } + + return strings.TrimSpace(strB.String()) +} + // WriteWireGuardServerConfig to write Wireguard server config. e.g. wg0.conf func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, clientDataList []model.ClientData, usersList []model.User, globalSettings model.GlobalSetting) error { var tmplWireguardConf string