# scripts/run_chronos.py
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
import sys
|
|
from typing import Tuple
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
|
|
# Directory three levels above this script's folder; it is pushed to the
# front of sys.path (once) so that imports are resolved from it first.
# NOTE(review): presumably this is the repository root that holds the
# `meteo` package - confirm the depth if the script is ever moved.
PROJECT_ROOT = Path(__file__).resolve().parents[3]
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    sys.path.insert(0, _project_root_entry)
|
|
|
|
from meteo.dataset import load_raw_csv
|
|
|
|
# Import Chronos lazily-guarded: when the package is absent, stop with an
# actionable installation hint instead of a raw ImportError traceback.
try:
    from chronos import ChronosPipeline
except ImportError as exc:  # pragma: no cover - guidance if deps missing
    _install_hint = (
        "chronos-forecasting est manquant. Installez-le (pip install chronos-forecasting) "
        "puis relancez."
    )
    raise SystemExit(_install_hint) from exc
|
|
|
|
|
|
# --- Input / output locations -------------------------------------------
# Source CSV read by the script.
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
# Parent of the folder containing this script; outputs are written below it.
DOC_DIR = Path(__file__).resolve().parents[1]
DATA_DIR = DOC_DIR / "data"
FIG_DIR = DOC_DIR / "figures"

# --- Run settings, each overridable through an environment variable -----
# Hugging Face model identifier.
MODEL_ID = os.getenv("CHRONOS_MODEL", "amazon/chronos-t5-small")
# Number of resampled steps given to the model (336 steps = 14 days at 1 h).
CONTEXT_LEN = int(os.getenv("CHRONOS_CONTEXT", "336"))
# Number of resampled steps to forecast (96 steps = 4 days at 1 h).
HORIZON_LEN = int(os.getenv("CHRONOS_HORIZON", "96"))
# Pandas resampling rule applied to the raw series.
RESAMPLE_RULE = os.getenv("CHRONOS_RESAMPLE", "1h")
# Number of stochastic sample paths requested from the model.
NUM_SAMPLES = int(os.getenv("CHRONOS_SAMPLES", "20"))
|
|
|
|
|
|
def _load_series(csv_path: Path, target_col: str, rule: str) -> pd.Series:
    """Load one column of the raw CSV and put it on a regular time grid.

    The column is averaged per ``rule`` bucket, gaps are filled by
    interpolation in both directions, and any value still missing is dropped.

    Args:
        csv_path: CSV file handed to ``load_raw_csv``.
        target_col: Name of the column to extract.
        rule: Pandas resampling rule (for example ``"1h"``).

    Returns:
        The resampled, interpolated series without NaN values.

    Raises:
        SystemExit: If ``target_col`` is not a column of the loaded frame.
    """
    # NOTE(review): `resample` requires a datetime-like index, so
    # `load_raw_csv` is assumed to return a time-indexed frame - confirm.
    frame = load_raw_csv(csv_path)
    if target_col not in frame.columns:
        raise SystemExit(f"Colonne absente dans le CSV : {target_col!r}")

    bucket_means = frame[target_col].resample(rule).mean()
    gap_filled = bucket_means.interpolate(limit_direction="both")
    return gap_filled.dropna()
|
|
|
|
|
|
def _split_context_target(series: pd.Series, context_len: int, horizon_len: int) -> Tuple[pd.Series, pd.Series]:
    """Split the tail of ``series`` into a context window and a target window.

    The last ``context_len + horizon_len`` points are kept; the first
    ``context_len`` of them form the context and the remaining
    ``horizon_len`` form the target.

    Args:
        series: Full series to cut from.
        context_len: Number of points in the context window (must be > 0).
        horizon_len: Number of points in the target window (must be > 0).

    Returns:
        A ``(context, target)`` tuple of consecutive slices of ``series``.

    Raises:
        SystemExit: If either length is not strictly positive, or if the
            series holds fewer than ``context_len + horizon_len`` points.
    """
    # Reject non-positive lengths explicitly: `iloc[-0:]` would return the
    # whole series and negative values would silently mis-slice the window.
    if context_len <= 0 or horizon_len <= 0:
        raise SystemExit(
            f"context_len et horizon_len doivent être > 0 (reçu : {context_len}, {horizon_len})."
        )

    needed = context_len + horizon_len
    if len(series) < needed:
        raise SystemExit(f"Pas assez de données après resampling : {len(series)} < {needed} (context+horizon).")

    window = series.iloc[-needed:]
    return window.iloc[:context_len], window.iloc[context_len:]
|
|
|
|
|
|
def _plot_forecast(target: pd.Series, forecast: pd.Series, output_path: Path) -> None:
    """Plot the observed series against the forecast and save the figure.

    Args:
        target: Observed values, drawn as a solid line.
        forecast: Predicted values, drawn as a dashed line.
        output_path: Destination image file; its parent directory is
            created when missing.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    figure, axes = plt.subplots(figsize=(10, 4))

    # Both curves share the same line width; only label and dash style differ.
    curves = (
        (target, {"label": "Observé"}),
        (forecast, {"label": "Chronos-2 (préd.)", "linestyle": "--"}),
    )
    for curve, style in curves:
        axes.plot(curve.index, curve.values, linewidth=2, **style)

    axes.set(
        title="Prévision Chronos-2 vs observation",
        xlabel="Date",
        ylabel="Température (°C)",
    )
    axes.grid(True, linestyle=":", alpha=0.4)
    axes.legend()

    figure.autofmt_xdate()
    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    # Close explicitly so the figure does not stay registered in pyplot.
    plt.close(figure)
|
|
|
|
|
|
def _plot_errors(errors: pd.Series, output_path: Path) -> None:
    """Plot the absolute error series and save the figure.

    The x-axis uses ``errors.index`` as-is and is labelled as a horizon in
    hours, so callers should pass a series indexed accordingly.

    Args:
        errors: Absolute errors to draw, one marker per point.
        output_path: Destination image file; its parent directory is
            created when missing.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    figure, axes = plt.subplots(figsize=(8, 4))
    axes.plot(errors.index, errors.values, marker="o")
    axes.set(
        title="Erreur absolue par horizon (heures)",
        xlabel="Horizon (h)",
        ylabel="Erreur absolue (°C)",
    )
    axes.grid(True, linestyle=":", alpha=0.4)

    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    # Close explicitly so the figure does not stay registered in pyplot.
    plt.close(figure)
