# scripts/run_chronos_multi.py
"""Zero-shot evaluation of the Chronos (t5-small) model on several weather targets.

For each entry of ``TARGETS`` the minutely CSV column is resampled to hourly
means, the last ``CONTEXT_H + max(HORIZONS_H)`` hours are split into a
context window and a held-out window, and Chronos samples ``NUM_SAMPLES``
forecast paths of ``max(HORIZONS_H)`` hours from the context. Each horizon
``h`` in ``HORIZONS_H`` is scored on the first ``h`` held-out hours:

* ``kind == "reg"``: MAE / RMSE of the sample-mean forecast;
* ``kind == "rain"``: F1 / Brier score of the rain event, the rain
  probability being the fraction of sampled paths above ``RAIN_THRESHOLD``.

Outputs: ``chronos_multi_metrics.csv`` plus one forecast CSV per target in
``DATA_DIR``, and one metrics figure per target in ``FIG_DIR``.
"""

from __future__ import annotations

from pathlib import Path
import sys
from typing import Iterable

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

# Make the project root importable so `meteo` resolves when run as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from meteo.dataset import load_raw_csv

try:
    from chronos import ChronosPipeline
except ImportError as exc:
    raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc


CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
DOC_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = DOC_DIR / "data"
FIG_DIR = DOC_DIR / "figures"

MODEL_ID = "amazon/chronos-t5-small"
RESAMPLE_RULE = "1h"  # stay hourly to remain aligned with Chronos pre-training
CONTEXT_H = 336  # 14 days of hourly context
HORIZONS_H = (1, 6, 24)  # evaluated horizons (10 min excluded: hourly model)
NUM_SAMPLES = 50  # sampled forecast paths per series
RAIN_THRESHOLD = 0.0  # a rain_rate strictly above this value counts as rain

# "reg" targets get MAE/RMSE; "rain" targets get F1/Brier on the rain event.
TARGETS = {
    "temperature": {"kind": "reg"},
    "wind_speed": {"kind": "reg"},
    "rain_rate": {"kind": "rain"},
}


def _load_series(target: str, df: pd.DataFrame | None = None) -> pd.Series:
    """Return column ``target`` resampled to ``RESAMPLE_RULE`` means, gaps interpolated.

    Args:
        target: column name to extract.
        df: already-loaded raw frame, so the CSV is read only once for all
            targets. When omitted, ``CSV_PATH`` is loaded here. Its index must
            be datetime-like for ``resample`` to work.

    Raises:
        SystemExit: if ``target`` is not a column of the data.
    """
    if df is None:
        df = load_raw_csv(CSV_PATH)
    if target not in df.columns:
        raise SystemExit(f"Colonne absente : {target}")
    # Interpolate in both directions so leading/trailing gaps are filled too.
    s = df[target].resample(RESAMPLE_RULE).mean().interpolate(limit_direction="both")
    return s.dropna()


def _prepare_window(series: pd.Series, context_h: int, horizon_h: int) -> tuple[np.ndarray, pd.Series]:
    """Split the last ``context_h + horizon_h`` points into context and held-out parts.

    Returns:
        The context values as a float array and the held-out ``horizon_h``
        points as a Series (index preserved, used to align forecasts).

    Raises:
        SystemExit: if the series is shorter than ``context_h + horizon_h``.
    """
    needed = context_h + horizon_h
    if len(series) < needed:
        raise SystemExit(f"Pas assez de données pour {needed} heures.")
    window = series.iloc[-needed:]
    context = window.iloc[:context_h]
    target = window.iloc[context_h:]
    return context.to_numpy(dtype=float), target


def _forecast_samples(pipeline: ChronosPipeline, context: np.ndarray, prediction_length: int) -> np.ndarray:
    """Sample Chronos forecast paths for a single series.

    Returns:
        Float array of shape ``(NUM_SAMPLES, prediction_length)``.
    """
    forecasts = pipeline.predict(
        [torch.tensor(context, dtype=torch.float32)],
        prediction_length=prediction_length,
        num_samples=NUM_SAMPLES,
    )
    # predict() returns (num_series, num_samples, prediction_length): select the
    # single series here. Averaging must happen over the *samples* axis later,
    # not over the series axis. `.float()` avoids bf16, which numpy lacks.
    return torch.as_tensor(forecasts[0]).detach().float().cpu().numpy().astype(np.float64)


def _regression_metrics(y_true: pd.Series, y_pred: pd.Series) -> dict[str, float]:
    """Return MAE and RMSE between two index-aligned Series."""
    err = y_pred - y_true
    return {
        "mae": float(err.abs().mean()),
        "rmse": float(np.sqrt((err**2).mean())),
    }


