# scripts/run_chronos.py from __future__ import annotations import os from pathlib import Path import sys from typing import Tuple import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch PROJECT_ROOT = Path(__file__).resolve().parents[3] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from meteo.dataset import load_raw_csv try: from chronos import ChronosPipeline except ImportError as exc: # pragma: no cover - guidance if deps missing raise SystemExit( "chronos-forecasting est manquant. Installez-le (pip install chronos-forecasting) " "puis relancez." ) from exc CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv" DOC_DIR = Path(__file__).resolve().parent.parent DATA_DIR = DOC_DIR / "data" FIG_DIR = DOC_DIR / "figures" MODEL_ID = os.getenv("CHRONOS_MODEL", "amazon/chronos-t5-small") CONTEXT_LEN = int(os.getenv("CHRONOS_CONTEXT", "336")) # 14 jours (1H) HORIZON_LEN = int(os.getenv("CHRONOS_HORIZON", "96")) # 4 jours (1H) RESAMPLE_RULE = os.getenv("CHRONOS_RESAMPLE", "1h") NUM_SAMPLES = int(os.getenv("CHRONOS_SAMPLES", "20")) # échantillons stochastiques def _load_series(csv_path: Path, target_col: str, rule: str) -> pd.Series: df = load_raw_csv(csv_path) if target_col not in df.columns: raise SystemExit(f"Colonne absente dans le CSV : {target_col!r}") series = ( df[target_col] .resample(rule) .mean() .interpolate(limit_direction="both") ) return series.dropna() def _split_context_target(series: pd.Series, context_len: int, horizon_len: int) -> Tuple[pd.Series, pd.Series]: needed = context_len + horizon_len if len(series) < needed: raise SystemExit(f"Pas assez de données après resampling : {len(series)} < {needed} (context+horizon).") window = series.iloc[-needed:] context = window.iloc[:context_len] target = window.iloc[context_len:] return context, target def _plot_forecast(target: pd.Series, forecast: pd.Series, output_path: Path) -> None: output_path.parent.mkdir(parents=True, exist_ok=True) fig, ax = plt.subplots(figsize=(10, 4)) ax.plot(target.index, target.values, label="Observé", linewidth=2) ax.plot(forecast.index, forecast.values, label="Chronos-2 (préd.)", linewidth=2, linestyle="--") ax.set_title("Prévision Chronos-2 vs observation") ax.set_xlabel("Date") ax.set_ylabel("Température (°C)") ax.grid(True, linestyle=":", alpha=0.4) ax.legend() fig.autofmt_xdate() fig.tight_layout() fig.savefig(output_path, dpi=150) plt.close(fig) def _plot_errors(errors: pd.Series, output_path: Path) -> None: output_path.parent.mkdir(parents=True, exist_ok=True) fig, ax = plt.subplots(figsize=(8, 4)) ax.plot(errors.index, errors.values, marker="o") ax.set_title("Erreur absolue par horizon (heures)") ax.set_xlabel("Horizon (h)") ax.set_ylabel("Erreur absolue (°C)") ax.grid(True, linestyle=":", alpha=0.4) fig.tight_layout() fig.savefig(output_path, dpi=150) plt.close(fig) def main() -> None: if not CSV_PATH.exists(): raise SystemExit(f"Fichier introuvable : {CSV_PATH}") series = _load_series(CSV_PATH, target_col="temperature", rule=RESAMPLE_RULE) context, target = _split_context_target(series, CONTEXT_LEN, HORIZON_LEN) model_slug = MODEL_ID.replace("/", "__") print(f"Contexte : {len(context)} points ({context.index[0]} -> {context.index[-1]})") print(f"Cible : {len(target)} points ({target.index[0]} -> {target.index[-1]})") print(f"Modèle HF : {MODEL_ID}") pipeline = ChronosPipeline.from_pretrained( MODEL_ID, device_map="auto", # CUDA/MPS/CPU automatique dtype="auto", ) context_array = context.to_numpy(dtype=float) context_tensor = torch.tensor(context_array, dtype=torch.float32) forecasts = pipeline.predict( [context_tensor], # batch de 1 série (Tensor attendu) prediction_length=HORIZON_LEN, num_samples=NUM_SAMPLES, ) forecast_mean = forecasts.mean(0) # Si batch dimension présente, on prend la première série if forecast_mean.ndim == 2: forecast_mean = forecast_mean[0] forecast_mean = np.asarray(forecast_mean).ravel() forecast_index = pd.date_range(target.index[0], periods=HORIZON_LEN, freq=RESAMPLE_RULE) forecast_series = pd.Series(forecast_mean, index=forecast_index, name="y_pred") target_series = target.rename("y_true") # Métriques abs_errors = (forecast_series - target_series).abs() mae = abs_errors.mean() rmse = np.sqrt(((forecast_series - target_series) ** 2).mean()) print(f"MAE : {mae:.3f} °C | RMSE : {rmse:.3f} °C") # Sauvegardes DATA_DIR.mkdir(parents=True, exist_ok=True) out_csv = DATA_DIR / f"chronos_forecast_{model_slug}.csv" per_horizon_csv = DATA_DIR / f"chronos_errors_{model_slug}.csv" pd.concat([target_series, forecast_series], axis=1).to_csv(out_csv, index=True) pd.DataFrame({"horizon_h": np.arange(1, HORIZON_LEN + 1), "abs_error": abs_errors.values}).to_csv(per_horizon_csv, index=False) # Figures _plot_forecast(target_series, forecast_series, FIG_DIR / f"chronos_forecast_{model_slug}.png") _plot_errors(abs_errors.rename("abs_error").rename_axis("h"), FIG_DIR / f"chronos_errors_{model_slug}.png") print(f"✔ Sauvegardé : {out_csv}") print(f"✔ Sauvegardé : {per_horizon_csv}") print(f"✔ Figures : {FIG_DIR / f'chronos_forecast_{model_slug}.png'} ; {FIG_DIR / f'chronos_errors_{model_slug}.png'}") if __name__ == "__main__": main()