mirror of
https://github.com/ngoduykhanh/wireguard-ui
synced 2024-06-02 14:02:13 +02:00
Parse and validate subnet ranges
This commit is contained in:
parent
e73047b14f
commit
92333a08d8
8
Makefile
Normal file
8
Makefile
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
build:
|
||||||
|
CGO_ENABLED=0 go build -v -ldflags="-s -w"
|
||||||
|
|
||||||
|
run:
|
||||||
|
./wireguard-ui
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f wireguard-ui
|
|
@ -381,9 +381,9 @@ func GetClient(db store.IStore) echo.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
qrCodeSettings := model.QRCodeSettings{
|
qrCodeSettings := model.QRCodeSettings{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
IncludeDNS: true,
|
IncludeDNS: true,
|
||||||
IncludeMTU: true,
|
IncludeMTU: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
clientData, err := db.GetClientByID(clientID, qrCodeSettings)
|
clientData, err := db.GetClientByID(clientID, qrCodeSettings)
|
||||||
|
@ -513,9 +513,9 @@ func EmailClient(db store.IStore, mailer emailer.Emailer, emailSubject, emailCon
|
||||||
}
|
}
|
||||||
|
|
||||||
qrCodeSettings := model.QRCodeSettings{
|
qrCodeSettings := model.QRCodeSettings{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
IncludeDNS: true,
|
IncludeDNS: true,
|
||||||
IncludeMTU: true,
|
IncludeMTU: true,
|
||||||
}
|
}
|
||||||
clientData, err := db.GetClientByID(payload.ID, qrCodeSettings)
|
clientData, err := db.GetClientByID(payload.ID, qrCodeSettings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1078,7 +1078,6 @@ func ApplyServerConfig(db store.IStore, tmplDir fs.FS) echo.HandlerFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// GetHashesChanges handler returns if database hashes have changed
|
// GetHashesChanges handler returns if database hashes have changed
|
||||||
func GetHashesChanges(db store.IStore) echo.HandlerFunc {
|
func GetHashesChanges(db store.IStore) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
|
|
15
main.go
15
main.go
|
@ -45,6 +45,7 @@ var (
|
||||||
flagSessionSecret string = util.RandomString(32)
|
flagSessionSecret string = util.RandomString(32)
|
||||||
flagWgConfTemplate string
|
flagWgConfTemplate string
|
||||||
flagBasePath string
|
flagBasePath string
|
||||||
|
flagSubnetRanges string
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -81,6 +82,7 @@ func init() {
|
||||||
flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.")
|
flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.")
|
||||||
flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.")
|
flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.")
|
||||||
flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL")
|
flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL")
|
||||||
|
flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
|
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
|
||||||
|
@ -127,6 +129,7 @@ func init() {
|
||||||
util.SessionSecret = []byte(flagSessionSecret)
|
util.SessionSecret = []byte(flagSessionSecret)
|
||||||
util.WgConfTemplate = flagWgConfTemplate
|
util.WgConfTemplate = flagWgConfTemplate
|
||||||
util.BasePath = util.ParseBasePath(flagBasePath)
|
util.BasePath = util.ParseBasePath(flagBasePath)
|
||||||
|
util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges)
|
||||||
|
|
||||||
// print only if log level is INFO or lower
|
// print only if log level is INFO or lower
|
||||||
if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO {
|
if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO {
|
||||||
|
@ -145,6 +148,7 @@ func init() {
|
||||||
//fmt.Println("Session secret\t:", util.SessionSecret)
|
//fmt.Println("Session secret\t:", util.SessionSecret)
|
||||||
fmt.Println("Custom wg.conf\t:", util.WgConfTemplate)
|
fmt.Println("Custom wg.conf\t:", util.WgConfTemplate)
|
||||||
fmt.Println("Base path\t:", util.BasePath+"/")
|
fmt.Println("Base path\t:", util.BasePath+"/")
|
||||||
|
fmt.Println("Subnet ranges\t:", util.GetSubnetRangesString())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,6 +174,17 @@ func main() {
|
||||||
// create the wireguard config on start, if it doesn't exist
|
// create the wireguard config on start, if it doesn't exist
|
||||||
initServerConfig(db, tmplDir)
|
initServerConfig(db, tmplDir)
|
||||||
|
|
||||||
|
// Check if subnet ranges are valid for the server configuration
|
||||||
|
// Remove any non-valid CIDRs
|
||||||
|
if err := util.ValidateAndFixSubnetRanges(db); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print valid ranges
|
||||||
|
if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO {
|
||||||
|
fmt.Println("Valid subnet ranges:", util.GetSubnetRangesString())
|
||||||
|
}
|
||||||
|
|
||||||
// register routes
|
// register routes
|
||||||
app := router.New(tmplDir, extraData, util.SessionSecret)
|
app := router.New(tmplDir, extraData, util.SessionSecret)
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,31 @@
|
||||||
package util
|
package util
|
||||||
|
|
||||||
import "strings"
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/labstack/gommon/log"
|
||||||
|
)
|
||||||
|
|
||||||
// Runtime config
|
// Runtime config
|
||||||
var (
|
var (
|
||||||
DisableLogin bool
|
DisableLogin bool
|
||||||
BindAddress string
|
BindAddress string
|
||||||
SmtpHostname string
|
SmtpHostname string
|
||||||
SmtpPort int
|
SmtpPort int
|
||||||
SmtpUsername string
|
SmtpUsername string
|
||||||
SmtpPassword string
|
SmtpPassword string
|
||||||
SmtpNoTLSCheck bool
|
SmtpNoTLSCheck bool
|
||||||
SmtpEncryption string
|
SmtpEncryption string
|
||||||
SmtpAuthType string
|
SmtpAuthType string
|
||||||
SendgridApiKey string
|
SendgridApiKey string
|
||||||
EmailFrom string
|
EmailFrom string
|
||||||
EmailFromName string
|
EmailFromName string
|
||||||
SessionSecret []byte
|
SessionSecret []byte
|
||||||
WgConfTemplate string
|
WgConfTemplate string
|
||||||
BasePath string
|
BasePath string
|
||||||
|
SubnetRanges map[string]([]*net.IPNet)
|
||||||
|
SubnetRangesOrder []string
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -66,3 +73,45 @@ func ParseBasePath(basePath string) string {
|
||||||
}
|
}
|
||||||
return basePath
|
return basePath
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ParseSubnetRanges(subnetRangesStr string) map[string]([]*net.IPNet) {
|
||||||
|
subnetRanges := map[string]([]*net.IPNet){}
|
||||||
|
if subnetRangesStr == "" {
|
||||||
|
return subnetRanges
|
||||||
|
}
|
||||||
|
cidrSet := map[string]bool{}
|
||||||
|
subnetRangesStr = strings.TrimSpace(subnetRangesStr)
|
||||||
|
subnetRangesStr = strings.Trim(subnetRangesStr, ";:,")
|
||||||
|
ranges := strings.Split(subnetRangesStr, ";")
|
||||||
|
for _, rng := range ranges {
|
||||||
|
rng = strings.TrimSpace(rng)
|
||||||
|
rngSpl := strings.Split(rng, ":")
|
||||||
|
if len(rngSpl) != 2 {
|
||||||
|
log.Warnf("Unable to parse subnet range: %v. Skipped.", rng)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rngName := strings.TrimSpace(rngSpl[0])
|
||||||
|
subnetRanges[rngName] = make([]*net.IPNet, 0)
|
||||||
|
cidrs := strings.Split(rngSpl[1], ",")
|
||||||
|
for _, cidr := range cidrs {
|
||||||
|
cidr = strings.TrimSpace(cidr)
|
||||||
|
_, net, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("[%v] Unable to parse CIDR: %v. Skipped.", rngName, cidr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if cidrSet[net.String()] {
|
||||||
|
log.Warnf("[%v] CIDR already exists: %v. Skipped.", rngName, net.String())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cidrSet[net.String()] = true
|
||||||
|
subnetRanges[rngName] = append(subnetRanges[rngName], net)
|
||||||
|
}
|
||||||
|
if len(subnetRanges[rngName]) == 0 {
|
||||||
|
delete(subnetRanges, rngName)
|
||||||
|
} else {
|
||||||
|
SubnetRangesOrder = append(SubnetRangesOrder, rngName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return subnetRanges
|
||||||
|
}
|
||||||
|
|
91
util/util.go
91
util/util.go
|
@ -95,6 +95,15 @@ func ClientDefaultsFromEnv() model.ClientDefaults {
|
||||||
return clientDefaults
|
return clientDefaults
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContainsCIDR to check if ipnet1 contains ipnet2
|
||||||
|
// https://stackoverflow.com/a/40406619/6111641
|
||||||
|
// https://go.dev/play/p/Q4J-JEN3sF
|
||||||
|
func ContainsCIDR(ipnet1, ipnet2 *net.IPNet) bool {
|
||||||
|
ones1, _ := ipnet1.Mask.Size()
|
||||||
|
ones2, _ := ipnet2.Mask.Size()
|
||||||
|
return ones1 <= ones2 && ipnet1.Contains(ipnet2.IP)
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateCIDR to validate a network CIDR
|
// ValidateCIDR to validate a network CIDR
|
||||||
func ValidateCIDR(cidr string) bool {
|
func ValidateCIDR(cidr string) bool {
|
||||||
_, _, err := net.ParseCIDR(cidr)
|
_, _, err := net.ParseCIDR(cidr)
|
||||||
|
@ -384,6 +393,88 @@ func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ip
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateAndFixSubnetRanges to check if subnet ranges are valid for the server configuration
|
||||||
|
// Removes all non-valid CIDRs
|
||||||
|
func ValidateAndFixSubnetRanges(db store.IStore) error {
|
||||||
|
if len(SubnetRangesOrder) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
server, err := db.GetServer()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var serverSubnets []*net.IPNet
|
||||||
|
for _, addr := range server.Interface.Addresses {
|
||||||
|
addr = strings.TrimSpace(addr)
|
||||||
|
_, net, err := net.ParseCIDR(addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
serverSubnets = append(serverSubnets, net)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rng := range SubnetRangesOrder {
|
||||||
|
cidrs := SubnetRanges[rng]
|
||||||
|
if len(cidrs) > 0 {
|
||||||
|
newCIDRs := make([]*net.IPNet, 0)
|
||||||
|
for _, cidr := range cidrs {
|
||||||
|
valid := false
|
||||||
|
|
||||||
|
for _, serverSubnet := range serverSubnets {
|
||||||
|
if ContainsCIDR(serverSubnet, cidr) {
|
||||||
|
valid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if valid {
|
||||||
|
newCIDRs = append(newCIDRs, cidr)
|
||||||
|
} else {
|
||||||
|
log.Warnf("[%v] CIDR is outside of all server subnets: %v. Removed.", rng, cidr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(newCIDRs) > 0 {
|
||||||
|
SubnetRanges[rng] = newCIDRs
|
||||||
|
} else {
|
||||||
|
delete(SubnetRanges, rng)
|
||||||
|
log.Warnf("[%v] No valid CIDRs in this subnet range. Removed.", rng)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSubnetRangesString to get a formatted string, representing active subnet ranges
|
||||||
|
func GetSubnetRangesString() string {
|
||||||
|
if len(SubnetRangesOrder) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
strB := strings.Builder{}
|
||||||
|
|
||||||
|
for _, rng := range SubnetRangesOrder {
|
||||||
|
cidrs := SubnetRanges[rng]
|
||||||
|
if len(cidrs) > 0 {
|
||||||
|
strB.WriteString(rng)
|
||||||
|
strB.WriteString(":[")
|
||||||
|
first := true
|
||||||
|
for _, cidr := range cidrs {
|
||||||
|
if !first {
|
||||||
|
strB.WriteString(", ")
|
||||||
|
}
|
||||||
|
strB.WriteString(cidr.String())
|
||||||
|
first = false
|
||||||
|
}
|
||||||
|
strB.WriteString("] ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(strB.String())
|
||||||
|
}
|
||||||
|
|
||||||
// WriteWireGuardServerConfig to write Wireguard server config. e.g. wg0.conf
|
// WriteWireGuardServerConfig to write Wireguard server config. e.g. wg0.conf
|
||||||
func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, clientDataList []model.ClientData, usersList []model.User, globalSettings model.GlobalSetting) error {
|
func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, clientDataList []model.ClientData, usersList []model.User, globalSettings model.GlobalSetting) error {
|
||||||
var tmplWireguardConf string
|
var tmplWireguardConf string
|
||||||
|
|
Loading…
Reference in a new issue