# Source listing metadata (repository viewer header, not part of the script):
# 1 · 2025-11-26 01:03:41 +01:00 · 153 lines · 5.6 KiB · Python

# scripts/run_chronos.py
from __future__ import annotations
import os
from pathlib import Path
import sys
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from meteo.dataset import load_raw_csv
try:
from chronos import ChronosPipeline
except ImportError as exc: # pragma: no cover - guidance if deps missing
raise SystemExit(
"chronos-forecasting est manquant. Installez-le (pip install chronos-forecasting) "
"puis relancez."
) from exc
# Input dataset, and output folders located in the directory that contains
# this script's own folder.
CSV_PATH = PROJECT_ROOT.joinpath("data", "weather_minutely.csv")
DOC_DIR = Path(__file__).resolve().parents[1]
DATA_DIR = DOC_DIR.joinpath("data")
FIG_DIR = DOC_DIR.joinpath("figures")

# Run settings; each one can be overridden through its environment variable.
MODEL_ID = os.environ.get("CHRONOS_MODEL", "amazon/chronos-t5-small")
CONTEXT_LEN = int(os.environ.get("CHRONOS_CONTEXT", "336"))  # 14 days at the default 1h step
HORIZON_LEN = int(os.environ.get("CHRONOS_HORIZON", "96"))  # 4 days at the default 1h step
RESAMPLE_RULE = os.environ.get("CHRONOS_RESAMPLE", "1h")
NUM_SAMPLES = int(os.environ.get("CHRONOS_SAMPLES", "20"))  # number of stochastic sample paths
def _load_series(csv_path: Path, target_col: str, rule: str) -> pd.Series:
    """Load one column of the raw CSV as a regularly resampled series.

    The column is resampled with ``rule`` using the mean of each bin, gaps are
    filled by interpolation in both directions, and any value that is still
    missing afterwards is dropped.

    Args:
        csv_path: Path handed to ``load_raw_csv``.
        target_col: Name of the column to extract.
        rule: Resampling rule passed to ``Series.resample`` (e.g. ``"1h"``).

    Returns:
        The resampled, interpolated series without NaN values.

    Raises:
        SystemExit: If ``target_col`` is not a column of the loaded frame.
    """
    # NOTE(review): resample() needs a datetime-like index, so load_raw_csv
    # presumably returns a time-indexed DataFrame -- confirm in meteo.dataset.
    frame = load_raw_csv(csv_path)
    if target_col not in frame.columns:
        raise SystemExit(f"Colonne absente dans le CSV : {target_col!r}")

    binned = frame[target_col].resample(rule).mean()
    filled = binned.interpolate(limit_direction="both")
    return filled.dropna()
def _split_context_target(series: pd.Series, context_len: int, horizon_len: int) -> Tuple[pd.Series, pd.Series]:
    """Split the tail of ``series`` into a context window and a target window.

    The last ``context_len + horizon_len`` points are used: the first
    ``context_len`` of them form the context, the remaining ``horizon_len``
    form the target.

    Args:
        series: Series to split; only its tail is used.
        context_len: Number of points in the context window (>= 0).
        horizon_len: Number of points in the target window (>= 0).

    Returns:
        A ``(context, target)`` tuple of consecutive slices of ``series``.

    Raises:
        SystemExit: If a length is negative or the series is shorter than
            ``context_len + horizon_len``.
    """
    if context_len < 0 or horizon_len < 0:
        raise SystemExit(
            f"Longueurs invalides : context={context_len}, horizon={horizon_len} (>= 0 attendu)."
        )

    needed = context_len + horizon_len
    if len(series) < needed:
        raise SystemExit(f"Pas assez de données après resampling : {len(series)} < {needed} (context+horizon).")

    # Use absolute positions: a negative-offset slice such as iloc[-needed:]
    # degenerates to the whole series when needed == 0.
    start = len(series) - needed
    boundary = start + context_len
    return series.iloc[start:boundary], series.iloc[boundary:]
