Use GeoIP overrides if no GeoIP database is configured.

This commit is contained in:
Joachim Bauch 2023-08-07 11:30:24 +02:00
parent e703982890
commit e9140178f9
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
2 changed files with 107 additions and 74 deletions

144
hub.go
View file

@ -256,53 +256,53 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
if err != nil {
return nil, err
}
if options, _ := config.GetOptions("geoip-overrides"); len(options) > 0 {
geoipOverrides = make(map[*net.IPNet]string)
for _, option := range options {
var ip net.IP
var ipNet *net.IPNet
if strings.Contains(option, "/") {
_, ipNet, err = net.ParseCIDR(option)
if err != nil {
return nil, fmt.Errorf("could not parse CIDR %s: %s", option, err)
}
} else {
ip = net.ParseIP(option)
if ip == nil {
return nil, fmt.Errorf("could not parse IP %s", option)
}
var mask net.IPMask
if ipv4 := ip.To4(); ipv4 != nil {
mask = net.CIDRMask(32, 32)
} else {
mask = net.CIDRMask(128, 128)
}
ipNet = &net.IPNet{
IP: ip,
Mask: mask,
}
}
value, _ := config.GetString("geoip-overrides", option)
value = strings.ToUpper(strings.TrimSpace(value))
if value == "" {
log.Printf("IP %s doesn't have a country assigned, skipping", option)
continue
} else if !IsValidCountry(value) {
log.Printf("Country %s for IP %s is invalid, skipping", value, option)
continue
}
log.Printf("Using country %s for %s", value, ipNet)
geoipOverrides[ipNet] = value
}
}
} else {
log.Printf("Not using GeoIP database")
}
if options, _ := config.GetOptions("geoip-overrides"); len(options) > 0 {
geoipOverrides = make(map[*net.IPNet]string)
for _, option := range options {
var ip net.IP
var ipNet *net.IPNet
if strings.Contains(option, "/") {
_, ipNet, err = net.ParseCIDR(option)
if err != nil {
return nil, fmt.Errorf("could not parse CIDR %s: %s", option, err)
}
} else {
ip = net.ParseIP(option)
if ip == nil {
return nil, fmt.Errorf("could not parse IP %s", option)
}
var mask net.IPMask
if ipv4 := ip.To4(); ipv4 != nil {
mask = net.CIDRMask(32, 32)
} else {
mask = net.CIDRMask(128, 128)
}
ipNet = &net.IPNet{
IP: ip,
Mask: mask,
}
}
value, _ := config.GetString("geoip-overrides", option)
value = strings.ToUpper(strings.TrimSpace(value))
if value == "" {
log.Printf("IP %s doesn't have a country assigned, skipping", option)
continue
} else if !IsValidCountry(value) {
log.Printf("Country %s for IP %s is invalid, skipping", value, option)
continue
}
log.Printf("Using country %s for %s", value, ipNet)
geoipOverrides[ipNet] = value
}
}
hub := &Hub{
events: events,
upgrader: websocket.Upgrader{
@ -2277,34 +2277,6 @@ func getRealUserIP(r *http.Request) string {
return r.RemoteAddr
}
func (h *Hub) lookupClientCountry(client *Client) string {
ip := net.ParseIP(client.RemoteAddr())
if ip == nil {
return noCountry
}
for overrideNet, country := range h.geoipOverrides {
if overrideNet.Contains(ip) {
return country
}
}
if ip.IsLoopback() {
return loopback
}
country, err := h.geoip.LookupCountry(ip)
if err != nil {
log.Printf("Could not lookup country for %s: %s", ip, err)
return unknownCountry
}
if country == "" {
return unknownCountry
}
return country
}
func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
addr := getRealUserIP(r)
agent := r.Header.Get("User-Agent")
@ -2335,11 +2307,35 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
}
func (h *Hub) OnLookupCountry(client *Client) string {
if h.geoip == nil {
return unknownCountry
ip := net.ParseIP(client.RemoteAddr())
if ip == nil {
return noCountry
}
return h.lookupClientCountry(client)
for overrideNet, country := range h.geoipOverrides {
if overrideNet.Contains(ip) {
return country
}
}
if ip.IsLoopback() {
return loopback
}
country := unknownCountry
if h.geoip != nil {
var err error
country, err = h.geoip.LookupCountry(ip)
if err != nil {
log.Printf("Could not lookup country for %s: %s", ip, err)
return unknownCountry
}
if country == "" {
country = unknownCountry
}
}
return country
}
func (h *Hub) OnClosed(client *Client) {

View file

@ -5229,3 +5229,40 @@ func TestSwitchToMultipleMixed(t *testing.T) {
"foo": "bar",
}, nil)
}
func TestGeoipOverrides(t *testing.T) {
country1 := "DE"
country2 := "IT"
country3 := "site1"
hub, _, _, _ := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
conf, err := getTestConfig(server)
if err != nil {
return nil, err
}
conf.AddOption("geoip-overrides", "10.1.0.0/16", country1)
conf.AddOption("geoip-overrides", "10.2.0.0/16", country2)
conf.AddOption("geoip-overrides", "192.168.10.20", country3)
return conf, err
})
if country := hub.OnLookupCountry(&Client{addr: "127.0.0.1"}); country != loopback {
t.Errorf("expected country %s, got %s", loopback, country)
}
if country := hub.OnLookupCountry(&Client{addr: "8.8.8.8"}); country != unknownCountry {
t.Errorf("expected country %s, got %s", unknownCountry, country)
}
if country := hub.OnLookupCountry(&Client{addr: "10.1.1.2"}); country != country1 {
t.Errorf("expected country %s, got %s", country1, country)
}
if country := hub.OnLookupCountry(&Client{addr: "10.2.1.2"}); country != country2 {
t.Errorf("expected country %s, got %s", country2, country)
}
if country := hub.OnLookupCountry(&Client{addr: "192.168.10.20"}); country != strings.ToUpper(country3) {
t.Errorf("expected country %s, got %s", strings.ToUpper(country3), country)
}
}