# File-viewer metadata (not code): 148 lines, 4.5 KiB, Python
# scripts/run_chronos_holdout6.py
|
||
from __future__ import annotations
|
||
|
||
from pathlib import Path
|
||
import sys
|
||
from typing import Dict
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
|
||
# Repository root, taken as the fourth ancestor of this file.
# NOTE(review): this assumes a fixed directory depth for the script — confirm
# if the script is ever moved.
PROJECT_ROOT = Path(__file__).resolve().parents[3]

# Prepend the root (once) so project-local packages such as ``meteo`` resolve
# when the script is launched directly.
_root = str(PROJECT_ROOT)
if _root not in sys.path:
    sys.path.insert(0, _root)
del _root
|
||
|
||
from meteo.dataset import load_raw_csv
|
||
|
||
try:
|
||
from chronos import ChronosPipeline
|
||
except ImportError as exc:
|
||
raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc
|
||
|
||
|
||
# Input data and output locations.
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
DOC_DIR = Path(__file__).resolve().parents[1]  # two levels above this file
DATA_DIR, FIG_DIR = DOC_DIR / "data", DOC_DIR / "figures"

# Forecasting setup (lengths are counted in RESAMPLE_RULE steps, i.e. hours).
MODEL_ID = "amazon/chronos-t5-small"
RESAMPLE_RULE = "1h"
CONTEXT_H = 12 * 24  # 288 hours = 12 days of context
HOLDOUT_H = 6
NUM_SAMPLES = 50

# Evaluated columns. ``kind`` selects the error metric ("angle" -> wrap-around
# angular error, otherwise plain absolute error); ``unit`` is reported as-is.
TARGETS = dict(
    temperature={"kind": "reg", "unit": "°C"},
    wind_speed={"kind": "reg", "unit": "km/h"},
    wind_direction={"kind": "angle", "unit": "deg"},
    humidity={"kind": "reg", "unit": "%"},
    pressure={"kind": "reg", "unit": "hPa"},
)
|
||
|
||
|
||
def _load_series(target: str, *, kind: str | None = None) -> pd.Series:
    """Load one CSV column and resample it onto the ``RESAMPLE_RULE`` grid.

    Non-angular targets are averaged per bin. Angular targets
    (``kind == "angle"``, in degrees) are averaged on the unit circle instead:
    a plain arithmetic mean of 350° and 10° would yield 180° rather than 0°.
    Bins left empty are filled by interpolation in both directions, and any
    value still missing is dropped.

    Args:
        target: Name of the column to extract.
        kind: ``"angle"`` for circular data in degrees, anything else for plain
            averaging. When omitted, it is looked up in ``TARGETS[target]["kind"]``
            (falling back to ``"reg"`` for unlisted columns).

    Returns:
        The resampled series, named ``target``.

    Raises:
        SystemExit: If ``target`` is not a column of the CSV.
    """
    # NOTE(review): ``resample`` requires a datetime-like index, so this assumes
    # load_raw_csv returns a timestamp-indexed DataFrame — confirm in meteo.dataset.
    df = load_raw_csv(CSV_PATH)
    if target not in df.columns:
        raise SystemExit(f"Colonne absente : {target}")
    if kind is None:
        kind = TARGETS.get(target, {}).get("kind", "reg")

    raw = df[target]
    if kind != "angle":
        return (
            raw.resample(RESAMPLE_RULE)
            .mean()
            .interpolate(limit_direction="both")
            .dropna()
        )

    # Circular mean: average the unit-vector components per bin and interpolate
    # those (interpolating raw degrees would sweep the long way across 0/360),
    # then map back to an angle in [0, 360).
    radians = np.deg2rad(raw)
    sin_part = np.sin(radians).resample(RESAMPLE_RULE).mean().interpolate(limit_direction="both")
    cos_part = np.cos(radians).resample(RESAMPLE_RULE).mean().interpolate(limit_direction="both")
    degrees = np.rad2deg(np.arctan2(sin_part, cos_part)) % 360.0
    return degrees.rename(target).dropna()
