// File: caddy-opnsense-blocker/internal/service/service.go
// (678 lines, 18 KiB, Go)

package service
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"strings"
"sync"
"syscall"
"time"
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/caddylog"
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config"
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/engine"
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model"
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/opnsense"
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/store"
)
// Service ties together the configured log sources, the decision engine
// (evaluator), the persistent store, the optional OPNsense alias client and
// the optional IP investigator. Run tails the sources and drives background
// investigation; the exported methods serve reads and manual overrides.
type Service struct {
	cfg       *config.Config
	store     *store.Store
	evaluator *engine.Evaluator
	// blocker may be nil; the call sites visible in this file check for nil
	// and then skip every OPNsense alias operation.
	blocker opnsense.AliasClient
	// investigator may be nil; refreshInvestigation then returns (nil, nil).
	investigator Investigator
	logger       *log.Logger
	// investigationQueueMu guards investigationQueued.
	investigationQueueMu sync.Mutex
	// investigationQueued holds normalized IPs that are waiting in
	// investigationQueue or currently being processed by a worker; it
	// de-duplicates enqueue requests.
	investigationQueued map[string]struct{}
	// investigationQueue feeds the background investigation workers. It is
	// nil when background investigation is disabled (see New).
	investigationQueue chan string
}
// Investigator produces a model.IPInvestigation for an IP address.
// userAgents carries recent User-Agent strings observed for that IP
// (refreshInvestigation passes at most 12 of them).
type Investigator interface {
	Investigate(ctx context.Context, ip string, userAgents []string) (model.IPInvestigation, error)
}
// New assembles a Service from its dependencies.
//
// A nil logger is replaced by one that discards all output. The background
// investigation queue (and its de-duplication set) is only allocated when an
// investigator is supplied and cfg.Investigation.BackgroundWorkers is
// positive; its capacity is cfg.Investigation.BackgroundBatchSize, raised to
// a floor of 64.
func New(cfg *config.Config, db *store.Store, blocker opnsense.AliasClient, investigator Investigator, logger *log.Logger) *Service {
	const minQueueCapacity = 64

	svc := &Service{
		cfg:          cfg,
		store:        db,
		evaluator:    engine.NewEvaluator(),
		blocker:      blocker,
		investigator: investigator,
		logger:       logger,
	}
	if svc.logger == nil {
		svc.logger = log.New(io.Discard, "", 0)
	}

	backgroundEnabled := investigator != nil && cfg.Investigation.BackgroundWorkers > 0
	if !backgroundEnabled {
		return svc
	}

	capacity := cfg.Investigation.BackgroundBatchSize
	if capacity < minQueueCapacity {
		capacity = minQueueCapacity
	}
	svc.investigationQueue = make(chan string, capacity)
	svc.investigationQueued = make(map[string]struct{}, capacity)
	return svc
}
// Run launches the service's goroutines and blocks until ctx is cancelled and
// every goroutine it started has returned. It always returns nil.
//
// When the investigation queue exists, it starts one scheduler and
// cfg.Investigation.BackgroundWorkers workers. It then starts one poller per
// configured source. All of them are expected to stop once ctx is done.
func (s *Service) Run(ctx context.Context) error {
	var group sync.WaitGroup
	spawn := func(task func()) {
		group.Add(1)
		go func() {
			defer group.Done()
			task()
		}()
	}

	if s.investigationQueue != nil {
		spawn(func() { s.runInvestigationScheduler(ctx) })
		for remaining := s.cfg.Investigation.BackgroundWorkers; remaining > 0; remaining-- {
			spawn(func() { s.runInvestigationWorker(ctx) })
		}
	}

	for i := range s.cfg.Sources {
		// Take a per-iteration copy so each goroutine owns its own source
		// config regardless of the Go version's loop-variable semantics.
		src := s.cfg.Sources[i]
		spawn(func() { s.runSource(ctx, src) })
	}

	<-ctx.Done()
	group.Wait()
	return nil
}
// GetOverview returns the store's overview, forwarding limit unchanged to
// store.GetOverview.
func (s *Service) GetOverview(ctx context.Context, limit int) (model.Overview, error) {
	overview, err := s.store.GetOverview(ctx, limit)
	return overview, err
}

