diff --git a/.codecov.yml b/.codecov.yml index 7f3ccd4..d790054 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -46,6 +46,10 @@ component_management: name: etcd paths: - etcd/** + - component_id: module_geoip + name: geoip + paths: + - geoip/** - component_id: module_internal name: internal paths: diff --git a/Makefile b/Makefile index b726693..f9085d9 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ GRPC_PROTO_GO_FILES := $(addsuffix .pb.go,$(GRPC_PROTO_FILES)) $(addsuffix _grpc TEST_GO_FILES := $(wildcard *_test.go)) EASYJSON_FILES := $(filter-out $(TEST_GO_FILES),$(wildcard api*.go api/signaling.go */api.go talk/ocs.go)) EASYJSON_GO_FILES := $(patsubst %.go,%_easyjson.go,$(EASYJSON_FILES)) -COMMON_GO_FILES := $(filter-out continentmap.go $(PROTO_GO_FILES) $(GRPC_PROTO_GO_FILES) $(EASYJSON_GO_FILES) $(TEST_GO_FILES),$(wildcard *.go)) +COMMON_GO_FILES := $(filter-out geoip/continentmap.go $(PROTO_GO_FILES) $(GRPC_PROTO_GO_FILES) $(EASYJSON_GO_FILES) $(TEST_GO_FILES),$(wildcard *.go)) CLIENT_TEST_GO_FILES := $(wildcard client/*_test.go)) CLIENT_GO_FILES := $(filter-out $(CLIENT_TEST_GO_FILES),$(wildcard client/*.go)) SERVER_TEST_GO_FILES := $(wildcard server/*_test.go)) @@ -92,7 +92,7 @@ $(GOPATHBIN)/protoc-gen-go-grpc: go.mod go.sum $(GOPATHBIN)/checklocks: go.mod go.sum $(GO) install gvisor.dev/gvisor/tools/checklocks/cmd/checklocks@go -continentmap.go: +geoip/continentmap.go: $(CURDIR)/scripts/get_continent_map.py $@ check-continentmap: @@ -100,7 +100,7 @@ check-continentmap: TMP=$$(mktemp -d) ;\ echo Make sure to remove $$TMP on error ;\ $(CURDIR)/scripts/get_continent_map.py $$TMP/continentmap.go ;\ - diff -u continentmap.go $$TMP/continentmap.go ;\ + diff -u geoip/continentmap.go $$TMP/continentmap.go ;\ rm -rf $$TMP get: @@ -215,6 +215,6 @@ tarball: vendor | $(TMPDIR) dist: tarball .NOTPARALLEL: $(EASYJSON_GO_FILES) -.PHONY: continentmap.go common vendor +.PHONY: geoip/continentmap.go common vendor .SECONDARY: $(EASYJSON_GO_FILES) $(PROTO_GO_FILES) .DELETE_ON_ERROR: diff --git a/api/signaling.go b/api/signaling.go index 04a7988..0c7df41 100644 --- a/api/signaling.go +++ b/api/signaling.go @@ -37,6 +37,7 @@ import ( "github.com/pion/sdp/v3" "github.com/strukturag/nextcloud-spreed-signaling/container" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/internal" ) @@ -301,9 +302,9 @@ func (e *Error) Error() string { } type WelcomeServerMessage struct { - Version string `json:"version"` - Features []string `json:"features,omitempty"` - Country string `json:"country,omitempty"` + Version string `json:"version"` + Features []string `json:"features,omitempty"` + Country geoip.Country `json:"country,omitempty"` } func NewWelcomeServerMessage(version string, feature ...string) *WelcomeServerMessage { diff --git a/api/signaling_easyjson.go b/api/signaling_easyjson.go index ff1a12b..4a4ce0e 100644 --- a/api/signaling_easyjson.go +++ b/api/signaling_easyjson.go @@ -8,6 +8,7 @@ import ( easyjson "github.com/mailru/easyjson" jlexer "github.com/mailru/easyjson/jlexer" jwriter "github.com/mailru/easyjson/jwriter" + geoip "github.com/strukturag/nextcloud-spreed-signaling/geoip" time "time" ) @@ -70,7 +71,7 @@ func easyjson6128dd2DecodeGithubComStrukturagNextcloudSpreedSignalingApi(in *jle if in.IsNull() { in.Skip() } else { - out.Country = string(in.String()) + out.Country = geoip.Country(in.String()) } default: in.SkipRecursive() diff --git a/api_backend.go b/api_backend.go index 322d075..bb720b8 100644 --- a/api_backend.go +++ b/api_backend.go @@ -35,6 +35,7 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/api" "github.com/strukturag/nextcloud-spreed-signaling/etcd" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" ) const ( @@ -458,7 +459,7 @@ type BackendServerInfoSfuProxy struct { Version string `json:"version,omitempty"` Features []string `json:"features,omitempty"` - Country string `json:"country,omitempty"` + Country geoip.Country `json:"country,omitempty"` Load *uint64 `json:"load,omitempty"` Bandwidth *EventProxyServerBandwidth `json:"bandwidth,omitempty"` } diff --git a/api_backend_easyjson.go b/api_backend_easyjson.go index 6dc0a11..9e8ef43 100644 --- a/api_backend_easyjson.go +++ b/api_backend_easyjson.go @@ -9,6 +9,7 @@ import ( jwriter "github.com/mailru/easyjson/jwriter" api "github.com/strukturag/nextcloud-spreed-signaling/api" etcd "github.com/strukturag/nextcloud-spreed-signaling/etcd" + geoip "github.com/strukturag/nextcloud-spreed-signaling/geoip" time "time" ) @@ -769,7 +770,7 @@ func easyjson4354c623DecodeGithubComStrukturagNextcloudSpreedSignaling5(in *jlex if in.IsNull() { in.Skip() } else { - out.Country = string(in.String()) + out.Country = geoip.Country(in.String()) } case "load": if in.IsNull() { diff --git a/client.go b/client.go index fcb86a3..0c024b2 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,7 @@ import ( "github.com/mailru/easyjson" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/internal" "github.com/strukturag/nextcloud-spreed-signaling/log" "github.com/strukturag/nextcloud-spreed-signaling/pool" @@ -57,33 +58,10 @@ const ( maxMessageSize = 64 * 1024 ) -var ( - noCountry = "no-country" - - loopback = "loopback" - - unknownCountry = "unknown-country" -) - func init() { RegisterClientStats() } -func IsValidCountry(country string) bool { - switch country { - case "": - fallthrough - case noCountry: - fallthrough - case loopback: - fallthrough - case unknownCountry: - return false - default: - return true - } -} - var ( InvalidFormat = api.NewError("invalid_format", "Invalid data format.") @@ -99,7 +77,7 @@ type WritableClientMessage interface { type HandlerClient interface { Context() context.Context RemoteAddr() string - Country() string + Country() geoip.Country UserAgent() string IsConnected() bool IsAuthenticated() bool @@ -122,7 +100,7 @@ type ClientHandler interface { } type ClientGeoIpHandler interface { - OnLookupCountry(HandlerClient) string + OnLookupCountry(HandlerClient) geoip.Country } type Client struct { @@ -132,7 +110,7 @@ type Client struct { addr string agent string closed atomic.Int32 - country *string + country *geoip.Country logRTT bool handlerMu sync.RWMutex @@ -246,13 +224,13 @@ func (c *Client) UserAgent() string { return c.agent } -func (c *Client) Country() string { +func (c *Client) Country() geoip.Country { if c.country == nil { - var country string + var country geoip.Country if handler, ok := c.getHandler().(ClientGeoIpHandler); ok { country = handler.OnLookupCountry(c) } else { - country = unknownCountry + country = geoip.UnknownCountry } c.country = &country } diff --git a/continentmap.go b/geoip/continentmap.go similarity index 98% rename from continentmap.go rename to geoip/continentmap.go index a99b944..75673bd 100644 --- a/continentmap.go +++ b/geoip/continentmap.go @@ -1,10 +1,10 @@ -package signaling +package geoip // This file has been automatically generated, do not modify. // Source: https://raw.githubusercontent.com/datasets/country-codes/refs/heads/main/data/country-codes.csv var ( - ContinentMap = map[string][]string{ + ContinentMap = map[Country][]Continent{ "AD": {"EU"}, "AE": {"AS"}, "AF": {"AS"}, diff --git a/geoip/geoip.go b/geoip/geoip.go new file mode 100644 index 0000000..a020a38 --- /dev/null +++ b/geoip/geoip.go @@ -0,0 +1,98 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2025 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package geoip + +import "strings" + +type ( + Country string + Continent string +) + +var ( + NoCountry = Country("no-country") + noCountryUpper = Country(strings.ToUpper("no-country")) + + Loopback = Country("loopback") + loopbackUpper = Country(strings.ToUpper("loopback")) + + UnknownCountry = Country("unknown-country") + unknownCountryUpper = Country(strings.ToUpper("unknown-country")) +) + +func IsValidCountry(country Country) bool { + switch country { + case "": + fallthrough + case NoCountry: + fallthrough + case noCountryUpper: + fallthrough + case Loopback: + fallthrough + case loopbackUpper: + fallthrough + case UnknownCountry: + fallthrough + case unknownCountryUpper: + return false + default: + return true + } +} + +func LookupContinents(country Country) []Continent { + continents, found := ContinentMap[country] + if !found { + return nil + } + + return continents +} + +func IsValidContinent(continent Continent) bool { + switch continent { + case "AF": + // Africa + fallthrough + case "AN": + // Antartica + fallthrough + case "AS": + // Asia + fallthrough + case "EU": + // Europe + fallthrough + case "NA": + // North America + fallthrough + case "SA": + // South America + fallthrough + case "OC": + // Oceania + return true + default: + return false + } +} diff --git a/geoip.go b/geoip/maxmind.go similarity index 63% rename from geoip.go rename to geoip/maxmind.go index 4da1aa4..d5f1c8c 100644 --- a/geoip.go +++ b/geoip/maxmind.go @@ -19,12 +19,11 @@ * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ -package signaling +package geoip import ( "archive/tar" "compress/gzip" - "context" "errors" "fmt" "io" @@ -36,10 +35,8 @@ import ( "sync/atomic" "time" - "github.com/dlintw/goconf" "github.com/oschwald/maxminddb-golang" - "github.com/strukturag/nextcloud-spreed-signaling/config" "github.com/strukturag/nextcloud-spreed-signaling/log" ) @@ -47,7 +44,7 @@ var ( ErrDatabaseNotInitialized = errors.New("GeoIP database not initialized yet") ) -func GetGeoIpDownloadUrl(license string) string { +func GetMaxMindDownloadUrl(license string) string { if license == "" { return "" } @@ -59,7 +56,7 @@ func GetGeoIpDownloadUrl(license string) string { return result } -type GeoLookup struct { +type Lookup struct { logger log.Logger url string isFile bool @@ -71,16 +68,16 @@ type GeoLookup struct { reader atomic.Pointer[maxminddb.Reader] } -func NewGeoLookupFromUrl(logger log.Logger, url string) (*GeoLookup, error) { - geoip := &GeoLookup{ +func NewLookupFromUrl(logger log.Logger, url string) (*Lookup, error) { + geoip := &Lookup{ logger: logger, url: url, } return geoip, nil } -func NewGeoLookupFromFile(logger log.Logger, filename string) (*GeoLookup, error) { - geoip := &GeoLookup{ +func NewLookupFromFile(logger log.Logger, filename string) (*Lookup, error) { + geoip := &Lookup{ logger: logger, url: filename, isFile: true, @@ -92,13 +89,13 @@ func NewGeoLookupFromFile(logger log.Logger, filename string) (*GeoLookup, error return geoip, nil } -func (g *GeoLookup) Close() { +func (g *Lookup) Close() { if reader := g.reader.Swap(nil); reader != nil { reader.Close() } } -func (g *GeoLookup) Update() error { +func (g *Lookup) Update() error { if g.isFile { return g.updateFile() } @@ -106,7 +103,7 @@ func (g *GeoLookup) Update() error { return g.updateUrl() } -func (g *GeoLookup) updateFile() error { +func (g *Lookup) updateFile() error { info, err := os.Stat(g.url) if err != nil { return err @@ -136,7 +133,7 @@ func (g *GeoLookup) updateFile() error { return nil } -func (g *GeoLookup) updateUrl() error { +func (g *Lookup) updateUrl() error { request, err := http.NewRequest("GET", g.url, nil) if err != nil { return err @@ -219,7 +216,7 @@ func (g *GeoLookup) updateUrl() error { return nil } -func (g *GeoLookup) LookupCountry(ip net.IP) (string, error) { +func (g *Lookup) LookupCountry(ip net.IP) (Country, error) { var record struct { Country struct { ISOCode string `maxminddb:"iso_code"` @@ -235,103 +232,5 @@ func (g *GeoLookup) LookupCountry(ip net.IP) (string, error) { return "", err } - return record.Country.ISOCode, nil -} - -func LookupContinents(country string) []string { - continents, found := ContinentMap[country] - if !found { - return nil - } - - return continents -} - -func IsValidContinent(continent string) bool { - switch continent { - case "AF": - // Africa - fallthrough - case "AN": - // Antartica - fallthrough - case "AS": - // Asia - fallthrough - case "EU": - // Europe - fallthrough - case "NA": - // North America - fallthrough - case "SA": - // South America - fallthrough - case "OC": - // Oceania - return true - default: - return false - } -} - -func LoadGeoIPOverrides(ctx context.Context, cfg *goconf.ConfigFile, ignoreErrors bool) (map[*net.IPNet]string, error) { - logger := log.LoggerFromContext(ctx) - options, _ := config.GetStringOptions(cfg, "geoip-overrides", true) - if len(options) == 0 { - return nil, nil - } - - var err error - geoipOverrides := make(map[*net.IPNet]string, len(options)) - for option, value := range options { - var ip net.IP - var ipNet *net.IPNet - if strings.Contains(option, "/") { - _, ipNet, err = net.ParseCIDR(option) - if err != nil { - if ignoreErrors { - logger.Printf("could not parse CIDR %s (%s), skipping", option, err) - continue - } - - return nil, fmt.Errorf("could not parse CIDR %s: %s", option, err) - } - } else { - ip = net.ParseIP(option) - if ip == nil { - if ignoreErrors { - logger.Printf("could not parse IP %s, skipping", option) - continue - } - - return nil, fmt.Errorf("could not parse IP %s", option) - } - - var mask net.IPMask - if ipv4 := ip.To4(); ipv4 != nil { - mask = net.CIDRMask(32, 32) - } else { - mask = net.CIDRMask(128, 128) - } - ipNet = &net.IPNet{ - IP: ip, - Mask: mask, - } - } - - value = strings.ToUpper(strings.TrimSpace(value)) - if value == "" { - logger.Printf("IP %s doesn't have a country assigned, skipping", option) - continue - } else if !IsValidCountry(value) { - logger.Printf("Country %s for IP %s is invalid, skipping", value, option) - continue - } - - logger.Printf("Using country %s for %s", value, ipNet) - geoipOverrides[ipNet] = value - } - - return geoipOverrides, nil + return Country(record.Country.ISOCode), nil } diff --git a/geoip_test.go b/geoip/maxmind_test.go similarity index 84% rename from geoip_test.go rename to geoip/maxmind_test.go index de59b78..82d4657 100644 --- a/geoip_test.go +++ b/geoip/maxmind_test.go @@ -19,7 +19,7 @@ * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ -package signaling +package geoip import ( "archive/tar" @@ -39,8 +39,8 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/log" ) -func testGeoLookupReader(t *testing.T, reader *GeoLookup) { - tests := map[string]string{ +func testLookupReader(t *testing.T, reader *Lookup) { + tests := map[string]Country{ // Example from maxminddb-golang code. "81.2.69.142": "GB", // Local addresses don't have a country assigned. @@ -59,7 +59,7 @@ func testGeoLookupReader(t *testing.T, reader *GeoLookup) { } } -func GetGeoIpUrlForTest(t *testing.T) string { +func GetIpUrlForTest(t *testing.T) string { t.Helper() var geoIpUrl string @@ -72,29 +72,29 @@ func GetGeoIpUrlForTest(t *testing.T) string { if license == "" { t.Skip("No MaxMind GeoLite2 license was set in MAXMIND_GEOLITE2_LICENSE environment variable.") } - geoIpUrl = GetGeoIpDownloadUrl(license) + geoIpUrl = GetMaxMindDownloadUrl(license) } return geoIpUrl } -func TestGeoLookup(t *testing.T) { +func TestLookup(t *testing.T) { t.Parallel() logger := log.NewLoggerForTest(t) require := require.New(t) - reader, err := NewGeoLookupFromUrl(logger, GetGeoIpUrlForTest(t)) + reader, err := NewLookupFromUrl(logger, GetIpUrlForTest(t)) require.NoError(err) defer reader.Close() require.NoError(reader.Update()) - testGeoLookupReader(t, reader) + testLookupReader(t, reader) } -func TestGeoLookupCaching(t *testing.T) { +func TestLookupCaching(t *testing.T) { t.Parallel() logger := log.NewLoggerForTest(t) require := require.New(t) - reader, err := NewGeoLookupFromUrl(logger, GetGeoIpUrlForTest(t)) + reader, err := NewLookupFromUrl(logger, GetIpUrlForTest(t)) require.NoError(err) defer reader.Close() @@ -105,9 +105,9 @@ func TestGeoLookupCaching(t *testing.T) { require.NoError(reader.Update()) } -func TestGeoLookupContinent(t *testing.T) { +func TestLookupContinent(t *testing.T) { t.Parallel() - tests := map[string][]string{ + tests := map[Country][]Continent{ "AU": {"OC"}, "DE": {"EU"}, "RU": {"EU"}, @@ -116,7 +116,7 @@ func TestGeoLookupContinent(t *testing.T) { } for country, expected := range tests { - t.Run(country, func(t *testing.T) { + t.Run(string(country), func(t *testing.T) { t.Parallel() continents := LookupContinents(country) if !assert.Len(t, continents, len(expected), "Continents didn't match for %s: got %s, expected %s", country, continents, expected) { @@ -131,19 +131,19 @@ func TestGeoLookupContinent(t *testing.T) { } } -func TestGeoLookupCloseEmpty(t *testing.T) { +func TestLookupCloseEmpty(t *testing.T) { t.Parallel() logger := log.NewLoggerForTest(t) - reader, err := NewGeoLookupFromUrl(logger, "ignore-url") + reader, err := NewLookupFromUrl(logger, "ignore-url") require.NoError(t, err) reader.Close() } -func TestGeoLookupFromFile(t *testing.T) { +func TestLookupFromFile(t *testing.T) { t.Parallel() logger := log.NewLoggerForTest(t) require := require.New(t) - geoIpUrl := GetGeoIpUrlForTest(t) + geoIpUrl := GetIpUrlForTest(t) resp, err := http.Get(geoIpUrl) require.NoError(err) @@ -196,11 +196,11 @@ func TestGeoLookupFromFile(t *testing.T) { require.True(foundDatabase, "Did not find GeoIP database in download from %s", geoIpUrl) - reader, err := NewGeoLookupFromFile(logger, tmpfile.Name()) + reader, err := NewLookupFromFile(logger, tmpfile.Name()) require.NoError(err) defer reader.Close() - testGeoLookupReader(t, reader) + testLookupReader(t, reader) } func TestIsValidContinent(t *testing.T) { diff --git a/geoip/overrides.go b/geoip/overrides.go new file mode 100644 index 0000000..a15faa7 --- /dev/null +++ b/geoip/overrides.go @@ -0,0 +1,130 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2025 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package geoip + +import ( + "context" + "fmt" + "maps" + "net" + "strings" + "sync/atomic" + + "github.com/dlintw/goconf" + "github.com/strukturag/nextcloud-spreed-signaling/config" + "github.com/strukturag/nextcloud-spreed-signaling/log" +) + +type Overrides map[*net.IPNet]Country + +func (o Overrides) Lookup(ip net.IP) (Country, bool) { + for overrideNet, country := range o { + if overrideNet.Contains(ip) { + return country, true + } + } + + return UnknownCountry, false +} + +type AtomicOverrides struct { + value atomic.Pointer[Overrides] +} + +func (a *AtomicOverrides) Store(value Overrides) { + if len(value) == 0 { + a.value.Store(nil) + } else { + v := maps.Clone(value) + a.value.Store(&v) + } +} + +func (a *AtomicOverrides) Load() Overrides { + value := a.value.Load() + if value == nil { + return nil + } + + return *value +} + +func LoadOverrides(ctx context.Context, cfg *goconf.ConfigFile, ignoreErrors bool) (Overrides, error) { + logger := log.LoggerFromContext(ctx) + options, _ := config.GetStringOptions(cfg, "geoip-overrides", true) + if len(options) == 0 { + return nil, nil + } + + var err error + geoipOverrides := make(Overrides, len(options)) + for option, value := range options { + var ip net.IP + var ipNet *net.IPNet + if strings.Contains(option, "/") { + _, ipNet, err = net.ParseCIDR(option) + if err != nil { + if ignoreErrors { + logger.Printf("could not parse CIDR %s (%s), skipping", option, err) + continue + } + + return nil, fmt.Errorf("could not parse CIDR %s: %w", option, err) + } + } else { + ip = net.ParseIP(option) + if ip == nil { + if ignoreErrors { + logger.Printf("could not parse IP %s, skipping", option) + continue + } + + return nil, fmt.Errorf("could not parse IP %s", option) + } + + var mask net.IPMask + if ipv4 := ip.To4(); ipv4 != nil { + mask = net.CIDRMask(32, 32) + } else { + mask = net.CIDRMask(128, 128) + } + ipNet = &net.IPNet{ + IP: ip, + Mask: mask, + } + } + + value = strings.ToUpper(strings.TrimSpace(value)) + if value == "" { + logger.Printf("IP %s doesn't have a country assigned, skipping", option) + continue + } else if !IsValidCountry(Country(value)) { + logger.Printf("Country %s for IP %s is invalid, skipping", value, option) + continue + } + + logger.Printf("Using country %s for %s", value, ipNet) + geoipOverrides[ipNet] = Country(value) + } + + return geoipOverrides, nil +} diff --git a/geoip/overrides_test.go b/geoip/overrides_test.go new file mode 100644 index 0000000..68845c2 --- /dev/null +++ b/geoip/overrides_test.go @@ -0,0 +1,182 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2025 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package geoip + +import ( + "maps" + "net" + "testing" + + "github.com/dlintw/goconf" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/strukturag/nextcloud-spreed-signaling/log" +) + +func mustSucceed1[T any, A1 any](t *testing.T, f func(a1 A1) (T, bool), a1 A1) T { + t.Helper() + result, ok := f(a1) + if !ok { + t.FailNow() + } + return result +} + +func TestOverridesEmpty(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + + config := goconf.NewConfigFile() + logger := log.NewLoggerForTest(t) + ctx := log.NewLoggerContext(t.Context(), logger) + + overrides, err := LoadOverrides(ctx, config, true) + require.NoError(err) + assert.Empty(overrides) +} + +func TestOverrides(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + + config := goconf.NewConfigFile() + config.AddOption("geoip-overrides", "10.1.0.0/16", "DE") + config.AddOption("geoip-overrides", "2001:db8::/48", "FR") + config.AddOption("geoip-overrides", "2001:db9::3", "CH") + config.AddOption("geoip-overrides", "10.3.4.5", "custom") + config.AddOption("geoip-overrides", "10.4.5.6", "loopback") + config.AddOption("geoip-overrides", "192.168.1.0", "") + + logger := log.NewLoggerForTest(t) + ctx := log.NewLoggerContext(t.Context(), logger) + + overrides, err := LoadOverrides(ctx, config, true) + require.NoError(err) + + if assert.Len(overrides, 4) { + assert.EqualValues("DE", mustSucceed1(t, overrides.Lookup, net.ParseIP("10.1.2.3"))) + assert.EqualValues("DE", mustSucceed1(t, overrides.Lookup, net.ParseIP("10.1.3.4"))) + assert.EqualValues("FR", mustSucceed1(t, overrides.Lookup, net.ParseIP("2001:db8::1"))) + assert.EqualValues("FR", mustSucceed1(t, overrides.Lookup, net.ParseIP("2001:db8::2"))) + assert.EqualValues("CH", mustSucceed1(t, overrides.Lookup, net.ParseIP("2001:db9::3"))) + assert.EqualValues("CUSTOM", mustSucceed1(t, overrides.Lookup, net.ParseIP("10.3.4.5"))) + + country, ok := overrides.Lookup(net.ParseIP("10.4.5.6")) + assert.False(ok, "expected no country, got %s", country) + + country, ok = overrides.Lookup(net.ParseIP("192.168.1.0")) + assert.False(ok, "expected no country, got %s", country) + } +} + +func TestOverridesInvalidIgnoreErrors(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + + config := goconf.NewConfigFile() + config.AddOption("geoip-overrides", "invalid-ip", "DE") + config.AddOption("geoip-overrides", "300.1.2.3/8", "DE") + config.AddOption("geoip-overrides", "10.2.0.0/16", "FR") + + logger := log.NewLoggerForTest(t) + ctx := log.NewLoggerContext(t.Context(), logger) + + overrides, err := LoadOverrides(ctx, config, true) + require.NoError(err) + + if assert.Len(overrides, 1) { + assert.EqualValues("FR", mustSucceed1(t, overrides.Lookup, net.ParseIP("10.2.3.4"))) + assert.EqualValues("FR", mustSucceed1(t, overrides.Lookup, net.ParseIP("10.2.4.5"))) + + country, ok := overrides.Lookup(net.ParseIP("10.3.4.5")) + assert.False(ok, "expected no country, got %s", country) + + country, ok = overrides.Lookup(net.ParseIP("192.168.1.0")) + assert.False(ok, "expected no country, got %s", country) + } +} + +func TestOverridesInvalidIPReturnErrors(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + config := goconf.NewConfigFile() + config.AddOption("geoip-overrides", "invalid-ip", "DE") + config.AddOption("geoip-overrides", "10.2.0.0/16", "FR") + + logger := log.NewLoggerForTest(t) + ctx := log.NewLoggerContext(t.Context(), logger) + + overrides, err := LoadOverrides(ctx, config, false) + assert.ErrorContains(err, "could not parse IP", err) + assert.Empty(overrides) +} + +func TestOverridesInvalidCIDRReturnErrors(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + config := goconf.NewConfigFile() + config.AddOption("geoip-overrides", "300.1.2.3/8", "DE") + config.AddOption("geoip-overrides", "10.2.0.0/16", "FR") + + logger := log.NewLoggerForTest(t) + ctx := log.NewLoggerContext(t.Context(), logger) + + overrides, err := LoadOverrides(ctx, config, false) + var e *net.ParseError + assert.ErrorAs(err, &e) + assert.Empty(overrides) +} + +func TestAtomicOverrides(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + overrides := make(Overrides) + overrides[&net.IPNet{ + IP: net.ParseIP("10.1.2.3."), + Mask: net.CIDRMask(32, 32), + }] = "DE" + + var value AtomicOverrides + assert.Nil(value.Load()) + value.Store(make(Overrides)) + assert.Nil(value.Load()) + value.Store(overrides) + if o := value.Load(); assert.NotEmpty(o) { + assert.Equal(overrides, o) + } + // Updating the overrides doesn't change the stored value. + overrides2 := maps.Clone(overrides) + overrides[&net.IPNet{ + IP: net.ParseIP("10.1.2.3."), + Mask: net.CIDRMask(32, 32), + }] = "FR" + if o := value.Load(); assert.NotEmpty(o) { + assert.Equal(overrides2, o) + } +} diff --git a/grpc_client.go b/grpc_client.go index 9e49772..07912fb 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -47,6 +47,7 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/async" "github.com/strukturag/nextcloud-spreed-signaling/dns" "github.com/strukturag/nextcloud-spreed-signaling/etcd" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/internal" "github.com/strukturag/nextcloud-spreed-signaling/log" ) @@ -369,7 +370,7 @@ func (c *GrpcClient) GetTransientData(ctx context.Context, room *Room) (Transien type ProxySessionReceiver interface { RemoteAddr() string - Country() string + Country() geoip.Country UserAgent() string OnProxyMessage(message *ServerSessionMessage) error @@ -429,7 +430,7 @@ func (c *GrpcClient) ProxySession(ctx context.Context, sessionId api.PublicSessi md := metadata.Pairs( "sessionId", string(sessionId), "remoteAddr", receiver.RemoteAddr(), - "country", receiver.Country(), + "country", string(receiver.Country()), "userAgent", receiver.UserAgent(), ) client, err := c.impl.ProxySession(metadata.NewOutgoingContext(ctx, md), grpc.WaitForReady(true)) diff --git a/grpc_remote_client.go b/grpc_remote_client.go index b04f52b..85fbd95 100644 --- a/grpc_remote_client.go +++ b/grpc_remote_client.go @@ -34,6 +34,7 @@ import ( "google.golang.org/grpc/status" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/log" ) @@ -57,7 +58,7 @@ type remoteGrpcClient struct { sessionId string remoteAddr string - country string + country geoip.Country userAgent string closeCtx context.Context @@ -82,7 +83,7 @@ func newRemoteGrpcClient(hub *Hub, request RpcSessions_ProxySessionServer) (*rem sessionId: getMD(md, "sessionId"), remoteAddr: getMD(md, "remoteAddr"), - country: getMD(md, "country"), + country: geoip.Country(getMD(md, "country")), userAgent: getMD(md, "userAgent"), closeCtx: closeCtx, @@ -131,7 +132,7 @@ func (c *remoteGrpcClient) UserAgent() string { return c.userAgent } -func (c *remoteGrpcClient) Country() string { +func (c *remoteGrpcClient) Country() geoip.Country { return c.country } diff --git a/hub.go b/hub.go index 1db57fa..f15d9e6 100644 --- a/hub.go +++ b/hub.go @@ -56,6 +56,7 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/config" "github.com/strukturag/nextcloud-spreed-signaling/container" "github.com/strukturag/nextcloud-spreed-signaling/etcd" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/internal" "github.com/strukturag/nextcloud-spreed-signaling/log" "github.com/strukturag/nextcloud-spreed-signaling/talk" @@ -207,8 +208,8 @@ type Hub struct { backend *BackendClient trustedProxies atomic.Pointer[container.IPList] - geoip *GeoLookup - geoipOverrides atomic.Pointer[map[*net.IPNet]string] + geoip *geoip.Lookup + geoipOverrides geoip.AtomicOverrides geoipUpdating atomic.Bool etcdClient etcd.Client @@ -329,18 +330,18 @@ func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events AsyncEvents, rpc } if geoipUrl == "" { if geoipLicense, _ := cfg.GetString("geoip", "license"); geoipLicense != "" { - geoipUrl = GetGeoIpDownloadUrl(geoipLicense) + geoipUrl = geoip.GetMaxMindDownloadUrl(geoipLicense) } } - var geoip *GeoLookup + var geoipLookup *geoip.Lookup if geoipUrl != "" { if geoipUrl, found := strings.CutPrefix(geoipUrl, "file://"); found { logger.Printf("Using GeoIP database from %s", geoipUrl) - geoip, err = NewGeoLookupFromFile(logger, geoipUrl) + geoipLookup, err = geoip.NewLookupFromFile(logger, geoipUrl) } else { logger.Printf("Downloading GeoIP database from %s", geoipUrl) - geoip, err = NewGeoLookupFromUrl(logger, geoipUrl) + geoipLookup, err = geoip.NewLookupFromUrl(logger, geoipUrl) } if err != nil { return nil, err @@ -349,7 +350,7 @@ func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events AsyncEvents, rpc logger.Printf("Not using GeoIP database") } - geoipOverrides, err := LoadGeoIPOverrides(ctx, cfg, false) + geoipOverrides, err := geoip.LoadOverrides(ctx, cfg, false) if err != nil { return nil, err } @@ -409,7 +410,7 @@ func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events AsyncEvents, rpc backendTimeout: backendTimeout, backend: backend, - geoip: geoip, + geoip: geoipLookup, etcdClient: etcdClient, rpcServer: rpcServer, @@ -444,9 +445,8 @@ func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events AsyncEvents, rpc } hub.trustedProxies.Store(trustedProxiesIps) - if len(geoipOverrides) > 0 { - hub.geoipOverrides.Store(&geoipOverrides) - } + hub.geoipOverrides.Store(geoipOverrides) + hub.setWelcomeMessage(&api.ServerMessage{ Type: "welcome", Welcome: api.NewWelcomeServerMessage(version, api.DefaultWelcomeFeatures...), @@ -588,12 +588,8 @@ func (h *Hub) Reload(ctx context.Context, config *goconf.ConfigFile) { h.logger.Printf("Error parsing trusted proxies from \"%s\": %s", trustedProxies, err) } - geoipOverrides, _ := LoadGeoIPOverrides(ctx, config, true) - if len(geoipOverrides) > 0 { - h.geoipOverrides.Store(&geoipOverrides) - } else { - h.geoipOverrides.Store(nil) - } + geoipOverrides, _ := geoip.LoadOverrides(ctx, config, true) + h.geoipOverrides.Store(geoipOverrides) if value, _ := config.GetString("mcu", "allowedcandidates"); value != "" { if allowed, err := container.ParseIPList(value); err != nil { @@ -1066,8 +1062,8 @@ func (h *Hub) processRegister(c HandlerClient, message *api.ClientMessage, backe } h.mu.Unlock() - if country := client.Country(); IsValidCountry(country) { - statsClientCountries.WithLabelValues(country).Inc() + if country := client.Country(); geoip.IsValidCountry(country) { + statsClientCountries.WithLabelValues(string(country)).Inc() } statsHubSessionsCurrent.WithLabelValues(backend.Id(), string(session.ClientType())).Inc() statsHubSessionsTotal.WithLabelValues(backend.Id(), string(session.ClientType())).Inc() @@ -3157,35 +3153,31 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { client.ReadPump() } -func (h *Hub) OnLookupCountry(client HandlerClient) string { +func (h *Hub) OnLookupCountry(client HandlerClient) geoip.Country { ip := net.ParseIP(client.RemoteAddr()) if ip == nil { - return noCountry + return geoip.NoCountry } - if overrides := h.geoipOverrides.Load(); overrides != nil { - for overrideNet, country := range *overrides { - if overrideNet.Contains(ip) { - return country - } - } + if country, found := h.geoipOverrides.Load().Lookup(ip); found { + return country } if ip.IsLoopback() { - return loopback + return geoip.Loopback } - country := unknownCountry + country := geoip.UnknownCountry if h.geoip != nil { var err error country, err = h.geoip.LookupCountry(ip) if err != nil { h.logger.Printf("Could not lookup country for %s: %s", ip, err) - return unknownCountry + return geoip.UnknownCountry } if country == "" { - country = unknownCountry + country = geoip.UnknownCountry } } return country diff --git a/hub_test.go b/hub_test.go index 79ba5b2..a5929f7 100644 --- a/hub_test.go +++ b/hub_test.go @@ -55,6 +55,7 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/api" "github.com/strukturag/nextcloud-spreed-signaling/async" "github.com/strukturag/nextcloud-spreed-signaling/container" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/internal" "github.com/strukturag/nextcloud-spreed-signaling/log" "github.com/strukturag/nextcloud-spreed-signaling/mock" @@ -5218,11 +5219,11 @@ func TestGeoipOverrides(t *testing.T) { return conf, err }) - assert.Equal(loopback, hub.OnLookupCountry(&Client{addr: "127.0.0.1"})) - assert.Equal(unknownCountry, hub.OnLookupCountry(&Client{addr: "8.8.8.8"})) - assert.Equal(country1, hub.OnLookupCountry(&Client{addr: "10.1.1.2"})) - assert.Equal(country2, hub.OnLookupCountry(&Client{addr: "10.2.1.2"})) - assert.Equal(strings.ToUpper(country3), hub.OnLookupCountry(&Client{addr: "192.168.10.20"})) + assert.Equal(geoip.Loopback, hub.OnLookupCountry(&Client{addr: "127.0.0.1"})) + assert.Equal(geoip.UnknownCountry, hub.OnLookupCountry(&Client{addr: "8.8.8.8"})) + assert.EqualValues(country1, hub.OnLookupCountry(&Client{addr: "10.1.1.2"})) + assert.EqualValues(country2, hub.OnLookupCountry(&Client{addr: "10.2.1.2"})) + assert.EqualValues(strings.ToUpper(country3), hub.OnLookupCountry(&Client{addr: "192.168.10.20"})) } func TestDialoutStatus(t *testing.T) { diff --git a/mcu_common.go b/mcu_common.go index 38d852c..82b5c54 100644 --- a/mcu_common.go +++ b/mcu_common.go @@ -30,6 +30,7 @@ import ( "github.com/dlintw/goconf" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/log" ) @@ -70,7 +71,7 @@ type McuListener interface { } type McuInitiator interface { - Country() string + Country() geoip.Country } type McuSettings interface { diff --git a/mcu_common_test.go b/mcu_common_test.go index 736b447..bdff6fb 100644 --- a/mcu_common_test.go +++ b/mcu_common_test.go @@ -25,6 +25,7 @@ import ( "testing" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" ) func TestCommonMcuStats(t *testing.T) { @@ -65,9 +66,9 @@ func (m *MockMcuListener) SubscriberClosed(subscriber McuSubscriber) { } type MockMcuInitiator struct { - country string + country geoip.Country } -func (m *MockMcuInitiator) Country() string { +func (m *MockMcuInitiator) Country() geoip.Country { return m.country } diff --git a/mcu_janus_test.go b/mcu_janus_test.go index 1b2eebb..1b1ba15 100644 --- a/mcu_janus_test.go +++ b/mcu_janus_test.go @@ -37,6 +37,7 @@ import ( "github.com/stretchr/testify/require" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/log" "github.com/strukturag/nextcloud-spreed-signaling/mock" ) @@ -675,10 +676,10 @@ func (c *TestMcuController) GetStreams(ctx context.Context) ([]PublisherStream, } type TestMcuInitiator struct { - country string + country geoip.Country } -func (i *TestMcuInitiator) Country() string { +func (i *TestMcuInitiator) Country() geoip.Country { return i.country } diff --git a/mcu_proxy.go b/mcu_proxy.go index 930cdfe..e0407ed 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -50,6 +50,7 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/config" "github.com/strukturag/nextcloud-spreed-signaling/dns" "github.com/strukturag/nextcloud-spreed-signaling/etcd" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/internal" "github.com/strukturag/nextcloud-spreed-signaling/log" ) @@ -77,6 +78,8 @@ const ( rttLogDuration = 500 * time.Millisecond ) +type ContinentsMap map[geoip.Continent][]geoip.Continent + type McuProxy interface { AddConnection(ignoreErrors bool, url string, ips ...net.IP) error KeepConnection(url string, ips ...net.IP) @@ -422,7 +425,7 @@ func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP, token str conn.reconnectInterval.Store(int64(initialReconnectInterval)) conn.load.Store(loadNotConnected) conn.bandwidth.Store(nil) - conn.country.Store("") + conn.country.Store(geoip.Country("")) conn.version.Store("") conn.features.Store([]string{}) statsProxyBackendLoadCurrent.WithLabelValues(conn.url.String()).Set(0) @@ -474,7 +477,7 @@ func (c *mcuProxyConnection) IsSameContinent(initiator McuInitiator) bool { return true } - initiatorContinents, found := ContinentMap[initiatorCountry] + initiatorContinents, found := geoip.ContinentMap[initiatorCountry] if found { m := c.proxy.getContinentsMap() // Map continents to other continents (e.g. use Europe for Africa). @@ -485,7 +488,7 @@ func (c *mcuProxyConnection) IsSameContinent(initiator McuInitiator) bool { } } - connContinents := ContinentMap[connCountry] + connContinents := geoip.ContinentMap[connCountry] return ContinentsOverlap(initiatorContinents, connContinents) } @@ -539,8 +542,8 @@ func (c *mcuProxyConnection) Bandwidth() *EventProxyServerBandwidth { return c.bandwidth.Load() } -func (c *mcuProxyConnection) Country() string { - return c.country.Load().(string) +func (c *mcuProxyConnection) Country() geoip.Country { + return c.country.Load().(geoip.Country) } func (c *mcuProxyConnection) Version() string { @@ -733,7 +736,7 @@ func (c *mcuProxyConnection) close() { c.conn = nil c.connectedSince.Store(0) if c.trackClose.CompareAndSwap(true, false) { - statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Dec() + statsConnectedProxyBackendsCurrent.WithLabelValues(string(c.Country())).Dec() } } } @@ -1005,9 +1008,9 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { case "hello": resumed := c.SessionId() == msg.Hello.SessionId c.sessionId.Store(msg.Hello.SessionId) - country := "" + var country geoip.Country if server := msg.Hello.Server; server != nil { - if country = server.Country; country != "" && !IsValidCountry(country) { + if country = server.Country; country != "" && !geoip.IsValidCountry(country) { c.logger.Printf("Proxy %s sent invalid country %s in hello response", c, country) country = "" } @@ -1029,7 +1032,7 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { c.logger.Printf("Received session %s from %s", c.SessionId(), c) } if c.trackClose.CompareAndSwap(false, true) { - statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Inc() + statsConnectedProxyBackendsCurrent.WithLabelValues(string(c.Country())).Inc() } c.helloProcessed.Store(true) @@ -1603,18 +1606,18 @@ func (m *mcuProxy) loadContinentsMap(cfg *goconf.ConfigFile) error { return nil } - continentsMap := make(map[string][]string) + continentsMap := make(ContinentsMap) for option, value := range options { - option = strings.ToUpper(strings.TrimSpace(option)) - if !IsValidContinent(option) { + option := geoip.Continent(strings.ToUpper(strings.TrimSpace(option))) + if !geoip.IsValidContinent(option) { m.logger.Printf("Ignore unknown continent %s", option) continue } - var values []string + var values []geoip.Continent for v := range internal.SplitEntries(value, ",") { - v = strings.ToUpper(v) - if !IsValidContinent(v) { + v := geoip.Continent(strings.ToUpper(v)) + if !geoip.IsValidContinent(v) { m.logger.Printf("Ignore unknown continent %s for override %s", v, option) continue } @@ -1916,24 +1919,24 @@ func (m *mcuProxy) GetServerInfoSfu() *BackendServerInfoSfu { return sfu } -func (m *mcuProxy) getContinentsMap() map[string][]string { +func (m *mcuProxy) getContinentsMap() ContinentsMap { continentsMap := m.continentsMap.Load() if continentsMap == nil { return nil } - return continentsMap.(map[string][]string) + return continentsMap.(ContinentsMap) } -func (m *mcuProxy) setContinentsMap(continentsMap map[string][]string) { +func (m *mcuProxy) setContinentsMap(continentsMap ContinentsMap) { if continentsMap == nil { - continentsMap = make(map[string][]string) + continentsMap = make(ContinentsMap) } m.continentsMap.Store(continentsMap) } type mcuProxyConnectionsList []*mcuProxyConnection -func ContinentsOverlap(a, b []string) bool { +func ContinentsOverlap(a, b []geoip.Continent) bool { if len(a) == 0 || len(b) == 0 { return false } @@ -1946,7 +1949,7 @@ func ContinentsOverlap(a, b []string) bool { return false } -func sortConnectionsForCountry(connections []*mcuProxyConnection, country string, continentMap map[string][]string) []*mcuProxyConnection { +func sortConnectionsForCountry(connections []*mcuProxyConnection, country geoip.Country, continentMap ContinentsMap) []*mcuProxyConnection { // Move connections in the same country to the start of the list. sorted := make(mcuProxyConnectionsList, 0, len(connections)) unprocessed := make(mcuProxyConnectionsList, 0, len(connections)) @@ -1957,7 +1960,7 @@ func sortConnectionsForCountry(connections []*mcuProxyConnection, country string unprocessed = append(unprocessed, conn) } } - if continents, found := ContinentMap[country]; found && len(unprocessed) > 1 { + if continents, found := geoip.ContinentMap[country]; found && len(unprocessed) > 1 { remaining := make(mcuProxyConnectionsList, 0, len(unprocessed)) // Map continents to other continents (e.g. use Europe for Africa). for _, continent := range continents { @@ -1969,8 +1972,8 @@ func sortConnectionsForCountry(connections []*mcuProxyConnection, country string // Next up are connections on the same or mapped continent. for _, conn := range unprocessed { connCountry := conn.Country() - if IsValidCountry(connCountry) { - connContinents := ContinentMap[connCountry] + if geoip.IsValidCountry(connCountry) { + connContinents := geoip.ContinentMap[connCountry] if ContinentsOverlap(continents, connContinents) { sorted = append(sorted, conn) } else { @@ -2013,7 +2016,7 @@ func (m *mcuProxy) getSortedConnections(initiator McuInitiator) []*mcuProxyConne } if initiator != nil { - if country := initiator.Country(); IsValidCountry(country) { + if country := initiator.Country(); geoip.IsValidCountry(country) { connections = sortConnectionsForCountry(connections, country, m.getContinentsMap()) } } diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index 40f5881..219ef34 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -52,6 +52,7 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/dns" "github.com/strukturag/nextcloud-spreed-signaling/etcd" "github.com/strukturag/nextcloud-spreed-signaling/etcd/etcdtest" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/internal" "github.com/strukturag/nextcloud-spreed-signaling/log" "github.com/strukturag/nextcloud-spreed-signaling/talk" @@ -66,7 +67,7 @@ func TestMcuProxyStats(t *testing.T) { collectAndLint(t, proxyMcuStats...) } -func newProxyConnectionWithCountry(country string) *mcuProxyConnection { +func newProxyConnectionWithCountry(country geoip.Country) *mcuProxyConnection { conn := &mcuProxyConnection{} conn.country.Store(country) return conn @@ -79,7 +80,7 @@ func Test_sortConnectionsForCountry(t *testing.T) { conn_jp := newProxyConnectionWithCountry("JP") conn_us := newProxyConnectionWithCountry("US") - testcases := map[string][][]*mcuProxyConnection{ + testcases := map[geoip.Country][][]*mcuProxyConnection{ // Direct country match "DE": { {conn_at, conn_jp, conn_de}, @@ -118,7 +119,7 @@ func Test_sortConnectionsForCountry(t *testing.T) { } for country, test := range testcases { - t.Run(country, func(t *testing.T) { + t.Run(string(country), func(t *testing.T) { t.Parallel() sorted := sortConnectionsForCountry(test[0], country, nil) for idx, conn := range sorted { @@ -135,7 +136,7 @@ func Test_sortConnectionsForCountryWithOverride(t *testing.T) { conn_jp := newProxyConnectionWithCountry("JP") conn_us := newProxyConnectionWithCountry("US") - testcases := map[string][][]*mcuProxyConnection{ + testcases := map[geoip.Country][][]*mcuProxyConnection{ // Direct country match "DE": { {conn_at, conn_jp, conn_de}, @@ -183,14 +184,14 @@ func Test_sortConnectionsForCountryWithOverride(t *testing.T) { }, } - continentMap := map[string][]string{ + continentMap := ContinentsMap{ // Use European connections for Africa. "AF": {"EU"}, // Use Asian and North American connections for Oceania. "OC": {"AS", "NA"}, } for country, test := range testcases { - t.Run(country, func(t *testing.T) { + t.Run(string(country), func(t *testing.T) { t.Parallel() sorted := sortConnectionsForCountry(test[0], country, continentMap) for idx, conn := range sorted { @@ -561,7 +562,7 @@ type TestProxyServerHandler struct { servers []*TestProxyServerHandler tokens map[string]*rsa.PublicKey upgrader *websocket.Upgrader - country string + country geoip.Country mu sync.Mutex load atomic.Uint64 @@ -815,7 +816,7 @@ func (h *TestProxyServerHandler) ClearClients() { clear(h.clients) } -func NewProxyServerForTest(t *testing.T, country string) *TestProxyServerHandler { +func NewProxyServerForTest(t *testing.T, country geoip.Country) *TestProxyServerHandler { t.Helper() upgrader := websocket.Upgrader{} diff --git a/proxy/proxy_remote.go b/proxy/proxy_remote.go index 05c1f4f..3e8712e 100644 --- a/proxy/proxy_remote.go +++ b/proxy/proxy_remote.go @@ -41,6 +41,7 @@ import ( signaling "github.com/strukturag/nextcloud-spreed-signaling" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/log" ) @@ -468,9 +469,9 @@ func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { resumed := c.sessionId == msg.Hello.SessionId c.sessionId = msg.Hello.SessionId c.helloReceived = true - country := "" + var country geoip.Country if msg.Hello.Server != nil { - if country = msg.Hello.Server.Country; country != "" && !signaling.IsValidCountry(country) { + if country = msg.Hello.Server.Country; country != "" && !geoip.IsValidCountry(country) { c.logger.Printf("Proxy %s sent invalid country %s in hello response", c, country) country = "" } diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index a0e7d4b..12f8fcd 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -54,6 +54,7 @@ import ( "github.com/strukturag/nextcloud-spreed-signaling/async" "github.com/strukturag/nextcloud-spreed-signaling/config" "github.com/strukturag/nextcloud-spreed-signaling/container" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/log" ) @@ -108,7 +109,7 @@ var ( type ProxyServer struct { version string - country string + country geoip.Country welcomeMessage string welcomeMsg *api.WelcomeServerMessage config *goconf.ConfigFile @@ -279,9 +280,9 @@ func NewProxyServer(ctx context.Context, r *mux.Router, version string, config * logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } - country, _ := config.GetString("app", "country") - country = strings.ToUpper(country) - if signaling.IsValidCountry(country) { + countryString, _ := config.GetString("app", "country") + country := geoip.Country(strings.ToUpper(countryString)) + if geoip.IsValidCountry(country) { logger.Printf("Sending %s as country information", country) } else if country != "" { return nil, fmt.Errorf("invalid country: %s", country) @@ -859,7 +860,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { type emptyInitiator struct{} -func (i *emptyInitiator) Country() string { +func (i *emptyInitiator) Country() geoip.Country { return "" } diff --git a/remotesession.go b/remotesession.go index 94bd4b7..8d71bc3 100644 --- a/remotesession.go +++ b/remotesession.go @@ -29,6 +29,7 @@ import ( "time" "github.com/strukturag/nextcloud-spreed-signaling/api" + "github.com/strukturag/nextcloud-spreed-signaling/geoip" "github.com/strukturag/nextcloud-spreed-signaling/log" ) @@ -65,7 +66,7 @@ func NewRemoteSession(hub *Hub, client *Client, remoteClient *GrpcClient, sessio return remoteSession, nil } -func (s *RemoteSession) Country() string { +func (s *RemoteSession) Country() geoip.Country { return s.client.Country() } @@ -138,7 +139,7 @@ func (s *RemoteSession) Close() { s.client.Close() } -func (s *RemoteSession) OnLookupCountry(client HandlerClient) string { +func (s *RemoteSession) OnLookupCountry(client HandlerClient) geoip.Country { return s.hub.OnLookupCountry(client) } diff --git a/scripts/get_continent_map.py b/scripts/get_continent_map.py index 61ef9c4..d86cf3d 100755 --- a/scripts/get_continent_map.py +++ b/scripts/get_continent_map.py @@ -87,13 +87,13 @@ def generate_map(filename): continents.setdefault(country, []).append(continent) out = StringIO() - out.write('package signaling\n') + out.write('package geoip\n') out.write('\n') out.write('// This file has been automatically generated, do not modify.\n') out.write('// Source: %s\n' % (URL)) out.write('\n') out.write('var (\n') - out.write('\tContinentMap = map[string][]string{\n') + out.write('\tContinentMap = map[Country][]Continent{\n') for country, continents in sorted(continents.items()): value = [] for continent in continents: