# scripts/compare_chronos.py
from __future__ import annotations

import re
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

# Root directory two levels above this script (the script lives in a
# sub-directory such as ``scripts/``).
DOC_DIR = Path(__file__).resolve().parent.parent
# Where forecast CSVs are read from and the summary CSV is written.
DATA_DIR = DOC_DIR.joinpath("data")
# Where generated figures are written.
FIG_DIR = DOC_DIR.joinpath("figures")


def _load_forecasts() -> pd.DataFrame:
|
||
records = []
|
||
pattern = re.compile(r"chronos_forecast_(.+)\.csv")
|
||
for csv in DATA_DIR.glob("chronos_forecast_*.csv"):
|
||
match = pattern.match(csv.name)
|
||
if not match:
|
||
continue
|
||
model = match.group(1).replace("__", "/")
|
||
df = pd.read_csv(csv)
|
||
if not {"y_true", "y_pred"}.issubset(df.columns):
|
||
continue
|
||
err = df["y_pred"] - df["y_true"]
|
||
records.append(
|
||
{
|
||
"model": model,
|
||
"mae": err.abs().mean(),
|
||
"rmse": (err.pow(2).mean()) ** 0.5,
|
||
}
|
||
)
|
||
return pd.DataFrame(records)
|
||
|
||
|
||
def _plot_comparison(df: pd.DataFrame, output_path: Path) -> None:
    """Save a grouped bar chart of MAE and RMSE for each model.

    Args:
        df: One row per model with ``model``, ``mae`` and ``rmse`` columns;
            bars are drawn in the row order of *df*.
        output_path: Image file to write. Its parent directory is created if
            it does not exist.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        positions = range(len(df))
        # Two series of 0.3-wide bars, shifted half a bar left/right of each tick.
        ax.bar([i - 0.15 for i in positions], df["mae"], width=0.3, label="MAE")
        ax.bar([i + 0.15 for i in positions], df["rmse"], width=0.3, label="RMSE")
        ax.set_xticks(list(positions))
        ax.set_xticklabels(df["model"], rotation=15)
        ax.set_ylabel("Erreur (°C)")
        ax.set_title("Chronos T5 – comparaison des tailles")
        ax.grid(True, linestyle=":", alpha=0.4, axis="y")
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving raised;
        # otherwise pyplot keeps it alive for the rest of the process.
        plt.close(fig)


def main() -> None:
    """Summarise all Chronos forecast outputs found in ``DATA_DIR``.

    Writes ``chronos_summary.csv`` (models ordered by ascending MAE) into
    ``DATA_DIR``, saves ``chronos_models_comparison.png`` into ``FIG_DIR``,
    then prints the table and both output paths.

    Raises:
        SystemExit: when no usable ``chronos_forecast_*.csv`` file was loaded.
    """
    metrics = _load_forecasts()
    if metrics.empty:
        raise SystemExit(
            "Aucune sortie chronos_forecast_*.csv trouvée dans data/. Lancez run_chronos.py d'abord."
        )

    # Lowest MAE first.
    ranked = metrics.sort_values("mae")
    summary_path = DATA_DIR / "chronos_summary.csv"
    figure_path = FIG_DIR / "chronos_models_comparison.png"

    ranked.to_csv(summary_path, index=False)
    _plot_comparison(ranked, figure_path)

    def three_decimals(value: float) -> str:
        return f"{value:.3f}"

    print(ranked.to_string(index=False, float_format=three_decimals))
    print(f"✔ Sauvegardé : {summary_path}")
    print(f"✔ Figure : {figure_path}")


if __name__ == "__main__":
    main()