From e9140178f92d3e11c6ec809677c25d54a2faa634 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 7 Aug 2023 11:30:24 +0200 Subject: [PATCH] Use GeoIP overrides if no GeoIP database is configured. --- hub.go | 144 +++++++++++++++++++++++++--------------------------- hub_test.go | 37 ++++++++++++++ 2 files changed, 107 insertions(+), 74 deletions(-) diff --git a/hub.go b/hub.go index 1832d20..7a4e51a 100644 --- a/hub.go +++ b/hub.go @@ -256,53 +256,53 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer if err != nil { return nil, err } - - if options, _ := config.GetOptions("geoip-overrides"); len(options) > 0 { - geoipOverrides = make(map[*net.IPNet]string) - for _, option := range options { - var ip net.IP - var ipNet *net.IPNet - if strings.Contains(option, "/") { - _, ipNet, err = net.ParseCIDR(option) - if err != nil { - return nil, fmt.Errorf("could not parse CIDR %s: %s", option, err) - } - } else { - ip = net.ParseIP(option) - if ip == nil { - return nil, fmt.Errorf("could not parse IP %s", option) - } - - var mask net.IPMask - if ipv4 := ip.To4(); ipv4 != nil { - mask = net.CIDRMask(32, 32) - } else { - mask = net.CIDRMask(128, 128) - } - ipNet = &net.IPNet{ - IP: ip, - Mask: mask, - } - } - - value, _ := config.GetString("geoip-overrides", option) - value = strings.ToUpper(strings.TrimSpace(value)) - if value == "" { - log.Printf("IP %s doesn't have a country assigned, skipping", option) - continue - } else if !IsValidCountry(value) { - log.Printf("Country %s for IP %s is invalid, skipping", value, option) - continue - } - - log.Printf("Using country %s for %s", value, ipNet) - geoipOverrides[ipNet] = value - } - } } else { log.Printf("Not using GeoIP database") } + if options, _ := config.GetOptions("geoip-overrides"); len(options) > 0 { + geoipOverrides = make(map[*net.IPNet]string) + for _, option := range options { + var ip net.IP + var ipNet *net.IPNet + if strings.Contains(option, "/") { + _, ipNet, err = net.ParseCIDR(option) + if err != nil { + return nil, fmt.Errorf("could not parse CIDR %s: %s", option, err) + } + } else { + ip = net.ParseIP(option) + if ip == nil { + return nil, fmt.Errorf("could not parse IP %s", option) + } + + var mask net.IPMask + if ipv4 := ip.To4(); ipv4 != nil { + mask = net.CIDRMask(32, 32) + } else { + mask = net.CIDRMask(128, 128) + } + ipNet = &net.IPNet{ + IP: ip, + Mask: mask, + } + } + + value, _ := config.GetString("geoip-overrides", option) + value = strings.ToUpper(strings.TrimSpace(value)) + if value == "" { + log.Printf("IP %s doesn't have a country assigned, skipping", option) + continue + } else if !IsValidCountry(value) { + log.Printf("Country %s for IP %s is invalid, skipping", value, option) + continue + } + + log.Printf("Using country %s for %s", value, ipNet) + geoipOverrides[ipNet] = value + } + } + hub := &Hub{ events: events, upgrader: websocket.Upgrader{ @@ -2277,34 +2277,6 @@ func getRealUserIP(r *http.Request) string { return r.RemoteAddr } -func (h *Hub) lookupClientCountry(client *Client) string { - ip := net.ParseIP(client.RemoteAddr()) - if ip == nil { - return noCountry - } - - for overrideNet, country := range h.geoipOverrides { - if overrideNet.Contains(ip) { - return country - } - } - - if ip.IsLoopback() { - return loopback - } - - country, err := h.geoip.LookupCountry(ip) - if err != nil { - log.Printf("Could not lookup country for %s: %s", ip, err) - return unknownCountry - } - - if country == "" { - return unknownCountry - } - return country -} - func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { addr := getRealUserIP(r) agent := r.Header.Get("User-Agent") @@ -2335,11 +2307,35 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { } func (h *Hub) OnLookupCountry(client *Client) string { - if h.geoip == nil { - return unknownCountry + ip := net.ParseIP(client.RemoteAddr()) + if ip == nil { + return noCountry } - return h.lookupClientCountry(client) + for overrideNet, country := range h.geoipOverrides { + if overrideNet.Contains(ip) { + return country + } + } + + if ip.IsLoopback() { + return loopback + } + + country := unknownCountry + if h.geoip != nil { + var err error + country, err = h.geoip.LookupCountry(ip) + if err != nil { + log.Printf("Could not lookup country for %s: %s", ip, err) + return unknownCountry + } + + if country == "" { + country = unknownCountry + } + } + return country } func (h *Hub) OnClosed(client *Client) { diff --git a/hub_test.go b/hub_test.go index 12d0ec1..6f05d30 100644 --- a/hub_test.go +++ b/hub_test.go @@ -5229,3 +5229,40 @@ func TestSwitchToMultipleMixed(t *testing.T) { "foo": "bar", }, nil) } + +func TestGeoipOverrides(t *testing.T) { + country1 := "DE" + country2 := "IT" + country3 := "site1" + hub, _, _, _ := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { + conf, err := getTestConfig(server) + if err != nil { + return nil, err + } + + conf.AddOption("geoip-overrides", "10.1.0.0/16", country1) + conf.AddOption("geoip-overrides", "10.2.0.0/16", country2) + conf.AddOption("geoip-overrides", "192.168.10.20", country3) + return conf, err + }) + + if country := hub.OnLookupCountry(&Client{addr: "127.0.0.1"}); country != loopback { + t.Errorf("expected country %s, got %s", loopback, country) + } + + if country := hub.OnLookupCountry(&Client{addr: "8.8.8.8"}); country != unknownCountry { + t.Errorf("expected country %s, got %s", unknownCountry, country) + } + + if country := hub.OnLookupCountry(&Client{addr: "10.1.1.2"}); country != country1 { + t.Errorf("expected country %s, got %s", country1, country) + } + + if country := hub.OnLookupCountry(&Client{addr: "10.2.1.2"}); country != country2 { + t.Errorf("expected country %s, got %s", country2, country) + } + + if country := hub.OnLookupCountry(&Client{addr: "192.168.10.20"}); country != strings.ToUpper(country3) { + t.Errorf("expected country %s, got %s", strings.ToUpper(country3), country) + } +}