1
2025-11-26 01:03:41 +01:00

158 lines
5.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# scripts/run_chronos_tuned.py
from __future__ import annotations
from pathlib import Path
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
# Repository root: three directories above the one containing this file.
PROJECT_ROOT = Path(__file__).resolve().parents[3]
# Put the root on sys.path (once) before the project import below.
# NOTE(review): presumably the `meteo` package lives directly under
# PROJECT_ROOT -- confirm against the repository layout.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from meteo.dataset import load_raw_csv
# Exit with an install hint rather than a raw traceback when the Chronos
# package cannot be imported; the original ImportError is kept as the cause.
try:
    from chronos import ChronosPipeline
except ImportError as exc:
    raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc
# Input CSV, resolved under the repository root.
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
# Output locations: DOC_DIR is the parent of the directory holding this script.
DOC_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = DOC_DIR / "data"
FIG_DIR = DOC_DIR / "figures"
# More conservative settings: shortened context, horizon capped at 64, more samples
MODEL_ID = "amazon/chronos-t5-small"
RESAMPLE_RULE = "1h"
CONTEXT_H = 288 # 12 days (288 steps at the 1h resampling rule)
MAX_HORIZON_H = 64 # <=64 recommended
HORIZONS_H = (1, 6, 24, 48) # kept consistent with our milestones
NUM_SAMPLES = 100
# Columns to evaluate. main() computes MAE/RMSE when "kind" is "reg" and
# rain/no-rain F1 + Brier for any other kind.
TARGETS = {
    "temperature": {"kind": "reg"},
    "wind_speed": {"kind": "reg"},
    "rain_rate": {"kind": "rain"},
}
def _load_series(target: str) -> pd.Series:
    """Load one column of the raw CSV as a resampled, gap-filled series.

    The column is averaged per ``RESAMPLE_RULE`` bin, missing bins are
    interpolated in both directions, and any value that is still missing
    afterwards is dropped.

    Args:
        target: Name of the column to extract from the raw CSV.

    Returns:
        The resampled series for ``target``.

    Raises:
        SystemExit: If ``target`` is not a column of the loaded frame.
    """
    frame = load_raw_csv(CSV_PATH)
    if target in frame.columns:
        binned = frame[target].resample(RESAMPLE_RULE).mean()
        filled = binned.interpolate(limit_direction="both")
        return filled.dropna()
    raise SystemExit(f"Colonne absente : {target}")
def _prepare_window(series: pd.Series, context_h: int, horizon_h: int) -> tuple[np.ndarray, pd.Series]:
    """Split the tail of ``series`` into a model context and a held-out target.

    The last ``context_h + horizon_h`` observations are used: the first
    ``context_h`` of them form the context and the remaining ``horizon_h``
    form the ground truth.

    Args:
        series: Time-ordered observations.
        context_h: Number of trailing points (before the target) fed to the model.
        horizon_h: Number of final points held out as ground truth.

    Returns:
        A pair ``(context, target)``: the context as a float NumPy array and
        the target as a Series that keeps its original index.

    Raises:
        ValueError: If ``context_h`` or ``horizon_h`` is negative.
        SystemExit: If ``series`` has fewer than ``context_h + horizon_h`` points.
    """
    if context_h < 0 or horizon_h < 0:
        raise ValueError(
            f"context_h and horizon_h must be non-negative, got {context_h} and {horizon_h}"
        )
    needed = context_h + horizon_h
    if len(series) < needed:
        raise SystemExit(f"Pas assez de données pour {needed} heures.")
    # Use explicit offsets from the start: a negative-stop slice such as
    # ``iloc[-needed:]`` degenerates to the whole series when ``needed`` is 0.
    start = len(series) - needed
    split = start + context_h
    context = series.iloc[start:split]
    target = series.iloc[split:]
    return context.to_numpy(dtype=float), target
