1
2025-11-26 01:03:41 +01:00

148 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# scripts/run_chronos_holdout6.py
from __future__ import annotations

import sys
from functools import lru_cache
from pathlib import Path
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
PROJECT_ROOT = Path(__file__).resolve().parents[3]

# Executing this file directly does not put the repository root on sys.path;
# prepend it (only if absent) so the project package `meteo` can be imported.
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path[0:0] = [str(PROJECT_ROOT)]

from meteo.dataset import load_raw_csv  # noqa: E402  (relies on the sys.path entry above)

try:
    from chronos import ChronosPipeline
except ImportError as missing:
    # chronos-forecasting is a heavy optional dependency: stop with an actionable hint.
    raise SystemExit(
        "chronos-forecasting manquant : pip install -r requirements.txt"
    ) from missing
# Raw input CSV, located in the repository-level `data` folder.
CSV_PATH = PROJECT_ROOT.joinpath("data", "weather_minutely.csv")
# Outputs are written under the directory two levels above this file:
# <DOC_DIR>/data for CSV results and <DOC_DIR>/figures for plots.
DOC_DIR = Path(__file__).resolve().parents[1]
DATA_DIR = DOC_DIR.joinpath("data")
FIG_DIR = DOC_DIR.joinpath("figures")
# Chronos checkpoint used for zero-shot forecasting.
MODEL_ID = "amazon/chronos-t5-small"
# Pandas frequency the raw series are resampled to before forecasting.
RESAMPLE_RULE = "1h"
CONTEXT_H = 288  # context length in resampled steps (288 h = 12 days at the 1h rule)
HOLDOUT_H = 6  # trailing steps held out and forecast
NUM_SAMPLES = 50  # sample paths drawn per forecast

# Forecast targets. "kind" selects how errors are measured ("angle" uses a
# circular error in degrees); "unit" is copied into the error table.
TARGETS = dict(
    temperature=dict(kind="reg", unit="°C"),
    wind_speed=dict(kind="reg", unit="km/h"),
    wind_direction=dict(kind="angle", unit="deg"),
    humidity=dict(kind="reg", unit="%"),
    pressure=dict(kind="reg", unit="hPa"),
)
@lru_cache(maxsize=1)
def _load_frame(csv_path: Path) -> pd.DataFrame:
    """Read the raw CSV once; later calls (one per target) reuse the same frame."""
    return load_raw_csv(csv_path)


def _resample(raw: pd.Series, *, rule: str, circular: bool) -> pd.Series:
    """Average *raw* into *rule* bins, fill gaps by interpolation, drop leftover NaNs.

    Args:
        raw: time-indexed series (``resample`` requires a datetime-like index).
        rule: pandas resampling frequency, e.g. ``"1h"``.
        circular: treat values as angles in degrees. The sin/cos components are
            averaged and interpolated instead of the raw angles, so 350° and 10°
            average to 0° rather than 180°. The result is folded into [0, 360).

    Returns:
        The resampled series with no NaNs.
    """
    if not circular:
        return raw.resample(rule).mean().interpolate(limit_direction="both").dropna()

    radians = np.deg2rad(raw)
    sin_part = np.sin(radians).resample(rule).mean().interpolate(limit_direction="both")
    cos_part = np.cos(radians).resample(rule).mean().interpolate(limit_direction="both")
    degrees = np.rad2deg(np.arctan2(sin_part, cos_part))
    # Adding 360 before the modulo stops a tiny negative angle from rounding to 360.0.
    return ((degrees + 360.0) % 360.0).dropna()


def _load_series(target: str) -> pd.Series:
    """Return the resampled, gap-filled series for column *target* of the raw CSV.

    The CSV at CSV_PATH is read through ``load_raw_csv``. The result is cached,
    so looping over TARGETS reads the file only once. The column is averaged
    into RESAMPLE_RULE bins and gaps are filled with
    ``interpolate(limit_direction="both")``. Targets whose TARGETS kind is
    ``"angle"`` are averaged circularly.

    Raises:
        SystemExit: if *target* is not a column of the CSV.
    """
    df = _load_frame(CSV_PATH)
    if target not in df.columns:
        raise SystemExit(f"Colonne absente : {target}")
    circular = TARGETS.get(target, {}).get("kind") == "angle"
    return _resample(df[target], rule=RESAMPLE_RULE, circular=circular)
def _angular_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
    """Return the absolute angular difference in degrees, folded into [0, 180].

    The inputs are aligned on their index by pandas arithmetic. Each result is
    the shorter way around the circle between prediction and truth.
    """
    wrapped = (y_pred - y_true).abs() % 360
    # Going the other way around the circle costs 360 - wrapped; keep the smaller.
    return np.minimum(wrapped, 360 - wrapped)