def _rain_metrics(y_true: pd.Series, proba: pd.Series) -> dict[str, float]:
    """Return F1 and Brier score for the "rain" event (value > ``RAIN_THRESHOLD``).

    Args:
        y_true: observed rain rate.
        proba: forecast probability of rain, index-aligned with ``y_true``.
            Rain is predicted when ``proba >= 0.5``.

    F1 is undefined when no rain is observed in the window; it is reported as
    0.0 in that case to keep the metrics table numeric.
    """
    y_true_bin = (y_true > RAIN_THRESHOLD).astype(int)
    y_pred_bin = (proba >= 0.5).astype(int)
    if y_true_bin.sum() == 0:
        f1 = 0.0
    else:
        tp = int(((y_true_bin == 1) & (y_pred_bin == 1)).sum())
        fp = int(((y_true_bin == 0) & (y_pred_bin == 1)).sum())
        fn = int(((y_true_bin == 1) & (y_pred_bin == 0)).sum())
        prec = tp / (tp + fp) if tp + fp > 0 else 0.0
        rec = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
    brier = float(((proba - y_true_bin) ** 2).mean())
    return {"f1": float(f1), "brier": brier}


def _plot_metrics(df: pd.DataFrame, target: str, output_path: Path) -> None:
    """Plot metrics versus horizon for one target and save the figure.

    ``df`` holds the metric rows of a single target. Its first ``kind`` value
    selects MAE/RMSE ("reg") or F1/Brier (anything else) curves.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(6, 4))
    if df["kind"].iloc[0] == "reg":
        ax.plot(df["horizon_h"], df["mae"], marker="o", label="MAE")
        ax.plot(df["horizon_h"], df["rmse"], marker="o", label="RMSE")
        ax.set_ylabel("Erreur (°C)" if target == "temperature" else "Erreur (unité)")
    else:
        ax.plot(df["horizon_h"], df["f1"], marker="o", label="F1")
        ax.plot(df["horizon_h"], df["brier"], marker="o", label="Brier")
        ax.set_ylabel("Score")
    ax.set_xlabel("Horizon (heures)")
    ax.set_title(f"Chronos (small) – {target}")
    ax.grid(True, linestyle=":", alpha=0.4)
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)


def main() -> None:
    """Evaluate Chronos on every target; write metrics, forecasts and figures.

    Raises:
        SystemExit: if the CSV is missing, a target column is absent, or the
            history is shorter than ``CONTEXT_H + max(HORIZONS_H)`` hours.
    """
    if not CSV_PATH.exists():
        raise SystemExit(f"Fichier introuvable : {CSV_PATH}")

    pipeline = ChronosPipeline.from_pretrained(
        MODEL_ID,
        device_map="auto",
        dtype="auto",
    )

    raw_df = load_raw_csv(CSV_PATH)  # read once, shared by every target
    max_h = max(HORIZONS_H)
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    rows: list[dict[str, object]] = []

    for target, meta in TARGETS.items():
        series = _load_series(target, raw_df)
        # One forecast over the longest horizon; shorter horizons use its prefix.
        context_arr, target_series = _prepare_window(series, CONTEXT_H, max_h)
        samples = _forecast_samples(pipeline, context_arr, max_h)

        # Reuse the held-out index so predictions and truth align row by row.
        forecast_series = pd.Series(samples.mean(axis=0), index=target_series.index, name="y_pred")
        columns = [forecast_series, target_series.rename("y_true")]

        if meta["kind"] == "reg":
            for h in HORIZONS_H:
                rows.append(
                    {
                        "target": target,
                        "kind": "reg",
                        "horizon_h": h,
                        **_regression_metrics(target_series.iloc[:h], forecast_series.iloc[:h]),
                    }
                )
        else:
            # Probabilistic rain forecast: share of sampled paths showing rain.
            rain_proba = pd.Series(
                (samples > RAIN_THRESHOLD).mean(axis=0),
                index=target_series.index,
                name="rain_proba",
            )
            columns.append(rain_proba)
            for h in HORIZONS_H:
                rows.append(
                    {
                        "target": target,
                        "kind": "cls",
                        "horizon_h": h,
                        **_rain_metrics(target_series.iloc[:h], rain_proba.iloc[:h]),
                    }
                )

        # Save the full forecast/truth window for inspection.
        out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}"
        pd.concat(columns, axis=1).to_csv(
            DATA_DIR / f"chronos_multi_forecast_{out_prefix}.csv",
            index=True,
        )

    df_metrics = pd.DataFrame(rows)
    metrics_path = DATA_DIR / "chronos_multi_metrics.csv"
    df_metrics.to_csv(metrics_path, index=False)

    for target in df_metrics["target"].unique():
        sub = df_metrics[df_metrics["target"] == target].sort_values("horizon_h")
        _plot_metrics(sub, target, FIG_DIR / f"chronos_multi_{target}.png")

    print(df_metrics.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
    print(f"✔ Sauvegardé : {metrics_path}")
    print("✔ Figures par cible dans figures/chronos_multi_<target>.png")


if __name__ == "__main__":
    main()