// ListEvents returns the events reported by store.ListRecentEvents, with
// limit forwarded unchanged.
func (s *Service) ListEvents(ctx context.Context, limit int) ([]model.Event, error) {
	events, err := s.store.ListRecentEvents(ctx, limit)
	return events, err
}

// ListIPs returns IP states from store.ListIPStates. limit and the state
// filter are forwarded unchanged; their exact semantics live in the store.
func (s *Service) ListIPs(ctx context.Context, limit int, state string) ([]model.IPState, error) {
	states, err := s.store.ListIPStates(ctx, limit, state)
	return states, err
}
// ListRecentIPs returns the store's recent-IP rows since `since` (limit is
// forwarded to the store). Each row is decorated as follows:
//   - Bot is copied from the stored investigation, when there is one.
//   - Actions is computed from the row's state and the live OPNsense status.
//
// Rows whose investigation is missing, or older than
// cfg.Investigation.RefreshAfter, are handed to enqueueInvestigation. That
// call is a no-op when background investigation is disabled.
//
// NOTE(review): resolveOPNsenseStatus runs once per row, i.e. one
// IsIPPresent call against OPNsense per returned IP when a client is set.
func (s *Service) ListRecentIPs(ctx context.Context, since time.Time, limit int) ([]model.RecentIPRow, error) {
	rows, err := s.store.ListRecentIPRows(ctx, since, limit)
	if err != nil {
		return nil, err
	}
	known, err := s.store.GetInvestigationsForIPs(ctx, recentRowIPs(rows))
	if err != nil {
		return nil, err
	}

	cutoff := time.Now().UTC().Add(-s.cfg.Investigation.RefreshAfter.Duration)
	for i := range rows {
		row := &rows[i]

		investigation, ok := known[row.IP]
		if ok {
			row.Bot = investigation.Bot
		}
		if !ok || investigation.UpdatedAt.Before(cutoff) {
			s.enqueueInvestigation(row.IP)
		}

		snapshot := model.IPState{
			IP:             row.IP,
			State:          row.State,
			ManualOverride: row.ManualOverride,
		}
		row.Actions = actionAvailability(snapshot, s.resolveOPNsenseStatus(ctx, snapshot))
	}
	return rows, nil
}
// GetIPDetails validates and normalizes ip, loads its stored details, then
// decorates them with the stored investigation (if any), the live OPNsense
// status and the available actions. An invalid address, or any store error,
// is returned with a zero IPDetails.
//
// NOTE(review): the literal arguments 0, 100, 100 are interpreted by
// store.GetIPDetails (presumably an offset and two list limits) — confirm.
func (s *Service) GetIPDetails(ctx context.Context, ip string) (model.IPDetails, error) {
	var details model.IPDetails
	normalized, err := normalizeIP(ip)
	if err == nil {
		details, err = s.store.GetIPDetails(ctx, normalized, 0, 100, 100)
	}
	if err != nil {
		return model.IPDetails{}, err
	}
	return s.decorateDetails(ctx, details)
}
// InvestigateIP behaves like GetIPDetails, but first forces a new
// investigation of ip, bypassing the RefreshAfter freshness check. When that
// produces a result, it replaces the investigation in the returned details.
// Errors from normalization, the store or the investigator are returned with
// a zero IPDetails.
func (s *Service) InvestigateIP(ctx context.Context, ip string) (model.IPDetails, error) {
	normalized, err := normalizeIP(ip)
	if err != nil {
		return model.IPDetails{}, err
	}
	details, err := s.store.GetIPDetails(ctx, normalized, 0, 100, 100)
	if err != nil {
		return model.IPDetails{}, err
	}

	if fresh, refreshErr := s.refreshInvestigation(ctx, normalized, true); refreshErr != nil {
		return model.IPDetails{}, refreshErr
	} else if fresh != nil {
		details.Investigation = fresh
	}
	return s.decorateDetails(ctx, details)
}
// ForceBlock sets a force-block manual override on ip with target state
// "blocked". When an OPNsense client is configured, it also calls
// AddIPIfMissing for the IP. An empty reason defaults to "manual block".
func (s *Service) ForceBlock(ctx context.Context, ip string, actor string, reason string) error {
	why := defaultReason(reason, "manual block")
	return s.applyManualOverride(ctx, ip, model.ManualOverrideForceBlock, model.IPStateBlocked, actor, why, "block")
}