def _plot_forecast(target: pd.Series, forecast: pd.Series, output_path: Path) -> None:
    """Draw the observed and predicted series on one figure and save it.

    Both series are plotted against their own index; the forecast is drawn
    dashed. The parent directory of ``output_path`` is created if needed.

    Args:
        target: Observed values.
        forecast: Predicted values.
        output_path: Destination image file.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    curves = (
        (target, {"label": "Observé"}),
        (forecast, {"label": "Chronos-2 (préd.)", "linestyle": "--"}),
    )

    figure, axes = plt.subplots(figsize=(10, 4))
    for curve, style in curves:
        axes.plot(curve.index, curve.values, linewidth=2, **style)
    axes.set(
        title="Prévision Chronos-2 vs observation",
        xlabel="Date",
        ylabel="Température (°C)",
    )
    axes.grid(True, linestyle=":", alpha=0.4)
    axes.legend()

    # Rotate the date tick labels before the final layout pass.
    figure.autofmt_xdate()
    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    plt.close(figure)
def _plot_errors(errors: pd.Series, output_path: Path) -> None:
    """Plot the absolute errors against the series index and save the figure.

    The parent directory of ``output_path`` is created if needed.

    Args:
        errors: Absolute error values; the index is used as the x axis.
        output_path: Destination image file.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    figure, axes = plt.subplots(figsize=(8, 4))
    axes.plot(errors.index, errors.values, marker="o")
    axes.set(
        title="Erreur absolue par horizon (heures)",
        xlabel="Horizon (h)",
        ylabel="Erreur absolue (°C)",
    )
    axes.grid(True, linestyle=":", alpha=0.4)

    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    plt.close(figure)
def main() -> None:
    """Backtest a Chronos model on the last window of the temperature series.

    Loads and resamples the ``temperature`` column, holds out the last
    ``HORIZON_LEN`` points as the target, forecasts them from the preceding
    ``CONTEXT_LEN`` points, prints MAE/RMSE, and writes two CSV files and two
    figures named after the model identifier.

    Raises:
        SystemExit: If the CSV file does not exist (the loading and splitting
            helpers raise it as well for a missing column or a series that is
            too short).
    """
    if not CSV_PATH.exists():
        raise SystemExit(f"Fichier introuvable : {CSV_PATH}")

    series = _load_series(CSV_PATH, target_col="temperature", rule=RESAMPLE_RULE)
    context, target = _split_context_target(series, CONTEXT_LEN, HORIZON_LEN)
    # "/" is not usable in file names, so flatten the Hugging Face identifier.
    model_slug = MODEL_ID.replace("/", "__")

    print(f"Contexte : {len(context)} points ({context.index[0]} -> {context.index[-1]})")
    print(f"Cible : {len(target)} points ({target.index[0]} -> {target.index[-1]})")
    print(f"Modèle HF : {MODEL_ID}")

    pipeline = ChronosPipeline.from_pretrained(
        MODEL_ID,
        device_map="auto",  # automatic CUDA/MPS/CPU placement
        dtype="auto",
    )

    context_tensor = torch.tensor(context.to_numpy(dtype=float), dtype=torch.float32)
    forecasts = pipeline.predict(
        [context_tensor],  # batch holding a single series (tensor expected)
        prediction_length=HORIZON_LEN,
        num_samples=NUM_SAMPLES,
    )

    # predict() yields (num_series, num_samples, prediction_length). Select the
    # single series first, then average over the sample axis; averaging over
    # axis 0 of the raw output would collapse the batch axis instead and leave
    # the individual sample paths un-averaged.
    samples = torch.as_tensor(forecasts)
    if samples.ndim == 3:
        samples = samples[0]
    forecast_mean = samples.float().mean(dim=0).detach().cpu().numpy().ravel()

    # Reuse the target index so predictions and observations align exactly.
    forecast_series = pd.Series(forecast_mean, index=target.index, name="y_pred")
    target_series = target.rename("y_true")

    # Metrics
    residuals = forecast_series - target_series
    abs_errors = residuals.abs()
    mae = abs_errors.mean()
    rmse = np.sqrt((residuals ** 2).mean())
    print(f"MAE : {mae:.3f} °C | RMSE : {rmse:.3f} °C")

    # Outputs
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    out_csv = DATA_DIR / f"chronos_forecast_{model_slug}.csv"
    per_horizon_csv = DATA_DIR / f"chronos_errors_{model_slug}.csv"
    forecast_png = FIG_DIR / f"chronos_forecast_{model_slug}.png"
    errors_png = FIG_DIR / f"chronos_errors_{model_slug}.png"

    horizon_steps = np.arange(1, HORIZON_LEN + 1)
    pd.concat([target_series, forecast_series], axis=1).to_csv(out_csv, index=True)
    pd.DataFrame({"horizon_h": horizon_steps, "abs_error": abs_errors.values}).to_csv(per_horizon_csv, index=False)

    # The error figure is labelled by horizon, so index the errors by step
    # number (1..HORIZON_LEN) rather than by timestamp.
    errors_by_horizon = pd.Series(
        abs_errors.to_numpy(),
        index=pd.Index(horizon_steps, name="h"),
        name="abs_error",
    )

    # Figures
    _plot_forecast(target_series, forecast_series, forecast_png)
    _plot_errors(errors_by_horizon, errors_png)

    print(f"✔ Sauvegardé : {out_csv}")
    print(f"✔ Sauvegardé : {per_horizon_csv}")
    print(f"✔ Figures : {forecast_png} ; {errors_png}")


if __name__ == "__main__":
    main()