Add common code to handle allowed IPs.

This commit is contained in:
Joachim Bauch 2023-03-16 17:11:57 +01:00
parent 407fee2685
commit be949f90b1
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
4 changed files with 164 additions and 71 deletions

80
allowed_ips.go Normal file
View file

@ -0,0 +1,80 @@
package signaling
import (
"fmt"
"net"
"strings"
)
type AllowedIps struct {
allowed []*net.IPNet
}
func (a *AllowedIps) Empty() bool {
return len(a.allowed) == 0
}
func (a *AllowedIps) Allowed(ip net.IP) bool {
for _, i := range a.allowed {
if i.Contains(ip) {
return true
}
}
return false
}
func parseIPNet(s string) (*net.IPNet, error) {
var ipnet *net.IPNet
if strings.ContainsRune(s, '/') {
var err error
if _, ipnet, err = net.ParseCIDR(s); err != nil {
return nil, fmt.Errorf("invalid IP address/subnet %s: %w", s, err)
}
} else {
ip := net.ParseIP(s)
if ip == nil {
return nil, fmt.Errorf("invalid IP address %s", s)
}
ipnet = &net.IPNet{
IP: ip,
Mask: net.CIDRMask(len(ip)*8, len(ip)*8),
}
}
return ipnet, nil
}
func ParseAllowedIps(allowed string) (*AllowedIps, error) {
var allowedIps []*net.IPNet
for _, ip := range strings.Split(allowed, ",") {
ip = strings.TrimSpace(ip)
if ip != "" {
i, err := parseIPNet(ip)
if err != nil {
return nil, err
}
allowedIps = append(allowedIps, i)
}
}
result := &AllowedIps{
allowed: allowedIps,
}
return result, nil
}
func DefaultAllowedIps() *AllowedIps {
allowedIps := []*net.IPNet{
{
IP: net.ParseIP("127.0.0.1"),
Mask: net.CIDRMask(32, 32),
},
}
result := &AllowedIps{
allowed: allowedIps,
}
return result
}

49
allowed_ips_test.go Normal file
View file

@ -0,0 +1,49 @@
package signaling
import (
"net"
"testing"
)
func TestAllowedIps(t *testing.T) {
a, err := ParseAllowedIps("127.0.0.1, 192.168.0.1, 192.168.1.1/24")
if err != nil {
t.Fatal(err)
}
if a.Empty() {
t.Fatal("should not be empty")
}
allowed := []string{
"127.0.0.1",
"192.168.0.1",
"192.168.1.1",
"192.168.1.100",
}
notAllowed := []string{
"192.168.0.2",
"10.1.2.3",
}
for _, addr := range allowed {
t.Run(addr, func(t *testing.T) {
ip := net.ParseIP(addr)
if ip == nil {
t.Errorf("error parsing %s", addr)
} else if !a.Allowed(ip) {
t.Errorf("should allow %s", addr)
}
})
}
for _, addr := range notAllowed {
t.Run(addr, func(t *testing.T) {
ip := net.ParseIP(addr)
if ip == nil {
t.Errorf("error parsing %s", addr)
} else if a.Allowed(ip) {
t.Errorf("should not allow %s", addr)
}
})
}
}

View file

@ -52,28 +52,6 @@ const (
sessionIdNotInMeeting = "0"
)
func parseIPNet(s string) (*net.IPNet, error) {
var ipnet *net.IPNet
if strings.ContainsRune(s, '/') {
var err error
if _, ipnet, err = net.ParseCIDR(s); err != nil {
return nil, fmt.Errorf("invalid IP address/subnet %s: %w", s, err)
}
} else {
ip := net.ParseIP(s)
if ip == nil {
return nil, fmt.Errorf("invalid IP address %s", s)
}
ipnet = &net.IPNet{
IP: ip,
Mask: net.CIDRMask(len(ip)*8, len(ip)*8),
}
}
return ipnet, nil
}
type BackendServer struct {
hub *Hub
events AsyncEvents
@ -87,7 +65,7 @@ type BackendServer struct {
turnvalid time.Duration
turnservers []string
statsAllowedIps []*net.IPNet
statsAllowedIps *AllowedIps
invalidSecret []byte
}
@ -122,27 +100,16 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac
}
statsAllowed, _ := config.GetString("stats", "allowed_ips")
var statsAllowedIps []*net.IPNet
for _, ip := range strings.Split(statsAllowed, ",") {
ip = strings.TrimSpace(ip)
if ip != "" {
i, err := parseIPNet(ip)
if err != nil {
return nil, err
}
statsAllowedIps = append(statsAllowedIps, i)
}
statsAllowedIps, err := ParseAllowedIps(statsAllowed)
if err != nil {
return nil, err
}
if len(statsAllowedIps) > 0 {
if !statsAllowedIps.Empty() {
log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed)
} else {
log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1")
statsAllowedIps = []*net.IPNet{
{
IP: net.ParseIP("127.0.0.1"),
Mask: net.CIDRMask(32, 32),
},
}
statsAllowedIps = DefaultAllowedIps()
}
invalidSecret := make([]byte, 32)
@ -786,15 +753,7 @@ func (b *BackendServer) allowStatsAccess(r *http.Request) bool {
return false
}
allowed := false
for _, i := range b.statsAllowedIps {
if i.Contains(ip) {
allowed = true
break
}
}
return allowed
return b.statsAllowedIps.Allowed(ip)
}
func (b *BackendServer) validateStatsRequest(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {

View file

@ -98,7 +98,7 @@ type ProxyServer struct {
upgrader websocket.Upgrader
tokens ProxyTokens
statsAllowedIps map[string]bool
statsAllowedIps *signaling.AllowedIps
sid uint64
cookie *securecookie.SecureCookie
@ -141,21 +141,16 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*
}
statsAllowed, _ := config.GetString("stats", "allowed_ips")
var statsAllowedIps map[string]bool
if statsAllowed == "" {
log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1")
statsAllowedIps = map[string]bool{
"127.0.0.1": true,
}
statsAllowedIps, err := signaling.ParseAllowedIps(statsAllowed)
if err != nil {
return nil, err
}
if !statsAllowedIps.Empty() {
log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed)
} else {
log.Printf("Only allowing access to the stats endpoing from %s", statsAllowed)
statsAllowedIps = make(map[string]bool)
for _, ip := range strings.Split(statsAllowed, ",") {
ip = strings.TrimSpace(ip)
if ip != "" {
statsAllowedIps[ip] = true
}
}
log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1")
statsAllowedIps = signaling.DefaultAllowedIps()
}
country, _ := config.GetString("app", "country")
@ -996,15 +991,25 @@ func (s *ProxyServer) getStats() map[string]interface{} {
return result
}
func (s *ProxyServer) allowStatsAccess(r *http.Request) bool {
addr := getRealUserIP(r)
if strings.Contains(addr, ":") {
if host, _, err := net.SplitHostPort(addr); err == nil {
addr = host
}
}
ip := net.ParseIP(addr)
if ip == nil {
return false
}
return s.statsAllowedIps.Allowed(ip)
}
func (s *ProxyServer) validateStatsRequest(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
addr := getRealUserIP(r)
if strings.Contains(addr, ":") {
if host, _, err := net.SplitHostPort(addr); err == nil {
addr = host
}
}
if !s.statsAllowedIps[addr] {
if !s.allowStatsAccess(r) {
http.Error(w, "Authentication check failed", http.StatusForbidden)
return
}