From 4e87d842379518951ddffcb6087d4d0a25891c10 Mon Sep 17 00:00:00 2001 From: "Codex, agent ChatGPT" Date: Thu, 12 Mar 2026 00:51:06 +0100 Subject: [PATCH] Build initial caddy-opnsense-blocker daemon --- .gitignore | 2 + README.md | 97 +++ cmd/caddy-opnsense-blocker/main.go | 91 +++ config.example.yaml | 74 +++ go.mod | 21 + go.sum | 53 ++ internal/caddylog/parser.go | 194 ++++++ internal/caddylog/parser_test.go | 57 ++ internal/config/config.go | 414 +++++++++++++ internal/config/config_test.go | 106 ++++ internal/engine/evaluator.go | 69 +++ internal/engine/evaluator_test.go | 153 +++++ internal/model/types.go | 140 +++++ internal/opnsense/client.go | 306 +++++++++ internal/opnsense/client_test.go | 134 ++++ internal/service/service.go | 412 +++++++++++++ internal/service/service_test.go | 247 ++++++++ internal/store/store.go | 961 +++++++++++++++++++++++++++++ internal/store/store_test.go | 116 ++++ internal/web/handler.go | 583 +++++++++++++++++ internal/web/handler_test.go | 124 ++++ 21 files changed, 4354 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 cmd/caddy-opnsense-blocker/main.go create mode 100644 config.example.yaml create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/caddylog/parser.go create mode 100644 internal/caddylog/parser_test.go create mode 100644 internal/config/config.go create mode 100644 internal/config/config_test.go create mode 100644 internal/engine/evaluator.go create mode 100644 internal/engine/evaluator_test.go create mode 100644 internal/model/types.go create mode 100644 internal/opnsense/client.go create mode 100644 internal/opnsense/client_test.go create mode 100644 internal/service/service.go create mode 100644 internal/service/service_test.go create mode 100644 internal/store/store.go create mode 100644 internal/store/store_test.go create mode 100644 internal/web/handler.go create mode 100644 internal/web/handler_test.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c00b98f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/data/ +/caddy-opnsense-blocker diff --git a/README.md b/README.md new file mode 100644 index 0000000..b67a2db --- /dev/null +++ b/README.md @@ -0,0 +1,97 @@ +# caddy-opnsense-blocker + +`caddy-opnsense-blocker` is a local-first daemon that ingests Caddy access logs in their default JSON format, evaluates suspicious requests, keeps persistent local state in SQLite, provides a lightweight web UI for review, and blocks or unblocks IP addresses through an OPNsense alias. + +## Features + +- Real-time ingestion of multiple Caddy JSON log files. +- One heuristic profile per log source. +- Persistent local state in SQLite. +- Local-only web UI for reviewing events and IPs. +- Manual block, unblock, and override reset actions. +- OPNsense alias backend with automatic alias creation. +- Concurrent polling across multiple log files. + +## Current scope + +This first version is intentionally strong on ingestion, persistence, UI, and OPNsense integration. +The decision engine is deliberately simple and deterministic for now: + +- suspicious path prefixes +- unexpected `POST` requests +- `.php` path detection +- explicit known-agent allow/deny rules +- excluded CIDR ranges +- manual overrides + +This keeps the application usable immediately while leaving room for a more advanced network-intelligence engine later. + +## Architecture + +- `internal/caddylog`: parses default Caddy JSON access logs +- `internal/engine`: evaluates requests against a profile +- `internal/store`: persists events, IP state, manual decisions, backend actions, and source offsets +- `internal/opnsense`: manages the target OPNsense alias through its API +- `internal/service`: runs concurrent log followers and applies automatic decisions +- `internal/web`: serves the local review UI and JSON API + +## Quick start + +1. Generate or provision OPNsense API credentials. +2. Copy `config.example.yaml` to `config.yaml` and adapt it. +3. Start the daemon: + +```bash +CGO_ENABLED=0 go run ./cmd/caddy-opnsense-blocker -config ./config.yaml +``` + +4. Open the local UI on the configured address, for example `http://127.0.0.1:9080`. + +## Example configuration + +See `config.example.yaml`. + +Important points: + +- Each source points to one Caddy log file. +- Each source references exactly one profile. +- `initial_position: end` means “start following new lines only” on first boot. +- The web UI should stay bound to a local address such as `127.0.0.1:9080`. + +## Web UI and API + +The web UI is intentionally small and server-rendered. +It refreshes through lightweight JSON polling and exposes these endpoints: + +- `GET /healthz` +- `GET /api/overview` +- `GET /api/events` +- `GET /api/ips` +- `GET /api/ips/{ip}` +- `POST /api/ips/{ip}/block` +- `POST /api/ips/{ip}/unblock` +- `POST /api/ips/{ip}/reset` + +## Development + +Run the test suite: + +```bash +CGO_ENABLED=0 go test ./... +``` + +Build the daemon: + +```bash +CGO_ENABLED=0 go build ./cmd/caddy-opnsense-blocker +``` + +`CGO_ENABLED=0` is useful on systems without a C toolchain. The application itself only relies on pure-Go dependencies. + +## Roadmap + +- richer decision engine +- asynchronous DNS / RDAP / ASN enrichment +- richer review filters in the UI +- alternative blocking backends besides OPNsense +- direct streaming ingestion targets in addition to file polling diff --git a/cmd/caddy-opnsense-blocker/main.go b/cmd/caddy-opnsense-blocker/main.go new file mode 100644 index 0000000..467c54f --- /dev/null +++ b/cmd/caddy-opnsense-blocker/main.go @@ -0,0 +1,91 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/opnsense" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/service" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/store" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/web" +) + +func main() { + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { + var configPath string + flag.StringVar(&configPath, "config", "./config.yaml", "Path to the YAML configuration file") + flag.Parse() + + cfg, err := config.Load(configPath) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + database, err := store.Open(cfg.Storage.Path) + if err != nil { + return fmt.Errorf("open store: %w", err) + } + defer database.Close() + + logger := log.New(os.Stderr, "caddy-opnsense-blocker: ", log.LstdFlags|log.Lmsgprefix) + + var blocker opnsense.AliasClient + if cfg.OPNsense.Enabled { + blocker = opnsense.NewClient(cfg.OPNsense) + } + + svc := service.New(cfg, database, blocker, logger) + handler := web.NewHandler(svc) + httpServer := &http.Server{ + Addr: cfg.Server.ListenAddress, + Handler: handler, + ReadTimeout: cfg.Server.ReadTimeout.Duration, + WriteTimeout: cfg.Server.WriteTimeout.Duration, + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + errCh := make(chan error, 2) + go func() { + if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + errCh <- err + } + }() + go func() { + logger.Printf("serving on %s", cfg.Server.ListenAddress) + if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + select { + case err := <-errCh: + stop() + shutdownCtx, cancel := context.WithTimeout(context.Background(), cfg.Server.ShutdownTimeout.Duration) + defer cancel() + _ = httpServer.Shutdown(shutdownCtx) + return err + case <-ctx.Done(): + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), cfg.Server.ShutdownTimeout.Duration) + defer cancel() + if err := httpServer.Shutdown(shutdownCtx); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("shutdown http server: %w", err) + } + return nil +} diff --git a/config.example.yaml b/config.example.yaml new file mode 100644 index 0000000..4cb0ce6 --- /dev/null +++ b/config.example.yaml @@ -0,0 +1,74 @@ +server: + listen_address: 127.0.0.1:9080 + read_timeout: 5s + write_timeout: 10s + shutdown_timeout: 15s + +storage: + path: ./data/caddy-opnsense-blocker.db + +opnsense: + enabled: true + base_url: https://router.example.test + api_key_file: /run/secrets/opnsense-api-key + api_secret_file: /run/secrets/opnsense-api-secret + timeout: 8s + insecure_skip_verify: false + ensure_alias: true + alias: + name: blocked-ips + type: host + description: Managed by caddy-opnsense-blocker + +profiles: + public-web: + auto_block: true + min_status: 400 + max_status: 599 + block_unexpected_posts: true + block_php_paths: true + allowed_post_paths: + - /search + suspicious_path_prefixes: + - /wp-admin + - /wp-login.php + - /.env + - /.git + excluded_cidrs: + - 10.0.0.0/8 + - 127.0.0.0/8 + known_agents: + - name: friendly-bot + decision: allow + user_agent_prefixes: + - FriendlyBot/ + + gitea: + auto_block: false + min_status: 400 + max_status: 599 + block_unexpected_posts: true + block_php_paths: false + allowed_post_paths: + - /user/login + - /user/sign_up + - /user/forgot_password + suspicious_path_prefixes: + - /install.php + - /.env + - /.git + +sources: + - name: public-web + path: /var/log/caddy/public-web-access.json + profile: public-web + initial_position: end + poll_interval: 1s + batch_size: 256 + + - name: gitea + path: /var/log/caddy/gitea-access.json + profile: gitea + initial_position: end + poll_interval: 1s + batch_size: 256 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..fb5af58 --- /dev/null +++ b/go.mod @@ -0,0 +1,21 @@ +module git.dern.ovh/infrastructure/caddy-opnsense-blocker + +go 1.25 + +require ( + gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.39.1 +) + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect + golang.org/x/sys v0.36.0 // indirect + modernc.org/libc v1.66.10 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9f7bc87 --- /dev/null +++ b/go.sum @@ -0,0 +1,53 @@ +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= +golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= +golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4= +modernc.org/cc/v4 v4.26.5/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.28.1 h1:wPKYn5EC/mYTqBO373jKjvX2n+3+aK7+sICCv4Fjy1A= +modernc.org/ccgo/v4 v4.28.1/go.mod h1:uD+4RnfrVgE6ec9NGguUNdhqzNIeeomeXf6CL0GTE5Q= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.66.10 h1:yZkb3YeLx4oynyR+iUsXsybsX4Ubx7MQlSYEw4yj59A= +modernc.org/libc v1.66.10/go.mod h1:8vGSEwvoUoltr4dlywvHqjtAqHBaw0j1jI7iFBTAr2I= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.39.1 h1:H+/wGFzuSCIEVCvXYVHX5RQglwhMOvtHSv+VtidL2r4= +modernc.org/sqlite v1.39.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/caddylog/parser.go b/internal/caddylog/parser.go new file mode 100644 index 0000000..8496d17 --- /dev/null +++ b/internal/caddylog/parser.go @@ -0,0 +1,194 @@ +package caddylog + +import ( + "encoding/json" + "errors" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model" +) + +var ErrEmptyLine = errors.New("empty log line") + +type accessLogEntry struct { + Timestamp json.RawMessage `json:"ts"` + Status int `json:"status"` + Request accessLogRequest `json:"request"` +} + +type accessLogRequest struct { + RemoteIP string `json:"remote_ip"` + ClientIP string `json:"client_ip"` + Host string `json:"host"` + Method string `json:"method"` + URI string `json:"uri"` + Headers map[string][]string `json:"headers"` +} + +func ParseLine(line string) (model.AccessLogRecord, error) { + trimmed := strings.TrimSpace(line) + if trimmed == "" { + return model.AccessLogRecord{}, ErrEmptyLine + } + + var entry accessLogEntry + if err := json.Unmarshal([]byte(trimmed), &entry); err != nil { + return model.AccessLogRecord{}, fmt.Errorf("decode caddy log line: %w", err) + } + if entry.Status == 0 { + return model.AccessLogRecord{}, errors.New("missing caddy status") + } + + remoteIP, err := normalizeIP(entry.Request.RemoteIP) + if err != nil && strings.TrimSpace(entry.Request.RemoteIP) != "" { + return model.AccessLogRecord{}, fmt.Errorf("normalize remote ip: %w", err) + } + + clientCandidate := entry.Request.ClientIP + if strings.TrimSpace(clientCandidate) == "" { + clientCandidate = entry.Request.RemoteIP + } + clientIP, err := normalizeIP(clientCandidate) + if err != nil { + return model.AccessLogRecord{}, fmt.Errorf("normalize client ip: %w", err) + } + + occurredAt, err := parseTimestamp(entry.Timestamp) + if err != nil { + return model.AccessLogRecord{}, fmt.Errorf("parse timestamp: %w", err) + } + + uri := entry.Request.URI + if strings.TrimSpace(uri) == "" { + uri = "/" + } + + return model.AccessLogRecord{ + OccurredAt: occurredAt, + RemoteIP: remoteIP, + ClientIP: clientIP, + Host: strings.TrimSpace(entry.Request.Host), + Method: strings.ToUpper(strings.TrimSpace(entry.Request.Method)), + URI: uri, + Path: pathFromURI(uri), + Status: entry.Status, + UserAgent: firstUserAgent(entry.Request.Headers), + RawJSON: trimmed, + }, nil +} + +func ParseLines(lines []string) ([]model.AccessLogRecord, error) { + records := make([]model.AccessLogRecord, 0, len(lines)) + for _, line := range lines { + record, err := ParseLine(line) + if err != nil { + if errors.Is(err, ErrEmptyLine) { + continue + } + return nil, err + } + records = append(records, record) + } + return records, nil +} + +func firstUserAgent(headers map[string][]string) string { + if len(headers) == 0 { + return "" + } + for _, key := range []string{"User-Agent", "user-agent", "USER-AGENT"} { + if values, ok := headers[key]; ok && len(values) > 0 { + return strings.TrimSpace(values[0]) + } + } + for key, values := range headers { + if strings.EqualFold(key, "user-agent") && len(values) > 0 { + return strings.TrimSpace(values[0]) + } + } + return "" +} + +func parseTimestamp(raw json.RawMessage) (time.Time, error) { + if len(raw) == 0 { + return time.Time{}, errors.New("missing timestamp") + } + + var numeric float64 + if err := json.Unmarshal(raw, &numeric); err == nil { + seconds := int64(numeric) + nanos := int64((numeric - float64(seconds)) * float64(time.Second)) + return time.Unix(seconds, nanos).UTC(), nil + } + + var text string + if err := json.Unmarshal(raw, &text); err == nil { + text = strings.TrimSpace(text) + if text == "" { + return time.Time{}, errors.New("empty timestamp") + } + if numeric, err := strconv.ParseFloat(text, 64); err == nil { + seconds := int64(numeric) + nanos := int64((numeric - float64(seconds)) * float64(time.Second)) + return time.Unix(seconds, nanos).UTC(), nil + } + for _, layout := range []string{time.RFC3339Nano, time.RFC3339} { + parsed, err := time.Parse(layout, text) + if err == nil { + return parsed.UTC(), nil + } + } + } + + return time.Time{}, fmt.Errorf("unsupported timestamp payload %s", string(raw)) +} + +func normalizeIP(value string) (string, error) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "", errors.New("missing ip address") + } + parsed := net.ParseIP(trimmed) + if parsed == nil { + return "", fmt.Errorf("invalid ip address %q", value) + } + return parsed.String(), nil +} + +func pathFromURI(rawURI string) string { + trimmed := strings.TrimSpace(rawURI) + if trimmed == "" { + return "/" + } + + parsed, err := url.ParseRequestURI(trimmed) + if err == nil { + if parsed.Path == "" { + return "/" + } + return normalizePath(parsed.Path) + } + + value := strings.SplitN(trimmed, "?", 2)[0] + value = strings.SplitN(value, "#", 2)[0] + return normalizePath(value) +} + +func normalizePath(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "/" + } + if !strings.HasPrefix(trimmed, "/") { + trimmed = "/" + trimmed + } + if trimmed != "/" { + trimmed = strings.TrimRight(trimmed, "/") + } + return strings.ToLower(trimmed) +} diff --git a/internal/caddylog/parser_test.go b/internal/caddylog/parser_test.go new file mode 100644 index 0000000..5b06881 --- /dev/null +++ b/internal/caddylog/parser_test.go @@ -0,0 +1,57 @@ +package caddylog + +import ( + "strings" + "testing" + "time" +) + +func TestParseLineWithNumericTimestamp(t *testing.T) { + t.Parallel() + + line := `{"ts":1710000000.5,"status":404,"request":{"remote_ip":"198.51.100.10","client_ip":"203.0.113.5","host":"example.test","method":"GET","uri":"/wp-login.php?foo=bar","headers":{"User-Agent":["UnitTestBot/1.0"]}}}` + record, err := ParseLine(line) + if err != nil { + t.Fatalf("parse line: %v", err) + } + + if got, want := record.ClientIP, "203.0.113.5"; got != want { + t.Fatalf("unexpected client ip: got %q want %q", got, want) + } + if got, want := record.RemoteIP, "198.51.100.10"; got != want { + t.Fatalf("unexpected remote ip: got %q want %q", got, want) + } + if got, want := record.Path, "/wp-login.php"; got != want { + t.Fatalf("unexpected path: got %q want %q", got, want) + } + if got, want := record.UserAgent, "UnitTestBot/1.0"; got != want { + t.Fatalf("unexpected user agent: got %q want %q", got, want) + } + if got, want := record.Method, "GET"; got != want { + t.Fatalf("unexpected method: got %q want %q", got, want) + } + expected := time.Unix(1710000000, 500000000).UTC() + if !record.OccurredAt.Equal(expected) { + t.Fatalf("unexpected timestamp: got %s want %s", record.OccurredAt, expected) + } +} + +func TestParseLineWithRFC3339TimestampAndMissingClientIP(t *testing.T) { + t.Parallel() + + line := `{"ts":"2025-03-11T12:13:14.123456Z","status":401,"request":{"remote_ip":"2001:db8::1","host":"git.example.test","method":"POST","uri":"user/login","headers":{"user-agent":["curl/8.0"]}}}` + record, err := ParseLine(line) + if err != nil { + t.Fatalf("parse line: %v", err) + } + + if got, want := record.ClientIP, "2001:db8::1"; got != want { + t.Fatalf("unexpected fallback client ip: got %q want %q", got, want) + } + if got, want := record.Path, "/user/login"; got != want { + t.Fatalf("unexpected path: got %q want %q", got, want) + } + if !strings.Contains(record.RawJSON, `"status":401`) { + t.Fatalf("raw json was not preserved") + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..31d606a --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,414 @@ +package config + +import ( + "errors" + "fmt" + "net" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "gopkg.in/yaml.v3" +) + +type Duration struct { + time.Duration +} + +func (d *Duration) UnmarshalYAML(node *yaml.Node) error { + if node == nil || node.Value == "" { + d.Duration = 0 + return nil + } + + parsed, err := time.ParseDuration(node.Value) + if err != nil { + return fmt.Errorf("invalid duration %q: %w", node.Value, err) + } + d.Duration = parsed + return nil +} + +func (d Duration) MarshalYAML() (any, error) { + return d.String(), nil +} + +type Config struct { + Server ServerConfig `yaml:"server"` + Storage StorageConfig `yaml:"storage"` + OPNsense OPNsenseConfig `yaml:"opnsense"` + Profiles map[string]ProfileConfig `yaml:"profiles"` + Sources []SourceConfig `yaml:"sources"` +} + +type ServerConfig struct { + ListenAddress string `yaml:"listen_address"` + ReadTimeout Duration `yaml:"read_timeout"` + WriteTimeout Duration `yaml:"write_timeout"` + ShutdownTimeout Duration `yaml:"shutdown_timeout"` +} + +type StorageConfig struct { + Path string `yaml:"path"` +} + +type OPNsenseConfig struct { + Enabled bool `yaml:"enabled"` + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + APISecret string `yaml:"api_secret"` + APIKeyFile string `yaml:"api_key_file"` + APISecretFile string `yaml:"api_secret_file"` + Timeout Duration `yaml:"timeout"` + InsecureSkipVerify bool `yaml:"insecure_skip_verify"` + EnsureAlias bool `yaml:"ensure_alias"` + Alias AliasConfig `yaml:"alias"` + APIPaths APIPathsConfig `yaml:"api_paths"` +} + +type AliasConfig struct { + Name string `yaml:"name"` + Type string `yaml:"type"` + Description string `yaml:"description"` +} + +type APIPathsConfig struct { + AliasGetUUID string `yaml:"alias_get_uuid"` + AliasAddItem string `yaml:"alias_add_item"` + AliasSetItem string `yaml:"alias_set_item"` + AliasReconfig string `yaml:"alias_reconfigure"` + AliasUtilList string `yaml:"alias_util_list"` + AliasUtilAdd string `yaml:"alias_util_add"` + AliasUtilDelete string `yaml:"alias_util_delete"` +} + +type SourceConfig struct { + Name string `yaml:"name"` + Path string `yaml:"path"` + Profile string `yaml:"profile"` + InitialPosition string `yaml:"initial_position"` + PollInterval Duration `yaml:"poll_interval"` + BatchSize int `yaml:"batch_size"` +} + +type KnownAgentRule struct { + Name string `yaml:"name"` + Decision string `yaml:"decision"` + UserAgentPrefixes []string `yaml:"user_agent_prefixes"` + CIDRs []string `yaml:"cidrs"` + + normalizedPrefixes []string + networks []*net.IPNet +} + +type ProfileConfig struct { + AutoBlock bool `yaml:"auto_block"` + MinStatus int `yaml:"min_status"` + MaxStatus int `yaml:"max_status"` + BlockUnexpectedPosts bool `yaml:"block_unexpected_posts"` + BlockPHPPaths bool `yaml:"block_php_paths"` + AllowedPostPaths []string `yaml:"allowed_post_paths"` + SuspiciousPathPrefixes []string `yaml:"suspicious_path_prefixes"` + ExcludedCIDRs []string `yaml:"excluded_cidrs"` + KnownAgents []KnownAgentRule `yaml:"known_agents"` + + normalizedAllowedPostPaths map[string]struct{} + normalizedSuspiciousPaths []string + excludedNetworks []*net.IPNet +} + +func Load(path string) (*Config, error) { + payload, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read config: %w", err) + } + + var cfg Config + if err := yaml.Unmarshal(payload, &cfg); err != nil { + return nil, fmt.Errorf("decode config: %w", err) + } + + if err := cfg.applyDefaults(); err != nil { + return nil, err + } + + if err := cfg.validate(path); err != nil { + return nil, err + } + + return &cfg, nil +} + +func (c *Config) applyDefaults() error { + if c.Server.ListenAddress == "" { + c.Server.ListenAddress = "127.0.0.1:9080" + } + if c.Server.ReadTimeout.Duration == 0 { + c.Server.ReadTimeout.Duration = 5 * time.Second + } + if c.Server.WriteTimeout.Duration == 0 { + c.Server.WriteTimeout.Duration = 10 * time.Second + } + if c.Server.ShutdownTimeout.Duration == 0 { + c.Server.ShutdownTimeout.Duration = 15 * time.Second + } + if c.Storage.Path == "" { + c.Storage.Path = "./data/caddy-opnsense-blocker.db" + } + + if c.OPNsense.Timeout.Duration == 0 { + c.OPNsense.Timeout.Duration = 8 * time.Second + } + if c.OPNsense.Alias.Type == "" { + c.OPNsense.Alias.Type = "host" + } + if c.OPNsense.Alias.Description == "" { + c.OPNsense.Alias.Description = "Managed by caddy-opnsense-blocker" + } + if c.OPNsense.APIPaths.AliasGetUUID == "" { + c.OPNsense.APIPaths.AliasGetUUID = "/api/firewall/alias/get_alias_u_u_i_d/{alias}" + } + if c.OPNsense.APIPaths.AliasAddItem == "" { + c.OPNsense.APIPaths.AliasAddItem = "/api/firewall/alias/add_item" + } + if c.OPNsense.APIPaths.AliasSetItem == "" { + c.OPNsense.APIPaths.AliasSetItem = "/api/firewall/alias/set_item/{uuid}" + } + if c.OPNsense.APIPaths.AliasReconfig == "" { + c.OPNsense.APIPaths.AliasReconfig = "/api/firewall/alias/reconfigure" + } + if c.OPNsense.APIPaths.AliasUtilList == "" { + c.OPNsense.APIPaths.AliasUtilList = "/api/firewall/alias_util/list/{alias}" + } + if c.OPNsense.APIPaths.AliasUtilAdd == "" { + c.OPNsense.APIPaths.AliasUtilAdd = "/api/firewall/alias_util/add/{alias}" + } + if c.OPNsense.APIPaths.AliasUtilDelete == "" { + c.OPNsense.APIPaths.AliasUtilDelete = "/api/firewall/alias_util/delete/{alias}" + } + if !c.OPNsense.EnsureAlias { + c.OPNsense.EnsureAlias = true + } + + for name, profile := range c.Profiles { + if profile.MinStatus == 0 { + profile.MinStatus = 400 + } + if profile.MaxStatus == 0 { + profile.MaxStatus = 599 + } + profile.normalizedAllowedPostPaths = make(map[string]struct{}, len(profile.AllowedPostPaths)) + for _, path := range profile.AllowedPostPaths { + profile.normalizedAllowedPostPaths[normalizePath(path)] = struct{}{} + } + profile.normalizedSuspiciousPaths = make([]string, 0, len(profile.SuspiciousPathPrefixes)) + for _, prefix := range profile.SuspiciousPathPrefixes { + profile.normalizedSuspiciousPaths = append(profile.normalizedSuspiciousPaths, normalizePrefix(prefix)) + } + sort.Strings(profile.normalizedSuspiciousPaths) + + for _, cidr := range profile.ExcludedCIDRs { + _, network, err := net.ParseCIDR(strings.TrimSpace(cidr)) + if err != nil { + return fmt.Errorf("profile %q has invalid excluded_cidr %q: %w", name, cidr, err) + } + profile.excludedNetworks = append(profile.excludedNetworks, network) + } + + for index, rule := range profile.KnownAgents { + decision := strings.ToLower(strings.TrimSpace(rule.Decision)) + profile.KnownAgents[index].Decision = decision + for _, prefix := range rule.UserAgentPrefixes { + normalized := strings.ToLower(strings.TrimSpace(prefix)) + if normalized != "" { + profile.KnownAgents[index].normalizedPrefixes = append(profile.KnownAgents[index].normalizedPrefixes, normalized) + } + } + for _, cidr := range rule.CIDRs { + _, network, err := net.ParseCIDR(strings.TrimSpace(cidr)) + if err != nil { + return fmt.Errorf("profile %q rule %q has invalid cidr %q: %w", name, rule.Name, cidr, err) + } + profile.KnownAgents[index].networks = append(profile.KnownAgents[index].networks, network) + } + } + + c.Profiles[name] = profile + } + + for index := range c.Sources { + c.Sources[index].InitialPosition = strings.ToLower(strings.TrimSpace(c.Sources[index].InitialPosition)) + if c.Sources[index].InitialPosition == "" { + c.Sources[index].InitialPosition = "end" + } + if c.Sources[index].PollInterval.Duration == 0 { + c.Sources[index].PollInterval.Duration = time.Second + } + if c.Sources[index].BatchSize <= 0 { + c.Sources[index].BatchSize = 256 + } + } + + return nil +} + +func (c *Config) validate(sourcePath string) error { + if len(c.Profiles) == 0 { + return errors.New("at least one profile is required") + } + if len(c.Sources) == 0 { + return errors.New("at least one source is required") + } + if _, _, err := net.SplitHostPort(c.Server.ListenAddress); err != nil { + return fmt.Errorf("invalid server.listen_address: %w", err) + } + if err := os.MkdirAll(filepath.Dir(c.Storage.Path), 0o755); err != nil { + return fmt.Errorf("prepare storage directory: %w", err) + } + + seenNames := map[string]struct{}{} + seenPaths := map[string]struct{}{} + for _, source := range c.Sources { + if source.Name == "" { + return errors.New("source.name must not be empty") + } + if source.Path == "" { + return fmt.Errorf("source %q must define a path", source.Name) + } + if _, ok := seenNames[source.Name]; ok { + return fmt.Errorf("duplicate source name %q", source.Name) + } + seenNames[source.Name] = struct{}{} + if _, ok := seenPaths[source.Path]; ok { + return fmt.Errorf("duplicate source path %q", source.Path) + } + seenPaths[source.Path] = struct{}{} + if _, ok := c.Profiles[source.Profile]; !ok { + return fmt.Errorf("source %q references unknown profile %q", source.Name, source.Profile) + } + if source.InitialPosition != "beginning" && source.InitialPosition != "end" { + return fmt.Errorf("source %q has invalid initial_position %q", source.Name, source.InitialPosition) + } + } + + for name, profile := range c.Profiles { + if profile.MinStatus < 100 || profile.MinStatus > 599 { + return fmt.Errorf("profile %q has invalid min_status %d", name, profile.MinStatus) + } + if profile.MaxStatus < profile.MinStatus || profile.MaxStatus > 599 { + return fmt.Errorf("profile %q has invalid max_status %d", name, profile.MaxStatus) + } + for _, prefix := range profile.normalizedSuspiciousPaths { + if prefix == "/" { + return fmt.Errorf("profile %q has overly broad suspicious path prefix %q", name, prefix) + } + } + for _, rule := range profile.KnownAgents { + if rule.Decision != "allow" && rule.Decision != "deny" { + return fmt.Errorf("profile %q known agent %q has invalid decision %q", name, rule.Name, rule.Decision) + } + if len(rule.normalizedPrefixes) == 0 && len(rule.networks) == 0 { + return fmt.Errorf("profile %q known agent %q must define user_agent_prefixes and/or cidrs", name, rule.Name) + } + } + } + + if c.OPNsense.Enabled { + if c.OPNsense.BaseURL == "" { + return errors.New("opnsense.base_url is required when opnsense is enabled") + } + if c.OPNsense.Alias.Name == "" { + return errors.New("opnsense.alias.name is required when opnsense is enabled") + } + if c.OPNsense.APIKey == "" && c.OPNsense.APIKeyFile == "" { + return errors.New("opnsense.api_key or opnsense.api_key_file is required when opnsense is enabled") + } + if c.OPNsense.APISecret == "" && c.OPNsense.APISecretFile == "" { + return errors.New("opnsense.api_secret or opnsense.api_secret_file is required when opnsense is enabled") + } + if c.OPNsense.APIKey == "" { + payload, err := os.ReadFile(c.OPNsense.APIKeyFile) + if err != nil { + return fmt.Errorf("read opnsense.api_key_file: %w", err) + } + c.OPNsense.APIKey = strings.TrimSpace(string(payload)) + } + if c.OPNsense.APISecret == "" { + payload, err := os.ReadFile(c.OPNsense.APISecretFile) + if err != nil { + return fmt.Errorf("read opnsense.api_secret_file: %w", err) + } + c.OPNsense.APISecret = strings.TrimSpace(string(payload)) + } + } + + _ = sourcePath + return nil +} + +func normalizePath(input string) string { + value := strings.TrimSpace(input) + if value == "" { + return "/" + } + value = strings.SplitN(value, "?", 2)[0] + value = strings.SplitN(value, "#", 2)[0] + if !strings.HasPrefix(value, "/") { + value = "/" + value + } + if value != "/" { + value = strings.TrimRight(value, "/") + } + return strings.ToLower(value) +} + +func normalizePrefix(input string) string { + return normalizePath(input) +} + +func (p ProfileConfig) IsExcluded(ip net.IP) bool { + for _, network := range p.excludedNetworks { + if network.Contains(ip) { + return true + } + } + return false +} + +func (p ProfileConfig) IsAllowedPostPath(path string) bool { + _, ok := p.normalizedAllowedPostPaths[normalizePath(path)] + return ok +} + +func (p ProfileConfig) SuspiciousPrefixes() []string { + return append([]string(nil), p.normalizedSuspiciousPaths...) +} + +func (p ProfileConfig) MatchKnownAgent(ip net.IP, userAgent string) (KnownAgentRule, bool) { + normalizedUA := strings.ToLower(strings.TrimSpace(userAgent)) + for _, rule := range p.KnownAgents { + uaMatched := len(rule.normalizedPrefixes) == 0 + for _, prefix := range rule.normalizedPrefixes { + if strings.HasPrefix(normalizedUA, prefix) { + uaMatched = true + break + } + } + if !uaMatched { + continue + } + cidrMatched := len(rule.networks) == 0 + for _, network := range rule.networks { + if network.Contains(ip) { + cidrMatched = true + break + } + } + if cidrMatched { + return rule, true + } + } + return KnownAgentRule{}, false +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..c094232 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,106 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "testing" +) + +func TestLoadAppliesDefaultsAndReadsSecrets(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + keyPath := filepath.Join(tempDir, "api-key") + secretPath := filepath.Join(tempDir, "api-secret") + if err := os.WriteFile(keyPath, []byte("test-key\n"), 0o600); err != nil { + t.Fatalf("write key file: %v", err) + } + if err := os.WriteFile(secretPath, []byte("test-secret\n"), 0o600); err != nil { + t.Fatalf("write secret file: %v", err) + } + + configPath := filepath.Join(tempDir, "config.yaml") + payload := fmt.Sprintf(`storage: + path: %s/data/blocker.db +opnsense: + enabled: true + base_url: https://router.example.test + api_key_file: %s + api_secret_file: %s + ensure_alias: true + alias: + name: blocked-ips +profiles: + main: + auto_block: true + block_unexpected_posts: true + block_php_paths: true + allowed_post_paths: + - /search + suspicious_path_prefixes: + - /wp-admin + excluded_cidrs: + - 10.0.0.0/8 + known_agents: + - name: friendly-bot + decision: allow + user_agent_prefixes: + - FriendlyBot/ +sources: + - name: main + path: %s/access.json + profile: main +`, tempDir, keyPath, secretPath, tempDir) + if err := os.WriteFile(configPath, []byte(payload), 0o600); err != nil { + t.Fatalf("write config file: %v", err) + } + + cfg, err := Load(configPath) + if err != nil { + t.Fatalf("load config: %v", err) + } + + if got, want := cfg.Server.ListenAddress, "127.0.0.1:9080"; got != want { + t.Fatalf("unexpected listen address: got %q want %q", got, want) + } + if got, want := cfg.Sources[0].InitialPosition, "end"; got != want { + t.Fatalf("unexpected initial position: got %q want %q", got, want) + } + if got, want := cfg.OPNsense.APIKey, "test-key"; got != want { + t.Fatalf("unexpected api key: got %q want %q", got, want) + } + if got, want := cfg.OPNsense.APISecret, "test-secret"; got != want { + t.Fatalf("unexpected api secret: got %q want %q", got, want) + } + profile := cfg.Profiles["main"] + if !profile.IsAllowedPostPath("/search") { + t.Fatalf("expected /search to be normalized as an allowed POST path") + } + if len(profile.SuspiciousPrefixes()) != 1 || profile.SuspiciousPrefixes()[0] != "/wp-admin" { + t.Fatalf("unexpected suspicious prefixes: %#v", profile.SuspiciousPrefixes()) + } +} + +func TestLoadRejectsInvalidInitialPosition(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + payload := fmt.Sprintf(`profiles: + main: + auto_block: true +sources: + - name: main + path: %s/access.json + profile: main + initial_position: sideways +`, tempDir) + if err := os.WriteFile(configPath, []byte(payload), 0o600); err != nil { + t.Fatalf("write config file: %v", err) + } + + if _, err := Load(configPath); err == nil { + t.Fatalf("expected invalid initial_position to be rejected") + } +} diff --git a/internal/engine/evaluator.go b/internal/engine/evaluator.go new file mode 100644 index 0000000..fc9fbe9 --- /dev/null +++ b/internal/engine/evaluator.go @@ -0,0 +1,69 @@ +package engine + +import ( + "fmt" + "net" + "strings" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model" +) + +type Evaluator struct{} + +func NewEvaluator() *Evaluator { + return &Evaluator{} +} + +func (e *Evaluator) Evaluate(record model.AccessLogRecord, profile config.ProfileConfig, override model.ManualOverride) model.Decision { + switch override { + case model.ManualOverrideForceAllow: + return model.Decision{Action: model.DecisionActionAllow, Reasons: []string{"manual_override_force_allow"}} + case model.ManualOverrideForceBlock: + return model.Decision{Action: model.DecisionActionBlock, Reasons: []string{"manual_override_force_block"}} + } + + if record.Status < profile.MinStatus || record.Status > profile.MaxStatus { + return model.Decision{Action: model.DecisionActionNone} + } + + ip := net.ParseIP(record.ClientIP) + if ip == nil { + return model.Decision{Action: model.DecisionActionReview, Reasons: []string{"invalid_client_ip"}} + } + if profile.IsExcluded(ip) { + return model.Decision{Action: model.DecisionActionAllow, Reasons: []string{"excluded_cidr"}} + } + if rule, ok := profile.MatchKnownAgent(ip, record.UserAgent); ok { + if rule.Decision == "allow" { + return model.Decision{Action: model.DecisionActionAllow, Reasons: []string{fmt.Sprintf("known_agent_allow:%s", rule.Name)}} + } + return blockDecision(profile.AutoBlock, []string{fmt.Sprintf("known_agent_deny:%s", rule.Name)}) + } + + blockReasons := make([]string, 0, 3) + for _, prefix := range profile.SuspiciousPrefixes() { + if strings.HasPrefix(record.Path, prefix) { + blockReasons = append(blockReasons, fmt.Sprintf("suspicious_path_prefix:%s", prefix)) + break + } + } + if profile.BlockUnexpectedPosts && strings.EqualFold(record.Method, "POST") && !profile.IsAllowedPostPath(record.Path) { + blockReasons = append(blockReasons, "unexpected_post") + } + if profile.BlockPHPPaths && strings.HasSuffix(record.Path, ".php") { + blockReasons = append(blockReasons, "php_path") + } + + return blockDecision(profile.AutoBlock, blockReasons) +} + +func blockDecision(autoBlock bool, reasons []string) model.Decision { + if len(reasons) == 0 { + return model.Decision{Action: model.DecisionActionNone} + } + if autoBlock { + return model.Decision{Action: model.DecisionActionBlock, Reasons: reasons} + } + return model.Decision{Action: model.DecisionActionReview, Reasons: reasons} +} diff --git a/internal/engine/evaluator_test.go b/internal/engine/evaluator_test.go new file mode 100644 index 0000000..16a9e78 --- /dev/null +++ b/internal/engine/evaluator_test.go @@ -0,0 +1,153 @@ +package engine + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model" +) + +func TestEvaluatorManualOverridesTakePriority(t *testing.T) { + t.Parallel() + + profile := loadProfile(t, ` +auto_block: true +block_unexpected_posts: true +block_php_paths: true +suspicious_path_prefixes: + - /wp-admin +`) + evaluator := NewEvaluator() + record := model.AccessLogRecord{ClientIP: "203.0.113.10", Status: 404, Method: "GET", Path: "/wp-admin/install.php", UserAgent: "curl/8.0"} + + if decision := evaluator.Evaluate(record, profile, model.ManualOverrideForceAllow); decision.Action != model.DecisionActionAllow { + t.Fatalf("expected manual allow to win, got %+v", decision) + } + if decision := evaluator.Evaluate(record, profile, model.ManualOverrideForceBlock); decision.Action != model.DecisionActionBlock { + t.Fatalf("expected manual block to win, got %+v", decision) + } +} + +func TestEvaluatorBlocksSuspiciousRequests(t *testing.T) { + t.Parallel() + + profile := loadProfile(t, ` +auto_block: true +block_unexpected_posts: true +block_php_paths: true +allowed_post_paths: + - /search +suspicious_path_prefixes: + - /wp-admin +`) + evaluator := NewEvaluator() + record := model.AccessLogRecord{ClientIP: "203.0.113.11", Status: 404, Method: "POST", Path: "/wp-admin/install.php", UserAgent: "curl/8.0"} + + decision := evaluator.Evaluate(record, profile, model.ManualOverrideNone) + if decision.Action != model.DecisionActionBlock { + t.Fatalf("expected block decision, got %+v", decision) + } + if len(decision.Reasons) < 2 { + t.Fatalf("expected multiple blocking reasons, got %+v", decision) + } +} + +func TestEvaluatorAllowsExcludedCIDRAndKnownAgents(t *testing.T) { + t.Parallel() + + profile := loadProfile(t, ` +auto_block: true +excluded_cidrs: + - 10.0.0.0/8 +known_agents: + - name: friendly-bot + decision: allow + user_agent_prefixes: + - FriendlyBot/ +`) + evaluator := NewEvaluator() + + excluded := model.AccessLogRecord{ClientIP: "10.0.0.5", Status: 404, Method: "GET", Path: "/wp-login.php", UserAgent: "curl/8.0"} + if decision := evaluator.Evaluate(excluded, profile, model.ManualOverrideNone); decision.Action != model.DecisionActionAllow { + t.Fatalf("expected excluded cidr to be allowed, got %+v", decision) + } + + knownAgent := model.AccessLogRecord{ClientIP: "203.0.113.12", Status: 404, Method: "GET", Path: "/wp-login.php", UserAgent: "FriendlyBot/2.0"} + if decision := evaluator.Evaluate(knownAgent, profile, model.ManualOverrideNone); decision.Action != model.DecisionActionAllow { + t.Fatalf("expected known agent to be allowed, got %+v", decision) + } +} + +func TestEvaluatorReturnsReviewWhenAutoBlockIsDisabled(t *testing.T) { + t.Parallel() + + profile := loadProfile(t, ` +auto_block: false +block_unexpected_posts: true +suspicious_path_prefixes: + - /admin +`) + evaluator := NewEvaluator() + record := model.AccessLogRecord{ClientIP: "203.0.113.13", Status: 404, Method: "POST", Path: "/admin", UserAgent: "curl/8.0"} + + decision := evaluator.Evaluate(record, profile, model.ManualOverrideNone) + if decision.Action != model.DecisionActionReview { + t.Fatalf("expected review decision, got %+v", decision) + } +} + +func loadProfile(t *testing.T, profileSnippet string) config.ProfileConfig { + t.Helper() + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + payload := fmt.Sprintf(`profiles: + main:%s +sources: + - name: main + path: %s/access.json + profile: main +`, indent(profileSnippet, 4), tempDir) + if err := os.WriteFile(configPath, []byte(payload), 0o600); err != nil { + t.Fatalf("write config file: %v", err) + } + cfg, err := config.Load(configPath) + if err != nil { + t.Fatalf("load config: %v", err) + } + return cfg.Profiles["main"] +} + +func indent(value string, spaces int) string { + padding := "" + for range spaces { + padding += " " + } + lines := []byte(value) + _ = lines + var output string + for _, line := range splitLines(value) { + trimmed := line + if trimmed == "" { + output += "\n" + continue + } + output += "\n" + padding + trimmed + } + return output +} + +func splitLines(value string) []string { + var lines []string + start := 0 + for index, character := range value { + if character == '\n' { + lines = append(lines, value[start:index]) + start = index + 1 + } + } + lines = append(lines, value[start:]) + return lines +} diff --git a/internal/model/types.go b/internal/model/types.go new file mode 100644 index 0000000..6c9f895 --- /dev/null +++ b/internal/model/types.go @@ -0,0 +1,140 @@ +package model + +import "time" + +type DecisionAction string + +const ( + DecisionActionNone DecisionAction = "none" + DecisionActionReview DecisionAction = "review" + DecisionActionBlock DecisionAction = "block" + DecisionActionAllow DecisionAction = "allow" +) + +type ManualOverride string + +const ( + ManualOverrideNone ManualOverride = "none" + ManualOverrideForceAllow ManualOverride = "force_allow" + ManualOverrideForceBlock ManualOverride = "force_block" +) + +type IPStateStatus string + +const ( + IPStateObserved IPStateStatus = "observed" + IPStateReview IPStateStatus = "review" + IPStateBlocked IPStateStatus = "blocked" + IPStateAllowed IPStateStatus = "allowed" +) + +type AccessLogRecord struct { + OccurredAt time.Time + RemoteIP string + ClientIP string + Host string + Method string + URI string + Path string + Status int + UserAgent string + RawJSON string +} + +type Decision struct { + Action DecisionAction + Reasons []string +} + +func (d Decision) PrimaryReason() string { + if len(d.Reasons) == 0 { + return "" + } + return d.Reasons[0] +} + +type Event struct { + ID int64 `json:"id"` + SourceName string `json:"source_name"` + ProfileName string `json:"profile_name"` + OccurredAt time.Time `json:"occurred_at"` + RemoteIP string `json:"remote_ip"` + ClientIP string `json:"client_ip"` + Host string `json:"host"` + Method string `json:"method"` + URI string `json:"uri"` + Path string `json:"path"` + Status int `json:"status"` + UserAgent string `json:"user_agent"` + Decision DecisionAction `json:"decision"` + DecisionReason string `json:"decision_reason"` + DecisionReasons []string `json:"decision_reasons,omitempty"` + Enforced bool `json:"enforced"` + RawJSON string `json:"raw_json"` + CreatedAt time.Time `json:"created_at"` + CurrentState IPStateStatus `json:"current_state"` + ManualOverride ManualOverride `json:"manual_override"` +} + +type IPState struct { + IP string `json:"ip"` + FirstSeenAt time.Time `json:"first_seen_at"` + LastSeenAt time.Time `json:"last_seen_at"` + LastSourceName string `json:"last_source_name"` + LastUserAgent string `json:"last_user_agent"` + LatestStatus int `json:"latest_status"` + TotalEvents int64 `json:"total_events"` + State IPStateStatus `json:"state"` + StateReason string `json:"state_reason"` + ManualOverride ManualOverride `json:"manual_override"` + LastEventID int64 `json:"last_event_id"` + UpdatedAt time.Time `json:"updated_at"` +} + +type DecisionRecord struct { + ID int64 `json:"id"` + EventID int64 `json:"event_id"` + IP string `json:"ip"` + SourceName string `json:"source_name"` + Kind string `json:"kind"` + Action DecisionAction `json:"action"` + Reason string `json:"reason"` + Actor string `json:"actor"` + Enforced bool `json:"enforced"` + CreatedAt time.Time `json:"created_at"` +} + +type OPNsenseAction struct { + ID int64 `json:"id"` + IP string `json:"ip"` + Action string `json:"action"` + Result string `json:"result"` + Message string `json:"message"` + CreatedAt time.Time `json:"created_at"` +} + +type SourceOffset struct { + SourceName string + Path string + Inode string + Offset int64 + UpdatedAt time.Time +} + +type IPDetails struct { + State IPState `json:"state"` + RecentEvents []Event `json:"recent_events"` + Decisions []DecisionRecord `json:"decisions"` + BackendActions []OPNsenseAction `json:"backend_actions"` +} + +type Overview struct { + TotalEvents int64 `json:"total_events"` + TotalIPs int64 `json:"total_ips"` + BlockedIPs int64 `json:"blocked_ips"` + ReviewIPs int64 `json:"review_ips"` + AllowedIPs int64 `json:"allowed_ips"` + ObservedIPs int64 `json:"observed_ips"` + RecentIPs []IPState `json:"recent_ips"` + RecentEvents []Event `json:"recent_events"` +} diff --git a/internal/opnsense/client.go b/internal/opnsense/client.go new file mode 100644 index 0000000..8def0d4 --- /dev/null +++ b/internal/opnsense/client.go @@ -0,0 +1,306 @@ +package opnsense + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config" +) + +type AliasClient interface { + AddIPIfMissing(ctx context.Context, ip string) (string, error) + RemoveIPIfPresent(ctx context.Context, ip string) (string, error) + IsIPPresent(ctx context.Context, ip string) (bool, error) +} + +type Client struct { + cfg config.OPNsenseConfig + httpClient *http.Client + + mu sync.Mutex + aliasUUID string + knownAliasIPs map[string]struct{} +} + +func NewClient(cfg config.OPNsenseConfig) *Client { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.InsecureSkipVerify}, + } + + return &Client{ + cfg: cfg, + httpClient: &http.Client{ + Timeout: cfg.Timeout.Duration, + Transport: transport, + }, + } +} + +func (c *Client) AddIPIfMissing(ctx context.Context, ip string) (string, error) { + normalized, err := normalizeIP(ip) + if err != nil { + return "", err + } + + c.mu.Lock() + defer c.mu.Unlock() + + snapshot, err := c.ensureAliasSnapshotLocked(ctx) + if err != nil { + return "", err + } + if _, ok := snapshot[normalized]; ok { + return "already_present", nil + } + + payload, err := c.requestJSON(ctx, http.MethodPost, c.cfg.APIPaths.AliasUtilAdd, map[string]string{"alias": c.cfg.Alias.Name}, map[string]string{"address": normalized}) + if err != nil { + return "", err + } + if status := strings.ToLower(strings.TrimSpace(asString(payload["status"]))); status != "done" { + return "", fmt.Errorf("opnsense alias add failed: %v", payload) + } + snapshot[normalized] = struct{}{} + return "added", nil +} + +func (c *Client) RemoveIPIfPresent(ctx context.Context, ip string) (string, error) { + normalized, err := normalizeIP(ip) + if err != nil { + return "", err + } + + c.mu.Lock() + defer c.mu.Unlock() + + snapshot, err := c.ensureAliasSnapshotLocked(ctx) + if err != nil { + return "", err + } + if _, ok := snapshot[normalized]; !ok { + return "already_absent", nil + } + + payload, err := c.requestJSON(ctx, http.MethodPost, c.cfg.APIPaths.AliasUtilDelete, map[string]string{"alias": c.cfg.Alias.Name}, map[string]string{"address": normalized}) + if err != nil { + return "", err + } + if status := strings.ToLower(strings.TrimSpace(asString(payload["status"]))); status != "done" { + return "", fmt.Errorf("opnsense alias delete failed: %v", payload) + } + delete(snapshot, normalized) + return "removed", nil +} + +func (c *Client) IsIPPresent(ctx context.Context, ip string) (bool, error) { + normalized, err := normalizeIP(ip) + if err != nil { + return false, err + } + + c.mu.Lock() + defer c.mu.Unlock() + + snapshot, err := c.ensureAliasSnapshotLocked(ctx) + if err != nil { + return false, err + } + _, ok := snapshot[normalized] + return ok, nil +} + +func (c *Client) ensureAliasSnapshotLocked(ctx context.Context) (map[string]struct{}, error) { + if c.knownAliasIPs != nil { + return c.knownAliasIPs, nil + } + if err := c.ensureAliasExistsLocked(ctx); err != nil { + return nil, err + } + payload, err := c.requestJSON(ctx, http.MethodGet, c.cfg.APIPaths.AliasUtilList, map[string]string{"alias": c.cfg.Alias.Name}, nil) + if err != nil { + return nil, err + } + rows, ok := payload["rows"].([]any) + if !ok { + return nil, fmt.Errorf("unexpected opnsense alias listing payload: %v", payload) + } + snapshot := make(map[string]struct{}, len(rows)) + for _, row := range rows { + rowMap, ok := row.(map[string]any) + if !ok { + return nil, fmt.Errorf("unexpected opnsense alias row payload: %T", row) + } + candidate := asString(rowMap["ip"]) + if candidate == "" { + candidate = asString(rowMap["address"]) + } + if candidate == "" { + candidate = asString(rowMap["item"]) + } + if candidate == "" { + continue + } + normalized, err := normalizeIP(candidate) + if err != nil { + continue + } + snapshot[normalized] = struct{}{} + } + c.knownAliasIPs = snapshot + return snapshot, nil +} + +func (c *Client) ensureAliasExistsLocked(ctx context.Context) error { + if c.aliasUUID != "" { + return nil + } + + uuid, err := c.getAliasUUIDLocked(ctx) + if err != nil { + return err + } + if uuid == "" { + if !c.cfg.EnsureAlias { + return fmt.Errorf("opnsense alias %q does not exist and ensure_alias is disabled", c.cfg.Alias.Name) + } + if _, err := c.requestJSON(ctx, http.MethodPost, c.cfg.APIPaths.AliasAddItem, nil, map[string]any{ + "alias": map[string]string{ + "enabled": "1", + "name": c.cfg.Alias.Name, + "type": c.cfg.Alias.Type, + "content": "", + "description": c.cfg.Alias.Description, + }, + }); err != nil { + return err + } + uuid, err = c.getAliasUUIDLocked(ctx) + if err != nil { + return err + } + if uuid == "" { + return fmt.Errorf("unable to create opnsense alias %q", c.cfg.Alias.Name) + } + if _, err := c.requestJSON(ctx, http.MethodPost, c.cfg.APIPaths.AliasSetItem, map[string]string{"uuid": uuid}, map[string]any{ + "alias": map[string]string{ + "enabled": "1", + "name": c.cfg.Alias.Name, + "type": c.cfg.Alias.Type, + "content": "", + "description": c.cfg.Alias.Description, + }, + }); err != nil { + return err + } + if err := c.reconfigureLocked(ctx); err != nil { + return err + } + } + c.aliasUUID = uuid + return nil +} + +func (c *Client) getAliasUUIDLocked(ctx context.Context) (string, error) { + payload, err := c.requestJSON(ctx, http.MethodGet, c.cfg.APIPaths.AliasGetUUID, map[string]string{"alias": c.cfg.Alias.Name}, nil) + if err != nil { + return "", err + } + return strings.TrimSpace(asString(payload["uuid"])), nil +} + +func (c *Client) reconfigureLocked(ctx context.Context) error { + payload, err := c.requestJSON(ctx, http.MethodPost, c.cfg.APIPaths.AliasReconfig, nil, nil) + if err != nil { + return err + } + status := strings.ToLower(strings.TrimSpace(asString(payload["status"]))) + if status != "ok" && status != "done" { + return fmt.Errorf("opnsense alias reconfigure failed: %v", payload) + } + return nil +} + +func (c *Client) requestJSON(ctx context.Context, method, pathTemplate string, pathValues map[string]string, body any) (map[string]any, error) { + requestURL, err := c.buildURL(pathTemplate, pathValues) + if err != nil { + return nil, err + } + + var payload io.Reader + if body != nil { + encoded, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("encode request body: %w", err) + } + payload = bytes.NewReader(encoded) + } + + req, err := http.NewRequestWithContext(ctx, method, requestURL, payload) + if err != nil { + return nil, fmt.Errorf("build request: %w", err) + } + req.SetBasicAuth(c.cfg.APIKey, c.cfg.APISecret) + req.Header.Set("Accept", "application/json") + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("perform request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + payload, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10)) + return nil, fmt.Errorf("unexpected status %s: %s", resp.Status, strings.TrimSpace(string(payload))) + } + + var decoded map[string]any + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + return decoded, nil +} + +func (c *Client) buildURL(pathTemplate string, values map[string]string) (string, error) { + baseURL := strings.TrimRight(c.cfg.BaseURL, "/") + if baseURL == "" { + return "", fmt.Errorf("missing opnsense base url") + } + path := pathTemplate + for key, value := range values { + path = strings.ReplaceAll(path, "{"+key+"}", url.PathEscape(value)) + } + return baseURL + path, nil +} + +func normalizeIP(ip string) (string, error) { + parsed := net.ParseIP(strings.TrimSpace(ip)) + if parsed == nil { + return "", fmt.Errorf("invalid ip address %q", ip) + } + return parsed.String(), nil +} + +func asString(value any) string { + switch typed := value.(type) { + case string: + return typed + case fmt.Stringer: + return typed.String() + case nil: + return "" + default: + return fmt.Sprintf("%v", typed) + } +} diff --git a/internal/opnsense/client_test.go b/internal/opnsense/client_test.go new file mode 100644 index 0000000..017fc8f --- /dev/null +++ b/internal/opnsense/client_test.go @@ -0,0 +1,134 @@ +package opnsense + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config" +) + +func TestClientCreatesAliasAndBlocksAndUnblocksIPs(t *testing.T) { + t.Parallel() + + type state struct { + mu sync.Mutex + aliasUUID string + aliasExists bool + ips map[string]struct{} + } + + backendState := &state{ips: map[string]struct{}{}} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok || username != "key" || password != "secret" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + backendState.mu.Lock() + defer backendState.mu.Unlock() + + switch { + case r.Method == http.MethodGet && r.URL.Path == "/api/firewall/alias/get_alias_u_u_i_d/blocked-ips": + if backendState.aliasExists { + _ = json.NewEncoder(w).Encode(map[string]any{"uuid": backendState.aliasUUID}) + } else { + _ = json.NewEncoder(w).Encode(map[string]any{"uuid": ""}) + } + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias/add_item": + backendState.aliasExists = true + backendState.aliasUUID = "uuid-1" + _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias/set_item/uuid-1": + _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias/reconfigure": + _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) + case r.Method == http.MethodGet && r.URL.Path == "/api/firewall/alias_util/list/blocked-ips": + rows := make([]map[string]string, 0, len(backendState.ips)) + for ip := range backendState.ips { + rows = append(rows, map[string]string{"ip": ip}) + } + _ = json.NewEncoder(w).Encode(map[string]any{"rows": rows}) + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias_util/add/blocked-ips": + var payload map[string]string + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + backendState.ips[payload["address"]] = struct{}{} + _ = json.NewEncoder(w).Encode(map[string]any{"status": "done"}) + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias_util/delete/blocked-ips": + var payload map[string]string + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + delete(backendState.ips, payload["address"]) + _ = json.NewEncoder(w).Encode(map[string]any{"status": "done"}) + default: + http.Error(w, "not found", http.StatusNotFound) + } + })) + defer server.Close() + + client := NewClient(config.OPNsenseConfig{ + Enabled: true, + BaseURL: server.URL, + APIKey: "key", + APISecret: "secret", + EnsureAlias: true, + Timeout: config.Duration{Duration: time.Second}, + Alias: config.AliasConfig{ + Name: "blocked-ips", + Type: "host", + }, + APIPaths: config.APIPathsConfig{ + AliasGetUUID: "/api/firewall/alias/get_alias_u_u_i_d/{alias}", + AliasAddItem: "/api/firewall/alias/add_item", + AliasSetItem: "/api/firewall/alias/set_item/{uuid}", + AliasReconfig: "/api/firewall/alias/reconfigure", + AliasUtilList: "/api/firewall/alias_util/list/{alias}", + AliasUtilAdd: "/api/firewall/alias_util/add/{alias}", + AliasUtilDelete: "/api/firewall/alias_util/delete/{alias}", + }, + }) + + ctx := context.Background() + if result, err := client.AddIPIfMissing(ctx, "203.0.113.10"); err != nil || result != "added" { + t.Fatalf("unexpected add result: result=%q err=%v", result, err) + } + if result, err := client.AddIPIfMissing(ctx, "203.0.113.10"); err != nil || result != "already_present" { + t.Fatalf("unexpected add replay result: result=%q err=%v", result, err) + } + present, err := client.IsIPPresent(ctx, "203.0.113.10") + if err != nil { + t.Fatalf("is ip present: %v", err) + } + if !present { + t.Fatalf("expected IP to be present in alias") + } + if result, err := client.RemoveIPIfPresent(ctx, "203.0.113.10"); err != nil || result != "removed" { + t.Fatalf("unexpected remove result: result=%q err=%v", result, err) + } + if result, err := client.RemoveIPIfPresent(ctx, "203.0.113.10"); err != nil || result != "already_absent" { + t.Fatalf("unexpected remove replay result: result=%q err=%v", result, err) + } + + backendState.mu.Lock() + defer backendState.mu.Unlock() + if !backendState.aliasExists || backendState.aliasUUID == "" { + t.Fatalf("expected alias to exist after first add") + } + if len(backendState.ips) != 0 { + t.Fatalf("expected alias to be empty after remove, got %v", backendState.ips) + } + if strings.TrimSpace(backendState.aliasUUID) == "" { + t.Fatalf("expected alias uuid to be populated") + } +} diff --git a/internal/service/service.go b/internal/service/service.go new file mode 100644 index 0000000..6b4c753 --- /dev/null +++ b/internal/service/service.go @@ -0,0 +1,412 @@ +package service + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "strings" + "sync" + "syscall" + "time" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/caddylog" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/engine" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/opnsense" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/store" +) + +type Service struct { + cfg *config.Config + store *store.Store + evaluator *engine.Evaluator + blocker opnsense.AliasClient + logger *log.Logger +} + +func New(cfg *config.Config, db *store.Store, blocker opnsense.AliasClient, logger *log.Logger) *Service { + if logger == nil { + logger = log.New(io.Discard, "", 0) + } + return &Service{ + cfg: cfg, + store: db, + evaluator: engine.NewEvaluator(), + blocker: blocker, + logger: logger, + } +} + +func (s *Service) Run(ctx context.Context) error { + var wg sync.WaitGroup + for _, source := range s.cfg.Sources { + source := source + wg.Add(1) + go func() { + defer wg.Done() + s.runSource(ctx, source) + }() + } + <-ctx.Done() + wg.Wait() + return nil +} + +func (s *Service) GetOverview(ctx context.Context, limit int) (model.Overview, error) { + return s.store.GetOverview(ctx, limit) +} + +func (s *Service) ListEvents(ctx context.Context, limit int) ([]model.Event, error) { + return s.store.ListRecentEvents(ctx, limit) +} + +func (s *Service) ListIPs(ctx context.Context, limit int, state string) ([]model.IPState, error) { + return s.store.ListIPStates(ctx, limit, state) +} + +func (s *Service) GetIPDetails(ctx context.Context, ip string) (model.IPDetails, error) { + normalized, err := normalizeIP(ip) + if err != nil { + return model.IPDetails{}, err + } + return s.store.GetIPDetails(ctx, normalized, 100, 100, 100) +} + +func (s *Service) ForceBlock(ctx context.Context, ip string, actor string, reason string) error { + return s.applyManualOverride(ctx, ip, model.ManualOverrideForceBlock, model.IPStateBlocked, actor, defaultReason(reason, "manual block"), "block") +} + +func (s *Service) ForceAllow(ctx context.Context, ip string, actor string, reason string) error { + return s.applyManualOverride(ctx, ip, model.ManualOverrideForceAllow, model.IPStateAllowed, actor, defaultReason(reason, "manual allow"), "unblock") +} + +func (s *Service) ClearOverride(ctx context.Context, ip string, actor string, reason string) error { + normalized, err := normalizeIP(ip) + if err != nil { + return err + } + reason = defaultReason(reason, "manual override cleared") + state, err := s.store.ClearManualOverride(ctx, normalized, reason) + if err != nil { + return err + } + return s.store.AddDecision(ctx, &model.DecisionRecord{ + EventID: state.LastEventID, + IP: normalized, + SourceName: state.LastSourceName, + Kind: "manual", + Action: model.DecisionActionNone, + Reason: reason, + Actor: defaultActor(actor), + Enforced: false, + CreatedAt: time.Now().UTC(), + }) +} + +func (s *Service) runSource(ctx context.Context, source config.SourceConfig) { + s.pollSource(ctx, source) + ticker := time.NewTicker(source.PollInterval.Duration) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.pollSource(ctx, source) + } + } +} + +func (s *Service) pollSource(ctx context.Context, source config.SourceConfig) { + lines, err := s.readNewLines(ctx, source) + if err != nil { + s.logger.Printf("source %s: %v", source.Name, err) + return + } + if len(lines) == 0 { + return + } + + profile := s.cfg.Profiles[source.Profile] + for _, line := range lines { + record, err := caddylog.ParseLine(line) + if err != nil { + if errors.Is(err, caddylog.ErrEmptyLine) { + continue + } + s.logger.Printf("source %s: parse line: %v", source.Name, err) + continue + } + if record.Status < profile.MinStatus || record.Status > profile.MaxStatus { + continue + } + if err := s.processRecord(ctx, source, profile, record); err != nil { + s.logger.Printf("source %s: process record: %v", source.Name, err) + } + } +} + +func (s *Service) processRecord(ctx context.Context, source config.SourceConfig, profile config.ProfileConfig, record model.AccessLogRecord) error { + state, found, err := s.store.GetIPState(ctx, record.ClientIP) + if err != nil { + return err + } + override := model.ManualOverrideNone + if found { + override = state.ManualOverride + } + + decision := s.evaluator.Evaluate(record, profile, override) + event := model.Event{ + SourceName: source.Name, + ProfileName: source.Profile, + OccurredAt: record.OccurredAt, + RemoteIP: record.RemoteIP, + ClientIP: record.ClientIP, + Host: record.Host, + Method: record.Method, + URI: record.URI, + Path: record.Path, + Status: record.Status, + UserAgent: record.UserAgent, + Decision: decision.Action, + DecisionReason: decision.PrimaryReason(), + DecisionReasons: append([]string(nil), decision.Reasons...), + Enforced: false, + RawJSON: record.RawJSON, + CreatedAt: time.Now().UTC(), + } + + var backendAction *model.OPNsenseAction + if decision.Action == model.DecisionActionBlock && s.blocker != nil { + result, blockErr := s.blocker.AddIPIfMissing(ctx, record.ClientIP) + backendAction = &model.OPNsenseAction{ + IP: record.ClientIP, + Action: "block", + CreatedAt: time.Now().UTC(), + } + if blockErr != nil { + backendAction.Result = "error" + backendAction.Message = blockErr.Error() + } else { + backendAction.Result = result + backendAction.Message = decision.PrimaryReason() + event.Enforced = true + } + } + + if err := s.store.RecordEvent(ctx, &event); err != nil { + return err + } + if decision.Action != model.DecisionActionNone { + if err := s.store.AddDecision(ctx, &model.DecisionRecord{ + EventID: event.ID, + IP: record.ClientIP, + SourceName: source.Name, + Kind: "automatic", + Action: decision.Action, + Reason: strings.Join(decision.Reasons, ", "), + Actor: "engine", + Enforced: event.Enforced, + CreatedAt: time.Now().UTC(), + }); err != nil { + return err + } + } + if backendAction != nil { + if err := s.store.AddBackendAction(ctx, backendAction); err != nil { + return err + } + } + return nil +} + +func (s *Service) readNewLines(ctx context.Context, source config.SourceConfig) ([]string, error) { + info, err := os.Stat(source.Path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("stat source path %q: %w", source.Path, err) + } + inode := fileIdentity(info) + size := info.Size() + + offset, found, err := s.store.GetSourceOffset(ctx, source.Name) + if err != nil { + return nil, err + } + if !found { + start := int64(0) + if source.InitialPosition == "end" { + start = size + } + offset = model.SourceOffset{ + SourceName: source.Name, + Path: source.Path, + Inode: inode, + Offset: start, + UpdatedAt: time.Now().UTC(), + } + if err := s.store.SaveSourceOffset(ctx, offset); err != nil { + return nil, err + } + if start >= size { + return nil, nil + } + } else if offset.Inode != inode || size < offset.Offset { + offset = model.SourceOffset{ + SourceName: source.Name, + Path: source.Path, + Inode: inode, + Offset: 0, + UpdatedAt: time.Now().UTC(), + } + } + + file, err := os.Open(source.Path) + if err != nil { + return nil, fmt.Errorf("open source path %q: %w", source.Path, err) + } + defer file.Close() + + if _, err := file.Seek(offset.Offset, io.SeekStart); err != nil { + return nil, fmt.Errorf("seek source path %q: %w", source.Path, err) + } + + reader := bufio.NewReader(file) + lines := make([]string, 0, source.BatchSize) + currentOffset := offset.Offset + + for len(lines) < source.BatchSize { + line, err := reader.ReadString('\n') + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, fmt.Errorf("read source path %q: %w", source.Path, err) + } + currentOffset += int64(len(line)) + lines = append(lines, strings.TrimRight(line, "\r\n")) + } + + offset.Path = source.Path + offset.Inode = inode + offset.Offset = currentOffset + offset.UpdatedAt = time.Now().UTC() + if err := s.store.SaveSourceOffset(ctx, offset); err != nil { + return nil, err + } + + return lines, nil +} + +func (s *Service) applyManualOverride(ctx context.Context, ip string, override model.ManualOverride, state model.IPStateStatus, actor string, reason string, backendAction string) error { + normalized, err := normalizeIP(ip) + if err != nil { + return err + } + + enforced := false + var backendRecord *model.OPNsenseAction + if s.blocker != nil { + backendRecord = &model.OPNsenseAction{ + IP: normalized, + Action: backendAction, + CreatedAt: time.Now().UTC(), + } + switch override { + case model.ManualOverrideForceBlock: + result, callErr := s.blocker.AddIPIfMissing(ctx, normalized) + if callErr != nil { + backendRecord.Result = "error" + backendRecord.Message = callErr.Error() + } else { + backendRecord.Result = result + backendRecord.Message = reason + enforced = true + } + case model.ManualOverrideForceAllow: + result, callErr := s.blocker.RemoveIPIfPresent(ctx, normalized) + if callErr != nil { + backendRecord.Result = "error" + backendRecord.Message = callErr.Error() + } else { + backendRecord.Result = result + backendRecord.Message = reason + enforced = true + } + } + } + + current, err := s.store.SetManualOverride(ctx, normalized, override, state, reason) + if err != nil { + return err + } + if err := s.store.AddDecision(ctx, &model.DecisionRecord{ + EventID: current.LastEventID, + IP: normalized, + SourceName: current.LastSourceName, + Kind: "manual", + Action: actionForOverride(override), + Reason: reason, + Actor: defaultActor(actor), + Enforced: enforced, + CreatedAt: time.Now().UTC(), + }); err != nil { + return err + } + if backendRecord != nil { + if err := s.store.AddBackendAction(ctx, backendRecord); err != nil { + return err + } + } + return nil +} + +func normalizeIP(ip string) (string, error) { + parsed := net.ParseIP(strings.TrimSpace(ip)) + if parsed == nil { + return "", fmt.Errorf("invalid ip address %q", ip) + } + return parsed.String(), nil +} + +func fileIdentity(info os.FileInfo) string { + if stat, ok := info.Sys().(*syscall.Stat_t); ok { + return fmt.Sprintf("%d:%d", stat.Dev, stat.Ino) + } + return fmt.Sprintf("fallback:%d:%d", info.ModTime().UnixNano(), info.Size()) +} + +func actionForOverride(override model.ManualOverride) model.DecisionAction { + switch override { + case model.ManualOverrideForceBlock: + return model.DecisionActionBlock + case model.ManualOverrideForceAllow: + return model.DecisionActionAllow + default: + return model.DecisionActionNone + } +} + +func defaultActor(actor string) string { + if strings.TrimSpace(actor) == "" { + return "web-ui" + } + return strings.TrimSpace(actor) +} + +func defaultReason(reason string, fallback string) string { + if strings.TrimSpace(reason) == "" { + return fallback + } + return strings.TrimSpace(reason) +} diff --git a/internal/service/service_test.go b/internal/service/service_test.go new file mode 100644 index 0000000..e68b9ee --- /dev/null +++ b/internal/service/service_test.go @@ -0,0 +1,247 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/opnsense" + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/store" +) + +func TestServiceProcessesMultipleSourcesAndManualActions(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + mainLogPath := filepath.Join(tempDir, "main.log") + giteaLogPath := filepath.Join(tempDir, "gitea.log") + if err := os.WriteFile(mainLogPath, nil, 0o600); err != nil { + t.Fatalf("create main log: %v", err) + } + if err := os.WriteFile(giteaLogPath, nil, 0o600); err != nil { + t.Fatalf("create gitea log: %v", err) + } + + backend := newFakeOPNsenseServer(t) + defer backend.Close() + + configPath := filepath.Join(tempDir, "config.yaml") + payload := fmt.Sprintf(`storage: + path: %s/blocker.db +opnsense: + enabled: true + base_url: %s + api_key: key + api_secret: secret + ensure_alias: true + alias: + name: blocked-ips +profiles: + main: + auto_block: true + block_unexpected_posts: true + block_php_paths: true + suspicious_path_prefixes: + - /wp-login.php + gitea: + auto_block: false + block_unexpected_posts: true + allowed_post_paths: + - /user/login + suspicious_path_prefixes: + - /install.php +sources: + - name: main + path: %s + profile: main + initial_position: beginning + poll_interval: 20ms + batch_size: 128 + - name: gitea + path: %s + profile: gitea + initial_position: beginning + poll_interval: 20ms + batch_size: 128 +`, tempDir, backend.URL, mainLogPath, giteaLogPath) + if err := os.WriteFile(configPath, []byte(payload), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := config.Load(configPath) + if err != nil { + t.Fatalf("load config: %v", err) + } + database, err := store.Open(cfg.Storage.Path) + if err != nil { + t.Fatalf("open store: %v", err) + } + defer database.Close() + svc := New(cfg, database, opnsense.NewClient(cfg.OPNsense), log.New(os.Stderr, "", 0)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { _ = svc.Run(ctx) }() + + appendLine(t, mainLogPath, caddyJSONLine("203.0.113.10", "198.51.100.10", "example.test", "GET", "/wp-login.php", 404, "curl/8.0", time.Now().UTC())) + appendLine(t, giteaLogPath, caddyJSONLine("203.0.113.11", "198.51.100.11", "git.example.test", "POST", "/user/login", 401, "curl/8.0", time.Now().UTC())) + appendLine(t, giteaLogPath, caddyJSONLine("203.0.113.12", "198.51.100.12", "git.example.test", "GET", "/install.php", 404, "curl/8.0", time.Now().UTC())) + + waitFor(t, 3*time.Second, func() bool { + overview, err := database.GetOverview(context.Background(), 10) + return err == nil && overview.TotalEvents == 3 + }) + + blockedState, found, err := database.GetIPState(context.Background(), "203.0.113.10") + if err != nil || !found { + t.Fatalf("load blocked state: found=%v err=%v", found, err) + } + if blockedState.State != model.IPStateBlocked { + t.Fatalf("expected blocked state, got %+v", blockedState) + } + + reviewState, found, err := database.GetIPState(context.Background(), "203.0.113.12") + if err != nil || !found { + t.Fatalf("load review state: found=%v err=%v", found, err) + } + if reviewState.State != model.IPStateReview { + t.Fatalf("expected review state, got %+v", reviewState) + } + + observedState, found, err := database.GetIPState(context.Background(), "203.0.113.11") + if err != nil || !found { + t.Fatalf("load observed state: found=%v err=%v", found, err) + } + if observedState.State != model.IPStateObserved { + t.Fatalf("expected observed state, got %+v", observedState) + } + + if err := svc.ForceAllow(context.Background(), "203.0.113.10", "test", "manual unblock"); err != nil { + t.Fatalf("force allow: %v", err) + } + state, found, err := database.GetIPState(context.Background(), "203.0.113.10") + if err != nil || !found { + t.Fatalf("reload unblocked state: found=%v err=%v", found, err) + } + if state.ManualOverride != model.ManualOverrideForceAllow || state.State != model.IPStateAllowed { + t.Fatalf("unexpected manual allow state: %+v", state) + } + + backend.mu.Lock() + defer backend.mu.Unlock() + if _, ok := backend.ips["203.0.113.10"]; ok { + t.Fatalf("expected IP to be removed from backend alias after manual unblock") + } +} + +type fakeOPNsenseServer struct { + *httptest.Server + mu sync.Mutex + aliasUUID string + aliasExists bool + ips map[string]struct{} +} + +func newFakeOPNsenseServer(t *testing.T) *fakeOPNsenseServer { + t.Helper() + backend := &fakeOPNsenseServer{ips: map[string]struct{}{}} + backend.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok || username != "key" || password != "secret" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + backend.mu.Lock() + defer backend.mu.Unlock() + + switch { + case r.Method == http.MethodGet && r.URL.Path == "/api/firewall/alias/get_alias_u_u_i_d/blocked-ips": + if backend.aliasExists { + _ = json.NewEncoder(w).Encode(map[string]any{"uuid": backend.aliasUUID}) + } else { + _ = json.NewEncoder(w).Encode(map[string]any{"uuid": ""}) + } + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias/add_item": + backend.aliasExists = true + backend.aliasUUID = "uuid-1" + _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias/set_item/uuid-1": + _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias/reconfigure": + _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) + case r.Method == http.MethodGet && r.URL.Path == "/api/firewall/alias_util/list/blocked-ips": + rows := make([]map[string]string, 0, len(backend.ips)) + for ip := range backend.ips { + rows = append(rows, map[string]string{"ip": ip}) + } + _ = json.NewEncoder(w).Encode(map[string]any{"rows": rows}) + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias_util/add/blocked-ips": + var payload map[string]string + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + backend.ips[payload["address"]] = struct{}{} + _ = json.NewEncoder(w).Encode(map[string]any{"status": "done"}) + case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias_util/delete/blocked-ips": + var payload map[string]string + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + delete(backend.ips, payload["address"]) + _ = json.NewEncoder(w).Encode(map[string]any{"status": "done"}) + default: + http.Error(w, "not found", http.StatusNotFound) + } + })) + return backend +} + +func appendLine(t *testing.T, path string, line string) { + t.Helper() + file, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0) + if err != nil { + t.Fatalf("open log file for append: %v", err) + } + defer file.Close() + if _, err := file.WriteString(line + "\n"); err != nil { + t.Fatalf("append log line: %v", err) + } +} + +func caddyJSONLine(clientIP string, remoteIP string, host string, method string, uri string, status int, userAgent string, occurredAt time.Time) string { + return fmt.Sprintf(`{"ts":%q,"status":%d,"request":{"remote_ip":%q,"client_ip":%q,"host":%q,"method":%q,"uri":%q,"headers":{"User-Agent":[%q]}}}`, + occurredAt.UTC().Format(time.RFC3339Nano), + status, + remoteIP, + clientIP, + host, + method, + uri, + userAgent, + ) +} + +func waitFor(t *testing.T, timeout time.Duration, condition func() bool) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if condition() { + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("condition was not met within %s", timeout) +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..3f83e3f --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,961 @@ +package store + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model" + _ "modernc.org/sqlite" +) + +const schema = ` +CREATE TABLE IF NOT EXISTS events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source_name TEXT NOT NULL, + profile_name TEXT NOT NULL, + occurred_at TEXT NOT NULL, + remote_ip TEXT NOT NULL, + client_ip TEXT NOT NULL, + host TEXT NOT NULL, + method TEXT NOT NULL, + uri TEXT NOT NULL, + path TEXT NOT NULL, + status INTEGER NOT NULL, + user_agent TEXT NOT NULL, + decision TEXT NOT NULL, + decision_reason TEXT NOT NULL, + decision_reasons_json TEXT NOT NULL, + enforced INTEGER NOT NULL DEFAULT 0, + raw_json TEXT NOT NULL, + created_at TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_events_occurred_at ON events(occurred_at DESC, id DESC); +CREATE INDEX IF NOT EXISTS idx_events_client_ip ON events(client_ip, occurred_at DESC, id DESC); +CREATE INDEX IF NOT EXISTS idx_events_source_name ON events(source_name, occurred_at DESC, id DESC); +CREATE INDEX IF NOT EXISTS idx_events_decision ON events(decision, occurred_at DESC, id DESC); + +CREATE TABLE IF NOT EXISTS ip_state ( + ip TEXT PRIMARY KEY, + first_seen_at TEXT NOT NULL, + last_seen_at TEXT NOT NULL, + last_source_name TEXT NOT NULL, + last_user_agent TEXT NOT NULL, + latest_status INTEGER NOT NULL, + total_events INTEGER NOT NULL, + state TEXT NOT NULL, + state_reason TEXT NOT NULL, + manual_override TEXT NOT NULL, + last_event_id INTEGER NOT NULL, + updated_at TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_ip_state_last_seen ON ip_state(last_seen_at DESC, ip ASC); +CREATE INDEX IF NOT EXISTS idx_ip_state_state ON ip_state(state, last_seen_at DESC, ip ASC); + +CREATE TABLE IF NOT EXISTS decisions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + event_id INTEGER NOT NULL, + ip TEXT NOT NULL, + source_name TEXT NOT NULL, + kind TEXT NOT NULL, + action TEXT NOT NULL, + reason TEXT NOT NULL, + actor TEXT NOT NULL, + enforced INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_decisions_ip ON decisions(ip, created_at DESC, id DESC); +CREATE INDEX IF NOT EXISTS idx_decisions_event_id ON decisions(event_id, created_at DESC, id DESC); + +CREATE TABLE IF NOT EXISTS backend_actions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ip TEXT NOT NULL, + action TEXT NOT NULL, + result TEXT NOT NULL, + message TEXT NOT NULL, + created_at TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_backend_actions_ip ON backend_actions(ip, created_at DESC, id DESC); + +CREATE TABLE IF NOT EXISTS source_offsets ( + source_name TEXT PRIMARY KEY, + path TEXT NOT NULL, + inode TEXT NOT NULL, + offset INTEGER NOT NULL, + updated_at TEXT NOT NULL +); +` + +type Store struct { + db *sql.DB +} + +func Open(path string) (*Store, error) { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return nil, fmt.Errorf("create storage directory: %w", err) + } + + db, err := sql.Open("sqlite", path) + if err != nil { + return nil, fmt.Errorf("open sqlite database: %w", err) + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(0) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for _, statement := range []string{ + "PRAGMA journal_mode = WAL;", + "PRAGMA busy_timeout = 5000;", + "PRAGMA foreign_keys = ON;", + } { + if _, err := db.ExecContext(ctx, statement); err != nil { + db.Close() + return nil, fmt.Errorf("apply sqlite pragma %q: %w", statement, err) + } + } + if _, err := db.ExecContext(ctx, schema); err != nil { + db.Close() + return nil, fmt.Errorf("apply sqlite schema: %w", err) + } + + return &Store{db: db}, nil +} + +func (s *Store) Close() error { + if s == nil || s.db == nil { + return nil + } + return s.db.Close() +} + +func (s *Store) RecordEvent(ctx context.Context, event *model.Event) error { + if event == nil { + return errors.New("nil event") + } + if event.OccurredAt.IsZero() { + event.OccurredAt = time.Now().UTC() + } + if event.CreatedAt.IsZero() { + event.CreatedAt = time.Now().UTC() + } + encodedReasons, err := json.Marshal(event.DecisionReasons) + if err != nil { + return fmt.Errorf("encode decision reasons: %w", err) + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() + + state, found, err := getIPStateTx(tx, event.ClientIP) + if err != nil { + return err + } + + result, err := tx.ExecContext( + ctx, + `INSERT INTO events ( + source_name, profile_name, occurred_at, remote_ip, client_ip, host, method, uri, path, + status, user_agent, decision, decision_reason, decision_reasons_json, enforced, raw_json, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + event.SourceName, + event.ProfileName, + formatTime(event.OccurredAt), + event.RemoteIP, + event.ClientIP, + event.Host, + event.Method, + event.URI, + event.Path, + event.Status, + event.UserAgent, + string(event.Decision), + event.DecisionReason, + string(encodedReasons), + boolToInt(event.Enforced), + event.RawJSON, + formatTime(event.CreatedAt), + ) + if err != nil { + return fmt.Errorf("insert event: %w", err) + } + eventID, err := result.LastInsertId() + if err != nil { + return fmt.Errorf("load inserted event id: %w", err) + } + event.ID = eventID + + updatedState := mergeEventIntoState(state, found, *event) + event.CurrentState = updatedState.State + event.ManualOverride = updatedState.ManualOverride + + if err := upsertIPStateTx(ctx, tx, updatedState); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit event transaction: %w", err) + } + return nil +} + +func (s *Store) AddDecision(ctx context.Context, decision *model.DecisionRecord) error { + if decision == nil { + return errors.New("nil decision record") + } + if decision.CreatedAt.IsZero() { + decision.CreatedAt = time.Now().UTC() + } + result, err := s.db.ExecContext( + ctx, + `INSERT INTO decisions (event_id, ip, source_name, kind, action, reason, actor, enforced, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + decision.EventID, + decision.IP, + decision.SourceName, + decision.Kind, + string(decision.Action), + decision.Reason, + decision.Actor, + boolToInt(decision.Enforced), + formatTime(decision.CreatedAt), + ) + if err != nil { + return fmt.Errorf("insert decision record: %w", err) + } + decision.ID, err = result.LastInsertId() + if err != nil { + return fmt.Errorf("load inserted decision id: %w", err) + } + return nil +} + +func (s *Store) AddBackendAction(ctx context.Context, action *model.OPNsenseAction) error { + if action == nil { + return errors.New("nil backend action") + } + if action.CreatedAt.IsZero() { + action.CreatedAt = time.Now().UTC() + } + result, err := s.db.ExecContext( + ctx, + `INSERT INTO backend_actions (ip, action, result, message, created_at) + VALUES (?, ?, ?, ?, ?)`, + action.IP, + action.Action, + action.Result, + action.Message, + formatTime(action.CreatedAt), + ) + if err != nil { + return fmt.Errorf("insert backend action: %w", err) + } + action.ID, err = result.LastInsertId() + if err != nil { + return fmt.Errorf("load inserted backend action id: %w", err) + } + return nil +} + +func (s *Store) GetIPState(ctx context.Context, ip string) (model.IPState, bool, error) { + return getIPStateDB(ctx, s.db, ip) +} + +func (s *Store) SetManualOverride(ctx context.Context, ip string, override model.ManualOverride, state model.IPStateStatus, reason string) (model.IPState, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return model.IPState{}, fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() + + current, found, err := getIPStateTx(tx, ip) + if err != nil { + return model.IPState{}, err + } + now := time.Now().UTC() + if !found { + current = model.IPState{ + IP: ip, + FirstSeenAt: now, + LastSeenAt: now, + LastSourceName: "", + LastUserAgent: "", + LatestStatus: 0, + TotalEvents: 0, + State: state, + StateReason: strings.TrimSpace(reason), + ManualOverride: override, + LastEventID: 0, + UpdatedAt: now, + } + } else { + current.ManualOverride = override + current.State = state + if strings.TrimSpace(reason) != "" { + current.StateReason = strings.TrimSpace(reason) + } + current.UpdatedAt = now + } + if err := upsertIPStateTx(ctx, tx, current); err != nil { + return model.IPState{}, err + } + if err := tx.Commit(); err != nil { + return model.IPState{}, fmt.Errorf("commit transaction: %w", err) + } + return current, nil +} + +func (s *Store) ClearManualOverride(ctx context.Context, ip string, reason string) (model.IPState, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return model.IPState{}, fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() + + current, found, err := getIPStateTx(tx, ip) + if err != nil { + return model.IPState{}, err + } + now := time.Now().UTC() + if !found { + current = model.IPState{ + IP: ip, + FirstSeenAt: now, + LastSeenAt: now, + State: model.IPStateObserved, + StateReason: strings.TrimSpace(reason), + ManualOverride: model.ManualOverrideNone, + UpdatedAt: now, + } + } else { + current.ManualOverride = model.ManualOverrideNone + if current.State == "" { + current.State = model.IPStateObserved + } + if strings.TrimSpace(reason) != "" { + current.StateReason = strings.TrimSpace(reason) + } + current.UpdatedAt = now + } + if err := upsertIPStateTx(ctx, tx, current); err != nil { + return model.IPState{}, err + } + if err := tx.Commit(); err != nil { + return model.IPState{}, fmt.Errorf("commit transaction: %w", err) + } + return current, nil +} + +func (s *Store) GetOverview(ctx context.Context, limit int) (model.Overview, error) { + if limit <= 0 { + limit = 50 + } + var overview model.Overview + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM events`).Scan(&overview.TotalEvents); err != nil { + return model.Overview{}, fmt.Errorf("count events: %w", err) + } + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state`).Scan(&overview.TotalIPs); err != nil { + return model.Overview{}, fmt.Errorf("count ip states: %w", err) + } + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state WHERE state = ?`, string(model.IPStateBlocked)).Scan(&overview.BlockedIPs); err != nil { + return model.Overview{}, fmt.Errorf("count blocked ip states: %w", err) + } + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state WHERE state = ?`, string(model.IPStateReview)).Scan(&overview.ReviewIPs); err != nil { + return model.Overview{}, fmt.Errorf("count review ip states: %w", err) + } + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state WHERE state = ?`, string(model.IPStateAllowed)).Scan(&overview.AllowedIPs); err != nil { + return model.Overview{}, fmt.Errorf("count allowed ip states: %w", err) + } + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state WHERE state = ?`, string(model.IPStateObserved)).Scan(&overview.ObservedIPs); err != nil { + return model.Overview{}, fmt.Errorf("count observed ip states: %w", err) + } + + recentIPs, err := s.ListIPStates(ctx, limit, "") + if err != nil { + return model.Overview{}, err + } + recentEvents, err := s.ListRecentEvents(ctx, limit) + if err != nil { + return model.Overview{}, err + } + overview.RecentIPs = recentIPs + overview.RecentEvents = recentEvents + return overview, nil +} + +func (s *Store) ListRecentEvents(ctx context.Context, limit int) ([]model.Event, error) { + if limit <= 0 { + limit = 50 + } + rows, err := s.db.QueryContext(ctx, ` + SELECT e.id, e.source_name, e.profile_name, e.occurred_at, e.remote_ip, e.client_ip, e.host, + e.method, e.uri, e.path, e.status, e.user_agent, e.decision, e.decision_reason, + e.decision_reasons_json, e.enforced, e.raw_json, e.created_at, + COALESCE(s.state, ''), COALESCE(s.manual_override, '') + FROM events e + LEFT JOIN ip_state s ON s.ip = e.client_ip + ORDER BY e.occurred_at DESC, e.id DESC + LIMIT ?`, + limit, + ) + if err != nil { + return nil, fmt.Errorf("list recent events: %w", err) + } + defer rows.Close() + + items := make([]model.Event, 0, limit) + for rows.Next() { + item, err := scanEvent(rows) + if err != nil { + return nil, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate recent events: %w", err) + } + return items, nil +} + +func (s *Store) ListIPStates(ctx context.Context, limit int, stateFilter string) ([]model.IPState, error) { + if limit <= 0 { + limit = 50 + } + query := `SELECT ip, first_seen_at, last_seen_at, last_source_name, last_user_agent, latest_status, + total_events, state, state_reason, manual_override, last_event_id, updated_at + FROM ip_state` + args := []any{} + if strings.TrimSpace(stateFilter) != "" { + query += ` WHERE state = ?` + args = append(args, strings.TrimSpace(stateFilter)) + } + query += ` ORDER BY last_seen_at DESC, ip ASC LIMIT ?` + args = append(args, limit) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("list ip states: %w", err) + } + defer rows.Close() + + items := make([]model.IPState, 0, limit) + for rows.Next() { + item, err := scanIPState(rows) + if err != nil { + return nil, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate ip states: %w", err) + } + return items, nil +} + +func (s *Store) GetIPDetails(ctx context.Context, ip string, eventLimit, decisionLimit, actionLimit int) (model.IPDetails, error) { + state, _, err := s.GetIPState(ctx, ip) + if err != nil { + return model.IPDetails{}, err + } + events, err := s.listEventsForIP(ctx, ip, eventLimit) + if err != nil { + return model.IPDetails{}, err + } + decisions, err := s.listDecisionsForIP(ctx, ip, decisionLimit) + if err != nil { + return model.IPDetails{}, err + } + actions, err := s.listBackendActionsForIP(ctx, ip, actionLimit) + if err != nil { + return model.IPDetails{}, err + } + return model.IPDetails{ + State: state, + RecentEvents: events, + Decisions: decisions, + BackendActions: actions, + }, nil +} + +func (s *Store) GetSourceOffset(ctx context.Context, sourceName string) (model.SourceOffset, bool, error) { + row := s.db.QueryRowContext(ctx, `SELECT source_name, path, inode, offset, updated_at FROM source_offsets WHERE source_name = ?`, sourceName) + var offset model.SourceOffset + var updatedAt string + if err := row.Scan(&offset.SourceName, &offset.Path, &offset.Inode, &offset.Offset, &updatedAt); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return model.SourceOffset{}, false, nil + } + return model.SourceOffset{}, false, fmt.Errorf("query source offset %q: %w", sourceName, err) + } + parsed, err := parseTime(updatedAt) + if err != nil { + return model.SourceOffset{}, false, fmt.Errorf("parse source offset updated_at: %w", err) + } + offset.UpdatedAt = parsed + return offset, true, nil +} + +func (s *Store) SaveSourceOffset(ctx context.Context, offset model.SourceOffset) error { + if offset.UpdatedAt.IsZero() { + offset.UpdatedAt = time.Now().UTC() + } + _, err := s.db.ExecContext( + ctx, + `INSERT INTO source_offsets (source_name, path, inode, offset, updated_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(source_name) DO UPDATE SET + path = excluded.path, + inode = excluded.inode, + offset = excluded.offset, + updated_at = excluded.updated_at`, + offset.SourceName, + offset.Path, + offset.Inode, + offset.Offset, + formatTime(offset.UpdatedAt), + ) + if err != nil { + return fmt.Errorf("upsert source offset: %w", err) + } + return nil +} + +func (s *Store) listEventsForIP(ctx context.Context, ip string, limit int) ([]model.Event, error) { + if limit <= 0 { + limit = 50 + } + rows, err := s.db.QueryContext(ctx, ` + SELECT e.id, e.source_name, e.profile_name, e.occurred_at, e.remote_ip, e.client_ip, e.host, + e.method, e.uri, e.path, e.status, e.user_agent, e.decision, e.decision_reason, + e.decision_reasons_json, e.enforced, e.raw_json, e.created_at, + COALESCE(s.state, ''), COALESCE(s.manual_override, '') + FROM events e + LEFT JOIN ip_state s ON s.ip = e.client_ip + WHERE e.client_ip = ? + ORDER BY e.occurred_at DESC, e.id DESC + LIMIT ?`, + ip, + limit, + ) + if err != nil { + return nil, fmt.Errorf("list events for ip %q: %w", ip, err) + } + defer rows.Close() + + items := make([]model.Event, 0, limit) + for rows.Next() { + item, err := scanEvent(rows) + if err != nil { + return nil, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate events for ip %q: %w", ip, err) + } + return items, nil +} + +func (s *Store) listDecisionsForIP(ctx context.Context, ip string, limit int) ([]model.DecisionRecord, error) { + if limit <= 0 { + limit = 50 + } + rows, err := s.db.QueryContext(ctx, ` + SELECT id, event_id, ip, source_name, kind, action, reason, actor, enforced, created_at + FROM decisions + WHERE ip = ? + ORDER BY created_at DESC, id DESC + LIMIT ?`, + ip, + limit, + ) + if err != nil { + return nil, fmt.Errorf("list decisions for ip %q: %w", ip, err) + } + defer rows.Close() + + items := make([]model.DecisionRecord, 0, limit) + for rows.Next() { + var item model.DecisionRecord + var action string + var enforced int + var createdAt string + if err := rows.Scan(&item.ID, &item.EventID, &item.IP, &item.SourceName, &item.Kind, &action, &item.Reason, &item.Actor, &enforced, &createdAt); err != nil { + return nil, fmt.Errorf("scan decision record: %w", err) + } + parsed, err := parseTime(createdAt) + if err != nil { + return nil, fmt.Errorf("parse decision created_at: %w", err) + } + item.Action = model.DecisionAction(action) + item.Enforced = enforced != 0 + item.CreatedAt = parsed + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate decisions for ip %q: %w", ip, err) + } + return items, nil +} + +func (s *Store) listBackendActionsForIP(ctx context.Context, ip string, limit int) ([]model.OPNsenseAction, error) { + if limit <= 0 { + limit = 50 + } + rows, err := s.db.QueryContext(ctx, ` + SELECT id, ip, action, result, message, created_at + FROM backend_actions + WHERE ip = ? + ORDER BY created_at DESC, id DESC + LIMIT ?`, + ip, + limit, + ) + if err != nil { + return nil, fmt.Errorf("list backend actions for ip %q: %w", ip, err) + } + defer rows.Close() + + items := make([]model.OPNsenseAction, 0, limit) + for rows.Next() { + var item model.OPNsenseAction + var createdAt string + if err := rows.Scan(&item.ID, &item.IP, &item.Action, &item.Result, &item.Message, &createdAt); err != nil { + return nil, fmt.Errorf("scan backend action: %w", err) + } + parsed, err := parseTime(createdAt) + if err != nil { + return nil, fmt.Errorf("parse backend action created_at: %w", err) + } + item.CreatedAt = parsed + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate backend actions for ip %q: %w", ip, err) + } + return items, nil +} + +func getIPStateDB(ctx context.Context, db queryer, ip string) (model.IPState, bool, error) { + row := db.QueryRowContext(ctx, ` + SELECT ip, first_seen_at, last_seen_at, last_source_name, last_user_agent, latest_status, + total_events, state, state_reason, manual_override, last_event_id, updated_at + FROM ip_state WHERE ip = ?`, ip) + + var item model.IPState + var firstSeenAt string + var lastSeenAt string + var updatedAt string + var state string + var manualOverride string + if err := row.Scan( + &item.IP, + &firstSeenAt, + &lastSeenAt, + &item.LastSourceName, + &item.LastUserAgent, + &item.LatestStatus, + &item.TotalEvents, + &state, + &item.StateReason, + &manualOverride, + &item.LastEventID, + &updatedAt, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return model.IPState{}, false, nil + } + return model.IPState{}, false, fmt.Errorf("query ip state %q: %w", ip, err) + } + + var err error + item.FirstSeenAt, err = parseTime(firstSeenAt) + if err != nil { + return model.IPState{}, false, fmt.Errorf("parse ip state first_seen_at: %w", err) + } + item.LastSeenAt, err = parseTime(lastSeenAt) + if err != nil { + return model.IPState{}, false, fmt.Errorf("parse ip state last_seen_at: %w", err) + } + item.UpdatedAt, err = parseTime(updatedAt) + if err != nil { + return model.IPState{}, false, fmt.Errorf("parse ip state updated_at: %w", err) + } + item.State = model.IPStateStatus(state) + item.ManualOverride = model.ManualOverride(manualOverride) + return item, true, nil +} + +func getIPStateTx(tx *sql.Tx, ip string) (model.IPState, bool, error) { + return getIPStateDB(context.Background(), tx, ip) +} + +func upsertIPStateTx(ctx context.Context, tx *sql.Tx, state model.IPState) error { + if state.UpdatedAt.IsZero() { + state.UpdatedAt = time.Now().UTC() + } + if state.FirstSeenAt.IsZero() { + state.FirstSeenAt = state.UpdatedAt + } + if state.LastSeenAt.IsZero() { + state.LastSeenAt = state.UpdatedAt + } + if state.State == "" { + state.State = model.IPStateObserved + } + if state.ManualOverride == "" { + state.ManualOverride = model.ManualOverrideNone + } + _, err := tx.ExecContext( + ctx, + `INSERT INTO ip_state ( + ip, first_seen_at, last_seen_at, last_source_name, last_user_agent, latest_status, + total_events, state, state_reason, manual_override, last_event_id, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(ip) DO UPDATE SET + first_seen_at = excluded.first_seen_at, + last_seen_at = excluded.last_seen_at, + last_source_name = excluded.last_source_name, + last_user_agent = excluded.last_user_agent, + latest_status = excluded.latest_status, + total_events = excluded.total_events, + state = excluded.state, + state_reason = excluded.state_reason, + manual_override = excluded.manual_override, + last_event_id = excluded.last_event_id, + updated_at = excluded.updated_at`, + state.IP, + formatTime(state.FirstSeenAt), + formatTime(state.LastSeenAt), + state.LastSourceName, + state.LastUserAgent, + state.LatestStatus, + state.TotalEvents, + string(state.State), + state.StateReason, + string(state.ManualOverride), + state.LastEventID, + formatTime(state.UpdatedAt), + ) + if err != nil { + return fmt.Errorf("upsert ip state %q: %w", state.IP, err) + } + return nil +} + +func mergeEventIntoState(existing model.IPState, found bool, event model.Event) model.IPState { + now := time.Now().UTC() + state := existing + if !found { + state = model.IPState{ + IP: event.ClientIP, + FirstSeenAt: event.OccurredAt, + LastSeenAt: event.OccurredAt, + LastSourceName: event.SourceName, + LastUserAgent: event.UserAgent, + LatestStatus: event.Status, + TotalEvents: 0, + State: model.IPStateObserved, + StateReason: "", + ManualOverride: model.ManualOverrideNone, + LastEventID: 0, + UpdatedAt: now, + } + } + if state.FirstSeenAt.IsZero() || event.OccurredAt.Before(state.FirstSeenAt) { + state.FirstSeenAt = event.OccurredAt + } + if state.LastSeenAt.IsZero() || event.OccurredAt.After(state.LastSeenAt) { + state.LastSeenAt = event.OccurredAt + } + state.LastSourceName = event.SourceName + state.LastUserAgent = event.UserAgent + state.LatestStatus = event.Status + state.TotalEvents++ + state.LastEventID = event.ID + state.UpdatedAt = now + if state.ManualOverride == "" { + state.ManualOverride = model.ManualOverrideNone + } + + switch state.ManualOverride { + case model.ManualOverrideForceBlock: + state.State = model.IPStateBlocked + if event.DecisionReason != "" { + state.StateReason = event.DecisionReason + } else if state.StateReason == "" { + state.StateReason = "manual override: block" + } + return state + case model.ManualOverrideForceAllow: + state.State = model.IPStateAllowed + if event.DecisionReason != "" { + state.StateReason = event.DecisionReason + } else if state.StateReason == "" { + state.StateReason = "manual override: allow" + } + return state + } + + switch event.Decision { + case model.DecisionActionBlock: + state.State = model.IPStateBlocked + state.StateReason = event.DecisionReason + case model.DecisionActionReview: + if state.State != model.IPStateBlocked && state.State != model.IPStateAllowed { + state.State = model.IPStateReview + state.StateReason = event.DecisionReason + } + case model.DecisionActionAllow: + state.State = model.IPStateAllowed + state.StateReason = event.DecisionReason + default: + if state.State == "" { + state.State = model.IPStateObserved + } + } + return state +} + +func scanEvent(scanner interface{ Scan(dest ...any) error }) (model.Event, error) { + var item model.Event + var occurredAt string + var createdAt string + var decision string + var decisionReasonsJSON string + var enforced int + var currentState string + var manualOverride string + if err := scanner.Scan( + &item.ID, + &item.SourceName, + &item.ProfileName, + &occurredAt, + &item.RemoteIP, + &item.ClientIP, + &item.Host, + &item.Method, + &item.URI, + &item.Path, + &item.Status, + &item.UserAgent, + &decision, + &item.DecisionReason, + &decisionReasonsJSON, + &enforced, + &item.RawJSON, + &createdAt, + ¤tState, + &manualOverride, + ); err != nil { + return model.Event{}, fmt.Errorf("scan event: %w", err) + } + parsedOccurredAt, err := parseTime(occurredAt) + if err != nil { + return model.Event{}, fmt.Errorf("parse event occurred_at: %w", err) + } + parsedCreatedAt, err := parseTime(createdAt) + if err != nil { + return model.Event{}, fmt.Errorf("parse event created_at: %w", err) + } + var reasons []string + if strings.TrimSpace(decisionReasonsJSON) != "" { + if err := json.Unmarshal([]byte(decisionReasonsJSON), &reasons); err != nil { + return model.Event{}, fmt.Errorf("decode event decision_reasons_json: %w", err) + } + } + item.OccurredAt = parsedOccurredAt + item.CreatedAt = parsedCreatedAt + item.Decision = model.DecisionAction(decision) + item.DecisionReasons = reasons + item.Enforced = enforced != 0 + item.CurrentState = model.IPStateStatus(currentState) + item.ManualOverride = model.ManualOverride(manualOverride) + return item, nil +} + +func scanIPState(scanner interface{ Scan(dest ...any) error }) (model.IPState, error) { + var item model.IPState + var firstSeenAt string + var lastSeenAt string + var updatedAt string + var state string + var manualOverride string + if err := scanner.Scan( + &item.IP, + &firstSeenAt, + &lastSeenAt, + &item.LastSourceName, + &item.LastUserAgent, + &item.LatestStatus, + &item.TotalEvents, + &state, + &item.StateReason, + &manualOverride, + &item.LastEventID, + &updatedAt, + ); err != nil { + return model.IPState{}, fmt.Errorf("scan ip state: %w", err) + } + var err error + item.FirstSeenAt, err = parseTime(firstSeenAt) + if err != nil { + return model.IPState{}, fmt.Errorf("parse ip state first_seen_at: %w", err) + } + item.LastSeenAt, err = parseTime(lastSeenAt) + if err != nil { + return model.IPState{}, fmt.Errorf("parse ip state last_seen_at: %w", err) + } + item.UpdatedAt, err = parseTime(updatedAt) + if err != nil { + return model.IPState{}, fmt.Errorf("parse ip state updated_at: %w", err) + } + item.State = model.IPStateStatus(state) + item.ManualOverride = model.ManualOverride(manualOverride) + return item, nil +} + +func parseTime(value string) (time.Time, error) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return time.Time{}, nil + } + parsed, err := time.Parse(time.RFC3339Nano, trimmed) + if err != nil { + return time.Time{}, err + } + return parsed.UTC(), nil +} + +func formatTime(value time.Time) string { + if value.IsZero() { + return "" + } + return value.UTC().Format(time.RFC3339Nano) +} + +func boolToInt(value bool) int { + if value { + return 1 + } + return 0 +} + +type queryer interface { + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} diff --git a/internal/store/store_test.go b/internal/store/store_test.go new file mode 100644 index 0000000..9c6adac --- /dev/null +++ b/internal/store/store_test.go @@ -0,0 +1,116 @@ +package store + +import ( + "context" + "path/filepath" + "testing" + "time" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model" +) + +func TestStoreRecordsEventsAndState(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "blocker.db") + db, err := Open(dbPath) + if err != nil { + t.Fatalf("open store: %v", err) + } + defer db.Close() + + ctx := context.Background() + occurredAt := time.Date(2025, 3, 11, 12, 0, 0, 0, time.UTC) + event := &model.Event{ + SourceName: "main", + ProfileName: "main", + OccurredAt: occurredAt, + RemoteIP: "198.51.100.10", + ClientIP: "203.0.113.10", + Host: "example.test", + Method: "GET", + URI: "/wp-login.php", + Path: "/wp-login.php", + Status: 404, + UserAgent: "curl/8.0", + Decision: model.DecisionActionBlock, + DecisionReason: "php_path", + DecisionReasons: []string{"php_path"}, + Enforced: true, + RawJSON: `{"status":404}`, + } + if err := db.RecordEvent(ctx, event); err != nil { + t.Fatalf("record event: %v", err) + } + if event.ID == 0 { + t.Fatalf("expected inserted event ID") + } + + state, found, err := db.GetIPState(ctx, "203.0.113.10") + if err != nil { + t.Fatalf("get ip state: %v", err) + } + if !found { + t.Fatalf("expected IP state to exist") + } + if state.State != model.IPStateBlocked { + t.Fatalf("unexpected ip state: %+v", state) + } + if state.TotalEvents != 1 { + t.Fatalf("unexpected total events: %d", state.TotalEvents) + } + + if _, err := db.SetManualOverride(ctx, "203.0.113.10", model.ManualOverrideForceAllow, model.IPStateAllowed, "manual allow"); err != nil { + t.Fatalf("set manual override: %v", err) + } + state, found, err = db.GetIPState(ctx, "203.0.113.10") + if err != nil || !found { + t.Fatalf("get overridden ip state: found=%v err=%v", found, err) + } + if state.ManualOverride != model.ManualOverrideForceAllow { + t.Fatalf("unexpected override after set: %+v", state) + } + + if _, err := db.ClearManualOverride(ctx, "203.0.113.10", "reset"); err != nil { + t.Fatalf("clear manual override: %v", err) + } + state, found, err = db.GetIPState(ctx, "203.0.113.10") + if err != nil || !found { + t.Fatalf("get reset ip state: found=%v err=%v", found, err) + } + if state.ManualOverride != model.ManualOverrideNone { + t.Fatalf("expected cleared override, got %+v", state) + } + + if err := db.AddDecision(ctx, &model.DecisionRecord{EventID: event.ID, IP: event.ClientIP, SourceName: event.SourceName, Kind: "automatic", Action: model.DecisionActionBlock, Reason: "php_path", Actor: "engine", Enforced: true}); err != nil { + t.Fatalf("add decision: %v", err) + } + if err := db.AddBackendAction(ctx, &model.OPNsenseAction{IP: event.ClientIP, Action: "block", Result: "added", Message: "php_path"}); err != nil { + t.Fatalf("add backend action: %v", err) + } + if err := db.SaveSourceOffset(ctx, model.SourceOffset{SourceName: "main", Path: "/tmp/main.log", Inode: "1:2", Offset: 42, UpdatedAt: occurredAt}); err != nil { + t.Fatalf("save source offset: %v", err) + } + offset, found, err := db.GetSourceOffset(ctx, "main") + if err != nil { + t.Fatalf("get source offset: %v", err) + } + if !found || offset.Offset != 42 { + t.Fatalf("unexpected source offset: found=%v offset=%+v", found, offset) + } + + overview, err := db.GetOverview(ctx, 10) + if err != nil { + t.Fatalf("get overview: %v", err) + } + if overview.TotalEvents != 1 || overview.TotalIPs != 1 { + t.Fatalf("unexpected overview counters: %+v", overview) + } + details, err := db.GetIPDetails(ctx, event.ClientIP, 10, 10, 10) + if err != nil { + t.Fatalf("get ip details: %v", err) + } + if len(details.RecentEvents) != 1 || len(details.Decisions) != 1 || len(details.BackendActions) != 1 { + t.Fatalf("unexpected ip details: %+v", details) + } +} diff --git a/internal/web/handler.go b/internal/web/handler.go new file mode 100644 index 0000000..c3f3d51 --- /dev/null +++ b/internal/web/handler.go @@ -0,0 +1,583 @@ +package web + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "html/template" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model" +) + +type App interface { + GetOverview(ctx context.Context, limit int) (model.Overview, error) + ListEvents(ctx context.Context, limit int) ([]model.Event, error) + ListIPs(ctx context.Context, limit int, state string) ([]model.IPState, error) + GetIPDetails(ctx context.Context, ip string) (model.IPDetails, error) + ForceBlock(ctx context.Context, ip string, actor string, reason string) error + ForceAllow(ctx context.Context, ip string, actor string, reason string) error + ClearOverride(ctx context.Context, ip string, actor string, reason string) error +} + +type handler struct { + app App + overviewPage *template.Template + ipDetailsPage *template.Template +} + +type pageData struct { + Title string + IP string +} + +type actionPayload struct { + Reason string `json:"reason"` + Actor string `json:"actor"` +} + +func NewHandler(app App) http.Handler { + h := &handler{ + app: app, + overviewPage: template.Must(template.New("overview").Parse(overviewHTML)), + ipDetailsPage: template.Must(template.New("ip-details").Parse(ipDetailsHTML)), + } + + mux := http.NewServeMux() + mux.HandleFunc("/", h.handleOverviewPage) + mux.HandleFunc("/healthz", h.handleHealth) + mux.HandleFunc("/ips/", h.handleIPPage) + mux.HandleFunc("/api/overview", h.handleAPIOverview) + mux.HandleFunc("/api/events", h.handleAPIEvents) + mux.HandleFunc("/api/ips", h.handleAPIIPs) + mux.HandleFunc("/api/ips/", h.handleAPIIP) + return mux +} + +func (h *handler) handleOverviewPage(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + renderTemplate(w, h.overviewPage, pageData{Title: "Caddy OPNsense Blocker"}) +} + +func (h *handler) handleIPPage(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + ip, ok := extractPathValue(r.URL.Path, "/ips/") + if !ok { + http.NotFound(w, r) + return + } + renderTemplate(w, h.ipDetailsPage, pageData{Title: "IP details", IP: ip}) +} + +func (h *handler) handleHealth(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + writeJSON(w, http.StatusOK, map[string]any{"status": "ok", "time": time.Now().UTC()}) +} + +func (h *handler) handleAPIOverview(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + limit := queryLimit(r, 50) + overview, err := h.app.GetOverview(r.Context(), limit) + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + writeJSON(w, http.StatusOK, overview) +} + +func (h *handler) handleAPIEvents(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + limit := queryLimit(r, 100) + events, err := h.app.ListEvents(r.Context(), limit) + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + writeJSON(w, http.StatusOK, events) +} + +func (h *handler) handleAPIIPs(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/ips" { + http.NotFound(w, r) + return + } + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + limit := queryLimit(r, 100) + state := strings.TrimSpace(r.URL.Query().Get("state")) + items, err := h.app.ListIPs(r.Context(), limit, state) + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + writeJSON(w, http.StatusOK, items) +} + +func (h *handler) handleAPIIP(w http.ResponseWriter, r *http.Request) { + ip, action, ok := extractAPIPath(r.URL.Path) + if !ok { + http.NotFound(w, r) + return + } + + if action == "" { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + details, err := h.app.GetIPDetails(r.Context(), ip) + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + writeJSON(w, http.StatusOK, details) + return + } + + if r.Method != http.MethodPost { + methodNotAllowed(w) + return + } + payload, err := decodeActionPayload(r) + if err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + switch action { + case "block": + err = h.app.ForceBlock(r.Context(), ip, payload.Actor, payload.Reason) + case "unblock": + err = h.app.ForceAllow(r.Context(), ip, payload.Actor, payload.Reason) + case "reset": + err = h.app.ClearOverride(r.Context(), ip, payload.Actor, payload.Reason) + default: + http.NotFound(w, r) + return + } + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + details, err := h.app.GetIPDetails(r.Context(), ip) + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + writeJSON(w, http.StatusOK, details) +} + +func decodeActionPayload(r *http.Request) (actionPayload, error) { + defer r.Body.Close() + var payload actionPayload + if r.ContentLength == 0 { + return payload, nil + } + decoder := json.NewDecoder(io.LimitReader(r.Body, 1<<20)) + if err := decoder.Decode(&payload); err != nil { + if errors.Is(err, io.EOF) { + return payload, nil + } + return actionPayload{}, fmt.Errorf("decode request body: %w", err) + } + return payload, nil +} + +func extractPathValue(path string, prefix string) (string, bool) { + if !strings.HasPrefix(path, prefix) { + return "", false + } + rest := strings.TrimPrefix(path, prefix) + rest = strings.Trim(rest, "/") + if rest == "" { + return "", false + } + decoded, err := url.PathUnescape(rest) + if err != nil { + return "", false + } + return decoded, true +} + +func extractAPIPath(path string) (ip string, action string, ok bool) { + if !strings.HasPrefix(path, "/api/ips/") { + return "", "", false + } + rest := strings.TrimPrefix(path, "/api/ips/") + rest = strings.Trim(rest, "/") + if rest == "" { + return "", "", false + } + parts := strings.Split(rest, "/") + decoded, err := url.PathUnescape(parts[0]) + if err != nil { + return "", "", false + } + if len(parts) == 1 { + return decoded, "", true + } + if len(parts) == 2 { + return decoded, parts[1], true + } + return "", "", false +} + +func queryLimit(r *http.Request, fallback int) int { + value := strings.TrimSpace(r.URL.Query().Get("limit")) + if value == "" { + return fallback + } + parsed, err := strconv.Atoi(value) + if err != nil || parsed <= 0 { + return fallback + } + if parsed > 500 { + return 500 + } + return parsed +} + +func writeJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(payload) +} + +func writeError(w http.ResponseWriter, status int, err error) { + writeJSON(w, status, map[string]any{"error": err.Error()}) +} + +func methodNotAllowed(w http.ResponseWriter) { + writeError(w, http.StatusMethodNotAllowed, errors.New("method not allowed")) +} + +func renderTemplate(w http.ResponseWriter, tmpl *template.Template, data pageData) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := tmpl.Execute(w, data); err != nil { + writeError(w, http.StatusInternalServerError, err) + } +} + +const overviewHTML = ` + + + + + {{ .Title }} + + + +
+

