Support reloading GeoIP overrides.

This commit is contained in:
Joachim Bauch 2024-05-28 12:26:05 +02:00
commit 15edeca814
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
2 changed files with 81 additions and 47 deletions

View file

@ -35,6 +35,7 @@ import (
"sync"
"time"
"github.com/dlintw/goconf"
"github.com/oschwald/maxminddb-golang"
)
@ -276,3 +277,63 @@ func IsValidContinent(continent string) bool {
return false
}
}
func LoadGeoIPOverrides(config *goconf.ConfigFile, ignoreErrors bool) (map[*net.IPNet]string, error) {
options, _ := GetStringOptions(config, "geoip-overrides", true)
if len(options) == 0 {
return nil, nil
}
var err error
geoipOverrides := make(map[*net.IPNet]string, len(options))
for option, value := range options {
var ip net.IP
var ipNet *net.IPNet
if strings.Contains(option, "/") {
_, ipNet, err = net.ParseCIDR(option)
if err != nil {
if ignoreErrors {
log.Printf("could not parse CIDR %s (%s), skipping", option, err)
continue
}
return nil, fmt.Errorf("could not parse CIDR %s: %s", option, err)
}
} else {
ip = net.ParseIP(option)
if ip == nil {
if ignoreErrors {
log.Printf("could not parse IP %s, skipping", option)
continue
}
return nil, fmt.Errorf("could not parse IP %s", option)
}
var mask net.IPMask
if ipv4 := ip.To4(); ipv4 != nil {
mask = net.CIDRMask(32, 32)
} else {
mask = net.CIDRMask(128, 128)
}
ipNet = &net.IPNet{
IP: ip,
Mask: mask,
}
}
value = strings.ToUpper(strings.TrimSpace(value))
if value == "" {
log.Printf("IP %s doesn't have a country assigned, skipping", option)
continue
} else if !IsValidCountry(value) {
log.Printf("Country %s for IP %s is invalid, skipping", value, option)
continue
}
log.Printf("Using country %s for %s", value, ipNet)
geoipOverrides[ipNet] = value
}
return geoipOverrides, nil
}

67
hub.go
View file

@ -170,7 +170,7 @@ type Hub struct {
trustedProxies atomic.Pointer[AllowedIps]
geoip *GeoLookup
geoipOverrides map[*net.IPNet]string
geoipOverrides atomic.Pointer[map[*net.IPNet]string]
geoipUpdating atomic.Bool
rpcServer *GrpcServer
@ -273,7 +273,6 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
}
var geoip *GeoLookup
var geoipOverrides map[*net.IPNet]string
if geoipUrl != "" {
if strings.HasPrefix(geoipUrl, "file://") {
geoipUrl = geoipUrl[7:]
@ -290,46 +289,9 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
log.Printf("Not using GeoIP database")
}
if options, _ := GetStringOptions(config, "geoip-overrides", true); len(options) > 0 {
geoipOverrides = make(map[*net.IPNet]string, len(options))
for option, value := range options {
var ip net.IP
var ipNet *net.IPNet
if strings.Contains(option, "/") {
_, ipNet, err = net.ParseCIDR(option)
if err != nil {
return nil, fmt.Errorf("could not parse CIDR %s: %s", option, err)
}
} else {
ip = net.ParseIP(option)
if ip == nil {
return nil, fmt.Errorf("could not parse IP %s", option)
}
var mask net.IPMask
if ipv4 := ip.To4(); ipv4 != nil {
mask = net.CIDRMask(32, 32)
} else {
mask = net.CIDRMask(128, 128)
}
ipNet = &net.IPNet{
IP: ip,
Mask: mask,
}
}
value = strings.ToUpper(strings.TrimSpace(value))
if value == "" {
log.Printf("IP %s doesn't have a country assigned, skipping", option)
continue
} else if !IsValidCountry(value) {
log.Printf("Country %s for IP %s is invalid, skipping", value, option)
continue
}
log.Printf("Using country %s for %s", value, ipNet)
geoipOverrides[ipNet] = value
}
geoipOverrides, err := LoadGeoIPOverrides(config, false)
if err != nil {
return nil, err
}
throttler, err := NewMemoryThrottler()
@ -379,8 +341,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
backendTimeout: backendTimeout,
backend: backend,
geoip: geoip,
geoipOverrides: geoipOverrides,
geoip: geoip,
rpcServer: rpcServer,
rpcClients: rpcClients,
@ -388,6 +349,9 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
throttler: throttler,
}
hub.trustedProxies.Store(trustedProxiesIps)
if len(geoipOverrides) > 0 {
hub.geoipOverrides.Store(&geoipOverrides)
}
hub.setWelcomeMessage(&ServerMessage{
Type: "welcome",
Welcome: NewWelcomeServerMessage(version, DefaultWelcomeFeatures...),
@ -526,6 +490,13 @@ func (h *Hub) Reload(config *goconf.ConfigFile) {
log.Printf("Error parsing trusted proxies from \"%s\": %s", trustedProxies, err)
}
geoipOverrides, _ := LoadGeoIPOverrides(config, true)
if len(geoipOverrides) > 0 {
h.geoipOverrides.Store(&geoipOverrides)
} else {
h.geoipOverrides.Store(nil)
}
if h.mcu != nil {
h.mcu.Reload(config)
}
@ -2685,9 +2656,11 @@ func (h *Hub) OnLookupCountry(client HandlerClient) string {
return noCountry
}
for overrideNet, country := range h.geoipOverrides {
if overrideNet.Contains(ip) {
return country
if overrides := h.geoipOverrides.Load(); overrides != nil {
for overrideNet, country := range *overrides {
if overrideNet.Contains(ip) {
return country
}
}
}