|
|
|
|
|
|
def _mean_forecast(samples) -> np.ndarray:
    """Collapse Chronos sample paths into one point forecast per horizon step.

    Chronos documents the output of ``predict`` as
    ``(num_series, num_samples, prediction_length)``. Only one series is
    submitted by this script, so the first batch entry is selected and the
    sample axis is averaged. A 2-D ``(num_samples, prediction_length)`` input
    is also accepted.

    Raises:
        SystemExit: If the input is neither 2-D nor 3-D.
    """
    if isinstance(samples, torch.Tensor):
        # Move to CPU / float32 first so the NumPy conversion cannot fail on
        # accelerator-resident or reduced-precision tensors.
        samples = samples.detach().to(device="cpu", dtype=torch.float32).numpy()
    paths = np.asarray(samples, dtype=float)
    if paths.ndim == 3:
        paths = paths[0]  # batch of one series
    if paths.ndim != 2:
        raise SystemExit(f"Forme de prévision inattendue : {paths.shape}")
    return paths.mean(axis=0)


def main() -> None:
    """Back-test a Chronos model on the last window of the temperature series.

    Steps: load and resample the ``temperature`` column, split its tail into
    context and target, forecast the target window, print MAE/RMSE, then
    write two CSV files under ``DATA_DIR`` and two figures under ``FIG_DIR``.

    Raises:
        SystemExit: If the CSV is missing, or propagated from the helpers.
    """
    if not CSV_PATH.exists():
        raise SystemExit(f"Fichier introuvable : {CSV_PATH}")

    series = _load_series(CSV_PATH, target_col="temperature", rule=RESAMPLE_RULE)
    context, target = _split_context_target(series, CONTEXT_LEN, HORIZON_LEN)

    # Slash-free model name, safe to embed in output file names.
    model_slug = MODEL_ID.replace("/", "__")
    print(f"Contexte : {len(context)} points ({context.index[0]} -> {context.index[-1]})")
    print(f"Cible : {len(target)} points ({target.index[0]} -> {target.index[-1]})")
    print(f"Modèle HF : {MODEL_ID}")

    # NOTE(review): older chronos-forecasting releases document `torch_dtype`
    # rather than `dtype` - confirm the installed version accepts `dtype`.
    pipeline = ChronosPipeline.from_pretrained(
        MODEL_ID,
        device_map="auto",  # automatic CUDA/MPS/CPU placement
        dtype="auto",
    )

    context_tensor = torch.tensor(context.to_numpy(dtype=float), dtype=torch.float32)
    # NOTE(review): Chronos-T5 checkpoints are reportedly tuned for short
    # horizons (about 64 steps); verify behaviour for longer HORIZON_LEN.
    forecasts = pipeline.predict(
        [context_tensor],  # batch holding a single series (tensor expected)
        prediction_length=HORIZON_LEN,
        num_samples=NUM_SAMPLES,
    )
    # Average over the sample paths of the single submitted series (the
    # previous code averaged the batch axis and kept only the first sample).
    forecast_mean = _mean_forecast(forecasts)

    # Reuse the target's own index so both series align exactly when
    # subtracted, instead of rebuilding a date range that could drift.
    forecast_series = pd.Series(forecast_mean, index=target.index, name="y_pred")
    target_series = target.rename("y_true")

    # Metrics
    residuals = forecast_series - target_series
    abs_errors = residuals.abs()
    mae = abs_errors.mean()
    rmse = np.sqrt((residuals ** 2).mean())
    print(f"MAE : {mae:.3f} °C | RMSE : {rmse:.3f} °C")

    # Output locations
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    out_csv = DATA_DIR / f"chronos_forecast_{model_slug}.csv"
    per_horizon_csv = DATA_DIR / f"chronos_errors_{model_slug}.csv"
    forecast_png = FIG_DIR / f"chronos_forecast_{model_slug}.png"
    errors_png = FIG_DIR / f"chronos_errors_{model_slug}.png"

    # CSV exports
    horizons = np.arange(1, HORIZON_LEN + 1)
    pd.concat([target_series, forecast_series], axis=1).to_csv(out_csv, index=True)
    pd.DataFrame({"horizon_h": horizons, "abs_error": abs_errors.values}).to_csv(per_horizon_csv, index=False)

    # Figures. The error plot is labelled in horizon hours, so index the
    # errors by horizon step (1..H) rather than by timestamp; `rename_axis`
    # alone would only rename the datetime index.
    errors_by_horizon = pd.Series(
        abs_errors.to_numpy(),
        index=pd.Index(horizons, name="h"),
        name="abs_error",
    )
    _plot_forecast(target_series, forecast_series, forecast_png)
    _plot_errors(errors_by_horizon, errors_png)

    print(f"✔ Sauvegardé : {out_csv}")
    print(f"✔ Sauvegardé : {per_horizon_csv}")
    print(f"✔ Figures : {forecast_png} ; {errors_png}")
|
|
|
|
|
|
# Script entry point: run the back-test only when executed directly,
# never as a side effect of importing this module.
if __name__ == "__main__":
    main()
|