def _plot_metrics(df: pd.DataFrame, target: str, output_path: Path) -> None:
    """Draw the metric-versus-horizon curves for one target and save them.

    The first row's ``kind`` decides what is drawn: ``"reg"`` plots the
    ``mae`` and ``rmse`` columns, any other value plots ``f1`` and ``brier``.

    Args:
        df: Metrics rows for a single target; must contain ``kind``,
            ``horizon_h`` and the metric columns for its kind.
        target: Target name, used in the title and to pick the y-axis label.
        output_path: Destination image file; parent directories are created.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    figure, axes = plt.subplots(figsize=(6, 4))

    # Pick the (column, legend label) pairs and axis label for this kind.
    if df["kind"].iloc[0] == "reg":
        curves = (("mae", "MAE"), ("rmse", "RMSE"))
        y_label = "Erreur (°C)" if target == "temperature" else "Erreur (unité)"
    else:
        curves = (("f1", "F1"), ("brier", "Brier"))
        y_label = "Score"

    horizons = df["horizon_h"]
    for column, legend_label in curves:
        axes.plot(horizons, df[column], marker="o", label=legend_label)
    axes.set_ylabel(y_label)

    axes.set_xlabel("Horizon (heures)")
    axes.set_title(f"Chronos (small, réglages prudents) {target}")
    axes.grid(True, linestyle=":", alpha=0.4)
    axes.legend()

    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    # Close explicitly so figures do not accumulate across targets.
    plt.close(figure)
def main() -> None:
    """Run the tuned Chronos evaluation for every entry in ``TARGETS``.

    For each target, the tail of the resampled series is split into a context
    of ``CONTEXT_H`` points and a ground truth of ``MAX_HORIZON_H`` points.
    ``NUM_SAMPLES`` sample paths are drawn from the pretrained model and
    averaged per step, and the mean forecast is scored against the truth over
    the first ``h`` steps for each ``h`` in ``HORIZONS_H``.

    Outputs:
        * one forecast-vs-truth CSV per target in ``DATA_DIR``;
        * ``chronos_tuned_metrics.csv`` in ``DATA_DIR``;
        * one metrics figure per target in ``FIG_DIR``;
        * the metrics table printed to stdout.

    Raises:
        SystemExit: Propagated from the loading/windowing helpers when a
            column is missing or the series is too short.
    """
    pipeline = ChronosPipeline.from_pretrained(
        MODEL_ID,
        device_map="auto",
        dtype="auto",
    )
    # Created once up front; every CSV written below lands in this directory.
    DATA_DIR.mkdir(parents=True, exist_ok=True)

    rows: list[dict[str, object]] = []
    for target, meta in TARGETS.items():
        series = _load_series(target)
        context_arr, target_series = _prepare_window(series, CONTEXT_H, MAX_HORIZON_H)
        context_tensor = torch.tensor(context_arr, dtype=torch.float32)
        forecasts = pipeline.predict(
            [context_tensor],
            prediction_length=MAX_HORIZON_H,
            num_samples=NUM_SAMPLES,
        )
        # predict() returns (num_series, num_samples, prediction_length).
        # Select the single series first, then average over the sample axis.
        # Averaging over axis 0 of the raw output would only collapse the
        # size-1 batch axis and leave one arbitrary sample path.
        sample_paths = forecasts[0]
        forecast_mean = sample_paths.mean(dim=0)
        forecast_series = pd.Series(
            np.asarray(forecast_mean).ravel(),
            index=pd.date_range(target_series.index[0], periods=MAX_HORIZON_H, freq=RESAMPLE_RULE),
        )

        for h in HORIZONS_H:
            # Metrics are aggregated over the first h steps, not step h alone.
            y_true = target_series.iloc[:h]
            y_pred = forecast_series.iloc[:h]
            if meta["kind"] == "reg":
                error = y_pred - y_true
                mae = float(error.abs().mean())
                rmse = float(np.sqrt((error ** 2).mean()))
                rows.append(
                    {"target": target, "kind": "reg", "horizon_h": h, "mae": mae, "rmse": rmse}
                )
            else:
                # Rain is scored as a binary event: any strictly positive rate.
                y_true_bin = (y_true > 0).astype(int)
                y_pred_bin = (y_pred > 0).astype(int)
                if y_true_bin.sum() == 0:
                    # No observed rain in the window: F1 is reported as 0.
                    f1 = 0.0
                else:
                    tp = ((y_true_bin == 1) & (y_pred_bin == 1)).sum()
                    fp = ((y_true_bin == 0) & (y_pred_bin == 1)).sum()
                    fn = ((y_true_bin == 1) & (y_pred_bin == 0)).sum()
                    prec = tp / (tp + fp) if tp + fp > 0 else 0.0
                    rec = tp / (tp + fn) if tp + fn > 0 else 0.0
                    f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
                # NOTE(review): the clipped mean rate is used as a stand-in for
                # a rain probability; it is not a calibrated probability.
                proba = y_pred.clip(0, 1)
                brier = float(((proba - y_true_bin) ** 2).mean())
                rows.append(
                    {"target": target, "kind": "cls", "horizon_h": h, "f1": float(f1), "brier": brier}
                )

        out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}_tuned"
        pd.concat([forecast_series.rename("y_pred"), target_series.rename("y_true")], axis=1).to_csv(
            DATA_DIR / f"chronos_tuned_forecast_{out_prefix}.csv",
            index=True,
        )

    df_metrics = pd.DataFrame(rows)
    metrics_path = DATA_DIR / "chronos_tuned_metrics.csv"
    df_metrics.to_csv(metrics_path, index=False)

    for target in df_metrics["target"].unique():
        sub = df_metrics[df_metrics["target"] == target].sort_values("horizon_h")
        _plot_metrics(sub, target, FIG_DIR / f"chronos_tuned_{target}.png")

    print(df_metrics.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
    print(f"✔ Sauvegardé : {metrics_path}")
    print("✔ Figures par cible (réglages prudents) dans figures/chronos_tuned_<target>.png")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()