1
2025-11-26 01:03:41 +01:00

69 lines
2.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# scripts/compare_chronos.py
from __future__ import annotations
from pathlib import Path
import re
import matplotlib.pyplot as plt
import pandas as pd
# Directory two levels above this script's resolved location; the data and
# figure folders used by this script live directly under it.
DOC_DIR = Path(__file__).resolve().parent.parent
# Input forecasts are read from here, and the summary CSV is written here.
DATA_DIR = DOC_DIR.joinpath("data")
# The comparison chart is written here.
FIG_DIR = DOC_DIR.joinpath("figures")
def _load_forecasts(data_dir: Path | None = None) -> pd.DataFrame:
    """Compute per-model MAE and RMSE from saved Chronos forecast CSV files.

    Scans *data_dir* for files named ``chronos_forecast_<model>.csv``. The
    model name is recovered from the file name by turning every ``"__"`` back
    into ``"/"``. A file is used only if it parses and has both ``y_true`` and
    ``y_pred`` columns. Empty files and files without those columns are
    skipped.

    Args:
        data_dir: Directory to scan. Defaults to the module-level ``DATA_DIR``.

    Returns:
        A DataFrame with one row per usable file and the columns ``model``,
        ``mae`` and ``rmse``, in file-name order. The columns are present even
        when no file was usable; in that case the frame is simply empty.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    pattern = re.compile(r"chronos_forecast_(.+)\.csv")
    records = []
    # Sorting makes the row order, and therefore tie-breaking when the caller
    # later sorts by MAE, independent of the filesystem's listing order.
    for path in sorted(data_dir.glob("chronos_forecast_*.csv")):
        match = pattern.fullmatch(path.name)
        if match is None:
            continue
        try:
            df = pd.read_csv(path)
        except pd.errors.EmptyDataError:
            # A zero-byte file holds no forecast. Skip it, like files that lack
            # the expected columns, instead of aborting the whole comparison.
            continue
        if not {"y_true", "y_pred"}.issubset(df.columns):
            continue
        err = df["y_pred"] - df["y_true"]
        # pandas' mean() skips NaN by default, so rows with a missing value
        # are ignored by both metrics.
        records.append(
            {
                "model": match.group(1).replace("__", "/"),
                "mae": err.abs().mean(),
                "rmse": err.pow(2).mean() ** 0.5,
            }
        )
    return pd.DataFrame(records, columns=["model", "mae", "rmse"])
def _plot_comparison(df: pd.DataFrame, output_path: Path) -> None:
    """Save a grouped MAE/RMSE bar chart, one group per model, to *output_path*.

    Args:
        df: Rows with ``model``, ``mae`` and ``rmse`` columns. Groups are drawn
            left to right in the row order of *df*.
        output_path: Destination image file. Its parent directories are
            created if they do not exist.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    bar_width = 0.3
    positions = list(range(len(df)))
    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        # Two bars per model, placed half a bar width either side of its tick.
        ax.bar(
            [p - bar_width / 2 for p in positions],
            df["mae"],
            width=bar_width,
            label="MAE",
        )
        ax.bar(
            [p + bar_width / 2 for p in positions],
            df["rmse"],
            width=bar_width,
            label="RMSE",
        )
        ax.set_xticks(positions)
        # Anchor rotated labels at their right end so each label finishes
        # under its own tick instead of drifting toward its neighbour.
        ax.set_xticklabels(
            df["model"], rotation=15, ha="right", rotation_mode="anchor"
        )
        ax.set_ylabel("Erreur (°C)")
        ax.set_title("Chronos T5 comparaison des tailles")
        ax.grid(True, linestyle=":", alpha=0.4, axis="y")
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
def main() -> None:
    """Summarise all Chronos forecasts as a CSV table and a comparison chart.

    Writes ``chronos_summary.csv`` to ``DATA_DIR`` and
    ``chronos_models_comparison.png`` to ``FIG_DIR``, prints the table, and
    then prints both output paths.

    Raises:
        SystemExit: If no usable ``chronos_forecast_*.csv`` file was found.
    """
    df = _load_forecasts()
    if df.empty:
        raise SystemExit("Aucune sortie chronos_forecast_*.csv trouvée dans data/. Lancez run_chronos.py d'abord.")
    # Lowest MAE first, in both the saved table and the chart.
    df_sorted = df.sort_values("mae")
    summary_path = DATA_DIR / "chronos_summary.csv"
    # Build the figure path once, so the path that is saved and the path that
    # is reported cannot drift apart.
    figure_path = FIG_DIR / "chronos_models_comparison.png"
    df_sorted.to_csv(summary_path, index=False)
    _plot_comparison(df_sorted, figure_path)
    print(df_sorted.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
    print(f"✔ Sauvegardé : {summary_path}")
    print(f"✔ Figure : {figure_path}")
# Script entry point. main() returns None, so SystemExit(None) ends the process
# with status 0, exactly as falling off the end of the script would.
if __name__ == "__main__":
    raise SystemExit(main())