# scripts/plot_chronos_multi_errors.py from __future__ import annotations from pathlib import Path import matplotlib.pyplot as plt import pandas as pd DOC_DIR = Path(__file__).resolve().parent.parent DATA_DIR = DOC_DIR / "data" FIG_DIR = DOC_DIR / "figures" def _plot_temp_wind(df: pd.DataFrame, output: Path) -> None: output.parent.mkdir(parents=True, exist_ok=True) fig, ax = plt.subplots(figsize=(6, 4)) for target in ("temperature", "wind_speed"): sub = df[(df["target"] == target) & (df["kind"] == "reg")].sort_values("horizon_h") ax.plot(sub["horizon_h"], sub["mae"], marker="o", label=f"{target} – MAE") ax.plot(sub["horizon_h"], sub["rmse"], marker="s", label=f"{target} – RMSE", linestyle="--") ax.set_xlabel("Horizon (heures)") ax.set_ylabel("Erreur") ax.set_title("Chronos small – erreurs température / vent") ax.grid(True, linestyle=":", alpha=0.4) ax.legend() fig.tight_layout() fig.savefig(output, dpi=150) plt.close(fig) def _plot_rain(df: pd.DataFrame, output: Path) -> None: output.parent.mkdir(parents=True, exist_ok=True) sub = df[(df["target"] == "rain_rate") & (df["kind"] == "cls")].sort_values("horizon_h") fig, ax1 = plt.subplots(figsize=(6, 4)) ax1.plot(sub["horizon_h"], sub["f1"], marker="o", color="tab:blue", label="F1") ax1.set_ylabel("F1", color="tab:blue") ax1.tick_params(axis="y", labelcolor="tab:blue") ax2 = ax1.twinx() ax2.plot(sub["horizon_h"], sub["brier"], marker="s", color="tab:red", linestyle="--", label="Brier") ax2.set_ylabel("Brier", color="tab:red") ax2.tick_params(axis="y", labelcolor="tab:red") ax1.set_xlabel("Horizon (heures)") ax1.set_title("Chronos small – pluie (F1/Brier)") ax1.grid(True, linestyle=":", alpha=0.4) # Combine legends handles, labels = [], [] for ax in (ax1, ax2): h, l = ax.get_legend_handles_labels() handles += h labels += l ax1.legend(handles, labels, loc="upper right") fig.tight_layout() fig.savefig(output, dpi=150) plt.close(fig) def main() -> None: metrics_path = DATA_DIR / "chronos_multi_metrics.csv" if not metrics_path.exists(): raise SystemExit("chronos_multi_metrics.csv introuvable. Lancez run_chronos_multi.py d'abord.") df = pd.read_csv(metrics_path) _plot_temp_wind(df, FIG_DIR / "chronos_multi_errors_temp_wind.png") _plot_rain(df, FIG_DIR / "chronos_multi_errors_rain.png") print("✔ Figures : chronos_multi_errors_temp_wind.png, chronos_multi_errors_rain.png") if __name__ == "__main__": main()