Support continent mapping overrides.

This can be used for example to route all users on continent A to proxies
on continent B. Useful if no proxy exists on continent A and the global
selection chooses a non-ideal proxy.
This commit is contained in:
Joachim Bauch 2021-08-06 16:00:54 +02:00
parent ffb79c747c
commit 7bf6fa903b
No known key found for this signature in database
GPG Key ID: 77C1D22D53E15F02
3 changed files with 156 additions and 4 deletions

View File

@ -988,6 +988,8 @@ type mcuProxy struct {
publisherWaitersId uint64
publisherWaiters map[uint64]chan bool
continentsMap atomic.Value
}
func NewMcuProxy(config *goconf.ConfigFile) (Mcu, error) {
@ -1045,6 +1047,10 @@ func NewMcuProxy(config *goconf.ConfigFile) (Mcu, error) {
publisherWaiters: make(map[uint64]chan bool),
}
if err := mcu.loadContinentsMap(config); err != nil {
return nil, err
}
skipverify, _ := config.GetBool("mcu", "skipverify")
if skipverify {
log.Println("WARNING: MCU verification is disabled!")
@ -1085,6 +1091,44 @@ func NewMcuProxy(config *goconf.ConfigFile) (Mcu, error) {
return mcu, nil
}
func (m *mcuProxy) loadContinentsMap(config *goconf.ConfigFile) error {
options, _ := config.GetOptions("continent-overrides")
if len(options) == 0 {
m.setContinentsMap(nil)
return nil
}
continentsMap := make(map[string][]string)
for _, option := range options {
option = strings.ToUpper(strings.TrimSpace(option))
if !IsValidContinent(option) {
log.Printf("Ignore unknown continent %s", option)
continue
}
var values []string
value, _ := config.GetString("continent-overrides", option)
for _, v := range strings.Split(value, ",") {
v = strings.ToUpper(strings.TrimSpace(v))
if !IsValidContinent(v) {
log.Printf("Ignore unknown continent %s for override %s", v, option)
continue
}
values = append(values, v)
}
if len(values) == 0 {
log.Printf("No valid values found for continent override %s, ignoring", option)
continue
}
continentsMap[option] = values
log.Printf("Mapping users on continent %s to %s", option, values)
}
m.setContinentsMap(continentsMap)
return nil
}
func (m *mcuProxy) getEtcdClient() *clientv3.Client {
c := m.client.Load()
if c == nil {
@ -1277,6 +1321,10 @@ func (m *mcuProxy) Reload(config *goconf.ConfigFile) {
m.connectionsMu.Lock()
defer m.connectionsMu.Unlock()
if err := m.loadContinentsMap(config); err != nil {
log.Printf("Error loading continents map: %s", err)
}
remove := make(map[string]*mcuProxyConnection)
for u, conn := range m.connectionsMap {
remove[u] = conn
@ -1462,6 +1510,21 @@ func (m *mcuProxy) GetStats() interface{} {
return result
}
func (m *mcuProxy) getContinentsMap() map[string][]string {
continentsMap := m.continentsMap.Load()
if continentsMap == nil {
return nil
}
return continentsMap.(map[string][]string)
}
func (m *mcuProxy) setContinentsMap(continentsMap map[string][]string) {
if continentsMap == nil {
continentsMap = make(map[string][]string)
}
m.continentsMap.Store(continentsMap)
}
type mcuProxyConnectionsList []*mcuProxyConnection
func (l mcuProxyConnectionsList) Len() int {
@ -1495,7 +1558,7 @@ func ContinentsOverlap(a, b []string) bool {
return false
}
func sortConnectionsForCountry(connections []*mcuProxyConnection, country string) []*mcuProxyConnection {
func sortConnectionsForCountry(connections []*mcuProxyConnection, country string, continentMap map[string][]string) []*mcuProxyConnection {
// Move connections in the same country to the start of the list.
sorted := make(mcuProxyConnectionsList, 0, len(connections))
unprocessed := make(mcuProxyConnectionsList, 0, len(connections))
@ -1508,7 +1571,14 @@ func sortConnectionsForCountry(connections []*mcuProxyConnection, country string
}
if continents, found := ContinentMap[country]; found && len(unprocessed) > 1 {
remaining := make(mcuProxyConnectionsList, 0, len(unprocessed))
// Next up are connections on the same continent.
// Map continents to other continents (e.g. use Europe for Africa).
for _, continent := range continents {
if toAdd, found := continentMap[continent]; found {
continents = append(continents, toAdd...)
}
}
// Next up are connections on the same or mapped continent.
for _, conn := range unprocessed {
connCountry := conn.Country()
if IsValidCountry(connCountry) {
@ -1556,7 +1626,7 @@ func (m *mcuProxy) getSortedConnections(initiator McuInitiator) []*mcuProxyConne
if initiator != nil {
if country := initiator.Country(); IsValidCountry(country) {
connections = sortConnectionsForCountry(connections, country)
connections = sortConnectionsForCountry(connections, country, m.getContinentsMap())
}
}
return connections

View File

@ -79,7 +79,81 @@ func Test_sortConnectionsForCountry(t *testing.T) {
country := country
test := test
t.Run(country, func(t *testing.T) {
sorted := sortConnectionsForCountry(test[0], country)
sorted := sortConnectionsForCountry(test[0], country, nil)
for idx, conn := range sorted {
if test[1][idx] != conn {
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
}
}
})
}
}
func Test_sortConnectionsForCountryWithOverride(t *testing.T) {
conn_de := newProxyConnectionWithCountry("DE")
conn_at := newProxyConnectionWithCountry("AT")
conn_jp := newProxyConnectionWithCountry("JP")
conn_us := newProxyConnectionWithCountry("US")
testcases := map[string][][]*mcuProxyConnection{
// Direct country match
"DE": {
{conn_at, conn_jp, conn_de},
{conn_de, conn_at, conn_jp},
},
// Direct country match
"AT": {
{conn_at, conn_jp, conn_de},
{conn_at, conn_de, conn_jp},
},
// Continent match
"CH": {
{conn_de, conn_jp, conn_at},
{conn_de, conn_at, conn_jp},
},
// Direct country match
"JP": {
{conn_de, conn_jp, conn_at},
{conn_jp, conn_de, conn_at},
},
// Continent match
"CN": {
{conn_de, conn_jp, conn_at},
{conn_jp, conn_de, conn_at},
},
// Partial continent match
"RU": {
{conn_us, conn_de, conn_jp, conn_at},
{conn_de, conn_jp, conn_at, conn_us},
},
// No match
"AR": {
{conn_us, conn_de, conn_jp, conn_at},
{conn_us, conn_de, conn_jp, conn_at},
},
// No match but override (OC -> AS / NA)
"AU": {
{conn_us, conn_jp},
{conn_us, conn_jp},
},
// No match but override (AF -> EU)
"ZA": {
{conn_de, conn_at},
{conn_de, conn_at},
},
}
continentMap := map[string][]string{
// Use European connections for Africa.
"AF": {"EU"},
// Use Asian and North American connections for Oceania.
"OC": {"AS", "NA"},
}
for country, test := range testcases {
country := country
test := test
t.Run(country, func(t *testing.T) {
sorted := sortConnectionsForCountry(test[0], country, continentMap)
for idx, conn := range sorted {
if test[1][idx] != conn {
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())

View File

@ -210,6 +210,14 @@ connectionsperhost = 8
#127.0.0.1 = DE
#192.168.0.0/24 = DE
[continent-overrides]
# Optional overrides for continent mappings. The key is a continent code, the
# value a comma-separated list of continent codes to map the continent to.
# Use European servers for clients in Africa.
#AF = EU
# Use servers in North Africa for clients in South America.
#SA = NA
[stats]
# Comma-separated list of IP addresses that are allowed to access the stats
# endpoint. Leave empty (or commented) to only allow access from "127.0.0.1".