mirror of
https://github.com/ngoduykhanh/wireguard-ui
synced 2024-05-07 08:16:34 +02:00
Parse and validate subnet ranges
This commit is contained in:
parent
e73047b14f
commit
92333a08d8
8
Makefile
Normal file
8
Makefile
Normal file
|
@ -0,0 +1,8 @@
|
|||
build:
|
||||
CGO_ENABLED=0 go build -v -ldflags="-s -w"
|
||||
|
||||
run:
|
||||
./wireguard-ui
|
||||
|
||||
clean:
|
||||
rm -f wireguard-ui
|
|
@ -381,9 +381,9 @@ func GetClient(db store.IStore) echo.HandlerFunc {
|
|||
}
|
||||
|
||||
qrCodeSettings := model.QRCodeSettings{
|
||||
Enabled: true,
|
||||
IncludeDNS: true,
|
||||
IncludeMTU: true,
|
||||
Enabled: true,
|
||||
IncludeDNS: true,
|
||||
IncludeMTU: true,
|
||||
}
|
||||
|
||||
clientData, err := db.GetClientByID(clientID, qrCodeSettings)
|
||||
|
@ -513,9 +513,9 @@ func EmailClient(db store.IStore, mailer emailer.Emailer, emailSubject, emailCon
|
|||
}
|
||||
|
||||
qrCodeSettings := model.QRCodeSettings{
|
||||
Enabled: true,
|
||||
IncludeDNS: true,
|
||||
IncludeMTU: true,
|
||||
Enabled: true,
|
||||
IncludeDNS: true,
|
||||
IncludeMTU: true,
|
||||
}
|
||||
clientData, err := db.GetClientByID(payload.ID, qrCodeSettings)
|
||||
if err != nil {
|
||||
|
@ -1078,7 +1078,6 @@ func ApplyServerConfig(db store.IStore, tmplDir fs.FS) echo.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// GetHashesChanges handler returns if database hashes have changed
|
||||
func GetHashesChanges(db store.IStore) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
|
|
15
main.go
15
main.go
|
@ -45,6 +45,7 @@ var (
|
|||
flagSessionSecret string = util.RandomString(32)
|
||||
flagWgConfTemplate string
|
||||
flagBasePath string
|
||||
flagSubnetRanges string
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -81,6 +82,7 @@ func init() {
|
|||
flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.")
|
||||
flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.")
|
||||
flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL")
|
||||
flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.")
|
||||
|
||||
var (
|
||||
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
|
||||
|
@ -127,6 +129,7 @@ func init() {
|
|||
util.SessionSecret = []byte(flagSessionSecret)
|
||||
util.WgConfTemplate = flagWgConfTemplate
|
||||
util.BasePath = util.ParseBasePath(flagBasePath)
|
||||
util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges)
|
||||
|
||||
// print only if log level is INFO or lower
|
||||
if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO {
|
||||
|
@ -145,6 +148,7 @@ func init() {
|
|||
//fmt.Println("Session secret\t:", util.SessionSecret)
|
||||
fmt.Println("Custom wg.conf\t:", util.WgConfTemplate)
|
||||
fmt.Println("Base path\t:", util.BasePath+"/")
|
||||
fmt.Println("Subnet ranges\t:", util.GetSubnetRangesString())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -170,6 +174,17 @@ func main() {
|
|||
// create the wireguard config on start, if it doesn't exist
|
||||
initServerConfig(db, tmplDir)
|
||||
|
||||
// Check if subnet ranges are valid for the server configuration
|
||||
// Remove any non-valid CIDRs
|
||||
if err := util.ValidateAndFixSubnetRanges(db); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Print valid ranges
|
||||
if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO {
|
||||
fmt.Println("Valid subnet ranges:", util.GetSubnetRangesString())
|
||||
}
|
||||
|
||||
// register routes
|
||||
app := router.New(tmplDir, extraData, util.SessionSecret)
|
||||
|
||||
|
|
|
@ -1,24 +1,31 @@
|
|||
package util
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/gommon/log"
|
||||
)
|
||||
|
||||
// Runtime config
|
||||
var (
|
||||
DisableLogin bool
|
||||
BindAddress string
|
||||
SmtpHostname string
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpNoTLSCheck bool
|
||||
SmtpEncryption string
|
||||
SmtpAuthType string
|
||||
SendgridApiKey string
|
||||
EmailFrom string
|
||||
EmailFromName string
|
||||
SessionSecret []byte
|
||||
WgConfTemplate string
|
||||
BasePath string
|
||||
DisableLogin bool
|
||||
BindAddress string
|
||||
SmtpHostname string
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpNoTLSCheck bool
|
||||
SmtpEncryption string
|
||||
SmtpAuthType string
|
||||
SendgridApiKey string
|
||||
EmailFrom string
|
||||
EmailFromName string
|
||||
SessionSecret []byte
|
||||
WgConfTemplate string
|
||||
BasePath string
|
||||
SubnetRanges map[string]([]*net.IPNet)
|
||||
SubnetRangesOrder []string
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -66,3 +73,45 @@ func ParseBasePath(basePath string) string {
|
|||
}
|
||||
return basePath
|
||||
}
|
||||
|
||||
func ParseSubnetRanges(subnetRangesStr string) map[string]([]*net.IPNet) {
|
||||
subnetRanges := map[string]([]*net.IPNet){}
|
||||
if subnetRangesStr == "" {
|
||||
return subnetRanges
|
||||
}
|
||||
cidrSet := map[string]bool{}
|
||||
subnetRangesStr = strings.TrimSpace(subnetRangesStr)
|
||||
subnetRangesStr = strings.Trim(subnetRangesStr, ";:,")
|
||||
ranges := strings.Split(subnetRangesStr, ";")
|
||||
for _, rng := range ranges {
|
||||
rng = strings.TrimSpace(rng)
|
||||
rngSpl := strings.Split(rng, ":")
|
||||
if len(rngSpl) != 2 {
|
||||
log.Warnf("Unable to parse subnet range: %v. Skipped.", rng)
|
||||
continue
|
||||
}
|
||||
rngName := strings.TrimSpace(rngSpl[0])
|
||||
subnetRanges[rngName] = make([]*net.IPNet, 0)
|
||||
cidrs := strings.Split(rngSpl[1], ",")
|
||||
for _, cidr := range cidrs {
|
||||
cidr = strings.TrimSpace(cidr)
|
||||
_, net, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
log.Warnf("[%v] Unable to parse CIDR: %v. Skipped.", rngName, cidr)
|
||||
continue
|
||||
}
|
||||
if cidrSet[net.String()] {
|
||||
log.Warnf("[%v] CIDR already exists: %v. Skipped.", rngName, net.String())
|
||||
continue
|
||||
}
|
||||
cidrSet[net.String()] = true
|
||||
subnetRanges[rngName] = append(subnetRanges[rngName], net)
|
||||
}
|
||||
if len(subnetRanges[rngName]) == 0 {
|
||||
delete(subnetRanges, rngName)
|
||||
} else {
|
||||
SubnetRangesOrder = append(SubnetRangesOrder, rngName)
|
||||
}
|
||||
}
|
||||
return subnetRanges
|
||||
}
|
||||
|
|
91
util/util.go
91
util/util.go
|
@ -95,6 +95,15 @@ func ClientDefaultsFromEnv() model.ClientDefaults {
|
|||
return clientDefaults
|
||||
}
|
||||
|
||||
// ContainsCIDR to check if ipnet1 contains ipnet2
|
||||
// https://stackoverflow.com/a/40406619/6111641
|
||||
// https://go.dev/play/p/Q4J-JEN3sF
|
||||
func ContainsCIDR(ipnet1, ipnet2 *net.IPNet) bool {
|
||||
ones1, _ := ipnet1.Mask.Size()
|
||||
ones2, _ := ipnet2.Mask.Size()
|
||||
return ones1 <= ones2 && ipnet1.Contains(ipnet2.IP)
|
||||
}
|
||||
|
||||
// ValidateCIDR to validate a network CIDR
|
||||
func ValidateCIDR(cidr string) bool {
|
||||
_, _, err := net.ParseCIDR(cidr)
|
||||
|
@ -384,6 +393,88 @@ func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ip
|
|||
return true, nil
|
||||
}
|
||||
|
||||
// ValidateAndFixSubnetRanges to check if subnet ranges are valid for the server configuration
|
||||
// Removes all non-valid CIDRs
|
||||
func ValidateAndFixSubnetRanges(db store.IStore) error {
|
||||
if len(SubnetRangesOrder) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
server, err := db.GetServer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var serverSubnets []*net.IPNet
|
||||
for _, addr := range server.Interface.Addresses {
|
||||
addr = strings.TrimSpace(addr)
|
||||
_, net, err := net.ParseCIDR(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverSubnets = append(serverSubnets, net)
|
||||
}
|
||||
|
||||
for _, rng := range SubnetRangesOrder {
|
||||
cidrs := SubnetRanges[rng]
|
||||
if len(cidrs) > 0 {
|
||||
newCIDRs := make([]*net.IPNet, 0)
|
||||
for _, cidr := range cidrs {
|
||||
valid := false
|
||||
|
||||
for _, serverSubnet := range serverSubnets {
|
||||
if ContainsCIDR(serverSubnet, cidr) {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if valid {
|
||||
newCIDRs = append(newCIDRs, cidr)
|
||||
} else {
|
||||
log.Warnf("[%v] CIDR is outside of all server subnets: %v. Removed.", rng, cidr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(newCIDRs) > 0 {
|
||||
SubnetRanges[rng] = newCIDRs
|
||||
} else {
|
||||
delete(SubnetRanges, rng)
|
||||
log.Warnf("[%v] No valid CIDRs in this subnet range. Removed.", rng)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSubnetRangesString to get a formatted string, representing active subnet ranges
|
||||
func GetSubnetRangesString() string {
|
||||
if len(SubnetRangesOrder) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
strB := strings.Builder{}
|
||||
|
||||
for _, rng := range SubnetRangesOrder {
|
||||
cidrs := SubnetRanges[rng]
|
||||
if len(cidrs) > 0 {
|
||||
strB.WriteString(rng)
|
||||
strB.WriteString(":[")
|
||||
first := true
|
||||
for _, cidr := range cidrs {
|
||||
if !first {
|
||||
strB.WriteString(", ")
|
||||
}
|
||||
strB.WriteString(cidr.String())
|
||||
first = false
|
||||
}
|
||||
strB.WriteString("] ")
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strB.String())
|
||||
}
|
||||
|
||||
// WriteWireGuardServerConfig to write Wireguard server config. e.g. wg0.conf
|
||||
func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, clientDataList []model.ClientData, usersList []model.User, globalSettings model.GlobalSetting) error {
|
||||
var tmplWireguardConf string
|
||||
|
|
Loading…
Reference in a new issue