diff --git a/backend_server.go b/backend_server.go index fb0cfb0..fc3d73f 100644 --- a/backend_server.go +++ b/backend_server.go @@ -52,6 +52,28 @@ const ( sessionIdNotInMeeting = "0" ) +func parseIPNet(s string) (*net.IPNet, error) { + var ipnet *net.IPNet + if strings.ContainsRune(s, '/') { + var err error + if _, ipnet, err = net.ParseCIDR(s); err != nil { + return nil, fmt.Errorf("invalid IP address/subnet %s: %w", s, err) + } + } else { + ip := net.ParseIP(s) + if ip == nil { + return nil, fmt.Errorf("invalid IP address %s", s) + } + + ipnet = &net.IPNet{ + IP: ip, + Mask: net.CIDRMask(len(ip)*8, len(ip)*8), + } + } + + return ipnet, nil +} + type BackendServer struct { hub *Hub events AsyncEvents @@ -65,7 +87,7 @@ type BackendServer struct { turnvalid time.Duration turnservers []string - statsAllowedIps map[string]bool + statsAllowedIps []*net.IPNet invalidSecret []byte } @@ -100,20 +122,26 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac } statsAllowed, _ := config.GetString("stats", "allowed_ips") - var statsAllowedIps map[string]bool - if statsAllowed == "" { - log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1") - statsAllowedIps = map[string]bool{ - "127.0.0.1": true, - } - } else { - log.Printf("Only allowing access to the stats endpoing from %s", statsAllowed) - statsAllowedIps = make(map[string]bool) - for _, ip := range strings.Split(statsAllowed, ",") { - ip = strings.TrimSpace(ip) - if ip != "" { - statsAllowedIps[ip] = true + var statsAllowedIps []*net.IPNet + for _, ip := range strings.Split(statsAllowed, ",") { + ip = strings.TrimSpace(ip) + if ip != "" { + i, err := parseIPNet(ip) + if err != nil { + return nil, err } + statsAllowedIps = append(statsAllowedIps, i) + } + } + if len(statsAllowedIps) > 0 { + log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) + } else { + log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1") + statsAllowedIps = []*net.IPNet{ + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.CIDRMask(32, 32), + }, } } @@ -745,15 +773,33 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body w.Write([]byte("{}")) // nolint } +func (b *BackendServer) allowStatsAccess(r *http.Request) bool { + addr := getRealUserIP(r) + if strings.Contains(addr, ":") { + if host, _, err := net.SplitHostPort(addr); err == nil { + addr = host + } + } + + ip := net.ParseIP(addr) + if ip == nil { + return false + } + + allowed := false + for _, i := range b.statsAllowedIps { + if i.Contains(ip) { + allowed = true + break + } + } + + return allowed +} + func (b *BackendServer) validateStatsRequest(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - addr := getRealUserIP(r) - if strings.Contains(addr, ":") { - if host, _, err := net.SplitHostPort(addr); err == nil { - addr = host - } - } - if !b.statsAllowedIps[addr] { + if !b.allowStatsAccess(r) { http.Error(w, "Authentication check failed", http.StatusForbidden) return } diff --git a/backend_server_test.go b/backend_server_test.go index dbd7446..82f3ee8 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -32,6 +32,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/textproto" "net/url" "reflect" "strings" @@ -1689,3 +1690,73 @@ func TestBackendServer_TurnCredentials(t *testing.T) { t.Errorf("Expected the list of servers as %s, got %s", turnServers, cred.URIs) } } + +func TestBackendServer_StatsAllowedIps(t *testing.T) { + config := goconf.NewConfigFile() + config.AddOption("stats", "allowed_ips", "127.0.0.1, 192.168.0.1, 192.168.1.1/24") + _, backend, _, _, _, _ := CreateBackendServerForTestFromConfig(t, config) + + allowed := []string{ + "127.0.0.1", + "127.0.0.1:1234", + "192.168.0.1:1234", + "192.168.1.1:1234", + "192.168.1.100:1234", + } + notAllowed := []string{ + "192.168.0.2:1234", + "10.1.2.3:1234", + } + + for _, addr := range allowed { + t.Run(addr, func(t *testing.T) { + r1 := &http.Request{ + RemoteAddr: addr, + } + if !backend.allowStatsAccess(r1) { + t.Errorf("should allow %s", addr) + } + + r2 := &http.Request{ + RemoteAddr: "1.2.3.4:12345", + Header: http.Header{ + textproto.CanonicalMIMEHeaderKey("x-real-ip"): []string{addr}, + }, + } + if !backend.allowStatsAccess(r2) { + t.Errorf("should allow %s", addr) + } + + r3 := &http.Request{ + RemoteAddr: "1.2.3.4:12345", + Header: http.Header{ + textproto.CanonicalMIMEHeaderKey("x-forwarded-for"): []string{addr}, + }, + } + if !backend.allowStatsAccess(r3) { + t.Errorf("should allow %s", addr) + } + + r4 := &http.Request{ + RemoteAddr: "1.2.3.4:12345", + Header: http.Header{ + textproto.CanonicalMIMEHeaderKey("x-forwarded-for"): []string{addr + ", 1.2.3.4:23456"}, + }, + } + if !backend.allowStatsAccess(r4) { + t.Errorf("should allow %s", addr) + } + }) + } + + for _, addr := range notAllowed { + t.Run(addr, func(t *testing.T) { + r := &http.Request{ + RemoteAddr: addr, + } + if backend.allowStatsAccess(r) { + t.Errorf("should not allow %s", addr) + } + }) + } +}