def _plot_errors(df: pd.DataFrame, output_path: Path) -> None:
    """Plot absolute error against horizon step, one line per target, and save it.

    Args:
        df: error table with at least the columns "target", "h" and "abs_error".
        output_path: image destination; missing parent directories are created.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    figure, axes = plt.subplots(figsize=(8, 4))
    target_col = df["target"]
    # unique() keeps order of first appearance, so the legend follows the table order.
    # NOTE(review): targets with different units share a single y-axis here.
    for name in target_col.unique():
        subset = df.loc[target_col == name]
        axes.plot(subset["h"], subset["abs_error"], marker="o", label=name)

    axes.set(
        xlabel="Heure du horizon (1-6)",
        ylabel="Erreur absolue (unité cible)",
        title="Chronos small erreurs sur les 6 dernières heures (holdout)",
    )
    axes.grid(True, linestyle=":", alpha=0.4)
    axes.legend()

    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    plt.close(figure)
def _point_forecast(samples: "torch.Tensor | np.ndarray", kind: str) -> np.ndarray:
    """Collapse the sample paths of ONE series into a point forecast per step.

    Args:
        samples: array-like of shape (num_samples, prediction_length).
        kind: ``"angle"`` for directions in degrees. These use a circular mean
            (via sin/cos) folded into [0, 360). Any other kind uses the
            arithmetic mean.

    Returns:
        1-D float array of length prediction_length.

    Raises:
        ValueError: if *samples* is not 2-D.
    """
    if isinstance(samples, torch.Tensor):
        samples = samples.detach().cpu().numpy()
    values = np.asarray(samples, dtype=float)
    if values.ndim != 2:
        raise ValueError(
            f"expected (num_samples, prediction_length) samples, got shape {values.shape}"
        )
    if kind == "angle":
        radians = np.deg2rad(values)
        mean_deg = np.rad2deg(
            np.arctan2(np.sin(radians).mean(axis=0), np.cos(radians).mean(axis=0))
        )
        # Adding 360 before the modulo stops a tiny negative angle from rounding to 360.0.
        return (mean_deg + 360.0) % 360.0
    return values.mean(axis=0)


def main() -> None:
    """Run the Chronos hold-out evaluation for every entry of TARGETS.

    For each target, the last CONTEXT_H + HOLDOUT_H resampled points are kept.
    The first CONTEXT_H points go to Chronos, which draws NUM_SAMPLES sample
    paths over the final HOLDOUT_H steps. Their point forecast is then compared
    with the held-out values.

    Outputs:
        - one forecast CSV per target and a combined error CSV in DATA_DIR;
        - an error plot in FIG_DIR;
        - the mean absolute error per target, printed to stdout.

    Raises:
        SystemExit: if the CSV is missing or a series is shorter than
            CONTEXT_H + HOLDOUT_H.
    """
    if not CSV_PATH.exists():
        raise SystemExit(f"Fichier introuvable : {CSV_PATH}")

    pipeline = ChronosPipeline.from_pretrained(
        MODEL_ID,
        device_map="auto",
        dtype="auto",
    )
    DATA_DIR.mkdir(parents=True, exist_ok=True)

    rows: list[Dict[str, object]] = []
    for target, meta in TARGETS.items():
        series = _load_series(target)
        if len(series) < CONTEXT_H + HOLDOUT_H:
            raise SystemExit("Pas assez de données pour contexte+holdout.")
        context = series.iloc[-(CONTEXT_H + HOLDOUT_H) : -HOLDOUT_H]
        holdout = series.iloc[-HOLDOUT_H:]

        context_tensor = torch.tensor(context.to_numpy(dtype=float), dtype=torch.float32)
        forecasts = pipeline.predict(
            [context_tensor],
            prediction_length=HOLDOUT_H,
            num_samples=NUM_SAMPLES,
        )
        # predict() returns (batch, num_samples, prediction_length). Select the
        # single series, THEN reduce over samples. The previous mean(0)[0]
        # averaged over the size-1 batch axis and kept only the first sample path.
        forecast_series = pd.Series(
            _point_forecast(forecasts[0], meta["kind"]),
            index=holdout.index,
        )

        if meta["kind"] == "angle":
            err = _angular_error(holdout, forecast_series)
        else:
            err = (forecast_series - holdout).abs()
        rows.extend(
            {
                "target": target,
                "timestamp": ts,
                "h": step,
                "abs_error": float(value),
                "unit": meta["unit"],
            }
            for step, (ts, value) in enumerate(err.items(), start=1)
        )

        out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}_holdout6"
        pd.concat([forecast_series.rename("y_pred"), holdout.rename("y_true")], axis=1).to_csv(
            DATA_DIR / f"chronos_holdout6_forecast_{out_prefix}.csv",
            index=True,
        )

    df_errors = pd.DataFrame(rows)
    err_path = DATA_DIR / "chronos_holdout6_errors.csv"
    df_errors.to_csv(err_path, index=False)
    fig_path = FIG_DIR / "chronos_holdout6_errors.png"
    _plot_errors(df_errors, fig_path)
    print(df_errors.groupby("target")["abs_error"].mean().to_string(float_format=lambda x: f"{x:.3f}"))
    print(f"✔ Sauvegardé : {err_path}")
    print(f"✔ Figure : {fig_path}")


if __name__ == "__main__":
    main()