{{ .Title }}

+
Local-only review and enforcement console
+
+
+
+
+

Recent IPs

+ + + + + +
IPStateOverrideEventsLast seenReasonActions
+
+
+

Recent Events

+ + + + + +
TimeSourceIPHostMethodPathStatusDecision
+
+
+ + +` + +const ipDetailsHTML = ` + + + + + {{ .Title }} + + + +
+
← Back
+

{{ .IP }}

+
+
+
+

State

+
+
+ + + +
+
+
+

Recent events

+ + + + + +
TimeSourceMethodPathStatusDecision
+
+
+

Decisions

+ + + + + +
TimeKindActionReasonActor
+
+
+

Backend actions

+ + + + + +
TimeActionResultMessage
+
+
+ + +` diff --git a/internal/web/handler_test.go b/internal/web/handler_test.go new file mode 100644 index 0000000..2db442e --- /dev/null +++ b/internal/web/handler_test.go @@ -0,0 +1,124 @@ +package web + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model" +) + +func TestHandlerServesOverviewAndManualActions(t *testing.T) { + t.Parallel() + + app := &stubApp{} + handler := NewHandler(app) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/api/overview?limit=10", nil) + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("unexpected overview status: %d", recorder.Code) + } + var overview model.Overview + if err := json.Unmarshal(recorder.Body.Bytes(), &overview); err != nil { + t.Fatalf("decode overview payload: %v", err) + } + if overview.TotalEvents != 1 || len(overview.RecentIPs) != 1 { + t.Fatalf("unexpected overview payload: %+v", overview) + } + + recorder = httptest.NewRecorder() + request = httptest.NewRequest(http.MethodPost, "/api/ips/203.0.113.10/block", strings.NewReader(`{"reason":"test reason","actor":"tester"}`)) + request.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("unexpected block status: %d body=%s", recorder.Code, recorder.Body.String()) + } + if app.lastAction != "block:203.0.113.10:tester:test reason" { + t.Fatalf("unexpected recorded action: %q", app.lastAction) + } + + recorder = httptest.NewRecorder() + request = httptest.NewRequest(http.MethodGet, "/", nil) + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("unexpected overview page status: %d", recorder.Code) + } + if !strings.Contains(recorder.Body.String(), "Local-only review and enforcement console") { + t.Fatalf("overview page did not render expected content") + } +} + +type stubApp struct { + lastAction string +} + +func (s *stubApp) GetOverview(context.Context, int) (model.Overview, error) { + now := time.Now().UTC() + return model.Overview{ + TotalEvents: 1, + TotalIPs: 1, + BlockedIPs: 1, + RecentIPs: []model.IPState{{ + IP: "203.0.113.10", + State: model.IPStateBlocked, + ManualOverride: model.ManualOverrideNone, + TotalEvents: 1, + LastSeenAt: now, + }}, + RecentEvents: []model.Event{{ + ID: 1, + SourceName: "main", + ClientIP: "203.0.113.10", + OccurredAt: now, + Decision: model.DecisionActionBlock, + CurrentState: model.IPStateBlocked, + }}, + }, nil +} + +func (s *stubApp) ListEvents(ctx context.Context, limit int) ([]model.Event, error) { + overview, _ := s.GetOverview(ctx, limit) + return overview.RecentEvents, nil +} + +func (s *stubApp) ListIPs(ctx context.Context, limit int, state string) ([]model.IPState, error) { + overview, _ := s.GetOverview(ctx, limit) + return overview.RecentIPs, nil +} + +func (s *stubApp) GetIPDetails(context.Context, string) (model.IPDetails, error) { + now := time.Now().UTC() + return model.IPDetails{ + State: model.IPState{ + IP: "203.0.113.10", + State: model.IPStateBlocked, + ManualOverride: model.ManualOverrideNone, + TotalEvents: 1, + LastSeenAt: now, + }, + RecentEvents: []model.Event{{ID: 1, ClientIP: "203.0.113.10", OccurredAt: now, Decision: model.DecisionActionBlock}}, + Decisions: []model.DecisionRecord{{ID: 1, IP: "203.0.113.10", Action: model.DecisionActionBlock, CreatedAt: now}}, + BackendActions: []model.OPNsenseAction{{ID: 1, IP: "203.0.113.10", Action: "block", Result: "added", CreatedAt: now}}, + }, nil +} + +func (s *stubApp) ForceBlock(_ context.Context, ip string, actor string, reason string) error { + s.lastAction = "block:" + ip + ":" + actor + ":" + reason + return nil +} + +func (s *stubApp) ForceAllow(_ context.Context, ip string, actor string, reason string) error { + s.lastAction = "allow:" + ip + ":" + actor + ":" + reason + return nil +} + +func (s *stubApp) ClearOverride(_ context.Context, ip string, actor string, reason string) error { + s.lastAction = "reset:" + ip + ":" + actor + ":" + reason + return nil +}