1
donnees_meteo/docs/11 - Modèle Chronos/scripts/plot_chronos_multi_errors.py

73 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# scripts/plot_chronos_multi_errors.py
from __future__ import annotations
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
# Root directory of this documentation chapter: two levels above this file,
# i.e. the directory that contains scripts/.
DOC_DIR = Path(__file__).resolve().parent.parent
# Directory holding the input metrics CSV (chronos_multi_metrics.csv).
DATA_DIR = DOC_DIR / "data"
# Directory where the generated PNG figures are written.
FIG_DIR = DOC_DIR / "figures"
def _plot_temp_wind(df: pd.DataFrame, output: Path) -> None:
    """Plot MAE and RMSE versus forecast horizon for temperature and wind speed.

    For each target in ``("temperature", "wind_speed")`` the rows with that
    ``target`` and ``kind == "reg"`` are sorted by ``horizon_h``. Each target
    gets an MAE curve (circle markers) and a dashed RMSE curve (square
    markers), all on a single shared axis.

    Args:
        df: Metrics table; must provide the columns ``target``, ``kind``,
            ``horizon_h``, ``mae`` and ``rmse``.
        output: Destination image path; missing parent directories are created.
    """
    output.parent.mkdir(parents=True, exist_ok=True)
    # The regression-row filter does not depend on the target: compute it once.
    is_reg = df["kind"] == "reg"
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        for target in ("temperature", "wind_speed"):
            sub = df[(df["target"] == target) & is_reg].sort_values("horizon_h")
            ax.plot(sub["horizon_h"], sub["mae"], marker="o", label=f"{target} MAE")
            ax.plot(
                sub["horizon_h"],
                sub["rmse"],
                marker="s",
                label=f"{target} RMSE",
                linestyle="--",
            )
        ax.set_xlabel("Horizon (heures)")
        ax.set_ylabel("Erreur")
        ax.set_title("Chronos small erreurs température / vent")
        ax.grid(True, linestyle=":", alpha=0.4)
        ax.legend()
        fig.tight_layout()
        fig.savefig(output, dpi=150)
    finally:
        # Release the pyplot figure even if plotting or saving raised, so
        # failed calls do not leave open figures behind.
        plt.close(fig)
def _plot_rain(df: pd.DataFrame, output: Path) -> None:
    """Plot rain-classification F1 and Brier score versus forecast horizon.

    Uses the rows with ``target == "rain_rate"`` and ``kind == "cls"``,
    sorted by ``horizon_h``.
    - F1 is drawn in blue on the left y-axis.
    - The Brier score is drawn in red and dashed on a twin right y-axis.
    - A single legend lists both curves.

    Args:
        df: Metrics table; must provide the columns ``target``, ``kind``,
            ``horizon_h``, ``f1`` and ``brier``.
        output: Destination image path; missing parent directories are created.
    """
    output.parent.mkdir(parents=True, exist_ok=True)
    sub = df[(df["target"] == "rain_rate") & (df["kind"] == "cls")].sort_values("horizon_h")
    fig, ax1 = plt.subplots(figsize=(6, 4))
    try:
        ax1.plot(sub["horizon_h"], sub["f1"], marker="o", color="tab:blue", label="F1")
        ax1.set_ylabel("F1", color="tab:blue")
        ax1.tick_params(axis="y", labelcolor="tab:blue")
        # Secondary y-axis (twinx) for the Brier score, sharing the horizon x-axis.
        ax2 = ax1.twinx()
        ax2.plot(
            sub["horizon_h"],
            sub["brier"],
            marker="s",
            color="tab:red",
            linestyle="--",
            label="Brier",
        )
        ax2.set_ylabel("Brier", color="tab:red")
        ax2.tick_params(axis="y", labelcolor="tab:red")
        ax1.set_xlabel("Horizon (heures)")
        ax1.set_title("Chronos small pluie (F1/Brier)")
        ax1.grid(True, linestyle=":", alpha=0.4)
        # Each axis only reports its own artists: merge both into one legend.
        handles, labels = [], []
        for axis in (ax1, ax2):
            axis_handles, axis_labels = axis.get_legend_handles_labels()
            handles.extend(axis_handles)
            labels.extend(axis_labels)
        # Attach the legend to the twin axis. ax2 is drawn after ax1, so a
        # legend owned by ax1 would be painted over by the Brier curve.
        ax2.legend(handles, labels, loc="upper right")
        fig.tight_layout()
        fig.savefig(output, dpi=150)
    finally:
        # Release the pyplot figure even if plotting or saving raised.
        plt.close(fig)
def main() -> None:
    """Load the Chronos multi-target metrics CSV and render both error figures.

    Reads ``DATA_DIR / "chronos_multi_metrics.csv"``. It then writes the
    temperature/wind figure and the rain figure into ``FIG_DIR``, in that
    order, and prints a confirmation line.

    Raises:
        SystemExit: if the metrics CSV does not exist.
    """
    csv_file = DATA_DIR / "chronos_multi_metrics.csv"
    if not csv_file.exists():
        # French message: "file not found, run run_chronos_multi.py first".
        raise SystemExit("chronos_multi_metrics.csv introuvable. Lancez run_chronos_multi.py d'abord.")
    metrics = pd.read_csv(csv_file)

    # (plotting function, output file name) pairs, rendered in this order.
    jobs = (
        (_plot_temp_wind, "chronos_multi_errors_temp_wind.png"),
        (_plot_rain, "chronos_multi_errors_rain.png"),
    )
    for plotter, filename in jobs:
        plotter(metrics, FIG_DIR / filename)

    print("✔ Figures : chronos_multi_errors_temp_wind.png, chronos_multi_errors_rain.png")


if __name__ == "__main__":
    main()