# scripts/compare_chronos.py from __future__ import annotations from pathlib import Path import re import matplotlib.pyplot as plt import pandas as pd DOC_DIR = Path(__file__).resolve().parent.parent DATA_DIR = DOC_DIR / "data" FIG_DIR = DOC_DIR / "figures" def _load_forecasts() -> pd.DataFrame: records = [] pattern = re.compile(r"chronos_forecast_(.+)\.csv") for csv in DATA_DIR.glob("chronos_forecast_*.csv"): match = pattern.match(csv.name) if not match: continue model = match.group(1).replace("__", "/") df = pd.read_csv(csv) if not {"y_true", "y_pred"}.issubset(df.columns): continue err = df["y_pred"] - df["y_true"] records.append( { "model": model, "mae": err.abs().mean(), "rmse": (err.pow(2).mean()) ** 0.5, } ) return pd.DataFrame(records) def _plot_comparison(df: pd.DataFrame, output_path: Path) -> None: output_path.parent.mkdir(parents=True, exist_ok=True) fig, ax = plt.subplots(figsize=(7, 4)) x = range(len(df)) ax.bar([i - 0.15 for i in x], df["mae"], width=0.3, label="MAE") ax.bar([i + 0.15 for i in x], df["rmse"], width=0.3, label="RMSE") ax.set_xticks(list(x)) ax.set_xticklabels(df["model"], rotation=15) ax.set_ylabel("Erreur (°C)") ax.set_title("Chronos T5 – comparaison des tailles") ax.grid(True, linestyle=":", alpha=0.4, axis="y") ax.legend() fig.tight_layout() fig.savefig(output_path, dpi=150) plt.close(fig) def main() -> None: df = _load_forecasts() if df.empty: raise SystemExit("Aucune sortie chronos_forecast_*.csv trouvée dans data/. Lancez run_chronos.py d'abord.") df_sorted = df.sort_values("mae") summary_path = DATA_DIR / "chronos_summary.csv" df_sorted.to_csv(summary_path, index=False) _plot_comparison(df_sorted, FIG_DIR / "chronos_models_comparison.png") print(df_sorted.to_string(index=False, float_format=lambda x: f"{x:.3f}")) print(f"✔ Sauvegardé : {summary_path}") print(f"✔ Figure : {FIG_DIR / 'chronos_models_comparison.png'}") if __name__ == "__main__": main()