Parse and validate subnet ranges

This commit is contained in:
0xCA 2023-11-04 22:47:08 +05:00
parent e73047b14f
commit 92333a08d8
5 changed files with 185 additions and 23 deletions

8
Makefile Normal file
View file

@ -0,0 +1,8 @@
build:
CGO_ENABLED=0 go build -v -ldflags="-s -w"
run:
./wireguard-ui
clean:
rm -f wireguard-ui

View file

@ -381,9 +381,9 @@ func GetClient(db store.IStore) echo.HandlerFunc {
} }
qrCodeSettings := model.QRCodeSettings{ qrCodeSettings := model.QRCodeSettings{
Enabled: true, Enabled: true,
IncludeDNS: true, IncludeDNS: true,
IncludeMTU: true, IncludeMTU: true,
} }
clientData, err := db.GetClientByID(clientID, qrCodeSettings) clientData, err := db.GetClientByID(clientID, qrCodeSettings)
@ -513,9 +513,9 @@ func EmailClient(db store.IStore, mailer emailer.Emailer, emailSubject, emailCon
} }
qrCodeSettings := model.QRCodeSettings{ qrCodeSettings := model.QRCodeSettings{
Enabled: true, Enabled: true,
IncludeDNS: true, IncludeDNS: true,
IncludeMTU: true, IncludeMTU: true,
} }
clientData, err := db.GetClientByID(payload.ID, qrCodeSettings) clientData, err := db.GetClientByID(payload.ID, qrCodeSettings)
if err != nil { if err != nil {
@ -1078,7 +1078,6 @@ func ApplyServerConfig(db store.IStore, tmplDir fs.FS) echo.HandlerFunc {
} }
} }
// GetHashesChanges handler returns if database hashes have changed // GetHashesChanges handler returns if database hashes have changed
func GetHashesChanges(db store.IStore) echo.HandlerFunc { func GetHashesChanges(db store.IStore) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {

15
main.go
View file

@ -45,6 +45,7 @@ var (
flagSessionSecret string = util.RandomString(32) flagSessionSecret string = util.RandomString(32)
flagWgConfTemplate string flagWgConfTemplate string
flagBasePath string flagBasePath string
flagSubnetRanges string
) )
const ( const (
@ -81,6 +82,7 @@ func init() {
flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.") flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.")
flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.") flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.")
flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL") flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL")
flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.")
var ( var (
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword) smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
@ -127,6 +129,7 @@ func init() {
util.SessionSecret = []byte(flagSessionSecret) util.SessionSecret = []byte(flagSessionSecret)
util.WgConfTemplate = flagWgConfTemplate util.WgConfTemplate = flagWgConfTemplate
util.BasePath = util.ParseBasePath(flagBasePath) util.BasePath = util.ParseBasePath(flagBasePath)
util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges)
// print only if log level is INFO or lower // print only if log level is INFO or lower
if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO { if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO {
@ -145,6 +148,7 @@ func init() {
//fmt.Println("Session secret\t:", util.SessionSecret) //fmt.Println("Session secret\t:", util.SessionSecret)
fmt.Println("Custom wg.conf\t:", util.WgConfTemplate) fmt.Println("Custom wg.conf\t:", util.WgConfTemplate)
fmt.Println("Base path\t:", util.BasePath+"/") fmt.Println("Base path\t:", util.BasePath+"/")
fmt.Println("Subnet ranges\t:", util.GetSubnetRangesString())
} }
} }
@ -170,6 +174,17 @@ func main() {
// create the wireguard config on start, if it doesn't exist // create the wireguard config on start, if it doesn't exist
initServerConfig(db, tmplDir) initServerConfig(db, tmplDir)
// Check if subnet ranges are valid for the server configuration
// Remove any non-valid CIDRs
if err := util.ValidateAndFixSubnetRanges(db); err != nil {
panic(err)
}
// Print valid ranges
if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO {
fmt.Println("Valid subnet ranges:", util.GetSubnetRangesString())
}
// register routes // register routes
app := router.New(tmplDir, extraData, util.SessionSecret) app := router.New(tmplDir, extraData, util.SessionSecret)

View file

@ -1,24 +1,31 @@
package util package util
import "strings" import (
"net"
"strings"
"github.com/labstack/gommon/log"
)
// Runtime config // Runtime config
var ( var (
DisableLogin bool DisableLogin bool
BindAddress string BindAddress string
SmtpHostname string SmtpHostname string
SmtpPort int SmtpPort int
SmtpUsername string SmtpUsername string
SmtpPassword string SmtpPassword string
SmtpNoTLSCheck bool SmtpNoTLSCheck bool
SmtpEncryption string SmtpEncryption string
SmtpAuthType string SmtpAuthType string
SendgridApiKey string SendgridApiKey string
EmailFrom string EmailFrom string
EmailFromName string EmailFromName string
SessionSecret []byte SessionSecret []byte
WgConfTemplate string WgConfTemplate string
BasePath string BasePath string
SubnetRanges map[string]([]*net.IPNet)
SubnetRangesOrder []string
) )
const ( const (
@ -66,3 +73,45 @@ func ParseBasePath(basePath string) string {
} }
return basePath return basePath
} }
func ParseSubnetRanges(subnetRangesStr string) map[string]([]*net.IPNet) {
subnetRanges := map[string]([]*net.IPNet){}
if subnetRangesStr == "" {
return subnetRanges
}
cidrSet := map[string]bool{}
subnetRangesStr = strings.TrimSpace(subnetRangesStr)
subnetRangesStr = strings.Trim(subnetRangesStr, ";:,")
ranges := strings.Split(subnetRangesStr, ";")
for _, rng := range ranges {
rng = strings.TrimSpace(rng)
rngSpl := strings.Split(rng, ":")
if len(rngSpl) != 2 {
log.Warnf("Unable to parse subnet range: %v. Skipped.", rng)
continue
}
rngName := strings.TrimSpace(rngSpl[0])
subnetRanges[rngName] = make([]*net.IPNet, 0)
cidrs := strings.Split(rngSpl[1], ",")
for _, cidr := range cidrs {
cidr = strings.TrimSpace(cidr)
_, net, err := net.ParseCIDR(cidr)
if err != nil {
log.Warnf("[%v] Unable to parse CIDR: %v. Skipped.", rngName, cidr)
continue
}
if cidrSet[net.String()] {
log.Warnf("[%v] CIDR already exists: %v. Skipped.", rngName, net.String())
continue
}
cidrSet[net.String()] = true
subnetRanges[rngName] = append(subnetRanges[rngName], net)
}
if len(subnetRanges[rngName]) == 0 {
delete(subnetRanges, rngName)
} else {
SubnetRangesOrder = append(SubnetRangesOrder, rngName)
}
}
return subnetRanges
}

View file

@ -95,6 +95,15 @@ func ClientDefaultsFromEnv() model.ClientDefaults {
return clientDefaults return clientDefaults
} }
// ContainsCIDR to check if ipnet1 contains ipnet2
// https://stackoverflow.com/a/40406619/6111641
// https://go.dev/play/p/Q4J-JEN3sF
func ContainsCIDR(ipnet1, ipnet2 *net.IPNet) bool {
ones1, _ := ipnet1.Mask.Size()
ones2, _ := ipnet2.Mask.Size()
return ones1 <= ones2 && ipnet1.Contains(ipnet2.IP)
}
// ValidateCIDR to validate a network CIDR // ValidateCIDR to validate a network CIDR
func ValidateCIDR(cidr string) bool { func ValidateCIDR(cidr string) bool {
_, _, err := net.ParseCIDR(cidr) _, _, err := net.ParseCIDR(cidr)
@ -384,6 +393,88 @@ func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ip
return true, nil return true, nil
} }
// ValidateAndFixSubnetRanges to check if subnet ranges are valid for the server configuration
// Removes all non-valid CIDRs
func ValidateAndFixSubnetRanges(db store.IStore) error {
if len(SubnetRangesOrder) == 0 {
return nil
}
server, err := db.GetServer()
if err != nil {
return err
}
var serverSubnets []*net.IPNet
for _, addr := range server.Interface.Addresses {
addr = strings.TrimSpace(addr)
_, net, err := net.ParseCIDR(addr)
if err != nil {
return err
}
serverSubnets = append(serverSubnets, net)
}
for _, rng := range SubnetRangesOrder {
cidrs := SubnetRanges[rng]
if len(cidrs) > 0 {
newCIDRs := make([]*net.IPNet, 0)
for _, cidr := range cidrs {
valid := false
for _, serverSubnet := range serverSubnets {
if ContainsCIDR(serverSubnet, cidr) {
valid = true
break
}
}
if valid {
newCIDRs = append(newCIDRs, cidr)
} else {
log.Warnf("[%v] CIDR is outside of all server subnets: %v. Removed.", rng, cidr)
}
}
if len(newCIDRs) > 0 {
SubnetRanges[rng] = newCIDRs
} else {
delete(SubnetRanges, rng)
log.Warnf("[%v] No valid CIDRs in this subnet range. Removed.", rng)
}
}
}
return nil
}
// GetSubnetRangesString to get a formatted string, representing active subnet ranges
func GetSubnetRangesString() string {
if len(SubnetRangesOrder) == 0 {
return ""
}
strB := strings.Builder{}
for _, rng := range SubnetRangesOrder {
cidrs := SubnetRanges[rng]
if len(cidrs) > 0 {
strB.WriteString(rng)
strB.WriteString(":[")
first := true
for _, cidr := range cidrs {
if !first {
strB.WriteString(", ")
}
strB.WriteString(cidr.String())
first = false
}
strB.WriteString("] ")
}
}
return strings.TrimSpace(strB.String())
}
// WriteWireGuardServerConfig to write Wireguard server config. e.g. wg0.conf // WriteWireGuardServerConfig to write Wireguard server config. e.g. wg0.conf
func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, clientDataList []model.ClientData, usersList []model.User, globalSettings model.GlobalSetting) error { func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, clientDataList []model.ClientData, usersList []model.User, globalSettings model.GlobalSetting) error {
var tmplWireguardConf string var tmplWireguardConf string