From be949f90b16230b18d887e3f733048757a1ee564 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 16 Mar 2023 17:11:57 +0100 Subject: [PATCH] Add common code to handle allowed IPs. --- allowed_ips.go | 80 +++++++++++++++++++++++++++++++++++++++++++ allowed_ips_test.go | 49 ++++++++++++++++++++++++++ backend_server.go | 57 +++++------------------------- proxy/proxy_server.go | 49 ++++++++++++++------------ 4 files changed, 164 insertions(+), 71 deletions(-) create mode 100644 allowed_ips.go create mode 100644 allowed_ips_test.go diff --git a/allowed_ips.go b/allowed_ips.go new file mode 100644 index 0000000..22b57ca --- /dev/null +++ b/allowed_ips.go @@ -0,0 +1,80 @@ +package signaling + +import ( + "fmt" + "net" + "strings" +) + +type AllowedIps struct { + allowed []*net.IPNet +} + +func (a *AllowedIps) Empty() bool { + return len(a.allowed) == 0 +} + +func (a *AllowedIps) Allowed(ip net.IP) bool { + for _, i := range a.allowed { + if i.Contains(ip) { + return true + } + } + + return false +} + +func parseIPNet(s string) (*net.IPNet, error) { + var ipnet *net.IPNet + if strings.ContainsRune(s, '/') { + var err error + if _, ipnet, err = net.ParseCIDR(s); err != nil { + return nil, fmt.Errorf("invalid IP address/subnet %s: %w", s, err) + } + } else { + ip := net.ParseIP(s) + if ip == nil { + return nil, fmt.Errorf("invalid IP address %s", s) + } + + ipnet = &net.IPNet{ + IP: ip, + Mask: net.CIDRMask(len(ip)*8, len(ip)*8), + } + } + + return ipnet, nil +} + +func ParseAllowedIps(allowed string) (*AllowedIps, error) { + var allowedIps []*net.IPNet + for _, ip := range strings.Split(allowed, ",") { + ip = strings.TrimSpace(ip) + if ip != "" { + i, err := parseIPNet(ip) + if err != nil { + return nil, err + } + allowedIps = append(allowedIps, i) + } + } + + result := &AllowedIps{ + allowed: allowedIps, + } + return result, nil +} + +func DefaultAllowedIps() *AllowedIps { + allowedIps := []*net.IPNet{ + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.CIDRMask(32, 32), + }, + } + + result := &AllowedIps{ + allowed: allowedIps, + } + return result +} diff --git a/allowed_ips_test.go b/allowed_ips_test.go new file mode 100644 index 0000000..001da6e --- /dev/null +++ b/allowed_ips_test.go @@ -0,0 +1,49 @@ +package signaling + +import ( + "net" + "testing" +) + +func TestAllowedIps(t *testing.T) { + a, err := ParseAllowedIps("127.0.0.1, 192.168.0.1, 192.168.1.1/24") + if err != nil { + t.Fatal(err) + } + if a.Empty() { + t.Fatal("should not be empty") + } + + allowed := []string{ + "127.0.0.1", + "192.168.0.1", + "192.168.1.1", + "192.168.1.100", + } + notAllowed := []string{ + "192.168.0.2", + "10.1.2.3", + } + + for _, addr := range allowed { + t.Run(addr, func(t *testing.T) { + ip := net.ParseIP(addr) + if ip == nil { + t.Errorf("error parsing %s", addr) + } else if !a.Allowed(ip) { + t.Errorf("should allow %s", addr) + } + }) + } + + for _, addr := range notAllowed { + t.Run(addr, func(t *testing.T) { + ip := net.ParseIP(addr) + if ip == nil { + t.Errorf("error parsing %s", addr) + } else if a.Allowed(ip) { + t.Errorf("should not allow %s", addr) + } + }) + } +} diff --git a/backend_server.go b/backend_server.go index fc3d73f..45a7f6f 100644 --- a/backend_server.go +++ b/backend_server.go @@ -52,28 +52,6 @@ const ( sessionIdNotInMeeting = "0" ) -func parseIPNet(s string) (*net.IPNet, error) { - var ipnet *net.IPNet - if strings.ContainsRune(s, '/') { - var err error - if _, ipnet, err = net.ParseCIDR(s); err != nil { - return nil, fmt.Errorf("invalid IP address/subnet %s: %w", s, err) - } - } else { - ip := net.ParseIP(s) - if ip == nil { - return nil, fmt.Errorf("invalid IP address %s", s) - } - - ipnet = &net.IPNet{ - IP: ip, - Mask: net.CIDRMask(len(ip)*8, len(ip)*8), - } - } - - return ipnet, nil -} - type BackendServer struct { hub *Hub events AsyncEvents @@ -87,7 +65,7 @@ type BackendServer struct { turnvalid time.Duration turnservers []string - statsAllowedIps []*net.IPNet + statsAllowedIps *AllowedIps invalidSecret []byte } @@ -122,27 +100,16 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac } statsAllowed, _ := config.GetString("stats", "allowed_ips") - var statsAllowedIps []*net.IPNet - for _, ip := range strings.Split(statsAllowed, ",") { - ip = strings.TrimSpace(ip) - if ip != "" { - i, err := parseIPNet(ip) - if err != nil { - return nil, err - } - statsAllowedIps = append(statsAllowedIps, i) - } + statsAllowedIps, err := ParseAllowedIps(statsAllowed) + if err != nil { + return nil, err } - if len(statsAllowedIps) > 0 { + + if !statsAllowedIps.Empty() { log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) } else { log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1") - statsAllowedIps = []*net.IPNet{ - { - IP: net.ParseIP("127.0.0.1"), - Mask: net.CIDRMask(32, 32), - }, - } + statsAllowedIps = DefaultAllowedIps() } invalidSecret := make([]byte, 32) @@ -786,15 +753,7 @@ func (b *BackendServer) allowStatsAccess(r *http.Request) bool { return false } - allowed := false - for _, i := range b.statsAllowedIps { - if i.Contains(ip) { - allowed = true - break - } - } - - return allowed + return b.statsAllowedIps.Allowed(ip) } func (b *BackendServer) validateStatsRequest(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index ac837c3..e057b59 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -98,7 +98,7 @@ type ProxyServer struct { upgrader websocket.Upgrader tokens ProxyTokens - statsAllowedIps map[string]bool + statsAllowedIps *signaling.AllowedIps sid uint64 cookie *securecookie.SecureCookie @@ -141,21 +141,16 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* } statsAllowed, _ := config.GetString("stats", "allowed_ips") - var statsAllowedIps map[string]bool - if statsAllowed == "" { - log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1") - statsAllowedIps = map[string]bool{ - "127.0.0.1": true, - } + statsAllowedIps, err := signaling.ParseAllowedIps(statsAllowed) + if err != nil { + return nil, err + } + + if !statsAllowedIps.Empty() { + log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) } else { - log.Printf("Only allowing access to the stats endpoing from %s", statsAllowed) - statsAllowedIps = make(map[string]bool) - for _, ip := range strings.Split(statsAllowed, ",") { - ip = strings.TrimSpace(ip) - if ip != "" { - statsAllowedIps[ip] = true - } - } + log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1") + statsAllowedIps = signaling.DefaultAllowedIps() } country, _ := config.GetString("app", "country") @@ -996,15 +991,25 @@ func (s *ProxyServer) getStats() map[string]interface{} { return result } +func (s *ProxyServer) allowStatsAccess(r *http.Request) bool { + addr := getRealUserIP(r) + if strings.Contains(addr, ":") { + if host, _, err := net.SplitHostPort(addr); err == nil { + addr = host + } + } + + ip := net.ParseIP(addr) + if ip == nil { + return false + } + + return s.statsAllowedIps.Allowed(ip) +} + func (s *ProxyServer) validateStatsRequest(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - addr := getRealUserIP(r) - if strings.Contains(addr, ":") { - if host, _, err := net.SplitHostPort(addr); err == nil { - addr = host - } - } - if !s.statsAllowedIps[addr] { + if !s.allowStatsAccess(r) { http.Error(w, "Authentication check failed", http.StatusForbidden) return }