diff --git a/config.example.yml b/config.example.yml index cb8e7f5..da44a45 100644 --- a/config.example.yml +++ b/config.example.yml @@ -74,3 +74,8 @@ security: auto_block_enabled: true auto_block_threshold: 50 auto_block_duration: 24h0m0s + # Only list proxies you control. Forwarded client IPs are honored solely + # when the connecting peer matches; an empty list ignores them, which + # prevents X-Forwarded-For / CF-Connecting-IP spoofing. + trusted_proxies: [] + trust_cf_header: false diff --git a/internal/api/config_handlers.go b/internal/api/config_handlers.go index 75b2d7c..6e39368 100644 --- a/internal/api/config_handlers.go +++ b/internal/api/config_handlers.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "net/http" "strings" @@ -53,23 +54,61 @@ func (s *Server) updateConfigKey(c *gin.Context) { } applied := false + var applyErr error if apply, ok := s.runtimeAppliers()[key]; ok { - apply(s) - applied = true + applyErr = apply(s) + applied = applyErr == nil } entry, _ := config.Get(s.config, key) - c.JSON(http.StatusOK, gin.H{ + resp := gin.H{ "entry": entry, "applied": applied, - }) + } + if applyErr != nil { + resp["apply_error"] = applyErr.Error() + } + c.JSON(http.StatusOK, resp) } -func (s *Server) runtimeAppliers() map[string]func(*Server) { - return map[string]func(*Server){ - "cleanup.timeout": func(srv *Server) { +func (s *Server) runtimeAppliers() map[string]func(*Server) error { + applyDetectorThresholds := func(srv *Server) error { + if srv.securityManager == nil { + return nil + } + srv.securityManager.SetDetectorThresholds( + srv.config.Security.RateThreshold, + srv.config.Security.NotFoundThreshold, + srv.config.Security.AuthFailureThreshold, + srv.config.Security.UniquePathsThreshold, + srv.config.Security.RepeatedHitsThreshold, + srv.config.Security.DetectionWindow, + ) + return nil + } + regenerateSecurityScripts := func(srv *Server) error { + if !srv.config.Security.Enabled { + return nil + } + if !srv.infraManager.IsNginxRunning() { + return fmt.Errorf("value saved but nginx is not running; regenerate security scripts to apply") + } + _, err := srv.infraManager.RefreshSecurityScripts() + return err + } + return map[string]func(*Server) error{ + "cleanup.timeout": func(srv *Server) error { srv.manager.SetCleanupTimeout(srv.config.Cleanup.Timeout) + return nil }, + "security.rate_threshold": applyDetectorThresholds, + "security.not_found_threshold": applyDetectorThresholds, + "security.auth_failure_threshold": applyDetectorThresholds, + "security.unique_paths_threshold": applyDetectorThresholds, + "security.repeated_hits_threshold": applyDetectorThresholds, + "security.detection_window": applyDetectorThresholds, + "security.trusted_proxies": regenerateSecurityScripts, + "security.trust_cf_header": regenerateSecurityScripts, } } diff --git a/internal/api/config_handlers_test.go b/internal/api/config_handlers_test.go new file mode 100644 index 0000000..f9edd3f --- /dev/null +++ b/internal/api/config_handlers_test.go @@ -0,0 +1,31 @@ +package api + +import ( + "testing" + + "github.com/flatrun/agent/pkg/config" +) + +func TestRuntimeConfigKeysAdvertisesTrustKeys(t *testing.T) { + server := &Server{config: &config.Config{}} + keys := server.runtimeConfigKeys() + + for _, key := range []string{"security.trusted_proxies", "security.trust_cf_header"} { + if !keys[key] { + t.Errorf("expected %q to be advertised as a runtime config key", key) + } + } +} + +func TestTrustKeyApplierNoOpWhenSecurityDisabled(t *testing.T) { + server := &Server{config: &config.Config{}} + server.config.Security.Enabled = false + + apply := server.runtimeAppliers()["security.trusted_proxies"] + if apply == nil { + t.Fatal("expected an applier for security.trusted_proxies") + } + if err := apply(server); err != nil { + t.Fatalf("applier should be a no-op when security is disabled, got: %v", err) + } +} diff --git a/internal/infra/manager.go b/internal/infra/manager.go index 472e47b..74ab02f 100644 --- a/internal/infra/manager.go +++ b/internal/infra/manager.go @@ -379,7 +379,7 @@ func (m *Manager) SetNginxRealtimeCaptureWithStatus(enabled bool) (map[string]in result["agent_ip"] = agentIP result["agent_port"] = agentPort - securityLua, err := templates.GetNginxSecurityLuaWithConfig(agentIP, agentPort, m.config.Security.InternalAPIToken) + securityLua, err := templates.GetNginxSecurityLuaWithConfig(agentIP, agentPort, m.config.Security.InternalAPIToken, m.config.Security.TrustedProxies, m.config.Security.TrustCFHeader) if err != nil { errors = append(errors, fmt.Sprintf("failed to get security.lua template: %v", err)) } else { @@ -1245,7 +1245,7 @@ func (m *Manager) RefreshSecurityScripts() (*RefreshSecurityScriptsResult, error } // Generate and write security.lua with injected IP - securityLua, err := templates.GetNginxSecurityLuaWithConfig(agentIP, agentPort, m.config.Security.InternalAPIToken) + securityLua, err := templates.GetNginxSecurityLuaWithConfig(agentIP, agentPort, m.config.Security.InternalAPIToken, m.config.Security.TrustedProxies, m.config.Security.TrustCFHeader) if err != nil { result.Errors = append(result.Errors, fmt.Sprintf("failed to generate security.lua: %v", err)) result.Success = false diff --git a/internal/security/manager.go b/internal/security/manager.go index 68b22f2..c3540c5 100644 --- a/internal/security/manager.go +++ b/internal/security/manager.go @@ -11,6 +11,9 @@ type Manager struct { detector *Detector deploymentsPath string mu sync.RWMutex + + wlMu sync.RWMutex + wlCache *whitelistCache } func NewManager(deploymentsPath string) (*Manager, error) { @@ -57,6 +60,14 @@ func (m *Manager) IngestEvent(event *IngestEvent, autoBlockDuration time.Duratio result := &IngestResult{} + whitelisted, err := m.IsRequestWhitelisted(event.SourceIP, event.RequestPath) + if err != nil { + return nil, err + } + if whitelisted { + return result, nil + } + // Check if IP is blocked - if so, don't process blocked, err := m.db.IsIPBlocked(event.SourceIP) if err != nil { @@ -180,11 +191,19 @@ func (m *Manager) GetWhitelist() ([]WhitelistEntry, error) { } func (m *Manager) AddWhitelistEntry(value, entryType, reason string) (int64, error) { - return m.db.AddWhitelistEntry(value, entryType, reason, false) + id, err := m.db.AddWhitelistEntry(value, entryType, reason, false) + if err == nil { + m.invalidateWhitelistCache() + } + return id, err } func (m *Manager) RemoveWhitelistEntry(id int64) error { - return m.db.RemoveWhitelistEntry(id) + err := m.db.RemoveWhitelistEntry(id) + if err == nil { + m.invalidateWhitelistCache() + } + return err } func (m *Manager) IsWhitelisted(value string) (bool, error) { @@ -196,6 +215,9 @@ func (m *Manager) AddDockerGatewayToWhitelist(gatewayIP string) error { return nil } _, err := m.db.AddWhitelistEntry(gatewayIP, "ip", "Docker gateway", true) + if err == nil { + m.invalidateWhitelistCache() + } return err } diff --git a/internal/security/manager_test.go b/internal/security/manager_test.go new file mode 100644 index 0000000..fd3a242 --- /dev/null +++ b/internal/security/manager_test.go @@ -0,0 +1,186 @@ +package security + +import ( + "testing" + "time" +) + +func newTestManager(t *testing.T) *Manager { + t.Helper() + m, err := NewManager(t.TempDir()) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + t.Cleanup(func() { m.Close() }) + return m +} + +func ingestAuthFailures(t *testing.T, m *Manager, ip, path string, n int) *IngestResult { + t.Helper() + var last *IngestResult + for i := 0; i < n; i++ { + var err error + last, err = m.IngestEvent(&IngestEvent{ + SourceIP: ip, + RequestPath: path, + RequestMethod: "GET", + StatusCode: 401, + UserAgent: "Mozilla/5.0", + }, time.Hour) + if err != nil { + t.Fatalf("IngestEvent: %v", err) + } + } + return last +} + +func TestIngestEventAutoBlocksOnRepeatedAuthFailures(t *testing.T) { + m := newTestManager(t) + + result := ingestAuthFailures(t, m, "203.0.113.10", "/api/v1/stats", 5) + + if !result.AutoBlocked { + t.Fatal("expected IP to be auto-blocked after repeated auth failures") + } + blocked, err := m.IsIPBlocked("203.0.113.10") + if err != nil { + t.Fatalf("IsIPBlocked: %v", err) + } + if !blocked { + t.Fatal("expected IP to be in blocked list") + } +} + +func TestIngestEventSkipsWhitelistedIP(t *testing.T) { + m := newTestManager(t) + + if _, err := m.AddWhitelistEntry("203.0.113.7", "ip", "test"); err != nil { + t.Fatalf("AddWhitelistEntry: %v", err) + } + + result := ingestAuthFailures(t, m, "203.0.113.7", "/api/v1/stats", 20) + + if result.Event != nil { + t.Fatal("expected no event for whitelisted IP") + } + if result.AutoBlocked { + t.Fatal("expected whitelisted IP to never be auto-blocked") + } + blocked, err := m.IsIPBlocked("203.0.113.7") + if err != nil { + t.Fatalf("IsIPBlocked: %v", err) + } + if blocked { + t.Fatal("whitelisted IP must not be blocked") + } +} + +func TestIngestEventSkipsIPInWhitelistedCIDR(t *testing.T) { + m := newTestManager(t) + + if _, err := m.AddWhitelistEntry("198.51.100.0/24", "cidr", "test range"); err != nil { + t.Fatalf("AddWhitelistEntry: %v", err) + } + + result := ingestAuthFailures(t, m, "198.51.100.20", "/api/v1/stats", 20) + + if result.Event != nil || result.AutoBlocked { + t.Fatal("expected IP inside whitelisted CIDR to be skipped") + } +} + +func TestIngestEventSkipsSeededPrivateNetworks(t *testing.T) { + m := newTestManager(t) + + for _, ip := range []string{"127.0.0.1", "10.1.2.3", "172.18.0.5", "192.168.1.50"} { + result := ingestAuthFailures(t, m, ip, "/api/v1/stats", 20) + if result.Event != nil || result.AutoBlocked { + t.Fatalf("expected default-whitelisted IP %s to be skipped", ip) + } + } +} + +func TestIngestEventSkipsWhitelistedPathPrefix(t *testing.T) { + m := newTestManager(t) + + result, err := m.IngestEvent(&IngestEvent{ + SourceIP: "203.0.113.99", + RequestPath: "/api/health", + RequestMethod: "GET", + StatusCode: 500, + UserAgent: "Mozilla/5.0", + }, time.Hour) + if err != nil { + t.Fatalf("IngestEvent: %v", err) + } + + if result.Event != nil || result.AutoBlocked { + t.Fatal("expected request to whitelisted path to be skipped") + } +} + +func TestWhitelistCacheInvalidatedOnMutation(t *testing.T) { + m := newTestManager(t) + + got, err := m.IsRequestWhitelisted("203.0.113.40", "/api/v1/stats") + if err != nil { + t.Fatalf("IsRequestWhitelisted: %v", err) + } + if got { + t.Fatal("IP unexpectedly whitelisted before adding entry") + } + + id, err := m.AddWhitelistEntry("203.0.113.40", "ip", "test") + if err != nil { + t.Fatalf("AddWhitelistEntry: %v", err) + } + got, err = m.IsRequestWhitelisted("203.0.113.40", "/api/v1/stats") + if err != nil { + t.Fatalf("IsRequestWhitelisted: %v", err) + } + if !got { + t.Fatal("expected entry added after cache build to be honored") + } + + if err := m.RemoveWhitelistEntry(id); err != nil { + t.Fatalf("RemoveWhitelistEntry: %v", err) + } + got, err = m.IsRequestWhitelisted("203.0.113.40", "/api/v1/stats") + if err != nil { + t.Fatalf("IsRequestWhitelisted: %v", err) + } + if got { + t.Fatal("expected removed entry to stop matching") + } +} + +func TestIsRequestWhitelisted(t *testing.T) { + m := newTestManager(t) + + if _, err := m.AddWhitelistEntry("2001:db8::/32", "cidr", "test v6"); err != nil { + t.Fatalf("AddWhitelistEntry: %v", err) + } + + cases := []struct { + ip string + path string + want bool + }{ + {"127.0.0.1", "/anything", true}, + {"10.255.0.1", "/anything", true}, + {"2001:db8::1", "/anything", true}, + {"203.0.113.5", "/api/_internal/blocked-ips", true}, + {"203.0.113.5", "/wp-login.php", false}, + {"not-an-ip", "/wp-login.php", false}, + } + + for _, tc := range cases { + got, err := m.IsRequestWhitelisted(tc.ip, tc.path) + if err != nil { + t.Fatalf("IsRequestWhitelisted(%s, %s): %v", tc.ip, tc.path, err) + } + if got != tc.want { + t.Errorf("IsRequestWhitelisted(%s, %s) = %v, want %v", tc.ip, tc.path, got, tc.want) + } + } +} diff --git a/internal/security/whitelist.go b/internal/security/whitelist.go new file mode 100644 index 0000000..d0702af --- /dev/null +++ b/internal/security/whitelist.go @@ -0,0 +1,84 @@ +package security + +import ( + "net/netip" + "strings" +) + +type whitelistCache struct { + ips map[string]struct{} + prefixes []netip.Prefix + paths []string +} + +func (m *Manager) whitelistCacheLoad() (*whitelistCache, error) { + m.wlMu.RLock() + cache := m.wlCache + m.wlMu.RUnlock() + if cache != nil { + return cache, nil + } + + entries, err := m.db.GetWhitelist() + if err != nil { + return nil, err + } + + cache = &whitelistCache{ips: make(map[string]struct{})} + for _, entry := range entries { + switch entry.Type { + case "ip": + cache.ips[entry.Value] = struct{}{} + case "cidr": + if prefix, err := netip.ParsePrefix(entry.Value); err == nil { + cache.prefixes = append(cache.prefixes, prefix) + } + case "path": + cache.paths = append(cache.paths, entry.Value) + } + } + + m.wlMu.Lock() + m.wlCache = cache + m.wlMu.Unlock() + return cache, nil +} + +func (m *Manager) invalidateWhitelistCache() { + m.wlMu.Lock() + m.wlCache = nil + m.wlMu.Unlock() +} + +// IsRequestWhitelisted reports whether a request's source IP or path matches +// any whitelist entry. IP entries match exactly, CIDR entries match contained +// addresses, and path entries match by prefix. +func (m *Manager) IsRequestWhitelisted(ip, path string) (bool, error) { + cache, err := m.whitelistCacheLoad() + if err != nil { + return false, err + } + + if ip != "" { + if _, ok := cache.ips[ip]; ok { + return true, nil + } + if addr, err := netip.ParseAddr(ip); err == nil { + for _, prefix := range cache.prefixes { + if prefix.Contains(addr) { + return true, nil + } + } + } + } + + if path != "" { + for _, p := range cache.paths { + if strings.HasPrefix(path, p) { + return true, nil + } + } + } + + return false, nil +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 9753c4c..199bb52 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -166,6 +166,9 @@ type SecurityConfig struct { // Internal API token for nginx-to-agent communication (auto-generated if empty) InternalAPIToken string `yaml:"internal_api_token" json:"-"` + + TrustedProxies []string `yaml:"trusted_proxies" json:"trusted_proxies"` + TrustCFHeader bool `yaml:"trust_cf_header" json:"trust_cf_header"` } type AuditConfig struct { diff --git a/templates/infra/nginx/lua/security.lua b/templates/infra/nginx/lua/security.lua index 55c40b2..de1e1b5 100644 --- a/templates/infra/nginx/lua/security.lua +++ b/templates/infra/nginx/lua/security.lua @@ -11,6 +11,9 @@ local AGENT_IP = "{{.AgentIP}}" local AGENT_PORT = {{.AgentPort}} local INTERNAL_TOKEN = "{{.InternalAPIToken}}" +local TRUSTED_PROXIES_RAW = "{{.TrustedProxies}}" +local TRUST_CF_HEADER = {{if .TrustCFHeader}}true{{else}}false{{end}} + -- Cache settings local CACHE_TTL = 30 -- seconds local BLOCKED_IPS_LAST_FETCH = "blocked_ips_last_fetch" @@ -63,27 +66,6 @@ local scanner_patterns = { "zgrab", } -local function get_real_client_ip() - local cf_ip = ngx.var.http_cf_connecting_ip - if cf_ip and cf_ip ~= "" then - return cf_ip - end - - local xff = ngx.var.http_x_forwarded_for - if xff and xff ~= "" then - local first_ip = xff:match("^([^,]+)") - if first_ip then - return first_ip:match("^%s*(.-)%s*$") - end - end - - return ngx.var.remote_addr -end - -function _M.get_client_ip() - return get_real_client_ip() -end - function _M.is_blocked(ip) if not ip then return false end if _M.is_whitelisted(ip, nil) then return false end @@ -271,6 +253,59 @@ local function is_ip_in_cidr(ip, cidr) end end +local trusted_proxies = {} +for cidr in TRUSTED_PROXIES_RAW:gmatch("[^,]+") do + local trimmed = cidr:match("^%s*(.-)%s*$") + if trimmed ~= "" then + trusted_proxies[#trusted_proxies + 1] = trimmed + end +end + +local function is_trusted_proxy(ip) + if not ip then return false end + for _, cidr in ipairs(trusted_proxies) do + if is_ip_in_cidr(ip, cidr) then return true end + end + return false +end + +local function get_real_client_ip() + local peer = ngx.var.remote_addr + + if not is_trusted_proxy(peer) then + return peer + end + + if TRUST_CF_HEADER then + local cf_ip = ngx.var.http_cf_connecting_ip + if cf_ip and cf_ip ~= "" then + return cf_ip:match("^%s*(.-)%s*$") + end + end + + local xff = ngx.var.http_x_forwarded_for + if xff and xff ~= "" then + local hops = {} + for hop in xff:gmatch("[^,]+") do + hops[#hops + 1] = hop:match("^%s*(.-)%s*$") + end + for i = #hops, 1, -1 do + if not is_trusted_proxy(hops[i]) then + return hops[i] + end + end + if #hops > 0 then + return hops[1] + end + end + + return peer +end + +function _M.get_client_ip() + return get_real_client_ip() +end + function _M.is_whitelisted(ip, path) local dict = ngx.shared.whitelist if not dict then return false end diff --git a/templates/templates.go b/templates/templates.go index 76f6b29..b974f4d 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -4,7 +4,9 @@ import ( "bytes" "embed" "io/fs" + "net/netip" "path/filepath" + "strings" "text/template" ) @@ -129,15 +131,41 @@ func GetNginxSecurityLua() ([]byte, error) { return FS.ReadFile("infra/nginx/lua/security.lua") } +// sanitizeTrustedProxies keeps only well-formed IP and CIDR entries in their +// canonical form. Anything else is dropped, which both rejects malformed +// config and guarantees the value cannot break out of the Lua string literal +// it is injected into. +func sanitizeTrustedProxies(entries []string) []string { + out := make([]string, 0, len(entries)) + for _, e := range entries { + e = strings.TrimSpace(e) + if e == "" { + continue + } + if prefix, err := netip.ParsePrefix(e); err == nil { + out = append(out, prefix.String()) + continue + } + // ParseAddr accepts IPv6 zone IDs that can carry arbitrary characters; + // a zone is meaningless for proxy trust, so reject those entries + if addr, err := netip.ParseAddr(e); err == nil && addr.Zone() == "" { + out = append(out, addr.String()) + } + } + return out +} + // LuaTemplateData contains the data for Lua template processing type LuaTemplateData struct { AgentIP string AgentPort int InternalAPIToken string + TrustedProxies string + TrustCFHeader bool } // GetNginxSecurityLuaWithConfig returns the security.lua template processed with agent config -func GetNginxSecurityLuaWithConfig(agentIP string, agentPort int, internalAPIToken string) ([]byte, error) { +func GetNginxSecurityLuaWithConfig(agentIP string, agentPort int, internalAPIToken string, trustedProxies []string, trustCFHeader bool) ([]byte, error) { content, err := FS.ReadFile("infra/nginx/lua/security.lua") if err != nil { return nil, err @@ -153,6 +181,8 @@ func GetNginxSecurityLuaWithConfig(agentIP string, agentPort int, internalAPITok AgentIP: agentIP, AgentPort: agentPort, InternalAPIToken: internalAPIToken, + TrustedProxies: strings.Join(sanitizeTrustedProxies(trustedProxies), ","), + TrustCFHeader: trustCFHeader, } if err := tmpl.Execute(&buf, data); err != nil { diff --git a/templates/trust_render_test.go b/templates/trust_render_test.go new file mode 100644 index 0000000..8d6e0c3 --- /dev/null +++ b/templates/trust_render_test.go @@ -0,0 +1,77 @@ +package templates + +import ( + "strings" + "testing" +) + +func TestSecurityLuaRendersTrustConfig(t *testing.T) { + out, err := GetNginxSecurityLuaWithConfig("10.0.0.1", 8080, "tok", nil, false) + if err != nil { + t.Fatalf("render default: %v", err) + } + s := string(out) + if !strings.Contains(s, `local TRUSTED_PROXIES_RAW = ""`) { + t.Errorf("default TRUSTED_PROXIES_RAW not empty:\n%s", grepLua(s)) + } + if !strings.Contains(s, `local TRUST_CF_HEADER = false`) { + t.Errorf("default TRUST_CF_HEADER not false") + } + if strings.Contains(s, "{{") { + t.Errorf("unrendered template directive remains") + } + + out, err = GetNginxSecurityLuaWithConfig("10.0.0.1", 8080, "tok", []string{"103.21.244.0/22", "172.16.0.0/12"}, true) + if err != nil { + t.Fatalf("render trusted: %v", err) + } + s = string(out) + if !strings.Contains(s, `local TRUSTED_PROXIES_RAW = "103.21.244.0/22,172.16.0.0/12"`) { + t.Errorf("trusted proxies not joined:\n%s", grepLua(s)) + } + if !strings.Contains(s, `local TRUST_CF_HEADER = true`) { + t.Errorf("TRUST_CF_HEADER not true") + } +} + +func TestSecurityLuaSanitizesTrustedProxies(t *testing.T) { + malicious := []string{ + `10.0.0.0/8";os.execute("touch /tmp/pwned");--`, + `fe80::1%e"vil`, + "1.2.3.4\nlocal x = 1", + "not-an-ip", + " 192.168.0.0/16 ", + "203.0.113.7", + } + out, err := GetNginxSecurityLuaWithConfig("10.0.0.1", 8080, "tok", malicious, false) + if err != nil { + t.Fatalf("render: %v", err) + } + s := string(out) + + if !strings.Contains(s, `local TRUSTED_PROXIES_RAW = "192.168.0.0/16,203.0.113.7"`) { + t.Errorf("expected only the two valid entries, got:\n%s", grepLua(s)) + } + line := "" + for _, l := range strings.Split(s, "\n") { + if strings.Contains(l, "TRUSTED_PROXIES_RAW") { + line = l + break + } + } + for _, bad := range []string{"os.execute", `\"`, "\\", "%", "not-an-ip", "local x"} { + if strings.Contains(line, bad) { + t.Errorf("sanitized line still contains %q: %s", bad, line) + } + } +} + +func grepLua(s string) string { + var b strings.Builder + for _, line := range strings.Split(s, "\n") { + if strings.Contains(line, "TRUSTED_PROXIES") || strings.Contains(line, "TRUST_CF") { + b.WriteString(line + "\n") + } + } + return b.String() +} diff --git a/test/e2e/docker-compose.lua.yml b/test/e2e/docker-compose.lua.yml index 37d7d1e..24dd280 100644 --- a/test/e2e/docker-compose.lua.yml +++ b/test/e2e/docker-compose.lua.yml @@ -27,6 +27,7 @@ services: - "18082:80" environment: - FLATRUN_AGENT_URL=http://flatrun-e2e-lua-agent:8090 + - FLATRUN_TRUSTED_PROXIES=10.0.0.0/8,172.16.0.0/12,192.168.0.0/16 volumes: - ./nginx/lua/nginx.conf:/usr/local/openresty/nginx/conf/nginx.conf:ro - ./nginx/lua/default.conf:/etc/nginx/conf.d/default.conf:ro diff --git a/test/e2e/lua_test.go b/test/e2e/lua_test.go index 343e035..d72835a 100644 --- a/test/e2e/lua_test.go +++ b/test/e2e/lua_test.go @@ -52,6 +52,7 @@ func TestLuaRealtimeCapture(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:%s/admin", luaNginxPort), nil) req.Header.Set("User-Agent", "Mozilla/5.0 FlatRunTest") + req.Header.Set("X-Forwarded-For", "203.0.113.50") client := &http.Client{} resp, err := client.Do(req) @@ -78,6 +79,7 @@ func TestLuaRealtimeCapture(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:%s/api/private", luaNginxPort), nil) req.Header.Set("User-Agent", "Mozilla/5.0 FlatRunTest") + req.Header.Set("X-Forwarded-For", "203.0.113.51") client := &http.Client{} resp, err := client.Do(req) @@ -104,6 +106,7 @@ func TestLuaRealtimeCapture(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:%s/error", luaNginxPort), nil) req.Header.Set("User-Agent", "Mozilla/5.0 FlatRunTest") + req.Header.Set("X-Forwarded-For", "203.0.113.52") client := &http.Client{} resp, err := client.Do(req) @@ -130,6 +133,7 @@ func TestLuaRealtimeCapture(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:%s/.env", luaNginxPort), nil) req.Header.Set("User-Agent", "Mozilla/5.0 FlatRunTest") + req.Header.Set("X-Forwarded-For", "203.0.113.53") client := &http.Client{} resp, err := client.Do(req) @@ -152,6 +156,7 @@ func TestLuaRealtimeCapture(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:%s/", luaNginxPort), nil) req.Header.Set("User-Agent", "nikto/2.1.6") + req.Header.Set("X-Forwarded-For", "203.0.113.54") client := &http.Client{} resp, err := client.Do(req) diff --git a/test/e2e/nginx/lua/security.lua b/test/e2e/nginx/lua/security.lua index 982fe7e..d02620b 100644 --- a/test/e2e/nginx/lua/security.lua +++ b/test/e2e/nginx/lua/security.lua @@ -11,6 +11,8 @@ local _M = {} -- Configuration via environment variable (test-specific) local AGENT_URL = os.getenv("FLATRUN_AGENT_URL") or "http://host.docker.internal:8080" local INTERNAL_TOKEN = os.getenv("FLATRUN_INTERNAL_TOKEN") or "" +local TRUSTED_PROXIES_RAW = os.getenv("FLATRUN_TRUSTED_PROXIES") or "" +local TRUST_CF_HEADER = os.getenv("FLATRUN_TRUST_CF_HEADER") == "true" -- Blocked IPs cache settings local BLOCKED_IPS_CACHE_TTL = 30 -- seconds @@ -138,6 +140,75 @@ local scanner_patterns = { "zgrab", } +local function ipv4_to_int(ip_str) + local parts = {ip_str:match("^(%d+)%.(%d+)%.(%d+)%.(%d+)$")} + if #parts ~= 4 then return nil end + return tonumber(parts[1]) * 16777216 + tonumber(parts[2]) * 65536 + + tonumber(parts[3]) * 256 + tonumber(parts[4]) +end + +local function is_ipv4_in_cidr(ip, cidr) + local cidr_ip, cidr_bits = cidr:match("^(.+)/(%d+)$") + if not cidr_ip then return ip == cidr end + -- IPv6 ranges fail closed (untrusted) so a spoofable forwarded header is never honored + if ip:find(":") or cidr_ip:find(":") then return false end + local bits = tonumber(cidr_bits) + local ip_int = ipv4_to_int(ip) + local cidr_int = ipv4_to_int(cidr_ip) + if not ip_int or not cidr_int then return false end + local mask = bits == 0 and 0 or (0xFFFFFFFF - (2^(32 - bits) - 1)) + return bit.band(ip_int, mask) == bit.band(cidr_int, mask) +end + +local trusted_proxies = {} +for cidr in TRUSTED_PROXIES_RAW:gmatch("[^,]+") do + local trimmed = cidr:match("^%s*(.-)%s*$") + if trimmed ~= "" then + trusted_proxies[#trusted_proxies + 1] = trimmed + end +end + +local function is_trusted_proxy(ip) + if not ip then return false end + for _, cidr in ipairs(trusted_proxies) do + if is_ipv4_in_cidr(ip, cidr) then return true end + end + return false +end + +local function get_real_client_ip() + local peer = ngx.var.remote_addr + + if not is_trusted_proxy(peer) then + return peer + end + + if TRUST_CF_HEADER then + local cf_ip = ngx.var.http_cf_connecting_ip + if cf_ip and cf_ip ~= "" then + return cf_ip:match("^%s*(.-)%s*$") + end + end + + local xff = ngx.var.http_x_forwarded_for + if xff and xff ~= "" then + local hops = {} + for hop in xff:gmatch("[^,]+") do + hops[#hops + 1] = hop:match("^%s*(.-)%s*$") + end + for i = #hops, 1, -1 do + if not is_trusted_proxy(hops[i]) then + return hops[i] + end + end + if #hops > 0 then + return hops[1] + end + end + + return peer +end + function _M.is_suspicious_path(uri) if not uri then return false end local uri_lower = string.lower(uri) @@ -163,7 +234,7 @@ end function _M.capture_event() local status = ngx.status local uri = ngx.var.uri - local ip = ngx.var.remote_addr + local ip = get_real_client_ip() local method = ngx.var.request_method local user_agent = ngx.var.http_user_agent or "" local host = ngx.var.host or ""