164 lines
6.0 KiB
Python
164 lines
6.0 KiB
Python
# scripts/run_chronos_multi.py
|
||
from __future__ import annotations
|
||
|
||
from pathlib import Path
|
||
import sys
|
||
from typing import Iterable
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
|
||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||
if str(PROJECT_ROOT) not in sys.path:
|
||
sys.path.insert(0, str(PROJECT_ROOT))
|
||
|
||
from meteo.dataset import load_raw_csv
|
||
|
||
try:
|
||
from chronos import ChronosPipeline
|
||
except ImportError as exc:
|
||
raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc
|
||
|
||
|
||
# Input CSV, located under the project root; read through load_raw_csv.
CSV_PATH = PROJECT_ROOT.joinpath("data", "weather_minutely.csv")
# Parent of the directory that contains this script; outputs are written below it.
DOC_DIR = Path(__file__).resolve().parents[1]
# Destination for the CSV outputs (metrics and raw forecasts).
DATA_DIR = DOC_DIR.joinpath("data")
# Destination for the per-target PNG figures.
FIG_DIR = DOC_DIR.joinpath("figures")
|
||
|
||
# Hugging Face identifier of the Chronos checkpoint to load.
MODEL_ID = "amazon/chronos-t5-small"
# Stay on an hourly grid to remain aligned with the Chronos pre-training.
RESAMPLE_RULE = "1h"
# Context length in hours (14 days).
CONTEXT_H = 14 * 24
# Requested horizons in hours (10 min is excluded because the model is hourly).
HORIZONS_H = (1, 6, 24)
# Number of sample paths requested from the pipeline for each forecast.
NUM_SAMPLES = 50

# Columns to forecast, each tagged with the kind of metrics computed for it:
# "reg" -> MAE / RMSE, anything else -> rain / no-rain F1 and Brier score.
TARGETS = {
    column: {"kind": kind}
    for column, kind in (
        ("temperature", "reg"),
        ("wind_speed", "reg"),
        ("rain_rate", "rain"),
    )
}
|
||
|
||
|
||
def _load_series(target: str) -> pd.Series:
    """Load one column of the raw CSV as a gap-free resampled series.

    The column is resampled with ``RESAMPLE_RULE`` (mean of each bin),
    interpolated in both directions, and any value still missing is dropped.

    Args:
        target: Name of the column to extract from the frame returned by
            ``load_raw_csv(CSV_PATH)``.

    Returns:
        The resampled, interpolated series.

    Raises:
        SystemExit: If ``target`` is not a column of the loaded frame.
    """
    frame = load_raw_csv(CSV_PATH)
    if target in frame.columns:
        # NOTE(review): ``resample`` presumably relies on load_raw_csv
        # returning a datetime-indexed frame -- confirm against that helper.
        binned = frame[target].resample(RESAMPLE_RULE).mean()
        filled = binned.interpolate(limit_direction="both")
        return filled.dropna()
    raise SystemExit(f"Colonne absente : {target}")
|
||
|
||
|
||
def _prepare_window(series: pd.Series, context_h: int, horizon_h: int) -> tuple[np.ndarray, pd.Series]:
|
||
needed = context_h + horizon_h
|
||
if len(series) < needed:
|
||
raise SystemExit(f"Pas assez de données pour {needed} heures.")
|
||
window = series.iloc[-needed:]
|
||
context = window.iloc[:context_h]
|
||
target = window.iloc[context_h:]
|
||
return context.to_numpy(dtype=float), target
|
||
|
||
|
||
def _plot_metrics(df: pd.DataFrame, target: str, output_path: Path) -> None:
    """Plot the per-horizon metrics of one target and save the figure.

    Args:
        df: Metric rows for a single target. Must be non-empty and contain
            the columns ``kind`` and ``horizon_h``, plus ``mae``/``rmse``
            when ``kind`` is ``"reg"`` or ``f1``/``brier`` otherwise.
        target: Target name, used in the title and to pick the y-axis label.
        output_path: Destination image file; its parent directory is created
            if needed.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        # The kind of the first row decides which pair of metrics is drawn.
        if df["kind"].iloc[0] == "reg":
            ax.plot(df["horizon_h"], df["mae"], marker="o", label="MAE")
            ax.plot(df["horizon_h"], df["rmse"], marker="o", label="RMSE")
            ax.set_ylabel("Erreur (°C)" if target == "temperature" else "Erreur (unité)")
        else:
            ax.plot(df["horizon_h"], df["f1"], marker="o", label="F1")
            ax.plot(df["horizon_h"], df["brier"], marker="o", label="Brier")
            ax.set_ylabel("Score")
        ax.set_xlabel("Horizon (heures)")
        ax.set_title(f"Chronos (small) – {target}")
        ax.grid(True, linestyle=":", alpha=0.4)
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Always release the figure, even when plotting or saving fails, so
        # pyplot does not keep it registered.
        plt.close(fig)
|
||
|
||
|
||
def _regression_metrics(y_true: pd.Series, y_pred: pd.Series) -> dict[str, float]:
    """Return the MAE and RMSE between two index-aligned series."""
    error = y_pred - y_true
    return {
        "mae": float(error.abs().mean()),
        "rmse": float(np.sqrt((error ** 2).mean())),
    }


def _rain_metrics(y_true: pd.Series, y_pred: pd.Series) -> dict[str, float]:
    """Return the rain / no-rain F1 and an approximate Brier score."""
    y_true_bin = (y_true > 0).astype(int)
    y_pred_bin = (y_pred > 0).astype(int)
    if y_true_bin.sum() == 0:
        # No positive in the ground truth: F1 is reported as 0.
        f1 = 0.0
    else:
        tp = ((y_true_bin == 1) & (y_pred_bin == 1)).sum()
        fp = ((y_true_bin == 0) & (y_pred_bin == 1)).sum()
        fn = ((y_true_bin == 1) & (y_pred_bin == 0)).sum()
        prec = tp / (tp + fp) if tp + fp > 0 else 0.0
        rec = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
    # Approximate Brier score: the predicted value clamped to [0, 1] is
    # treated as a probability.
    proba = y_pred.clip(0, 1)
    brier = float(((proba - y_true_bin) ** 2).mean())
    return {"f1": float(f1), "brier": brier}


def main() -> None:
    """Evaluate Chronos on every configured target and export the results.

    For each entry of ``TARGETS`` the last ``CONTEXT_H + max(HORIZONS_H)``
    points are split into a context and a hold-out target, a single forecast
    of ``max(HORIZONS_H)`` steps is sampled, and metrics are computed over
    the first ``h`` steps for every ``h`` in ``HORIZONS_H``.

    Outputs: one forecast/target CSV per target and a metrics CSV in
    ``DATA_DIR``, one figure per target in ``FIG_DIR``, and the metrics
    table printed to stdout.

    Raises:
        SystemExit: If the input CSV is missing, or propagated from the
            loading / windowing helpers.
    """
    if not CSV_PATH.exists():
        raise SystemExit(f"Fichier introuvable : {CSV_PATH}")

    # NOTE(review): some transformers / chronos releases name the dtype
    # keyword ``torch_dtype`` -- confirm against the pinned versions.
    pipeline = ChronosPipeline.from_pretrained(
        MODEL_ID,
        device_map="auto",
        dtype="auto",
    )

    max_horizon = max(HORIZONS_H)
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    rows: list[dict[str, object]] = []

    for target, meta in TARGETS.items():
        series = _load_series(target)
        # Use the largest window (max horizon); shorter horizons are prefixes.
        context_arr, target_series = _prepare_window(series, CONTEXT_H, max_horizon)

        context_tensor = torch.tensor(context_arr, dtype=torch.float32)
        forecasts = pipeline.predict(
            [context_tensor],
            prediction_length=max_horizon,
            num_samples=NUM_SAMPLES,
        )
        # Chronos returns (num_series, num_samples, prediction_length).
        # Select the single series first, then average over the sample axis;
        # averaging over axis 0 of the 3-D tensor and indexing [0] afterwards
        # would keep only the first sample path instead of the mean.
        samples = forecasts[0] if forecasts.ndim == 3 else forecasts
        forecast_mean = samples.mean(dim=0)
        # Reuse the target's own index so both series are aligned by
        # construction when they are subtracted below.
        forecast_series = pd.Series(
            np.asarray(forecast_mean).ravel(),
            index=target_series.index,
        )

        for h in HORIZONS_H:
            y_true = target_series.iloc[:h]
            y_pred = forecast_series.iloc[:h]
            if meta["kind"] == "reg":
                rows.append(
                    {
                        "target": target,
                        "kind": "reg",
                        "horizon_h": h,
                        **_regression_metrics(y_true, y_pred),
                    }
                )
            else:
                rows.append(
                    {
                        "target": target,
                        "kind": "cls",
                        "horizon_h": h,
                        **_rain_metrics(y_true, y_pred),
                    }
                )

        # Save the full forecast and target side by side for inspection.
        out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}"
        pd.concat(
            [forecast_series.rename("y_pred"), target_series.rename("y_true")],
            axis=1,
        ).to_csv(
            DATA_DIR / f"chronos_multi_forecast_{out_prefix}.csv",
            index=True,
        )

    df_metrics = pd.DataFrame(rows)
    metrics_path = DATA_DIR / "chronos_multi_metrics.csv"
    df_metrics.to_csv(metrics_path, index=False)

    for target in df_metrics["target"].unique():
        sub = df_metrics[df_metrics["target"] == target].sort_values("horizon_h")
        _plot_metrics(sub, target, FIG_DIR / f"chronos_multi_{target}.png")

    print(df_metrics.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
    print(f"✔ Sauvegardé : {metrics_path}")
    print("✔ Figures par cible dans figures/chronos_multi_<target>.png")
|
||
|
||
|
||
# Run the evaluation only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|