You've already forked caddy-opnsense-blocker
Build initial caddy-opnsense-blocker daemon
This commit is contained in:
961
internal/store/store.go
Normal file
961
internal/store/store.go
Normal file
@@ -0,0 +1,961 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const schema = `
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source_name TEXT NOT NULL,
|
||||
profile_name TEXT NOT NULL,
|
||||
occurred_at TEXT NOT NULL,
|
||||
remote_ip TEXT NOT NULL,
|
||||
client_ip TEXT NOT NULL,
|
||||
host TEXT NOT NULL,
|
||||
method TEXT NOT NULL,
|
||||
uri TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
status INTEGER NOT NULL,
|
||||
user_agent TEXT NOT NULL,
|
||||
decision TEXT NOT NULL,
|
||||
decision_reason TEXT NOT NULL,
|
||||
decision_reasons_json TEXT NOT NULL,
|
||||
enforced INTEGER NOT NULL DEFAULT 0,
|
||||
raw_json TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_events_occurred_at ON events(occurred_at DESC, id DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_client_ip ON events(client_ip, occurred_at DESC, id DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_source_name ON events(source_name, occurred_at DESC, id DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_decision ON events(decision, occurred_at DESC, id DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ip_state (
|
||||
ip TEXT PRIMARY KEY,
|
||||
first_seen_at TEXT NOT NULL,
|
||||
last_seen_at TEXT NOT NULL,
|
||||
last_source_name TEXT NOT NULL,
|
||||
last_user_agent TEXT NOT NULL,
|
||||
latest_status INTEGER NOT NULL,
|
||||
total_events INTEGER NOT NULL,
|
||||
state TEXT NOT NULL,
|
||||
state_reason TEXT NOT NULL,
|
||||
manual_override TEXT NOT NULL,
|
||||
last_event_id INTEGER NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ip_state_last_seen ON ip_state(last_seen_at DESC, ip ASC);
|
||||
CREATE INDEX IF NOT EXISTS idx_ip_state_state ON ip_state(state, last_seen_at DESC, ip ASC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS decisions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
event_id INTEGER NOT NULL,
|
||||
ip TEXT NOT NULL,
|
||||
source_name TEXT NOT NULL,
|
||||
kind TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
reason TEXT NOT NULL,
|
||||
actor TEXT NOT NULL,
|
||||
enforced INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_decisions_ip ON decisions(ip, created_at DESC, id DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_decisions_event_id ON decisions(event_id, created_at DESC, id DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS backend_actions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ip TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
result TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_backend_actions_ip ON backend_actions(ip, created_at DESC, id DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS source_offsets (
|
||||
source_name TEXT PRIMARY KEY,
|
||||
path TEXT NOT NULL,
|
||||
inode TEXT NOT NULL,
|
||||
offset INTEGER NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func Open(path string) (*Store, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create storage directory: %w", err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open sqlite database: %w", err)
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for _, statement := range []string{
|
||||
"PRAGMA journal_mode = WAL;",
|
||||
"PRAGMA busy_timeout = 5000;",
|
||||
"PRAGMA foreign_keys = ON;",
|
||||
} {
|
||||
if _, err := db.ExecContext(ctx, statement); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("apply sqlite pragma %q: %w", statement, err)
|
||||
}
|
||||
}
|
||||
if _, err := db.ExecContext(ctx, schema); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("apply sqlite schema: %w", err)
|
||||
}
|
||||
|
||||
return &Store{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
if s == nil || s.db == nil {
|
||||
return nil
|
||||
}
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *Store) RecordEvent(ctx context.Context, event *model.Event) error {
|
||||
if event == nil {
|
||||
return errors.New("nil event")
|
||||
}
|
||||
if event.OccurredAt.IsZero() {
|
||||
event.OccurredAt = time.Now().UTC()
|
||||
}
|
||||
if event.CreatedAt.IsZero() {
|
||||
event.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
encodedReasons, err := json.Marshal(event.DecisionReasons)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode decision reasons: %w", err)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
state, found, err := getIPStateTx(tx, event.ClientIP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO events (
|
||||
source_name, profile_name, occurred_at, remote_ip, client_ip, host, method, uri, path,
|
||||
status, user_agent, decision, decision_reason, decision_reasons_json, enforced, raw_json, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
event.SourceName,
|
||||
event.ProfileName,
|
||||
formatTime(event.OccurredAt),
|
||||
event.RemoteIP,
|
||||
event.ClientIP,
|
||||
event.Host,
|
||||
event.Method,
|
||||
event.URI,
|
||||
event.Path,
|
||||
event.Status,
|
||||
event.UserAgent,
|
||||
string(event.Decision),
|
||||
event.DecisionReason,
|
||||
string(encodedReasons),
|
||||
boolToInt(event.Enforced),
|
||||
event.RawJSON,
|
||||
formatTime(event.CreatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert event: %w", err)
|
||||
}
|
||||
eventID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("load inserted event id: %w", err)
|
||||
}
|
||||
event.ID = eventID
|
||||
|
||||
updatedState := mergeEventIntoState(state, found, *event)
|
||||
event.CurrentState = updatedState.State
|
||||
event.ManualOverride = updatedState.ManualOverride
|
||||
|
||||
if err := upsertIPStateTx(ctx, tx, updatedState); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit event transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) AddDecision(ctx context.Context, decision *model.DecisionRecord) error {
|
||||
if decision == nil {
|
||||
return errors.New("nil decision record")
|
||||
}
|
||||
if decision.CreatedAt.IsZero() {
|
||||
decision.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
result, err := s.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO decisions (event_id, ip, source_name, kind, action, reason, actor, enforced, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
decision.EventID,
|
||||
decision.IP,
|
||||
decision.SourceName,
|
||||
decision.Kind,
|
||||
string(decision.Action),
|
||||
decision.Reason,
|
||||
decision.Actor,
|
||||
boolToInt(decision.Enforced),
|
||||
formatTime(decision.CreatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert decision record: %w", err)
|
||||
}
|
||||
decision.ID, err = result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("load inserted decision id: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) AddBackendAction(ctx context.Context, action *model.OPNsenseAction) error {
|
||||
if action == nil {
|
||||
return errors.New("nil backend action")
|
||||
}
|
||||
if action.CreatedAt.IsZero() {
|
||||
action.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
result, err := s.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO backend_actions (ip, action, result, message, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
action.IP,
|
||||
action.Action,
|
||||
action.Result,
|
||||
action.Message,
|
||||
formatTime(action.CreatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert backend action: %w", err)
|
||||
}
|
||||
action.ID, err = result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("load inserted backend action id: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) GetIPState(ctx context.Context, ip string) (model.IPState, bool, error) {
|
||||
return getIPStateDB(ctx, s.db, ip)
|
||||
}
|
||||
|
||||
func (s *Store) SetManualOverride(ctx context.Context, ip string, override model.ManualOverride, state model.IPStateStatus, reason string) (model.IPState, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return model.IPState{}, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
current, found, err := getIPStateTx(tx, ip)
|
||||
if err != nil {
|
||||
return model.IPState{}, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if !found {
|
||||
current = model.IPState{
|
||||
IP: ip,
|
||||
FirstSeenAt: now,
|
||||
LastSeenAt: now,
|
||||
LastSourceName: "",
|
||||
LastUserAgent: "",
|
||||
LatestStatus: 0,
|
||||
TotalEvents: 0,
|
||||
State: state,
|
||||
StateReason: strings.TrimSpace(reason),
|
||||
ManualOverride: override,
|
||||
LastEventID: 0,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
} else {
|
||||
current.ManualOverride = override
|
||||
current.State = state
|
||||
if strings.TrimSpace(reason) != "" {
|
||||
current.StateReason = strings.TrimSpace(reason)
|
||||
}
|
||||
current.UpdatedAt = now
|
||||
}
|
||||
if err := upsertIPStateTx(ctx, tx, current); err != nil {
|
||||
return model.IPState{}, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return model.IPState{}, fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
|
||||
func (s *Store) ClearManualOverride(ctx context.Context, ip string, reason string) (model.IPState, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return model.IPState{}, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
current, found, err := getIPStateTx(tx, ip)
|
||||
if err != nil {
|
||||
return model.IPState{}, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if !found {
|
||||
current = model.IPState{
|
||||
IP: ip,
|
||||
FirstSeenAt: now,
|
||||
LastSeenAt: now,
|
||||
State: model.IPStateObserved,
|
||||
StateReason: strings.TrimSpace(reason),
|
||||
ManualOverride: model.ManualOverrideNone,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
} else {
|
||||
current.ManualOverride = model.ManualOverrideNone
|
||||
if current.State == "" {
|
||||
current.State = model.IPStateObserved
|
||||
}
|
||||
if strings.TrimSpace(reason) != "" {
|
||||
current.StateReason = strings.TrimSpace(reason)
|
||||
}
|
||||
current.UpdatedAt = now
|
||||
}
|
||||
if err := upsertIPStateTx(ctx, tx, current); err != nil {
|
||||
return model.IPState{}, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return model.IPState{}, fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOverview(ctx context.Context, limit int) (model.Overview, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
var overview model.Overview
|
||||
if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM events`).Scan(&overview.TotalEvents); err != nil {
|
||||
return model.Overview{}, fmt.Errorf("count events: %w", err)
|
||||
}
|
||||
if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state`).Scan(&overview.TotalIPs); err != nil {
|
||||
return model.Overview{}, fmt.Errorf("count ip states: %w", err)
|
||||
}
|
||||
if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state WHERE state = ?`, string(model.IPStateBlocked)).Scan(&overview.BlockedIPs); err != nil {
|
||||
return model.Overview{}, fmt.Errorf("count blocked ip states: %w", err)
|
||||
}
|
||||
if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state WHERE state = ?`, string(model.IPStateReview)).Scan(&overview.ReviewIPs); err != nil {
|
||||
return model.Overview{}, fmt.Errorf("count review ip states: %w", err)
|
||||
}
|
||||
if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state WHERE state = ?`, string(model.IPStateAllowed)).Scan(&overview.AllowedIPs); err != nil {
|
||||
return model.Overview{}, fmt.Errorf("count allowed ip states: %w", err)
|
||||
}
|
||||
if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ip_state WHERE state = ?`, string(model.IPStateObserved)).Scan(&overview.ObservedIPs); err != nil {
|
||||
return model.Overview{}, fmt.Errorf("count observed ip states: %w", err)
|
||||
}
|
||||
|
||||
recentIPs, err := s.ListIPStates(ctx, limit, "")
|
||||
if err != nil {
|
||||
return model.Overview{}, err
|
||||
}
|
||||
recentEvents, err := s.ListRecentEvents(ctx, limit)
|
||||
if err != nil {
|
||||
return model.Overview{}, err
|
||||
}
|
||||
overview.RecentIPs = recentIPs
|
||||
overview.RecentEvents = recentEvents
|
||||
return overview, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListRecentEvents(ctx context.Context, limit int) ([]model.Event, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT e.id, e.source_name, e.profile_name, e.occurred_at, e.remote_ip, e.client_ip, e.host,
|
||||
e.method, e.uri, e.path, e.status, e.user_agent, e.decision, e.decision_reason,
|
||||
e.decision_reasons_json, e.enforced, e.raw_json, e.created_at,
|
||||
COALESCE(s.state, ''), COALESCE(s.manual_override, '')
|
||||
FROM events e
|
||||
LEFT JOIN ip_state s ON s.ip = e.client_ip
|
||||
ORDER BY e.occurred_at DESC, e.id DESC
|
||||
LIMIT ?`,
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list recent events: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := make([]model.Event, 0, limit)
|
||||
for rows.Next() {
|
||||
item, err := scanEvent(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate recent events: %w", err)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListIPStates(ctx context.Context, limit int, stateFilter string) ([]model.IPState, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
query := `SELECT ip, first_seen_at, last_seen_at, last_source_name, last_user_agent, latest_status,
|
||||
total_events, state, state_reason, manual_override, last_event_id, updated_at
|
||||
FROM ip_state`
|
||||
args := []any{}
|
||||
if strings.TrimSpace(stateFilter) != "" {
|
||||
query += ` WHERE state = ?`
|
||||
args = append(args, strings.TrimSpace(stateFilter))
|
||||
}
|
||||
query += ` ORDER BY last_seen_at DESC, ip ASC LIMIT ?`
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list ip states: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := make([]model.IPState, 0, limit)
|
||||
for rows.Next() {
|
||||
item, err := scanIPState(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate ip states: %w", err)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetIPDetails(ctx context.Context, ip string, eventLimit, decisionLimit, actionLimit int) (model.IPDetails, error) {
|
||||
state, _, err := s.GetIPState(ctx, ip)
|
||||
if err != nil {
|
||||
return model.IPDetails{}, err
|
||||
}
|
||||
events, err := s.listEventsForIP(ctx, ip, eventLimit)
|
||||
if err != nil {
|
||||
return model.IPDetails{}, err
|
||||
}
|
||||
decisions, err := s.listDecisionsForIP(ctx, ip, decisionLimit)
|
||||
if err != nil {
|
||||
return model.IPDetails{}, err
|
||||
}
|
||||
actions, err := s.listBackendActionsForIP(ctx, ip, actionLimit)
|
||||
if err != nil {
|
||||
return model.IPDetails{}, err
|
||||
}
|
||||
return model.IPDetails{
|
||||
State: state,
|
||||
RecentEvents: events,
|
||||
Decisions: decisions,
|
||||
BackendActions: actions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetSourceOffset(ctx context.Context, sourceName string) (model.SourceOffset, bool, error) {
|
||||
row := s.db.QueryRowContext(ctx, `SELECT source_name, path, inode, offset, updated_at FROM source_offsets WHERE source_name = ?`, sourceName)
|
||||
var offset model.SourceOffset
|
||||
var updatedAt string
|
||||
if err := row.Scan(&offset.SourceName, &offset.Path, &offset.Inode, &offset.Offset, &updatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return model.SourceOffset{}, false, nil
|
||||
}
|
||||
return model.SourceOffset{}, false, fmt.Errorf("query source offset %q: %w", sourceName, err)
|
||||
}
|
||||
parsed, err := parseTime(updatedAt)
|
||||
if err != nil {
|
||||
return model.SourceOffset{}, false, fmt.Errorf("parse source offset updated_at: %w", err)
|
||||
}
|
||||
offset.UpdatedAt = parsed
|
||||
return offset, true, nil
|
||||
}
|
||||
|
||||
func (s *Store) SaveSourceOffset(ctx context.Context, offset model.SourceOffset) error {
|
||||
if offset.UpdatedAt.IsZero() {
|
||||
offset.UpdatedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := s.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO source_offsets (source_name, path, inode, offset, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(source_name) DO UPDATE SET
|
||||
path = excluded.path,
|
||||
inode = excluded.inode,
|
||||
offset = excluded.offset,
|
||||
updated_at = excluded.updated_at`,
|
||||
offset.SourceName,
|
||||
offset.Path,
|
||||
offset.Inode,
|
||||
offset.Offset,
|
||||
formatTime(offset.UpdatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upsert source offset: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) listEventsForIP(ctx context.Context, ip string, limit int) ([]model.Event, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT e.id, e.source_name, e.profile_name, e.occurred_at, e.remote_ip, e.client_ip, e.host,
|
||||
e.method, e.uri, e.path, e.status, e.user_agent, e.decision, e.decision_reason,
|
||||
e.decision_reasons_json, e.enforced, e.raw_json, e.created_at,
|
||||
COALESCE(s.state, ''), COALESCE(s.manual_override, '')
|
||||
FROM events e
|
||||
LEFT JOIN ip_state s ON s.ip = e.client_ip
|
||||
WHERE e.client_ip = ?
|
||||
ORDER BY e.occurred_at DESC, e.id DESC
|
||||
LIMIT ?`,
|
||||
ip,
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list events for ip %q: %w", ip, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := make([]model.Event, 0, limit)
|
||||
for rows.Next() {
|
||||
item, err := scanEvent(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate events for ip %q: %w", ip, err)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (s *Store) listDecisionsForIP(ctx context.Context, ip string, limit int) ([]model.DecisionRecord, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT id, event_id, ip, source_name, kind, action, reason, actor, enforced, created_at
|
||||
FROM decisions
|
||||
WHERE ip = ?
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?`,
|
||||
ip,
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list decisions for ip %q: %w", ip, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := make([]model.DecisionRecord, 0, limit)
|
||||
for rows.Next() {
|
||||
var item model.DecisionRecord
|
||||
var action string
|
||||
var enforced int
|
||||
var createdAt string
|
||||
if err := rows.Scan(&item.ID, &item.EventID, &item.IP, &item.SourceName, &item.Kind, &action, &item.Reason, &item.Actor, &enforced, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("scan decision record: %w", err)
|
||||
}
|
||||
parsed, err := parseTime(createdAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse decision created_at: %w", err)
|
||||
}
|
||||
item.Action = model.DecisionAction(action)
|
||||
item.Enforced = enforced != 0
|
||||
item.CreatedAt = parsed
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate decisions for ip %q: %w", ip, err)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (s *Store) listBackendActionsForIP(ctx context.Context, ip string, limit int) ([]model.OPNsenseAction, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT id, ip, action, result, message, created_at
|
||||
FROM backend_actions
|
||||
WHERE ip = ?
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?`,
|
||||
ip,
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list backend actions for ip %q: %w", ip, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := make([]model.OPNsenseAction, 0, limit)
|
||||
for rows.Next() {
|
||||
var item model.OPNsenseAction
|
||||
var createdAt string
|
||||
if err := rows.Scan(&item.ID, &item.IP, &item.Action, &item.Result, &item.Message, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("scan backend action: %w", err)
|
||||
}
|
||||
parsed, err := parseTime(createdAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse backend action created_at: %w", err)
|
||||
}
|
||||
item.CreatedAt = parsed
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate backend actions for ip %q: %w", ip, err)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func getIPStateDB(ctx context.Context, db queryer, ip string) (model.IPState, bool, error) {
|
||||
row := db.QueryRowContext(ctx, `
|
||||
SELECT ip, first_seen_at, last_seen_at, last_source_name, last_user_agent, latest_status,
|
||||
total_events, state, state_reason, manual_override, last_event_id, updated_at
|
||||
FROM ip_state WHERE ip = ?`, ip)
|
||||
|
||||
var item model.IPState
|
||||
var firstSeenAt string
|
||||
var lastSeenAt string
|
||||
var updatedAt string
|
||||
var state string
|
||||
var manualOverride string
|
||||
if err := row.Scan(
|
||||
&item.IP,
|
||||
&firstSeenAt,
|
||||
&lastSeenAt,
|
||||
&item.LastSourceName,
|
||||
&item.LastUserAgent,
|
||||
&item.LatestStatus,
|
||||
&item.TotalEvents,
|
||||
&state,
|
||||
&item.StateReason,
|
||||
&manualOverride,
|
||||
&item.LastEventID,
|
||||
&updatedAt,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return model.IPState{}, false, nil
|
||||
}
|
||||
return model.IPState{}, false, fmt.Errorf("query ip state %q: %w", ip, err)
|
||||
}
|
||||
|
||||
var err error
|
||||
item.FirstSeenAt, err = parseTime(firstSeenAt)
|
||||
if err != nil {
|
||||
return model.IPState{}, false, fmt.Errorf("parse ip state first_seen_at: %w", err)
|
||||
}
|
||||
item.LastSeenAt, err = parseTime(lastSeenAt)
|
||||
if err != nil {
|
||||
return model.IPState{}, false, fmt.Errorf("parse ip state last_seen_at: %w", err)
|
||||
}
|
||||
item.UpdatedAt, err = parseTime(updatedAt)
|
||||
if err != nil {
|
||||
return model.IPState{}, false, fmt.Errorf("parse ip state updated_at: %w", err)
|
||||
}
|
||||
item.State = model.IPStateStatus(state)
|
||||
item.ManualOverride = model.ManualOverride(manualOverride)
|
||||
return item, true, nil
|
||||
}
|
||||
|
||||
func getIPStateTx(tx *sql.Tx, ip string) (model.IPState, bool, error) {
|
||||
return getIPStateDB(context.Background(), tx, ip)
|
||||
}
|
||||
|
||||
func upsertIPStateTx(ctx context.Context, tx *sql.Tx, state model.IPState) error {
|
||||
if state.UpdatedAt.IsZero() {
|
||||
state.UpdatedAt = time.Now().UTC()
|
||||
}
|
||||
if state.FirstSeenAt.IsZero() {
|
||||
state.FirstSeenAt = state.UpdatedAt
|
||||
}
|
||||
if state.LastSeenAt.IsZero() {
|
||||
state.LastSeenAt = state.UpdatedAt
|
||||
}
|
||||
if state.State == "" {
|
||||
state.State = model.IPStateObserved
|
||||
}
|
||||
if state.ManualOverride == "" {
|
||||
state.ManualOverride = model.ManualOverrideNone
|
||||
}
|
||||
_, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO ip_state (
|
||||
ip, first_seen_at, last_seen_at, last_source_name, last_user_agent, latest_status,
|
||||
total_events, state, state_reason, manual_override, last_event_id, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
first_seen_at = excluded.first_seen_at,
|
||||
last_seen_at = excluded.last_seen_at,
|
||||
last_source_name = excluded.last_source_name,
|
||||
last_user_agent = excluded.last_user_agent,
|
||||
latest_status = excluded.latest_status,
|
||||
total_events = excluded.total_events,
|
||||
state = excluded.state,
|
||||
state_reason = excluded.state_reason,
|
||||
manual_override = excluded.manual_override,
|
||||
last_event_id = excluded.last_event_id,
|
||||
updated_at = excluded.updated_at`,
|
||||
state.IP,
|
||||
formatTime(state.FirstSeenAt),
|
||||
formatTime(state.LastSeenAt),
|
||||
state.LastSourceName,
|
||||
state.LastUserAgent,
|
||||
state.LatestStatus,
|
||||
state.TotalEvents,
|
||||
string(state.State),
|
||||
state.StateReason,
|
||||
string(state.ManualOverride),
|
||||
state.LastEventID,
|
||||
formatTime(state.UpdatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upsert ip state %q: %w", state.IP, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeEventIntoState(existing model.IPState, found bool, event model.Event) model.IPState {
|
||||
now := time.Now().UTC()
|
||||
state := existing
|
||||
if !found {
|
||||
state = model.IPState{
|
||||
IP: event.ClientIP,
|
||||
FirstSeenAt: event.OccurredAt,
|
||||
LastSeenAt: event.OccurredAt,
|
||||
LastSourceName: event.SourceName,
|
||||
LastUserAgent: event.UserAgent,
|
||||
LatestStatus: event.Status,
|
||||
TotalEvents: 0,
|
||||
State: model.IPStateObserved,
|
||||
StateReason: "",
|
||||
ManualOverride: model.ManualOverrideNone,
|
||||
LastEventID: 0,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
if state.FirstSeenAt.IsZero() || event.OccurredAt.Before(state.FirstSeenAt) {
|
||||
state.FirstSeenAt = event.OccurredAt
|
||||
}
|
||||
if state.LastSeenAt.IsZero() || event.OccurredAt.After(state.LastSeenAt) {
|
||||
state.LastSeenAt = event.OccurredAt
|
||||
}
|
||||
state.LastSourceName = event.SourceName
|
||||
state.LastUserAgent = event.UserAgent
|
||||
state.LatestStatus = event.Status
|
||||
state.TotalEvents++
|
||||
state.LastEventID = event.ID
|
||||
state.UpdatedAt = now
|
||||
if state.ManualOverride == "" {
|
||||
state.ManualOverride = model.ManualOverrideNone
|
||||
}
|
||||
|
||||
switch state.ManualOverride {
|
||||
case model.ManualOverrideForceBlock:
|
||||
state.State = model.IPStateBlocked
|
||||
if event.DecisionReason != "" {
|
||||
state.StateReason = event.DecisionReason
|
||||
} else if state.StateReason == "" {
|
||||
state.StateReason = "manual override: block"
|
||||
}
|
||||
return state
|
||||
case model.ManualOverrideForceAllow:
|
||||
state.State = model.IPStateAllowed
|
||||
if event.DecisionReason != "" {
|
||||
state.StateReason = event.DecisionReason
|
||||
} else if state.StateReason == "" {
|
||||
state.StateReason = "manual override: allow"
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
switch event.Decision {
|
||||
case model.DecisionActionBlock:
|
||||
state.State = model.IPStateBlocked
|
||||
state.StateReason = event.DecisionReason
|
||||
case model.DecisionActionReview:
|
||||
if state.State != model.IPStateBlocked && state.State != model.IPStateAllowed {
|
||||
state.State = model.IPStateReview
|
||||
state.StateReason = event.DecisionReason
|
||||
}
|
||||
case model.DecisionActionAllow:
|
||||
state.State = model.IPStateAllowed
|
||||
state.StateReason = event.DecisionReason
|
||||
default:
|
||||
if state.State == "" {
|
||||
state.State = model.IPStateObserved
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func scanEvent(scanner interface{ Scan(dest ...any) error }) (model.Event, error) {
|
||||
var item model.Event
|
||||
var occurredAt string
|
||||
var createdAt string
|
||||
var decision string
|
||||
var decisionReasonsJSON string
|
||||
var enforced int
|
||||
var currentState string
|
||||
var manualOverride string
|
||||
if err := scanner.Scan(
|
||||
&item.ID,
|
||||
&item.SourceName,
|
||||
&item.ProfileName,
|
||||
&occurredAt,
|
||||
&item.RemoteIP,
|
||||
&item.ClientIP,
|
||||
&item.Host,
|
||||
&item.Method,
|
||||
&item.URI,
|
||||
&item.Path,
|
||||
&item.Status,
|
||||
&item.UserAgent,
|
||||
&decision,
|
||||
&item.DecisionReason,
|
||||
&decisionReasonsJSON,
|
||||
&enforced,
|
||||
&item.RawJSON,
|
||||
&createdAt,
|
||||
¤tState,
|
||||
&manualOverride,
|
||||
); err != nil {
|
||||
return model.Event{}, fmt.Errorf("scan event: %w", err)
|
||||
}
|
||||
parsedOccurredAt, err := parseTime(occurredAt)
|
||||
if err != nil {
|
||||
return model.Event{}, fmt.Errorf("parse event occurred_at: %w", err)
|
||||
}
|
||||
parsedCreatedAt, err := parseTime(createdAt)
|
||||
if err != nil {
|
||||
return model.Event{}, fmt.Errorf("parse event created_at: %w", err)
|
||||
}
|
||||
var reasons []string
|
||||
if strings.TrimSpace(decisionReasonsJSON) != "" {
|
||||
if err := json.Unmarshal([]byte(decisionReasonsJSON), &reasons); err != nil {
|
||||
return model.Event{}, fmt.Errorf("decode event decision_reasons_json: %w", err)
|
||||
}
|
||||
}
|
||||
item.OccurredAt = parsedOccurredAt
|
||||
item.CreatedAt = parsedCreatedAt
|
||||
item.Decision = model.DecisionAction(decision)
|
||||
item.DecisionReasons = reasons
|
||||
item.Enforced = enforced != 0
|
||||
item.CurrentState = model.IPStateStatus(currentState)
|
||||
item.ManualOverride = model.ManualOverride(manualOverride)
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func scanIPState(scanner interface{ Scan(dest ...any) error }) (model.IPState, error) {
|
||||
var item model.IPState
|
||||
var firstSeenAt string
|
||||
var lastSeenAt string
|
||||
var updatedAt string
|
||||
var state string
|
||||
var manualOverride string
|
||||
if err := scanner.Scan(
|
||||
&item.IP,
|
||||
&firstSeenAt,
|
||||
&lastSeenAt,
|
||||
&item.LastSourceName,
|
||||
&item.LastUserAgent,
|
||||
&item.LatestStatus,
|
||||
&item.TotalEvents,
|
||||
&state,
|
||||
&item.StateReason,
|
||||
&manualOverride,
|
||||
&item.LastEventID,
|
||||
&updatedAt,
|
||||
); err != nil {
|
||||
return model.IPState{}, fmt.Errorf("scan ip state: %w", err)
|
||||
}
|
||||
var err error
|
||||
item.FirstSeenAt, err = parseTime(firstSeenAt)
|
||||
if err != nil {
|
||||
return model.IPState{}, fmt.Errorf("parse ip state first_seen_at: %w", err)
|
||||
}
|
||||
item.LastSeenAt, err = parseTime(lastSeenAt)
|
||||
if err != nil {
|
||||
return model.IPState{}, fmt.Errorf("parse ip state last_seen_at: %w", err)
|
||||
}
|
||||
item.UpdatedAt, err = parseTime(updatedAt)
|
||||
if err != nil {
|
||||
return model.IPState{}, fmt.Errorf("parse ip state updated_at: %w", err)
|
||||
}
|
||||
item.State = model.IPStateStatus(state)
|
||||
item.ManualOverride = model.ManualOverride(manualOverride)
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func parseTime(value string) (time.Time, error) {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
parsed, err := time.Parse(time.RFC3339Nano, trimmed)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return parsed.UTC(), nil
|
||||
}
|
||||
|
||||
func formatTime(value time.Time) string {
|
||||
if value.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return value.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func boolToInt(value bool) int {
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type queryer interface {
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
116
internal/store/store_test.go
Normal file
116
internal/store/store_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model"
|
||||
)
|
||||
|
||||
func TestStoreRecordsEventsAndState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "blocker.db")
|
||||
db, err := Open(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open store: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
occurredAt := time.Date(2025, 3, 11, 12, 0, 0, 0, time.UTC)
|
||||
event := &model.Event{
|
||||
SourceName: "main",
|
||||
ProfileName: "main",
|
||||
OccurredAt: occurredAt,
|
||||
RemoteIP: "198.51.100.10",
|
||||
ClientIP: "203.0.113.10",
|
||||
Host: "example.test",
|
||||
Method: "GET",
|
||||
URI: "/wp-login.php",
|
||||
Path: "/wp-login.php",
|
||||
Status: 404,
|
||||
UserAgent: "curl/8.0",
|
||||
Decision: model.DecisionActionBlock,
|
||||
DecisionReason: "php_path",
|
||||
DecisionReasons: []string{"php_path"},
|
||||
Enforced: true,
|
||||
RawJSON: `{"status":404}`,
|
||||
}
|
||||
if err := db.RecordEvent(ctx, event); err != nil {
|
||||
t.Fatalf("record event: %v", err)
|
||||
}
|
||||
if event.ID == 0 {
|
||||
t.Fatalf("expected inserted event ID")
|
||||
}
|
||||
|
||||
state, found, err := db.GetIPState(ctx, "203.0.113.10")
|
||||
if err != nil {
|
||||
t.Fatalf("get ip state: %v", err)
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected IP state to exist")
|
||||
}
|
||||
if state.State != model.IPStateBlocked {
|
||||
t.Fatalf("unexpected ip state: %+v", state)
|
||||
}
|
||||
if state.TotalEvents != 1 {
|
||||
t.Fatalf("unexpected total events: %d", state.TotalEvents)
|
||||
}
|
||||
|
||||
if _, err := db.SetManualOverride(ctx, "203.0.113.10", model.ManualOverrideForceAllow, model.IPStateAllowed, "manual allow"); err != nil {
|
||||
t.Fatalf("set manual override: %v", err)
|
||||
}
|
||||
state, found, err = db.GetIPState(ctx, "203.0.113.10")
|
||||
if err != nil || !found {
|
||||
t.Fatalf("get overridden ip state: found=%v err=%v", found, err)
|
||||
}
|
||||
if state.ManualOverride != model.ManualOverrideForceAllow {
|
||||
t.Fatalf("unexpected override after set: %+v", state)
|
||||
}
|
||||
|
||||
if _, err := db.ClearManualOverride(ctx, "203.0.113.10", "reset"); err != nil {
|
||||
t.Fatalf("clear manual override: %v", err)
|
||||
}
|
||||
state, found, err = db.GetIPState(ctx, "203.0.113.10")
|
||||
if err != nil || !found {
|
||||
t.Fatalf("get reset ip state: found=%v err=%v", found, err)
|
||||
}
|
||||
if state.ManualOverride != model.ManualOverrideNone {
|
||||
t.Fatalf("expected cleared override, got %+v", state)
|
||||
}
|
||||
|
||||
if err := db.AddDecision(ctx, &model.DecisionRecord{EventID: event.ID, IP: event.ClientIP, SourceName: event.SourceName, Kind: "automatic", Action: model.DecisionActionBlock, Reason: "php_path", Actor: "engine", Enforced: true}); err != nil {
|
||||
t.Fatalf("add decision: %v", err)
|
||||
}
|
||||
if err := db.AddBackendAction(ctx, &model.OPNsenseAction{IP: event.ClientIP, Action: "block", Result: "added", Message: "php_path"}); err != nil {
|
||||
t.Fatalf("add backend action: %v", err)
|
||||
}
|
||||
if err := db.SaveSourceOffset(ctx, model.SourceOffset{SourceName: "main", Path: "/tmp/main.log", Inode: "1:2", Offset: 42, UpdatedAt: occurredAt}); err != nil {
|
||||
t.Fatalf("save source offset: %v", err)
|
||||
}
|
||||
offset, found, err := db.GetSourceOffset(ctx, "main")
|
||||
if err != nil {
|
||||
t.Fatalf("get source offset: %v", err)
|
||||
}
|
||||
if !found || offset.Offset != 42 {
|
||||
t.Fatalf("unexpected source offset: found=%v offset=%+v", found, offset)
|
||||
}
|
||||
|
||||
overview, err := db.GetOverview(ctx, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("get overview: %v", err)
|
||||
}
|
||||
if overview.TotalEvents != 1 || overview.TotalIPs != 1 {
|
||||
t.Fatalf("unexpected overview counters: %+v", overview)
|
||||
}
|
||||
details, err := db.GetIPDetails(ctx, event.ClientIP, 10, 10, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("get ip details: %v", err)
|
||||
}
|
||||
if len(details.RecentEvents) != 1 || len(details.Decisions) != 1 || len(details.BackendActions) != 1 {
|
||||
t.Fatalf("unexpected ip details: %+v", details)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user