1
donnees_meteo/docs/11 - Modèle Chronos/scripts/plot_chronos_errors_combined.py

58 lines
1.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# scripts/plot_chronos_errors_combined.py
from __future__ import annotations
from pathlib import Path
import re
import matplotlib.pyplot as plt
import pandas as pd
# Root of this documentation chapter: the folder that contains scripts/.
DOC_DIR: Path = Path(__file__).resolve().parent.parent
# Input folder holding the chronos_forecast_*.csv files.
DATA_DIR: Path = DOC_DIR.joinpath("data")
# Output folder for rendered figures.
FIG_DIR: Path = DOC_DIR.joinpath("figures")
def load_errors(data_dir: Path | None = None) -> pd.DataFrame:
    """Collect per-horizon absolute errors from every Chronos T5 forecast CSV.

    A file contributes rows only if both of these hold:

    * its whole name matches
      ``chronos_forecast_amazon__chronos-t5-<size>.csv``;
    * it has ``y_true`` and ``y_pred`` columns.

    Each such file gives one row per line. The horizon is the 1-based row
    position, so rows are assumed to be stored in forecast order (presumably
    one row per hourly step -- TODO confirm against the script that writes
    these CSVs).

    Args:
        data_dir: Folder to scan. Defaults to the module-level ``DATA_DIR``.

    Returns:
        Long-format DataFrame with columns:

        * ``model`` -- e.g. ``"amazon/chronos-t5-small"``;
        * ``horizon_h`` -- int, starting at 1;
        * ``abs_error`` -- ``|y_pred - y_true|``.

        The frame is empty, but keeps these columns, when no usable file
        is found.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    pattern = re.compile(r"chronos_forecast_(amazon__chronos-t5-[a-z]+)\.csv")
    columns = ["model", "horizon_h", "abs_error"]
    frames = []
    # sorted(): glob order depends on the filesystem; keep results deterministic.
    for csv_path in sorted(data_dir.glob("chronos_forecast_*.csv")):
        # fullmatch (not match): the entire file name must fit the scheme,
        # otherwise names like "...-small.csv.csv" would be picked up too.
        m = pattern.fullmatch(csv_path.name)
        if m is None:
            continue
        df = pd.read_csv(csv_path)
        if not {"y_true", "y_pred"}.issubset(df.columns):
            continue
        abs_error = (df["y_pred"] - df["y_true"]).abs().to_numpy()
        if len(abs_error) == 0:
            # Nothing to plot; also avoids concatenating empty frames.
            continue
        # File stem "amazon__chronos-t5-x" -> model id "amazon/chronos-t5-x".
        model = m.group(1).replace("__", "/")
        frames.append(
            pd.DataFrame(
                {
                    "model": [model] * len(abs_error),
                    "horizon_h": range(1, len(abs_error) + 1),
                    "abs_error": abs_error,
                },
                columns=columns,
            )
        )
    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, ignore_index=True)
def plot_errors(df: pd.DataFrame, output_path: Path) -> None:
    """Draw one absolute-error-vs-horizon curve per model and save it.

    Args:
        df: Long-format errors with ``model``, ``horizon_h`` and
            ``abs_error`` columns (the layout returned by ``load_errors``).
            Each model's points are sorted by horizon before plotting.
        output_path: Destination image file. Missing parent folders are
            created. The image is saved at 150 dpi.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        for model, sub in df.groupby("model"):
            sub_sorted = sub.sort_values("horizon_h")
            ax.plot(
                sub_sorted["horizon_h"],
                sub_sorted["abs_error"],
                label=model,
                linewidth=2,
            )
        ax.set_xlabel("Horizon (heures)")
        ax.set_ylabel("Erreur absolue (°C)")
        ax.set_title("Chronos T5 erreur absolue vs horizon")
        ax.grid(True, linestyle=":", alpha=0.4)
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Close the figure even if plotting or saving raises, so pyplot
        # does not keep the failed figure open.
        plt.close(fig)
def main() -> None:
    """Load all Chronos T5 forecast errors and render the combined figure.

    Raises:
        SystemExit: if ``load_errors`` returns an empty DataFrame, i.e. no
            chronos_forecast_*.csv file was found or none had the expected
            columns.
    """
    errors = load_errors()
    if errors.empty:
        raise SystemExit("Aucun fichier chronos_forecast_*.csv trouvé ou colonnes manquantes.")
    figure_path = FIG_DIR / "chronos_errors_combined.png"
    plot_errors(errors, figure_path)
    print("✔ Figure : figures/chronos_errors_combined.png")
# Build the figure only when run as a script, not when imported.
if __name__ == "__main__":
    main()