58 lines
1.8 KiB
Python
58 lines
1.8 KiB
Python
# scripts/plot_chronos_errors_combined.py
|
||
from __future__ import annotations
|
||
|
||
from pathlib import Path
|
||
import re
|
||
|
||
import matplotlib.pyplot as plt
|
||
import pandas as pd
|
||
|
||
# Directory two levels above this script file (the script's own directory's parent).
DOC_DIR = Path(__file__).resolve().parent.parent
# Input forecasts (chronos_forecast_*.csv) are read from here.
DATA_DIR = DOC_DIR.joinpath("data")
# Generated figures are written here.
FIG_DIR = DOC_DIR.joinpath("figures")
|
||
|
||
|
||
def load_errors(data_dir: Path | None = None) -> pd.DataFrame:
    """Collect per-horizon absolute errors from Chronos T5 forecast CSVs.

    Scans *data_dir* for files named
    ``chronos_forecast_amazon__chronos-t5-<size>.csv``. For each file that has
    both ``y_true`` and ``y_pred`` columns, it computes ``|y_pred - y_true|``
    row by row. Row *i* (1-based) is reported as horizon *i*. This assumes
    each CSV holds one row per hourly step, in horizon order — TODO confirm
    against whatever produces these files.

    Args:
        data_dir: Directory to scan. ``None`` (the default) means the
            module-level ``DATA_DIR``.

    Returns:
        Long-format frame with columns ``model`` (e.g.
        ``"amazon/chronos-t5-small"``), ``horizon_h`` and ``abs_error``.
        It has no rows when no matching file with the required columns exists.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    columns = ["model", "horizon_h", "abs_error"]
    pattern = re.compile(r"chronos_forecast_(amazon__chronos-t5-[a-z]+)\.csv")
    frames: list[pd.DataFrame] = []
    # Sort so row order does not depend on the filesystem's listing order.
    for csv_path in sorted(data_dir.glob("chronos_forecast_*.csv")):
        # fullmatch: reject names that merely start with a valid model file name.
        m = pattern.fullmatch(csv_path.name)
        if m is None:
            continue
        # File names encode the model id with "/" written as "__".
        model = m.group(1).replace("__", "/")
        df = pd.read_csv(csv_path)
        if not {"y_true", "y_pred"}.issubset(df.columns):
            continue
        err = (df["y_pred"] - df["y_true"]).abs()
        if err.empty:
            continue
        frames.append(
            pd.DataFrame(
                {
                    "model": model,
                    "horizon_h": range(1, len(err) + 1),
                    # to_numpy() avoids any index alignment with the range column.
                    "abs_error": err.to_numpy(),
                }
            )
        )
    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, ignore_index=True)[columns]
|
||
|
||
|
||
def plot_errors(df: pd.DataFrame, output_path: Path) -> None:
    """Plot absolute error against horizon, one line per model, and save it.

    Args:
        df: Long-format frame with ``model``, ``horizon_h`` and ``abs_error``
            columns, such as the one returned by ``load_errors``.
        output_path: Image file to write. Missing parent directories are created.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        for model, sub in df.groupby("model"):
            # Sort so each line is drawn in increasing horizon order.
            sub_sorted = sub.sort_values("horizon_h")
            ax.plot(sub_sorted["horizon_h"], sub_sorted["abs_error"], label=model, linewidth=2)
        ax.set_xlabel("Horizon (heures)")
        ax.set_ylabel("Erreur absolue (°C)")
        ax.set_title("Chronos T5 – erreur absolue vs horizon")
        ax.grid(True, linestyle=":", alpha=0.4)
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if plotting or saving fails, so pyplot
        # does not keep it open.
        plt.close(fig)
|
||
|
||
|
||
def main() -> None:
    """Build the combined Chronos error figure from the forecast CSVs.

    Raises:
        SystemExit: if ``load_errors`` returns an empty frame (no matching
            CSV, or none with the required columns).
    """
    errors = load_errors()
    if errors.empty:
        raise SystemExit("Aucun fichier chronos_forecast_*.csv trouvé ou colonnes manquantes.")
    target = FIG_DIR / "chronos_errors_combined.png"
    plot_errors(errors, target)
    print("✔ Figure : figures/chronos_errors_combined.png")
|
||
|
||
|
||
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|