Use GeoIP overrides if no GeoIP database is configured.

This commit is contained in:
Joachim Bauch 2023-08-07 11:30:24 +02:00
parent e703982890
commit e9140178f9
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
2 changed files with 107 additions and 74 deletions

62
hub.go
View file

@ -256,6 +256,9 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
if err != nil {
return nil, err
}
} else {
log.Printf("Not using GeoIP database")
}
if options, _ := config.GetOptions("geoip-overrides"); len(options) > 0 {
geoipOverrides = make(map[*net.IPNet]string)
@ -299,9 +302,6 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
geoipOverrides[ipNet] = value
}
}
} else {
log.Printf("Not using GeoIP database")
}
hub := &Hub{
events: events,
@ -2277,34 +2277,6 @@ func getRealUserIP(r *http.Request) string {
return r.RemoteAddr
}
func (h *Hub) lookupClientCountry(client *Client) string {
ip := net.ParseIP(client.RemoteAddr())
if ip == nil {
return noCountry
}
for overrideNet, country := range h.geoipOverrides {
if overrideNet.Contains(ip) {
return country
}
}
if ip.IsLoopback() {
return loopback
}
country, err := h.geoip.LookupCountry(ip)
if err != nil {
log.Printf("Could not lookup country for %s: %s", ip, err)
return unknownCountry
}
if country == "" {
return unknownCountry
}
return country
}
func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
addr := getRealUserIP(r)
agent := r.Header.Get("User-Agent")
@ -2335,11 +2307,35 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
}
func (h *Hub) OnLookupCountry(client *Client) string {
if h.geoip == nil {
ip := net.ParseIP(client.RemoteAddr())
if ip == nil {
return noCountry
}
for overrideNet, country := range h.geoipOverrides {
if overrideNet.Contains(ip) {
return country
}
}
if ip.IsLoopback() {
return loopback
}
country := unknownCountry
if h.geoip != nil {
var err error
country, err = h.geoip.LookupCountry(ip)
if err != nil {
log.Printf("Could not lookup country for %s: %s", ip, err)
return unknownCountry
}
return h.lookupClientCountry(client)
if country == "" {
country = unknownCountry
}
}
return country
}
func (h *Hub) OnClosed(client *Client) {

View file

@ -5229,3 +5229,40 @@ func TestSwitchToMultipleMixed(t *testing.T) {
"foo": "bar",
}, nil)
}
func TestGeoipOverrides(t *testing.T) {
country1 := "DE"
country2 := "IT"
country3 := "site1"
hub, _, _, _ := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
conf, err := getTestConfig(server)
if err != nil {
return nil, err
}
conf.AddOption("geoip-overrides", "10.1.0.0/16", country1)
conf.AddOption("geoip-overrides", "10.2.0.0/16", country2)
conf.AddOption("geoip-overrides", "192.168.10.20", country3)
return conf, err
})
if country := hub.OnLookupCountry(&Client{addr: "127.0.0.1"}); country != loopback {
t.Errorf("expected country %s, got %s", loopback, country)
}
if country := hub.OnLookupCountry(&Client{addr: "8.8.8.8"}); country != unknownCountry {
t.Errorf("expected country %s, got %s", unknownCountry, country)
}
if country := hub.OnLookupCountry(&Client{addr: "10.1.1.2"}); country != country1 {
t.Errorf("expected country %s, got %s", country1, country)
}
if country := hub.OnLookupCountry(&Client{addr: "10.2.1.2"}); country != country2 {
t.Errorf("expected country %s, got %s", country2, country)
}
if country := hub.OnLookupCountry(&Client{addr: "192.168.10.20"}); country != strings.ToUpper(country3) {
t.Errorf("expected country %s, got %s", strings.ToUpper(country3), country)
}
}