// ForceAllow sets a force-allow manual override on ip with target state
// "allowed". When an OPNsense client is configured, it also calls
// RemoveIPIfPresent for the IP. An empty reason defaults to "manual allow".
func (s *Service) ForceAllow(ctx context.Context, ip string, actor string, reason string) error {
	why := defaultReason(reason, "manual allow")
	return s.applyManualOverride(ctx, ip, model.ManualOverrideForceAllow, model.IPStateAllowed, actor, why, "unblock")
}
// ClearOverride removes any manual override on ip. It then records a
// "manual", non-enforced decision with no action, attributed to actor
// (defaulted via defaultActor). An empty reason defaults to
// "manual override cleared". The OPNsense alias is not touched here.
func (s *Service) ClearOverride(ctx context.Context, ip string, actor string, reason string) error {
	normalized, err := normalizeIP(ip)
	if err != nil {
		return err
	}

	why := defaultReason(reason, "manual override cleared")
	cleared, err := s.store.ClearManualOverride(ctx, normalized, why)
	if err != nil {
		return err
	}

	entry := &model.DecisionRecord{
		EventID:    cleared.LastEventID,
		IP:         normalized,
		SourceName: cleared.LastSourceName,
		Kind:       "manual",
		Action:     model.DecisionActionNone,
		Reason:     why,
		Actor:      defaultActor(actor),
		Enforced:   false,
		CreatedAt:  time.Now().UTC(),
	}
	return s.store.AddDecision(ctx, entry)
}
// runSource polls source once immediately, then again on every tick of
// source.PollInterval, until ctx is cancelled.
//
// NOTE(review): time.NewTicker panics for a non-positive interval; this
// presumably relies on config validation guaranteeing PollInterval > 0.
func (s *Service) runSource(ctx context.Context, source config.SourceConfig) {
	s.pollSource(ctx, source)

	ticker := time.NewTicker(source.PollInterval.Duration)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
		s.pollSource(ctx, source)
	}
}
// pollSource reads the next batch of new lines from source and processes
// each one with the profile named by source.Profile.
//
// Handling per line:
//   - Empty lines are skipped silently; other parse failures are logged.
//   - Records whose status lies outside [profile.MinStatus, profile.MaxStatus]
//     are ignored.
//   - Processing errors are logged per record and do not stop the batch.
//
// A read error is logged and aborts the poll.
//
// NOTE(review): if source.Profile is not a key of cfg.Profiles, the zero
// value profile is used; presumably config validation rules that out.
func (s *Service) pollSource(ctx context.Context, source config.SourceConfig) {
	lines, err := s.readNewLines(ctx, source)
	switch {
	case err != nil:
		s.logger.Printf("source %s: %v", source.Name, err)
		return
	case len(lines) == 0:
		return
	}

	profile := s.cfg.Profiles[source.Profile]
	for _, line := range lines {
		record, parseErr := caddylog.ParseLine(line)
		if parseErr != nil {
			if !errors.Is(parseErr, caddylog.ErrEmptyLine) {
				s.logger.Printf("source %s: parse line: %v", source.Name, parseErr)
			}
			continue
		}

		inRange := record.Status >= profile.MinStatus && record.Status <= profile.MaxStatus
		if !inRange {
			continue
		}

		if procErr := s.processRecord(ctx, source, profile, record); procErr != nil {
			s.logger.Printf("source %s: process record: %v", source.Name, procErr)
		}
	}
}
// processRecord evaluates one parsed access-log record and persists the
// outcome, in this order:
//  1. Read the client IP's stored state to pick up any manual override.
//  2. Evaluate the record against the profile and that override.
//  3. For a "block" decision with an OPNsense client configured, call
//     AddIPIfMissing. Success marks the event enforced; failure is kept as an
//     "error" backend action and is not returned.
//  4. Store the event. Then, unless the decision is "none", store an
//     "automatic" decision that references event.ID (presumably assigned by
//     RecordEvent). Then store the backend action, if one was attempted.
//  5. Queue the client IP for background investigation.
//
// Any store error aborts the remaining steps and is returned.
func (s *Service) processRecord(ctx context.Context, source config.SourceConfig, profile config.ProfileConfig, record model.AccessLogRecord) error {
	current, known, err := s.store.GetIPState(ctx, record.ClientIP)
	if err != nil {
		return err
	}
	override := model.ManualOverrideNone
	if known {
		override = current.ManualOverride
	}

	verdict := s.evaluator.Evaluate(record, profile, override)
	event := model.Event{
		SourceName:      source.Name,
		ProfileName:     source.Profile,
		OccurredAt:      record.OccurredAt,
		RemoteIP:        record.RemoteIP,
		ClientIP:        record.ClientIP,
		Host:            record.Host,
		Method:          record.Method,
		URI:             record.URI,
		Path:            record.Path,
		Status:          record.Status,
		UserAgent:       record.UserAgent,
		Decision:        verdict.Action,
		DecisionReason:  verdict.PrimaryReason(),
		DecisionReasons: append([]string(nil), verdict.Reasons...),
		Enforced:        false,
		RawJSON:         record.RawJSON,
		CreatedAt:       time.Now().UTC(),
	}

	var backendAction *model.OPNsenseAction
	if s.blocker != nil && verdict.Action == model.DecisionActionBlock {
		result, blockErr := s.blocker.AddIPIfMissing(ctx, record.ClientIP)
		backendAction = &model.OPNsenseAction{
			IP:        record.ClientIP,
			Action:    "block",
			CreatedAt: time.Now().UTC(),
		}
		if blockErr == nil {
			backendAction.Result = result
			backendAction.Message = verdict.PrimaryReason()
			event.Enforced = true
		} else {
			backendAction.Result = "error"
			backendAction.Message = blockErr.Error()
		}
	}

	if err := s.store.RecordEvent(ctx, &event); err != nil {
		return err
	}

	if verdict.Action != model.DecisionActionNone {
		entry := &model.DecisionRecord{
			EventID:    event.ID,
			IP:         record.ClientIP,
			SourceName: source.Name,
			Kind:       "automatic",
			Action:     verdict.Action,
			Reason:     strings.Join(verdict.Reasons, ", "),
			Actor:      "engine",
			Enforced:   event.Enforced,
			CreatedAt:  time.Now().UTC(),
		}
		if err := s.store.AddDecision(ctx, entry); err != nil {
			return err
		}
	}

	if backendAction != nil {
		if err := s.store.AddBackendAction(ctx, backendAction); err != nil {
			return err
		}
	}

	s.enqueueInvestigation(record.ClientIP)
	return nil
}
// runInvestigationScheduler queues recent IPs whose investigation is missing
// or stale. It runs once immediately and then on every
// cfg.Investigation.BackgroundPollInterval tick, until ctx is cancelled.
//
// NOTE(review): time.NewTicker panics for a non-positive interval; this
// presumably relies on config validation.
func (s *Service) runInvestigationScheduler(ctx context.Context) {
	s.enqueueRecentInvestigations(ctx)

	ticker := time.NewTicker(s.cfg.Investigation.BackgroundPollInterval.Duration)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
		s.enqueueRecentInvestigations(ctx)
	}
}
// runInvestigationWorker takes IPs from investigationQueue until ctx is
// cancelled. For each IP it:
//   - refreshes the investigation non-forcibly, so fresh entries are reused;
//   - bounds that refresh by cfg.Investigation.Timeout;
//   - removes the IP from the de-duplication set afterwards, whatever the
//     outcome.
//
// Errors other than context.Canceled are logged.
func (s *Service) runInvestigationWorker(ctx context.Context) {
	for {
		var ip string
		select {
		case <-ctx.Done():
			return
		case ip = <-s.investigationQueue:
		}

		func() {
			defer s.markInvestigationDone(ip)

			itemCtx, cancel := context.WithTimeout(ctx, s.cfg.Investigation.Timeout.Duration)
			_, err := s.refreshInvestigation(itemCtx, ip, false)
			cancel()
			if err == nil || errors.Is(err, context.Canceled) {
				return
			}
			s.logger.Printf("investigation %s: %v", ip, err)
		}()
	}
}
// enqueueRecentInvestigations loads up to cfg.Investigation.BackgroundBatchSize
// IPs seen within cfg.Investigation.BackgroundLookback. It queues those with
// no stored investigation, or one older than cfg.Investigation.RefreshAfter.
// Store errors are logged and end the pass early. It does nothing when
// background investigation is disabled.
func (s *Service) enqueueRecentInvestigations(ctx context.Context) {
	if s.investigationQueue == nil {
		return
	}

	lookback := s.cfg.Investigation.BackgroundLookback.Duration
	rows, err := s.store.ListRecentIPRows(ctx, time.Now().UTC().Add(-lookback), s.cfg.Investigation.BackgroundBatchSize)
	if err != nil {
		s.logger.Printf("list recent IPs for investigation: %v", err)
		return
	}
	known, err := s.store.GetInvestigationsForIPs(ctx, recentRowIPs(rows))
	if err != nil {
		s.logger.Printf("list investigations for recent IPs: %v", err)
		return
	}

	cutoff := time.Now().UTC().Add(-s.cfg.Investigation.RefreshAfter.Duration)
	for i := range rows {
		ip := rows[i].IP
		if existing, ok := known[ip]; ok && !existing.UpdatedAt.Before(cutoff) {
			continue // still fresh
		}
		s.enqueueInvestigation(ip)
	}
}
// enqueueInvestigation offers ip (normalized) to the background queue
// without blocking.
//
// It returns immediately in these cases:
//   - background investigation is disabled;
//   - ip is not a valid address;
//   - the IP is already queued or being processed.
//
// If the queue is full the request is dropped and the IP is un-marked, so a
// later call can try again.
func (s *Service) enqueueInvestigation(ip string) {
	if s.investigationQueue == nil {
		return
	}
	normalized, err := normalizeIP(ip)
	if err != nil {
		return
	}

	s.investigationQueueMu.Lock()
	_, pending := s.investigationQueued[normalized]
	if !pending {
		s.investigationQueued[normalized] = struct{}{}
	}
	s.investigationQueueMu.Unlock()
	if pending {
		return
	}

	select {
	case s.investigationQueue <- normalized:
	default:
		s.markInvestigationDone(normalized)
	}
}
// markInvestigationDone removes ip from the de-duplication set, under
// investigationQueueMu, so it can be queued again.
func (s *Service) markInvestigationDone(ip string) {
	s.investigationQueueMu.Lock()
	delete(s.investigationQueued, ip)
	s.investigationQueueMu.Unlock()
}
// refreshInvestigation returns an investigation for ip.
//
// It reuses the stored one when force is false, one exists, and it is younger
// than cfg.Investigation.RefreshAfter. Otherwise it runs the investigator with
// up to 12 recent user agents and saves the result before returning it.
// It returns (nil, nil) when no investigator is configured.
func (s *Service) refreshInvestigation(ctx context.Context, ip string, force bool) (*model.IPInvestigation, error) {
	if s.investigator == nil {
		return nil, nil
	}
	normalized, err := normalizeIP(ip)
	if err != nil {
		return nil, err
	}

	cached, found, err := s.store.GetInvestigation(ctx, normalized)
	if err != nil {
		return nil, err
	}
	if !force && found && time.Since(cached.UpdatedAt) < s.cfg.Investigation.RefreshAfter.Duration {
		return &cached, nil
	}

	agents, err := s.store.ListRecentUserAgentsForIP(ctx, normalized, 12)
	if err != nil {
		return nil, err
	}
	result, err := s.investigator.Investigate(ctx, normalized, agents)
	if err != nil {
		return nil, err
	}
	if err = s.store.SaveInvestigation(ctx, result); err != nil {
		return nil, err
	}
	return &result, nil
}
// readNewLines returns up to source.BatchSize complete lines appended to
// source.Path since the offset stored for source.Name, and persists the new
// offset before returning.
//
// Behaviour:
//   - A missing file is not an error: it returns (nil, nil).
//   - With no stored offset, reading starts at byte 0, or at the current end of
//     file when source.InitialPosition == "end". That starting offset is saved
//     immediately.
//   - If the file identity (see fileIdentity) differs from the stored one, or
//     the file is now smaller than the stored offset, reading restarts at
//     byte 0.
//   - Bytes after the last '\n' are treated as a line still being written.
//     They are neither returned nor counted in the offset, so the full line is
//     read on a later poll.
//
// The offset is saved before the caller processes the lines, so a crash in
// between skips those lines rather than replaying them.
//
// Identity and size come from the opened descriptor (file.Stat), not from a
// separate os.Stat of the path. With a separate stat, a rotation between stat
// and open would pair one file's identity and offset with another file's
// bytes.
func (s *Service) readNewLines(ctx context.Context, source config.SourceConfig) ([]string, error) {
	file, err := os.Open(source.Path)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil, nil
		}
		return nil, fmt.Errorf("open source path %q: %w", source.Path, err)
	}
	defer file.Close()

	info, err := file.Stat()
	if err != nil {
		return nil, fmt.Errorf("stat source path %q: %w", source.Path, err)
	}
	inode := fileIdentity(info)
	size := info.Size()

	offset, found, err := s.store.GetSourceOffset(ctx, source.Name)
	if err != nil {
		return nil, err
	}
	switch {
	case !found:
		start := int64(0)
		if source.InitialPosition == "end" {
			start = size
		}
		offset = model.SourceOffset{
			SourceName: source.Name,
			Path:       source.Path,
			Inode:      inode,
			Offset:     start,
			UpdatedAt:  time.Now().UTC(),
		}
		// Save even when there is nothing to read yet (early return below),
		// so the next poll finds an offset and does not re-apply
		// InitialPosition.
		if err := s.store.SaveSourceOffset(ctx, offset); err != nil {
			return nil, err
		}
		if start >= size {
			return nil, nil
		}
	case offset.Inode != inode || size < offset.Offset:
		// A different file (rotation) or a truncated one: start over.
		offset = model.SourceOffset{
			SourceName: source.Name,
			Path:       source.Path,
			Inode:      inode,
			Offset:     0,
			UpdatedAt:  time.Now().UTC(),
		}
	}

	if _, err := file.Seek(offset.Offset, io.SeekStart); err != nil {
		return nil, fmt.Errorf("seek source path %q: %w", source.Path, err)
	}

	reader := bufio.NewReader(file)
	lines := make([]string, 0, source.BatchSize)
	currentOffset := offset.Offset
	for len(lines) < source.BatchSize {
		line, err := reader.ReadString('\n')
		if err != nil {
			if errors.Is(err, io.EOF) {
				// Anything returned alongside io.EOF has no newline yet; leave
				// it (and the offset) for the next poll.
				break
			}
			return nil, fmt.Errorf("read source path %q: %w", source.Path, err)
		}
		currentOffset += int64(len(line))
		lines = append(lines, strings.TrimRight(line, "\r\n"))
	}

	offset.Path = source.Path
	offset.Inode = inode
	offset.Offset = currentOffset
	offset.UpdatedAt = time.Now().UTC()
	if err := s.store.SaveSourceOffset(ctx, offset); err != nil {
		return nil, err
	}
	return lines, nil
}
// applyManualOverride is the shared implementation of ForceBlock and
// ForceAllow.
//
// When an OPNsense client is configured, it first prepares a backend action
// labelled backendAction. It then calls AddIPIfMissing for
// ManualOverrideForceBlock, or RemoveIPIfPresent for ManualOverrideForceAllow.
// A failed call is recorded as an "error" result and leaves the decision
// unenforced, but does not abort.
//
// It then stores the override and target state, appends a "manual" decision
// attributed to actor, and finally stores the prepared backend action.
// Store errors are returned.
//
// NOTE(review): for any other override value with a client configured, a
// backend action with an empty Result is still stored.
func (s *Service) applyManualOverride(ctx context.Context, ip string, override model.ManualOverride, state model.IPStateStatus, actor string, reason string, backendAction string) error {
	normalized, err := normalizeIP(ip)
	if err != nil {
		return err
	}

	var (
		backendRecord *model.OPNsenseAction
		enforced      bool
	)
	if s.blocker != nil {
		backendRecord = &model.OPNsenseAction{
			IP:        normalized,
			Action:    backendAction,
			CreatedAt: time.Now().UTC(),
		}
		switch override {
		case model.ManualOverrideForceBlock:
			result, callErr := s.blocker.AddIPIfMissing(ctx, normalized)
			if enforced = callErr == nil; enforced {
				backendRecord.Result, backendRecord.Message = result, reason
			} else {
				backendRecord.Result, backendRecord.Message = "error", callErr.Error()
			}
		case model.ManualOverrideForceAllow:
			result, callErr := s.blocker.RemoveIPIfPresent(ctx, normalized)
			if enforced = callErr == nil; enforced {
				backendRecord.Result, backendRecord.Message = result, reason
			} else {
				backendRecord.Result, backendRecord.Message = "error", callErr.Error()
			}
		}
	}

	current, err := s.store.SetManualOverride(ctx, normalized, override, state, reason)
	if err != nil {
		return err
	}

	entry := &model.DecisionRecord{
		EventID:    current.LastEventID,
		IP:         normalized,
		SourceName: current.LastSourceName,
		Kind:       "manual",
		Action:     actionForOverride(override),
		Reason:     reason,
		Actor:      defaultActor(actor),
		Enforced:   enforced,
		CreatedAt:  time.Now().UTC(),
	}
	if err := s.store.AddDecision(ctx, entry); err != nil {
		return err
	}

	if backendRecord == nil {
		return nil
	}
	if err := s.store.AddBackendAction(ctx, backendRecord); err != nil {
		return err
	}
	return nil
}
// normalizeIP trims surrounding whitespace, parses the result with
// net.ParseIP and returns the canonical form from net.IP.String.
//
// IPv4-mapped IPv6 addresses collapse to dotted IPv4, and IPv6 becomes
// lower-case and zero-compressed. Input that net.ParseIP rejects (including
// CIDR notation and zoned addresses) produces an error quoting the original,
// untrimmed input.
func normalizeIP(ip string) (string, error) {
	if parsed := net.ParseIP(strings.TrimSpace(ip)); parsed != nil {
		return parsed.String(), nil
	}
	return "", fmt.Errorf("invalid ip address %q", ip)
}
// fileIdentity returns a string identifying the underlying file described by
// info, independent of its contents. readNewLines compares it with the stored
// identity to tell a replaced (rotated) log from one that merely grew.
//
// When info.Sys() is a *syscall.Stat_t, the identity is "<dev>:<inode>".
// Otherwise no content-independent identity is available, so the fallback
// uses only the base name. It must not include the modification time or size:
// both change on every append, which would make readNewLines treat each poll
// as a rotation and re-read the file from offset 0.
//
// With the fallback, truncation is still detected by readNewLines' size < offset
// check. A rename-and-recreate rotation whose new file already exceeds the old
// offset is not detected.
func fileIdentity(info os.FileInfo) string {
	if stat, ok := info.Sys().(*syscall.Stat_t); ok {
		return fmt.Sprintf("%d:%d", stat.Dev, stat.Ino)
	}
	return "fallback:" + info.Name()
}
// actionForOverride maps a manual override to the decision action recorded
// for it: force-block -> block, force-allow -> allow, anything else -> none.
func actionForOverride(override model.ManualOverride) model.DecisionAction {
	if override == model.ManualOverrideForceBlock {
		return model.DecisionActionBlock
	}
	if override == model.ManualOverrideForceAllow {
		return model.DecisionActionAllow
	}
	return model.DecisionActionNone
}
// defaultActor returns actor with surrounding whitespace removed, or "web-ui"
// when nothing is left.
func defaultActor(actor string) string {
	if trimmed := strings.TrimSpace(actor); trimmed != "" {
		return trimmed
	}
	return "web-ui"
}
// defaultReason returns reason with surrounding whitespace removed. When that
// leaves an empty string it returns fallback exactly as given, untrimmed.
func defaultReason(reason string, fallback string) string {
	trimmed := strings.TrimSpace(reason)
	if trimmed == "" {
		return fallback
	}
	return trimmed
}
// decorateDetails completes details for presentation in three steps:
//  1. When the state has an IP but no investigation is attached, attach the
//     stored investigation, if one exists.
//  2. Attach the live OPNsense status.
//  3. Attach the action availability derived from the state and that status.
//
// A store error is returned with a zero IPDetails.
func (s *Service) decorateDetails(ctx context.Context, details model.IPDetails) (model.IPDetails, error) {
	needsInvestigation := details.State.IP != "" && details.Investigation == nil
	if needsInvestigation {
		stored, found, err := s.store.GetInvestigation(ctx, details.State.IP)
		switch {
		case err != nil:
			return model.IPDetails{}, err
		case found:
			details.Investigation = &stored
		}
	}

	status := s.resolveOPNsenseStatus(ctx, details.State)
	details.OPNsense = status
	details.Actions = actionAvailability(details.State, status)
	return details, nil
}
// resolveOPNsenseStatus reports whether an OPNsense client is configured and,
// when it is and state has an IP, whether that IP is in the alias.
//
// CheckedAt is stamped just before the lookup. A lookup failure is reported
// through Error, and Present is then left false.
func (s *Service) resolveOPNsenseStatus(ctx context.Context, state model.IPState) model.OPNsenseStatus {
	configured := s.blocker != nil
	if !configured || state.IP == "" {
		return model.OPNsenseStatus{Configured: configured}
	}

	status := model.OPNsenseStatus{
		Configured: true,
		CheckedAt:  time.Now().UTC(),
	}
	present, err := s.blocker.IsIPPresent(ctx, state.IP)
	if err != nil {
		status.Error = err.Error()
	} else {
		status.Present = present
	}
	return status
}
// actionAvailability decides which manual actions make sense for an IP.
//
// Whether the IP counts as "blocked" comes from OPNsense when the backend is
// configured and its lookup succeeded. Otherwise it is inferred locally: the
// state is blocked, or a force-block override is set.
//   - Block is offered when the IP is not blocked; Unblock when it is.
//   - Clearing the override is offered whenever one is set.
func actionAvailability(state model.IPState, backend model.OPNsenseStatus) model.ActionAvailability {
	var blocked bool
	switch {
	case backend.Configured && backend.Error == "":
		blocked = backend.Present
	default:
		blocked = state.State == model.IPStateBlocked || state.ManualOverride == model.ManualOverrideForceBlock
	}

	return model.ActionAvailability{
		CanBlock:         !blocked,
		CanUnblock:       blocked,
		CanClearOverride: state.ManualOverride != model.ManualOverrideNone,
	}
}
// collectUserAgents returns, in order, the UserAgent of every event whose
// user agent is not blank. Values are returned untrimmed, and the result is
// never nil.
//
// NOTE(review): no caller of this helper is visible in this file.
func collectUserAgents(events []model.Event) []string {
	agents := make([]string, 0, len(events))
	for i := range events {
		if ua := events[i].UserAgent; strings.TrimSpace(ua) != "" {
			agents = append(agents, ua)
		}
	}
	return agents
}
// recentRowIPs extracts the IP of each row, preserving order. The result is
// a non-nil slice of the same length as items.
func recentRowIPs(items []model.RecentIPRow) []string {
	ips := make([]string, len(items))
	for i := range items {
		ips[i] = items[i].IP
	}
	return ips
}