73 lines
2.6 KiB
Python
73 lines
2.6 KiB
Python
# scripts/plot_chronos_multi_errors.py
|
||
from __future__ import annotations
|
||
|
||
from pathlib import Path
|
||
|
||
import matplotlib.pyplot as plt
|
||
import pandas as pd
|
||
|
||
# Directory layout, resolved relative to this script: the root is two levels
# above this file (the script sits in a sub-directory of it). Metrics are read
# from ``<root>/data`` and figures are written to ``<root>/figures``.
DOC_DIR = Path(__file__).resolve().parents[1]
DATA_DIR = DOC_DIR.joinpath("data")
FIG_DIR = DOC_DIR.joinpath("figures")
|
||
|
||
|
||
def _plot_temp_wind(df: pd.DataFrame, output: Path) -> None:
    """Plot MAE and RMSE versus forecast horizon for temperature and wind speed.

    Only rows with ``kind == "reg"`` are used. For each of the targets
    ``"temperature"`` and ``"wind_speed"``:

    - MAE is drawn as a solid line with circle markers.
    - RMSE is drawn as a dashed line with square markers.

    Both are plotted against ``horizon_h``, sorted ascending.

    Args:
        df: Metrics table. It is indexed by the columns ``target``, ``kind``,
            ``horizon_h``, ``mae`` and ``rmse``, so all of them must be present.
        output: Destination image path. Missing parent directories are created.
    """
    output.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        # Loop-invariant row filter: only regression metrics are relevant here.
        is_reg = df["kind"] == "reg"
        for target in ("temperature", "wind_speed"):
            sub = df[(df["target"] == target) & is_reg].sort_values("horizon_h")
            ax.plot(sub["horizon_h"], sub["mae"], marker="o", label=f"{target} – MAE")
            ax.plot(
                sub["horizon_h"],
                sub["rmse"],
                marker="s",
                label=f"{target} – RMSE",
                linestyle="--",
            )

        ax.set_xlabel("Horizon (heures)")
        ax.set_ylabel("Erreur")
        ax.set_title("Chronos small – erreurs température / vent")
        ax.grid(True, linestyle=":", alpha=0.4)
        ax.legend()

        fig.tight_layout()
        fig.savefig(output, dpi=150)
    finally:
        # Release the figure even if plotting or saving raised. Otherwise pyplot
        # keeps it registered and open figures accumulate across calls.
        plt.close(fig)
|
||
|
||
|
||
def _plot_rain(df: pd.DataFrame, output: Path) -> None:
    """Plot rain-classification F1 and Brier score versus forecast horizon.

    Only rows with ``target == "rain_rate"`` and ``kind == "cls"`` are used,
    sorted by ``horizon_h``. The two metrics share the x axis but have
    separate y axes:

    - F1 is on the left axis (blue, solid, circle markers).
    - Brier is on a twin right axis (red, dashed, square markers).

    A single combined legend is drawn on the left axes.

    Args:
        df: Metrics table. It is indexed by the columns ``target``, ``kind``,
            ``horizon_h``, ``f1`` and ``brier``, so all of them must be present.
        output: Destination image path. Missing parent directories are created.
    """
    output.parent.mkdir(parents=True, exist_ok=True)

    sub = df[(df["target"] == "rain_rate") & (df["kind"] == "cls")].sort_values("horizon_h")

    fig, ax1 = plt.subplots(figsize=(6, 4))
    try:
        ax1.plot(sub["horizon_h"], sub["f1"], marker="o", color="tab:blue", label="F1")
        ax1.set_ylabel("F1", color="tab:blue")
        ax1.tick_params(axis="y", labelcolor="tab:blue")

        # Second y axis sharing the same x axis, so Brier keeps its own scale.
        ax2 = ax1.twinx()
        ax2.plot(
            sub["horizon_h"],
            sub["brier"],
            marker="s",
            color="tab:red",
            linestyle="--",
            label="Brier",
        )
        ax2.set_ylabel("Brier", color="tab:red")
        ax2.tick_params(axis="y", labelcolor="tab:red")

        ax1.set_xlabel("Horizon (heures)")
        ax1.set_title("Chronos small – pluie (F1/Brier)")
        ax1.grid(True, linestyle=":", alpha=0.4)

        # Each Axes only knows its own artists. Merge both sets so F1 and Brier
        # appear in a single legend box.
        f1_handles, f1_labels = ax1.get_legend_handles_labels()
        brier_handles, brier_labels = ax2.get_legend_handles_labels()
        ax1.legend(f1_handles + brier_handles, f1_labels + brier_labels, loc="upper right")

        fig.tight_layout()
        fig.savefig(output, dpi=150)
    finally:
        # Release the figure even if plotting or saving raised. Otherwise pyplot
        # keeps it registered and open figures accumulate across calls.
        plt.close(fig)
|
||
|
||
|
||
def main() -> None:
    """Load the Chronos multi-target metrics CSV and render both error figures.

    Reads ``DATA_DIR / "chronos_multi_metrics.csv"`` and writes two figures
    into ``FIG_DIR``, in this order:

    1. The temperature/wind error plot.
    2. The rain F1/Brier plot.

    A confirmation line is printed at the end.

    Raises:
        SystemExit: if the metrics CSV does not exist.
    """
    csv_file = DATA_DIR / "chronos_multi_metrics.csv"
    if not csv_file.exists():
        # Stop with a readable hint instead of letting read_csv fail on a missing file.
        raise SystemExit("chronos_multi_metrics.csv introuvable. Lancez run_chronos_multi.py d'abord.")

    metrics = pd.read_csv(csv_file)

    # (plotting function, output file name) pairs, rendered in order.
    jobs = (
        (_plot_temp_wind, "chronos_multi_errors_temp_wind.png"),
        (_plot_rain, "chronos_multi_errors_rain.png"),
    )
    for plot_fn, file_name in jobs:
        plot_fn(metrics, FIG_DIR / file_name)

    print("✔ Figures : chronos_multi_errors_temp_wind.png, chronos_multi_errors_rain.png")


# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|