You've already forked caddy-opnsense-blocker
Build initial caddy-opnsense-blocker daemon
This commit is contained in:
412
internal/service/service.go
Normal file
412
internal/service/service.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/caddylog"
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config"
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/engine"
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model"
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/opnsense"
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/store"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
cfg *config.Config
|
||||
store *store.Store
|
||||
evaluator *engine.Evaluator
|
||||
blocker opnsense.AliasClient
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func New(cfg *config.Config, db *store.Store, blocker opnsense.AliasClient, logger *log.Logger) *Service {
|
||||
if logger == nil {
|
||||
logger = log.New(io.Discard, "", 0)
|
||||
}
|
||||
return &Service{
|
||||
cfg: cfg,
|
||||
store: db,
|
||||
evaluator: engine.NewEvaluator(),
|
||||
blocker: blocker,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run(ctx context.Context) error {
|
||||
var wg sync.WaitGroup
|
||||
for _, source := range s.cfg.Sources {
|
||||
source := source
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.runSource(ctx, source)
|
||||
}()
|
||||
}
|
||||
<-ctx.Done()
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) GetOverview(ctx context.Context, limit int) (model.Overview, error) {
|
||||
return s.store.GetOverview(ctx, limit)
|
||||
}
|
||||
|
||||
func (s *Service) ListEvents(ctx context.Context, limit int) ([]model.Event, error) {
|
||||
return s.store.ListRecentEvents(ctx, limit)
|
||||
}
|
||||
|
||||
func (s *Service) ListIPs(ctx context.Context, limit int, state string) ([]model.IPState, error) {
|
||||
return s.store.ListIPStates(ctx, limit, state)
|
||||
}
|
||||
|
||||
func (s *Service) GetIPDetails(ctx context.Context, ip string) (model.IPDetails, error) {
|
||||
normalized, err := normalizeIP(ip)
|
||||
if err != nil {
|
||||
return model.IPDetails{}, err
|
||||
}
|
||||
return s.store.GetIPDetails(ctx, normalized, 100, 100, 100)
|
||||
}
|
||||
|
||||
func (s *Service) ForceBlock(ctx context.Context, ip string, actor string, reason string) error {
|
||||
return s.applyManualOverride(ctx, ip, model.ManualOverrideForceBlock, model.IPStateBlocked, actor, defaultReason(reason, "manual block"), "block")
|
||||
}
|
||||
|
||||
func (s *Service) ForceAllow(ctx context.Context, ip string, actor string, reason string) error {
|
||||
return s.applyManualOverride(ctx, ip, model.ManualOverrideForceAllow, model.IPStateAllowed, actor, defaultReason(reason, "manual allow"), "unblock")
|
||||
}
|
||||
|
||||
func (s *Service) ClearOverride(ctx context.Context, ip string, actor string, reason string) error {
|
||||
normalized, err := normalizeIP(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reason = defaultReason(reason, "manual override cleared")
|
||||
state, err := s.store.ClearManualOverride(ctx, normalized, reason)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.store.AddDecision(ctx, &model.DecisionRecord{
|
||||
EventID: state.LastEventID,
|
||||
IP: normalized,
|
||||
SourceName: state.LastSourceName,
|
||||
Kind: "manual",
|
||||
Action: model.DecisionActionNone,
|
||||
Reason: reason,
|
||||
Actor: defaultActor(actor),
|
||||
Enforced: false,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) runSource(ctx context.Context, source config.SourceConfig) {
|
||||
s.pollSource(ctx, source)
|
||||
ticker := time.NewTicker(source.PollInterval.Duration)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.pollSource(ctx, source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) pollSource(ctx context.Context, source config.SourceConfig) {
|
||||
lines, err := s.readNewLines(ctx, source)
|
||||
if err != nil {
|
||||
s.logger.Printf("source %s: %v", source.Name, err)
|
||||
return
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
profile := s.cfg.Profiles[source.Profile]
|
||||
for _, line := range lines {
|
||||
record, err := caddylog.ParseLine(line)
|
||||
if err != nil {
|
||||
if errors.Is(err, caddylog.ErrEmptyLine) {
|
||||
continue
|
||||
}
|
||||
s.logger.Printf("source %s: parse line: %v", source.Name, err)
|
||||
continue
|
||||
}
|
||||
if record.Status < profile.MinStatus || record.Status > profile.MaxStatus {
|
||||
continue
|
||||
}
|
||||
if err := s.processRecord(ctx, source, profile, record); err != nil {
|
||||
s.logger.Printf("source %s: process record: %v", source.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) processRecord(ctx context.Context, source config.SourceConfig, profile config.ProfileConfig, record model.AccessLogRecord) error {
|
||||
state, found, err := s.store.GetIPState(ctx, record.ClientIP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
override := model.ManualOverrideNone
|
||||
if found {
|
||||
override = state.ManualOverride
|
||||
}
|
||||
|
||||
decision := s.evaluator.Evaluate(record, profile, override)
|
||||
event := model.Event{
|
||||
SourceName: source.Name,
|
||||
ProfileName: source.Profile,
|
||||
OccurredAt: record.OccurredAt,
|
||||
RemoteIP: record.RemoteIP,
|
||||
ClientIP: record.ClientIP,
|
||||
Host: record.Host,
|
||||
Method: record.Method,
|
||||
URI: record.URI,
|
||||
Path: record.Path,
|
||||
Status: record.Status,
|
||||
UserAgent: record.UserAgent,
|
||||
Decision: decision.Action,
|
||||
DecisionReason: decision.PrimaryReason(),
|
||||
DecisionReasons: append([]string(nil), decision.Reasons...),
|
||||
Enforced: false,
|
||||
RawJSON: record.RawJSON,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
var backendAction *model.OPNsenseAction
|
||||
if decision.Action == model.DecisionActionBlock && s.blocker != nil {
|
||||
result, blockErr := s.blocker.AddIPIfMissing(ctx, record.ClientIP)
|
||||
backendAction = &model.OPNsenseAction{
|
||||
IP: record.ClientIP,
|
||||
Action: "block",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
if blockErr != nil {
|
||||
backendAction.Result = "error"
|
||||
backendAction.Message = blockErr.Error()
|
||||
} else {
|
||||
backendAction.Result = result
|
||||
backendAction.Message = decision.PrimaryReason()
|
||||
event.Enforced = true
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.store.RecordEvent(ctx, &event); err != nil {
|
||||
return err
|
||||
}
|
||||
if decision.Action != model.DecisionActionNone {
|
||||
if err := s.store.AddDecision(ctx, &model.DecisionRecord{
|
||||
EventID: event.ID,
|
||||
IP: record.ClientIP,
|
||||
SourceName: source.Name,
|
||||
Kind: "automatic",
|
||||
Action: decision.Action,
|
||||
Reason: strings.Join(decision.Reasons, ", "),
|
||||
Actor: "engine",
|
||||
Enforced: event.Enforced,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if backendAction != nil {
|
||||
if err := s.store.AddBackendAction(ctx, backendAction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) readNewLines(ctx context.Context, source config.SourceConfig) ([]string, error) {
|
||||
info, err := os.Stat(source.Path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("stat source path %q: %w", source.Path, err)
|
||||
}
|
||||
inode := fileIdentity(info)
|
||||
size := info.Size()
|
||||
|
||||
offset, found, err := s.store.GetSourceOffset(ctx, source.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !found {
|
||||
start := int64(0)
|
||||
if source.InitialPosition == "end" {
|
||||
start = size
|
||||
}
|
||||
offset = model.SourceOffset{
|
||||
SourceName: source.Name,
|
||||
Path: source.Path,
|
||||
Inode: inode,
|
||||
Offset: start,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
if err := s.store.SaveSourceOffset(ctx, offset); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if start >= size {
|
||||
return nil, nil
|
||||
}
|
||||
} else if offset.Inode != inode || size < offset.Offset {
|
||||
offset = model.SourceOffset{
|
||||
SourceName: source.Name,
|
||||
Path: source.Path,
|
||||
Inode: inode,
|
||||
Offset: 0,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.Open(source.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open source path %q: %w", source.Path, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err := file.Seek(offset.Offset, io.SeekStart); err != nil {
|
||||
return nil, fmt.Errorf("seek source path %q: %w", source.Path, err)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(file)
|
||||
lines := make([]string, 0, source.BatchSize)
|
||||
currentOffset := offset.Offset
|
||||
|
||||
for len(lines) < source.BatchSize {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return nil, fmt.Errorf("read source path %q: %w", source.Path, err)
|
||||
}
|
||||
currentOffset += int64(len(line))
|
||||
lines = append(lines, strings.TrimRight(line, "\r\n"))
|
||||
}
|
||||
|
||||
offset.Path = source.Path
|
||||
offset.Inode = inode
|
||||
offset.Offset = currentOffset
|
||||
offset.UpdatedAt = time.Now().UTC()
|
||||
if err := s.store.SaveSourceOffset(ctx, offset); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return lines, nil
|
||||
}
|
||||
|
||||
func (s *Service) applyManualOverride(ctx context.Context, ip string, override model.ManualOverride, state model.IPStateStatus, actor string, reason string, backendAction string) error {
|
||||
normalized, err := normalizeIP(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enforced := false
|
||||
var backendRecord *model.OPNsenseAction
|
||||
if s.blocker != nil {
|
||||
backendRecord = &model.OPNsenseAction{
|
||||
IP: normalized,
|
||||
Action: backendAction,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
switch override {
|
||||
case model.ManualOverrideForceBlock:
|
||||
result, callErr := s.blocker.AddIPIfMissing(ctx, normalized)
|
||||
if callErr != nil {
|
||||
backendRecord.Result = "error"
|
||||
backendRecord.Message = callErr.Error()
|
||||
} else {
|
||||
backendRecord.Result = result
|
||||
backendRecord.Message = reason
|
||||
enforced = true
|
||||
}
|
||||
case model.ManualOverrideForceAllow:
|
||||
result, callErr := s.blocker.RemoveIPIfPresent(ctx, normalized)
|
||||
if callErr != nil {
|
||||
backendRecord.Result = "error"
|
||||
backendRecord.Message = callErr.Error()
|
||||
} else {
|
||||
backendRecord.Result = result
|
||||
backendRecord.Message = reason
|
||||
enforced = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
current, err := s.store.SetManualOverride(ctx, normalized, override, state, reason)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.store.AddDecision(ctx, &model.DecisionRecord{
|
||||
EventID: current.LastEventID,
|
||||
IP: normalized,
|
||||
SourceName: current.LastSourceName,
|
||||
Kind: "manual",
|
||||
Action: actionForOverride(override),
|
||||
Reason: reason,
|
||||
Actor: defaultActor(actor),
|
||||
Enforced: enforced,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if backendRecord != nil {
|
||||
if err := s.store.AddBackendAction(ctx, backendRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeIP(ip string) (string, error) {
|
||||
parsed := net.ParseIP(strings.TrimSpace(ip))
|
||||
if parsed == nil {
|
||||
return "", fmt.Errorf("invalid ip address %q", ip)
|
||||
}
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
func fileIdentity(info os.FileInfo) string {
|
||||
if stat, ok := info.Sys().(*syscall.Stat_t); ok {
|
||||
return fmt.Sprintf("%d:%d", stat.Dev, stat.Ino)
|
||||
}
|
||||
return fmt.Sprintf("fallback:%d:%d", info.ModTime().UnixNano(), info.Size())
|
||||
}
|
||||
|
||||
func actionForOverride(override model.ManualOverride) model.DecisionAction {
|
||||
switch override {
|
||||
case model.ManualOverrideForceBlock:
|
||||
return model.DecisionActionBlock
|
||||
case model.ManualOverrideForceAllow:
|
||||
return model.DecisionActionAllow
|
||||
default:
|
||||
return model.DecisionActionNone
|
||||
}
|
||||
}
|
||||
|
||||
func defaultActor(actor string) string {
|
||||
if strings.TrimSpace(actor) == "" {
|
||||
return "web-ui"
|
||||
}
|
||||
return strings.TrimSpace(actor)
|
||||
}
|
||||
|
||||
func defaultReason(reason string, fallback string) string {
|
||||
if strings.TrimSpace(reason) == "" {
|
||||
return fallback
|
||||
}
|
||||
return strings.TrimSpace(reason)
|
||||
}
|
||||
247
internal/service/service_test.go
Normal file
247
internal/service/service_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/config"
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/model"
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/opnsense"
|
||||
"git.dern.ovh/infrastructure/caddy-opnsense-blocker/internal/store"
|
||||
)
|
||||
|
||||
func TestServiceProcessesMultipleSourcesAndManualActions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
mainLogPath := filepath.Join(tempDir, "main.log")
|
||||
giteaLogPath := filepath.Join(tempDir, "gitea.log")
|
||||
if err := os.WriteFile(mainLogPath, nil, 0o600); err != nil {
|
||||
t.Fatalf("create main log: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(giteaLogPath, nil, 0o600); err != nil {
|
||||
t.Fatalf("create gitea log: %v", err)
|
||||
}
|
||||
|
||||
backend := newFakeOPNsenseServer(t)
|
||||
defer backend.Close()
|
||||
|
||||
configPath := filepath.Join(tempDir, "config.yaml")
|
||||
payload := fmt.Sprintf(`storage:
|
||||
path: %s/blocker.db
|
||||
opnsense:
|
||||
enabled: true
|
||||
base_url: %s
|
||||
api_key: key
|
||||
api_secret: secret
|
||||
ensure_alias: true
|
||||
alias:
|
||||
name: blocked-ips
|
||||
profiles:
|
||||
main:
|
||||
auto_block: true
|
||||
block_unexpected_posts: true
|
||||
block_php_paths: true
|
||||
suspicious_path_prefixes:
|
||||
- /wp-login.php
|
||||
gitea:
|
||||
auto_block: false
|
||||
block_unexpected_posts: true
|
||||
allowed_post_paths:
|
||||
- /user/login
|
||||
suspicious_path_prefixes:
|
||||
- /install.php
|
||||
sources:
|
||||
- name: main
|
||||
path: %s
|
||||
profile: main
|
||||
initial_position: beginning
|
||||
poll_interval: 20ms
|
||||
batch_size: 128
|
||||
- name: gitea
|
||||
path: %s
|
||||
profile: gitea
|
||||
initial_position: beginning
|
||||
poll_interval: 20ms
|
||||
batch_size: 128
|
||||
`, tempDir, backend.URL, mainLogPath, giteaLogPath)
|
||||
if err := os.WriteFile(configPath, []byte(payload), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
database, err := store.Open(cfg.Storage.Path)
|
||||
if err != nil {
|
||||
t.Fatalf("open store: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
svc := New(cfg, database, opnsense.NewClient(cfg.OPNsense), log.New(os.Stderr, "", 0))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() { _ = svc.Run(ctx) }()
|
||||
|
||||
appendLine(t, mainLogPath, caddyJSONLine("203.0.113.10", "198.51.100.10", "example.test", "GET", "/wp-login.php", 404, "curl/8.0", time.Now().UTC()))
|
||||
appendLine(t, giteaLogPath, caddyJSONLine("203.0.113.11", "198.51.100.11", "git.example.test", "POST", "/user/login", 401, "curl/8.0", time.Now().UTC()))
|
||||
appendLine(t, giteaLogPath, caddyJSONLine("203.0.113.12", "198.51.100.12", "git.example.test", "GET", "/install.php", 404, "curl/8.0", time.Now().UTC()))
|
||||
|
||||
waitFor(t, 3*time.Second, func() bool {
|
||||
overview, err := database.GetOverview(context.Background(), 10)
|
||||
return err == nil && overview.TotalEvents == 3
|
||||
})
|
||||
|
||||
blockedState, found, err := database.GetIPState(context.Background(), "203.0.113.10")
|
||||
if err != nil || !found {
|
||||
t.Fatalf("load blocked state: found=%v err=%v", found, err)
|
||||
}
|
||||
if blockedState.State != model.IPStateBlocked {
|
||||
t.Fatalf("expected blocked state, got %+v", blockedState)
|
||||
}
|
||||
|
||||
reviewState, found, err := database.GetIPState(context.Background(), "203.0.113.12")
|
||||
if err != nil || !found {
|
||||
t.Fatalf("load review state: found=%v err=%v", found, err)
|
||||
}
|
||||
if reviewState.State != model.IPStateReview {
|
||||
t.Fatalf("expected review state, got %+v", reviewState)
|
||||
}
|
||||
|
||||
observedState, found, err := database.GetIPState(context.Background(), "203.0.113.11")
|
||||
if err != nil || !found {
|
||||
t.Fatalf("load observed state: found=%v err=%v", found, err)
|
||||
}
|
||||
if observedState.State != model.IPStateObserved {
|
||||
t.Fatalf("expected observed state, got %+v", observedState)
|
||||
}
|
||||
|
||||
if err := svc.ForceAllow(context.Background(), "203.0.113.10", "test", "manual unblock"); err != nil {
|
||||
t.Fatalf("force allow: %v", err)
|
||||
}
|
||||
state, found, err := database.GetIPState(context.Background(), "203.0.113.10")
|
||||
if err != nil || !found {
|
||||
t.Fatalf("reload unblocked state: found=%v err=%v", found, err)
|
||||
}
|
||||
if state.ManualOverride != model.ManualOverrideForceAllow || state.State != model.IPStateAllowed {
|
||||
t.Fatalf("unexpected manual allow state: %+v", state)
|
||||
}
|
||||
|
||||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
if _, ok := backend.ips["203.0.113.10"]; ok {
|
||||
t.Fatalf("expected IP to be removed from backend alias after manual unblock")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeOPNsenseServer struct {
|
||||
*httptest.Server
|
||||
mu sync.Mutex
|
||||
aliasUUID string
|
||||
aliasExists bool
|
||||
ips map[string]struct{}
|
||||
}
|
||||
|
||||
func newFakeOPNsenseServer(t *testing.T) *fakeOPNsenseServer {
|
||||
t.Helper()
|
||||
backend := &fakeOPNsenseServer{ips: map[string]struct{}{}}
|
||||
backend.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok || username != "key" || password != "secret" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/firewall/alias/get_alias_u_u_i_d/blocked-ips":
|
||||
if backend.aliasExists {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"uuid": backend.aliasUUID})
|
||||
} else {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"uuid": ""})
|
||||
}
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias/add_item":
|
||||
backend.aliasExists = true
|
||||
backend.aliasUUID = "uuid-1"
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias/set_item/uuid-1":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias/reconfigure":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/firewall/alias_util/list/blocked-ips":
|
||||
rows := make([]map[string]string, 0, len(backend.ips))
|
||||
for ip := range backend.ips {
|
||||
rows = append(rows, map[string]string{"ip": ip})
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"rows": rows})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias_util/add/blocked-ips":
|
||||
var payload map[string]string
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
backend.ips[payload["address"]] = struct{}{}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"status": "done"})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/firewall/alias_util/delete/blocked-ips":
|
||||
var payload map[string]string
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
delete(backend.ips, payload["address"])
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"status": "done"})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
return backend
|
||||
}
|
||||
|
||||
func appendLine(t *testing.T, path string, line string) {
|
||||
t.Helper()
|
||||
file, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("open log file for append: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
if _, err := file.WriteString(line + "\n"); err != nil {
|
||||
t.Fatalf("append log line: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func caddyJSONLine(clientIP string, remoteIP string, host string, method string, uri string, status int, userAgent string, occurredAt time.Time) string {
|
||||
return fmt.Sprintf(`{"ts":%q,"status":%d,"request":{"remote_ip":%q,"client_ip":%q,"host":%q,"method":%q,"uri":%q,"headers":{"User-Agent":[%q]}}}`,
|
||||
occurredAt.UTC().Format(time.RFC3339Nano),
|
||||
status,
|
||||
remoteIP,
|
||||
clientIP,
|
||||
host,
|
||||
method,
|
||||
uri,
|
||||
userAgent,
|
||||
)
|
||||
}
|
||||
|
||||
func waitFor(t *testing.T, timeout time.Duration, condition func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if condition() {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("condition was not met within %s", timeout)
|
||||
}
|
||||
Reference in New Issue
Block a user