|
||
|
||
|
||
def _angular_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
|
||
diff = (y_pred - y_true).abs() % 360
|
||
return diff.apply(lambda x: min(x, 360 - x))
|
||
|
||
|
||
def _plot_errors(df: pd.DataFrame, output_path: Path) -> None:
    """Plot ``abs_error`` against horizon step ``h``, one line per target.

    Creates the parent directory of ``output_path`` if needed, saves the figure
    at 150 dpi and closes it afterwards.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 4))

    # Targets in order of first appearance; one marked line each.
    for name in pd.unique(df["target"]):
        mask = df["target"] == name
        ax.plot(df.loc[mask, "h"], df.loc[mask, "abs_error"], marker="o", label=name)

    ax.set(
        xlabel="Heure du horizon (1-6)",
        ylabel="Erreur absolue (unité cible)",
        title="Chronos small – erreurs sur les 6 dernières heures (holdout)",
    )
    ax.grid(True, linestyle=":", alpha=0.4)
    ax.legend()

    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
|
||
|
||
|
||
def _point_forecast(pipeline: "ChronosPipeline", context: pd.Series, kind: str) -> np.ndarray:
    """Return a ``HOLDOUT_H``-long point forecast for ``context``.

    The pipeline draws ``NUM_SAMPLES`` sample paths. They are reduced to one
    value per step with the arithmetic mean, or with the circular mean when
    ``kind == "angle"`` (samples at 355° and 5° average to 0°, not 180°).
    """
    context_tensor = torch.tensor(context.to_numpy(dtype=float), dtype=torch.float32)
    forecasts = pipeline.predict(
        [context_tensor],
        prediction_length=HOLDOUT_H,
        num_samples=NUM_SAMPLES,
    )
    # Chronos returns (n_series, n_samples, prediction_length). Select our only
    # series BEFORE reducing: averaging axis 0 of the full tensor (the series
    # axis) and then taking row 0 would keep just the first sample path.
    samples = torch.as_tensor(forecasts).detach().to(torch.float32).cpu()
    if samples.ndim == 3:
        samples = samples[0]
    samples_np = samples.numpy().reshape(-1, HOLDOUT_H)  # (n_samples, HOLDOUT_H)

    if kind == "angle":
        radians = np.deg2rad(samples_np)
        mean_angle = np.arctan2(np.sin(radians).mean(axis=0), np.cos(radians).mean(axis=0))
        return np.rad2deg(mean_angle) % 360.0
    return samples_np.mean(axis=0)


def main() -> None:
    """Evaluate Chronos on the last ``HOLDOUT_H`` hours of each target.

    For every entry of ``TARGETS``: forecast the final ``HOLDOUT_H`` points from
    the ``CONTEXT_H`` points before them, and write a per-target CSV with the
    predictions and ground truth to ``DATA_DIR``. Then write the per-step
    absolute errors of all targets to one CSV, plot them to ``FIG_DIR``, and
    print the mean error per target.

    Raises:
        SystemExit: If the CSV is missing or a target has too little history.
    """
    if not CSV_PATH.exists():
        raise SystemExit(f"Fichier introuvable : {CSV_PATH}")

    # NOTE(review): older transformers releases spell this kwarg ``torch_dtype``;
    # confirm ``dtype`` is honoured by the installed version.
    pipeline = ChronosPipeline.from_pretrained(
        MODEL_ID,
        device_map="auto",
        dtype="auto",
    )

    DATA_DIR.mkdir(parents=True, exist_ok=True)
    rows: list[Dict[str, object]] = []

    for target, meta in TARGETS.items():
        series = _load_series(target)
        if len(series) < CONTEXT_H + HOLDOUT_H:
            raise SystemExit("Pas assez de données pour contexte+holdout.")

        # The last HOLDOUT_H points are held out; the CONTEXT_H before them feed the model.
        context = series.iloc[-(CONTEXT_H + HOLDOUT_H) : -HOLDOUT_H]
        holdout = series.iloc[-HOLDOUT_H:]

        forecast_series = pd.Series(
            _point_forecast(pipeline, context, meta["kind"]),
            index=holdout.index,
        )

        if meta["kind"] == "angle":
            err = _angular_error(holdout, forecast_series)
        else:
            err = (forecast_series - holdout).abs()

        # One row per horizon step, h = 1..HOLDOUT_H.
        rows.extend(
            {
                "target": target,
                "timestamp": ts,
                "h": step,
                "abs_error": float(e),
                "unit": meta["unit"],
            }
            for step, (ts, e) in enumerate(err.items(), start=1)
        )

        out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}_holdout6"
        pd.concat([forecast_series.rename("y_pred"), holdout.rename("y_true")], axis=1).to_csv(
            DATA_DIR / f"chronos_holdout6_forecast_{out_prefix}.csv",
            index=True,
        )

    df_errors = pd.DataFrame(rows)
    err_path = DATA_DIR / "chronos_holdout6_errors.csv"
    df_errors.to_csv(err_path, index=False)
    fig_path = FIG_DIR / "chronos_holdout6_errors.png"
    _plot_errors(df_errors, fig_path)

    print(df_errors.groupby("target")["abs_error"].mean().to_string(float_format=lambda x: f"{x:.3f}"))
    print(f"✔ Sauvegardé : {err_path}")
    # Report the real location; FIG_DIR is anchored to this file, not the CWD.
    print(f"✔ Figure : {fig_path}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the holdout evaluation only when executed as a script, not on import.
    main()
|