You've already forked donnees_meteo
Modèle Chronos-2
This commit is contained in:
68
docs/11 - Modèle Chronos/scripts/compare_chronos.py
Normal file
68
docs/11 - Modèle Chronos/scripts/compare_chronos.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# scripts/compare_chronos.py
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
DOC_DIR = Path(__file__).resolve().parent.parent
|
||||
DATA_DIR = DOC_DIR / "data"
|
||||
FIG_DIR = DOC_DIR / "figures"
|
||||
|
||||
|
||||
def _load_forecasts() -> pd.DataFrame:
|
||||
records = []
|
||||
pattern = re.compile(r"chronos_forecast_(.+)\.csv")
|
||||
for csv in DATA_DIR.glob("chronos_forecast_*.csv"):
|
||||
match = pattern.match(csv.name)
|
||||
if not match:
|
||||
continue
|
||||
model = match.group(1).replace("__", "/")
|
||||
df = pd.read_csv(csv)
|
||||
if not {"y_true", "y_pred"}.issubset(df.columns):
|
||||
continue
|
||||
err = df["y_pred"] - df["y_true"]
|
||||
records.append(
|
||||
{
|
||||
"model": model,
|
||||
"mae": err.abs().mean(),
|
||||
"rmse": (err.pow(2).mean()) ** 0.5,
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(records)
|
||||
|
||||
|
||||
def _plot_comparison(df: pd.DataFrame, output_path: Path) -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig, ax = plt.subplots(figsize=(7, 4))
|
||||
x = range(len(df))
|
||||
ax.bar([i - 0.15 for i in x], df["mae"], width=0.3, label="MAE")
|
||||
ax.bar([i + 0.15 for i in x], df["rmse"], width=0.3, label="RMSE")
|
||||
ax.set_xticks(list(x))
|
||||
ax.set_xticklabels(df["model"], rotation=15)
|
||||
ax.set_ylabel("Erreur (°C)")
|
||||
ax.set_title("Chronos T5 – comparaison des tailles")
|
||||
ax.grid(True, linestyle=":", alpha=0.4, axis="y")
|
||||
ax.legend()
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=150)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
df = _load_forecasts()
|
||||
if df.empty:
|
||||
raise SystemExit("Aucune sortie chronos_forecast_*.csv trouvée dans data/. Lancez run_chronos.py d'abord.")
|
||||
df_sorted = df.sort_values("mae")
|
||||
summary_path = DATA_DIR / "chronos_summary.csv"
|
||||
df_sorted.to_csv(summary_path, index=False)
|
||||
_plot_comparison(df_sorted, FIG_DIR / "chronos_models_comparison.png")
|
||||
print(df_sorted.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
|
||||
print(f"✔ Sauvegardé : {summary_path}")
|
||||
print(f"✔ Figure : {FIG_DIR / 'chronos_models_comparison.png'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
152
docs/11 - Modèle Chronos/scripts/run_chronos.py
Normal file
152
docs/11 - Modèle Chronos/scripts/run_chronos.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# scripts/run_chronos.py
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from meteo.dataset import load_raw_csv
|
||||
|
||||
try:
|
||||
from chronos import ChronosPipeline
|
||||
except ImportError as exc: # pragma: no cover - guidance if deps missing
|
||||
raise SystemExit(
|
||||
"chronos-forecasting est manquant. Installez-le (pip install chronos-forecasting) "
|
||||
"puis relancez."
|
||||
) from exc
|
||||
|
||||
|
||||
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
|
||||
DOC_DIR = Path(__file__).resolve().parent.parent
|
||||
DATA_DIR = DOC_DIR / "data"
|
||||
FIG_DIR = DOC_DIR / "figures"
|
||||
|
||||
MODEL_ID = os.getenv("CHRONOS_MODEL", "amazon/chronos-t5-small")
|
||||
CONTEXT_LEN = int(os.getenv("CHRONOS_CONTEXT", "336")) # 14 jours (1H)
|
||||
HORIZON_LEN = int(os.getenv("CHRONOS_HORIZON", "96")) # 4 jours (1H)
|
||||
RESAMPLE_RULE = os.getenv("CHRONOS_RESAMPLE", "1h")
|
||||
NUM_SAMPLES = int(os.getenv("CHRONOS_SAMPLES", "20")) # échantillons stochastiques
|
||||
|
||||
|
||||
def _load_series(csv_path: Path, target_col: str, rule: str) -> pd.Series:
|
||||
df = load_raw_csv(csv_path)
|
||||
if target_col not in df.columns:
|
||||
raise SystemExit(f"Colonne absente dans le CSV : {target_col!r}")
|
||||
series = (
|
||||
df[target_col]
|
||||
.resample(rule)
|
||||
.mean()
|
||||
.interpolate(limit_direction="both")
|
||||
)
|
||||
return series.dropna()
|
||||
|
||||
|
||||
def _split_context_target(series: pd.Series, context_len: int, horizon_len: int) -> Tuple[pd.Series, pd.Series]:
|
||||
needed = context_len + horizon_len
|
||||
if len(series) < needed:
|
||||
raise SystemExit(f"Pas assez de données après resampling : {len(series)} < {needed} (context+horizon).")
|
||||
window = series.iloc[-needed:]
|
||||
context = window.iloc[:context_len]
|
||||
target = window.iloc[context_len:]
|
||||
return context, target
|
||||
|
||||
|
||||
def _plot_forecast(target: pd.Series, forecast: pd.Series, output_path: Path) -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig, ax = plt.subplots(figsize=(10, 4))
|
||||
ax.plot(target.index, target.values, label="Observé", linewidth=2)
|
||||
ax.plot(forecast.index, forecast.values, label="Chronos-2 (préd.)", linewidth=2, linestyle="--")
|
||||
ax.set_title("Prévision Chronos-2 vs observation")
|
||||
ax.set_xlabel("Date")
|
||||
ax.set_ylabel("Température (°C)")
|
||||
ax.grid(True, linestyle=":", alpha=0.4)
|
||||
ax.legend()
|
||||
fig.autofmt_xdate()
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=150)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def _plot_errors(errors: pd.Series, output_path: Path) -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig, ax = plt.subplots(figsize=(8, 4))
|
||||
ax.plot(errors.index, errors.values, marker="o")
|
||||
ax.set_title("Erreur absolue par horizon (heures)")
|
||||
ax.set_xlabel("Horizon (h)")
|
||||
ax.set_ylabel("Erreur absolue (°C)")
|
||||
ax.grid(True, linestyle=":", alpha=0.4)
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=150)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if not CSV_PATH.exists():
|
||||
raise SystemExit(f"Fichier introuvable : {CSV_PATH}")
|
||||
|
||||
series = _load_series(CSV_PATH, target_col="temperature", rule=RESAMPLE_RULE)
|
||||
context, target = _split_context_target(series, CONTEXT_LEN, HORIZON_LEN)
|
||||
|
||||
model_slug = MODEL_ID.replace("/", "__")
|
||||
print(f"Contexte : {len(context)} points ({context.index[0]} -> {context.index[-1]})")
|
||||
print(f"Cible : {len(target)} points ({target.index[0]} -> {target.index[-1]})")
|
||||
print(f"Modèle HF : {MODEL_ID}")
|
||||
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
MODEL_ID,
|
||||
device_map="auto", # CUDA/MPS/CPU automatique
|
||||
dtype="auto",
|
||||
)
|
||||
|
||||
context_array = context.to_numpy(dtype=float)
|
||||
context_tensor = torch.tensor(context_array, dtype=torch.float32)
|
||||
forecasts = pipeline.predict(
|
||||
[context_tensor], # batch de 1 série (Tensor attendu)
|
||||
prediction_length=HORIZON_LEN,
|
||||
num_samples=NUM_SAMPLES,
|
||||
)
|
||||
forecast_mean = forecasts.mean(0)
|
||||
# Si batch dimension présente, on prend la première série
|
||||
if forecast_mean.ndim == 2:
|
||||
forecast_mean = forecast_mean[0]
|
||||
forecast_mean = np.asarray(forecast_mean).ravel()
|
||||
|
||||
forecast_index = pd.date_range(target.index[0], periods=HORIZON_LEN, freq=RESAMPLE_RULE)
|
||||
forecast_series = pd.Series(forecast_mean, index=forecast_index, name="y_pred")
|
||||
target_series = target.rename("y_true")
|
||||
|
||||
# Métriques
|
||||
abs_errors = (forecast_series - target_series).abs()
|
||||
mae = abs_errors.mean()
|
||||
rmse = np.sqrt(((forecast_series - target_series) ** 2).mean())
|
||||
print(f"MAE : {mae:.3f} °C | RMSE : {rmse:.3f} °C")
|
||||
|
||||
# Sauvegardes
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
out_csv = DATA_DIR / f"chronos_forecast_{model_slug}.csv"
|
||||
per_horizon_csv = DATA_DIR / f"chronos_errors_{model_slug}.csv"
|
||||
|
||||
pd.concat([target_series, forecast_series], axis=1).to_csv(out_csv, index=True)
|
||||
pd.DataFrame({"horizon_h": np.arange(1, HORIZON_LEN + 1), "abs_error": abs_errors.values}).to_csv(per_horizon_csv, index=False)
|
||||
|
||||
# Figures
|
||||
_plot_forecast(target_series, forecast_series, FIG_DIR / f"chronos_forecast_{model_slug}.png")
|
||||
_plot_errors(abs_errors.rename("abs_error").rename_axis("h"), FIG_DIR / f"chronos_errors_{model_slug}.png")
|
||||
|
||||
print(f"✔ Sauvegardé : {out_csv}")
|
||||
print(f"✔ Sauvegardé : {per_horizon_csv}")
|
||||
print(f"✔ Figures : {FIG_DIR / f'chronos_forecast_{model_slug}.png'} ; {FIG_DIR / f'chronos_errors_{model_slug}.png'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
147
docs/11 - Modèle Chronos/scripts/run_chronos_holdout6.py
Normal file
147
docs/11 - Modèle Chronos/scripts/run_chronos_holdout6.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# scripts/run_chronos_holdout6.py
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import Dict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from meteo.dataset import load_raw_csv
|
||||
|
||||
try:
|
||||
from chronos import ChronosPipeline
|
||||
except ImportError as exc:
|
||||
raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc
|
||||
|
||||
|
||||
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
|
||||
DOC_DIR = Path(__file__).resolve().parent.parent
|
||||
DATA_DIR = DOC_DIR / "data"
|
||||
FIG_DIR = DOC_DIR / "figures"
|
||||
|
||||
MODEL_ID = "amazon/chronos-t5-small"
|
||||
RESAMPLE_RULE = "1h"
|
||||
CONTEXT_H = 288 # 12 jours
|
||||
HOLDOUT_H = 6
|
||||
NUM_SAMPLES = 50
|
||||
|
||||
TARGETS = {
|
||||
"temperature": {"kind": "reg", "unit": "°C"},
|
||||
"wind_speed": {"kind": "reg", "unit": "km/h"},
|
||||
"wind_direction": {"kind": "angle", "unit": "deg"},
|
||||
"humidity": {"kind": "reg", "unit": "%"},
|
||||
"pressure": {"kind": "reg", "unit": "hPa"},
|
||||
}
|
||||
|
||||
|
||||
def _load_series(target: str) -> pd.Series:
|
||||
df = load_raw_csv(CSV_PATH)
|
||||
if target not in df.columns:
|
||||
raise SystemExit(f"Colonne absente : {target}")
|
||||
return (
|
||||
df[target]
|
||||
.resample(RESAMPLE_RULE)
|
||||
.mean()
|
||||
.interpolate(limit_direction="both")
|
||||
.dropna()
|
||||
)
|
||||
|
||||
|
||||
def _angular_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
|
||||
diff = (y_pred - y_true).abs() % 360
|
||||
return diff.apply(lambda x: min(x, 360 - x))
|
||||
|
||||
|
||||
def _plot_errors(df: pd.DataFrame, output_path: Path) -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig, ax = plt.subplots(figsize=(8, 4))
|
||||
for target in df["target"].unique():
|
||||
sub = df[df["target"] == target]
|
||||
ax.plot(sub["h"], sub["abs_error"], marker="o", label=target)
|
||||
ax.set_xlabel("Heure du horizon (1-6)")
|
||||
ax.set_ylabel("Erreur absolue (unité cible)")
|
||||
ax.set_title("Chronos small – erreurs sur les 6 dernières heures (holdout)")
|
||||
ax.grid(True, linestyle=":", alpha=0.4)
|
||||
ax.legend()
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=150)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if not CSV_PATH.exists():
|
||||
raise SystemExit(f"Fichier introuvable : {CSV_PATH}")
|
||||
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
MODEL_ID,
|
||||
device_map="auto",
|
||||
dtype="auto",
|
||||
)
|
||||
|
||||
rows: list[Dict[str, object]] = []
|
||||
|
||||
for target, meta in TARGETS.items():
|
||||
series = _load_series(target)
|
||||
if len(series) < CONTEXT_H + HOLDOUT_H:
|
||||
raise SystemExit("Pas assez de données pour contexte+holdout.")
|
||||
|
||||
context = series.iloc[-(CONTEXT_H + HOLDOUT_H) : -HOLDOUT_H]
|
||||
holdout = series.iloc[-HOLDOUT_H:]
|
||||
|
||||
context_tensor = torch.tensor(context.to_numpy(dtype=float), dtype=torch.float32)
|
||||
forecasts = pipeline.predict(
|
||||
[context_tensor],
|
||||
prediction_length=HOLDOUT_H,
|
||||
num_samples=NUM_SAMPLES,
|
||||
)
|
||||
forecast_mean = forecasts.mean(0)
|
||||
if forecast_mean.ndim == 2:
|
||||
forecast_mean = forecast_mean[0]
|
||||
forecast_series = pd.Series(
|
||||
np.asarray(forecast_mean).ravel(),
|
||||
index=holdout.index,
|
||||
)
|
||||
|
||||
if meta["kind"] == "angle":
|
||||
err = _angular_error(holdout, forecast_series)
|
||||
else:
|
||||
err = (forecast_series - holdout).abs()
|
||||
|
||||
for i, (ts, e) in enumerate(err.items(), start=1):
|
||||
rows.append(
|
||||
{
|
||||
"target": target,
|
||||
"timestamp": ts,
|
||||
"h": i,
|
||||
"abs_error": float(e),
|
||||
"unit": meta["unit"],
|
||||
}
|
||||
)
|
||||
|
||||
out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}_holdout6"
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
pd.concat([forecast_series.rename("y_pred"), holdout.rename("y_true")], axis=1).to_csv(
|
||||
DATA_DIR / f"chronos_holdout6_forecast_{out_prefix}.csv",
|
||||
index=True,
|
||||
)
|
||||
|
||||
df_errors = pd.DataFrame(rows)
|
||||
err_path = DATA_DIR / "chronos_holdout6_errors.csv"
|
||||
df_errors.to_csv(err_path, index=False)
|
||||
_plot_errors(df_errors, FIG_DIR / "chronos_holdout6_errors.png")
|
||||
|
||||
print(df_errors.groupby("target")["abs_error"].mean().to_string(float_format=lambda x: f"{x:.3f}"))
|
||||
print(f"✔ Sauvegardé : {err_path}")
|
||||
print("✔ Figure : figures/chronos_holdout6_errors.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
163
docs/11 - Modèle Chronos/scripts/run_chronos_multi.py
Normal file
163
docs/11 - Modèle Chronos/scripts/run_chronos_multi.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# scripts/run_chronos_multi.py
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import Iterable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from meteo.dataset import load_raw_csv
|
||||
|
||||
try:
|
||||
from chronos import ChronosPipeline
|
||||
except ImportError as exc:
|
||||
raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc
|
||||
|
||||
|
||||
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
|
||||
DOC_DIR = Path(__file__).resolve().parent.parent
|
||||
DATA_DIR = DOC_DIR / "data"
|
||||
FIG_DIR = DOC_DIR / "figures"
|
||||
|
||||
MODEL_ID = "amazon/chronos-t5-small"
|
||||
RESAMPLE_RULE = "1h" # on reste sur l'heure pour rester aligné avec le pré-entraînement Chronos
|
||||
CONTEXT_H = 336 # 14 jours
|
||||
HORIZONS_H = (1, 6, 24) # horizons demandés (10 min exclu car modèle horaire)
|
||||
NUM_SAMPLES = 50
|
||||
|
||||
TARGETS = {
|
||||
"temperature": {"kind": "reg"},
|
||||
"wind_speed": {"kind": "reg"},
|
||||
"rain_rate": {"kind": "rain"},
|
||||
}
|
||||
|
||||
|
||||
def _load_series(target: str) -> pd.Series:
|
||||
df = load_raw_csv(CSV_PATH)
|
||||
if target not in df.columns:
|
||||
raise SystemExit(f"Colonne absente : {target}")
|
||||
s = df[target].resample(RESAMPLE_RULE).mean().interpolate(limit_direction="both")
|
||||
return s.dropna()
|
||||
|
||||
|
||||
def _prepare_window(series: pd.Series, context_h: int, horizon_h: int) -> tuple[np.ndarray, pd.Series]:
|
||||
needed = context_h + horizon_h
|
||||
if len(series) < needed:
|
||||
raise SystemExit(f"Pas assez de données pour {needed} heures.")
|
||||
window = series.iloc[-needed:]
|
||||
context = window.iloc[:context_h]
|
||||
target = window.iloc[context_h:]
|
||||
return context.to_numpy(dtype=float), target
|
||||
|
||||
|
||||
def _plot_metrics(df: pd.DataFrame, target: str, output_path: Path) -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
if df["kind"].iloc[0] == "reg":
|
||||
ax.plot(df["horizon_h"], df["mae"], marker="o", label="MAE")
|
||||
ax.plot(df["horizon_h"], df["rmse"], marker="o", label="RMSE")
|
||||
ax.set_ylabel("Erreur (°C)" if target == "temperature" else "Erreur (unité)")
|
||||
else:
|
||||
ax.plot(df["horizon_h"], df["f1"], marker="o", label="F1")
|
||||
ax.plot(df["horizon_h"], df["brier"], marker="o", label="Brier")
|
||||
ax.set_ylabel("Score")
|
||||
ax.set_xlabel("Horizon (heures)")
|
||||
ax.set_title(f"Chronos (small) – {target}")
|
||||
ax.grid(True, linestyle=":", alpha=0.4)
|
||||
ax.legend()
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=150)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if not CSV_PATH.exists():
|
||||
raise SystemExit(f"Fichier introuvable : {CSV_PATH}")
|
||||
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
MODEL_ID,
|
||||
device_map="auto",
|
||||
dtype="auto",
|
||||
)
|
||||
|
||||
rows: list[dict[str, object]] = []
|
||||
|
||||
for target, meta in TARGETS.items():
|
||||
series = _load_series(target)
|
||||
# On prend la plus grande fenêtre (max horizon)
|
||||
context_arr, target_series = _prepare_window(series, CONTEXT_H, max(HORIZONS_H))
|
||||
|
||||
context_tensor = torch.tensor(context_arr, dtype=torch.float32)
|
||||
forecasts = pipeline.predict(
|
||||
[context_tensor],
|
||||
prediction_length=max(HORIZONS_H),
|
||||
num_samples=NUM_SAMPLES,
|
||||
)
|
||||
forecast_mean = forecasts.mean(0)
|
||||
if forecast_mean.ndim == 2:
|
||||
forecast_mean = forecast_mean[0]
|
||||
forecast_series = pd.Series(
|
||||
np.asarray(forecast_mean).ravel(),
|
||||
index=pd.date_range(target_series.index[0], periods=max(HORIZONS_H), freq=RESAMPLE_RULE),
|
||||
)
|
||||
|
||||
for h in HORIZONS_H:
|
||||
y_true = target_series.iloc[:h]
|
||||
y_pred = forecast_series.iloc[:h]
|
||||
if meta["kind"] == "reg":
|
||||
mae = float((y_pred - y_true).abs().mean())
|
||||
rmse = float(np.sqrt(((y_pred - y_true) ** 2).mean()))
|
||||
rows.append(
|
||||
{"target": target, "kind": "reg", "horizon_h": h, "mae": mae, "rmse": rmse}
|
||||
)
|
||||
else:
|
||||
y_true_bin = (y_true > 0).astype(int)
|
||||
y_pred_bin = (y_pred > 0).astype(int)
|
||||
if y_true_bin.sum() == 0:
|
||||
f1 = 0.0
|
||||
else:
|
||||
tp = ((y_true_bin == 1) & (y_pred_bin == 1)).sum()
|
||||
fp = ((y_true_bin == 0) & (y_pred_bin == 1)).sum()
|
||||
fn = ((y_true_bin == 1) & (y_pred_bin == 0)).sum()
|
||||
prec = tp / (tp + fp) if tp + fp > 0 else 0.0
|
||||
rec = tp / (tp + fn) if tp + fn > 0 else 0.0
|
||||
f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
|
||||
# Brier approximé en traitant y_pred comme proba clampée [0,1]
|
||||
proba = y_pred.clip(0, 1)
|
||||
brier = float(((proba - y_true_bin) ** 2).mean())
|
||||
rows.append(
|
||||
{"target": target, "kind": "cls", "horizon_h": h, "f1": float(f1), "brier": brier}
|
||||
)
|
||||
|
||||
# Sauvegarde forecast/target complets pour inspection
|
||||
out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}"
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
pd.concat([forecast_series.rename("y_pred"), target_series.rename("y_true")], axis=1).to_csv(
|
||||
DATA_DIR / f"chronos_multi_forecast_{out_prefix}.csv",
|
||||
index=True,
|
||||
)
|
||||
|
||||
df_metrics = pd.DataFrame(rows)
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
metrics_path = DATA_DIR / "chronos_multi_metrics.csv"
|
||||
df_metrics.to_csv(metrics_path, index=False)
|
||||
|
||||
for target in df_metrics["target"].unique():
|
||||
sub = df_metrics[df_metrics["target"] == target].sort_values("horizon_h")
|
||||
_plot_metrics(sub, target, FIG_DIR / f"chronos_multi_{target}.png")
|
||||
|
||||
print(df_metrics.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
|
||||
print(f"✔ Sauvegardé : {metrics_path}")
|
||||
print("✔ Figures par cible dans figures/chronos_multi_<target>.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
157
docs/11 - Modèle Chronos/scripts/run_chronos_tuned.py
Normal file
157
docs/11 - Modèle Chronos/scripts/run_chronos_tuned.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# scripts/run_chronos_tuned.py
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from meteo.dataset import load_raw_csv
|
||||
|
||||
try:
|
||||
from chronos import ChronosPipeline
|
||||
except ImportError as exc:
|
||||
raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc
|
||||
|
||||
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
|
||||
DOC_DIR = Path(__file__).resolve().parent.parent
|
||||
DATA_DIR = DOC_DIR / "data"
|
||||
FIG_DIR = DOC_DIR / "figures"
|
||||
|
||||
# Réglages plus prudents : contexte raccourci, horizon 64 max, plus d'échantillons
|
||||
MODEL_ID = "amazon/chronos-t5-small"
|
||||
RESAMPLE_RULE = "1h"
|
||||
CONTEXT_H = 288 # 12 jours
|
||||
MAX_HORIZON_H = 64 # <=64 conseillé
|
||||
HORIZONS_H = (1, 6, 24, 48) # on reste cohérents avec nos jalons
|
||||
NUM_SAMPLES = 100
|
||||
|
||||
TARGETS = {
|
||||
"temperature": {"kind": "reg"},
|
||||
"wind_speed": {"kind": "reg"},
|
||||
"rain_rate": {"kind": "rain"},
|
||||
}
|
||||
|
||||
|
||||
def _load_series(target: str) -> pd.Series:
|
||||
df = load_raw_csv(CSV_PATH)
|
||||
if target not in df.columns:
|
||||
raise SystemExit(f"Colonne absente : {target}")
|
||||
s = df[target].resample(RESAMPLE_RULE).mean().interpolate(limit_direction="both")
|
||||
return s.dropna()
|
||||
|
||||
|
||||
def _prepare_window(series: pd.Series, context_h: int, horizon_h: int) -> tuple[np.ndarray, pd.Series]:
|
||||
needed = context_h + horizon_h
|
||||
if len(series) < needed:
|
||||
raise SystemExit(f"Pas assez de données pour {needed} heures.")
|
||||
window = series.iloc[-needed:]
|
||||
context = window.iloc[:context_h]
|
||||
target = window.iloc[context_h:]
|
||||
return context.to_numpy(dtype=float), target
|
||||
|
||||
|
||||
def _plot_metrics(df: pd.DataFrame, target: str, output_path: Path) -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
if df["kind"].iloc[0] == "reg":
|
||||
ax.plot(df["horizon_h"], df["mae"], marker="o", label="MAE")
|
||||
ax.plot(df["horizon_h"], df["rmse"], marker="o", label="RMSE")
|
||||
ax.set_ylabel("Erreur (°C)" if target == "temperature" else "Erreur (unité)")
|
||||
else:
|
||||
ax.plot(df["horizon_h"], df["f1"], marker="o", label="F1")
|
||||
ax.plot(df["horizon_h"], df["brier"], marker="o", label="Brier")
|
||||
ax.set_ylabel("Score")
|
||||
ax.set_xlabel("Horizon (heures)")
|
||||
ax.set_title(f"Chronos (small, réglages prudents) – {target}")
|
||||
ax.grid(True, linestyle=":", alpha=0.4)
|
||||
ax.legend()
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=150)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
MODEL_ID,
|
||||
device_map="auto",
|
||||
dtype="auto",
|
||||
)
|
||||
|
||||
rows: list[dict[str, object]] = []
|
||||
|
||||
for target, meta in TARGETS.items():
|
||||
series = _load_series(target)
|
||||
context_arr, target_series = _prepare_window(series, CONTEXT_H, MAX_HORIZON_H)
|
||||
|
||||
context_tensor = torch.tensor(context_arr, dtype=torch.float32)
|
||||
forecasts = pipeline.predict(
|
||||
[context_tensor],
|
||||
prediction_length=MAX_HORIZON_H,
|
||||
num_samples=NUM_SAMPLES,
|
||||
)
|
||||
forecast_mean = forecasts.mean(0)
|
||||
if forecast_mean.ndim == 2:
|
||||
forecast_mean = forecast_mean[0]
|
||||
forecast_series = pd.Series(
|
||||
np.asarray(forecast_mean).ravel(),
|
||||
index=pd.date_range(target_series.index[0], periods=MAX_HORIZON_H, freq=RESAMPLE_RULE),
|
||||
)
|
||||
|
||||
for h in HORIZONS_H:
|
||||
y_true = target_series.iloc[:h]
|
||||
y_pred = forecast_series.iloc[:h]
|
||||
if meta["kind"] == "reg":
|
||||
mae = float((y_pred - y_true).abs().mean())
|
||||
rmse = float(np.sqrt(((y_pred - y_true) ** 2).mean()))
|
||||
rows.append(
|
||||
{"target": target, "kind": "reg", "horizon_h": h, "mae": mae, "rmse": rmse}
|
||||
)
|
||||
else:
|
||||
y_true_bin = (y_true > 0).astype(int)
|
||||
y_pred_bin = (y_pred > 0).astype(int)
|
||||
if y_true_bin.sum() == 0:
|
||||
f1 = 0.0
|
||||
else:
|
||||
tp = ((y_true_bin == 1) & (y_pred_bin == 1)).sum()
|
||||
fp = ((y_true_bin == 0) & (y_pred_bin == 1)).sum()
|
||||
fn = ((y_true_bin == 1) & (y_pred_bin == 0)).sum()
|
||||
prec = tp / (tp + fp) if tp + fp > 0 else 0.0
|
||||
rec = tp / (tp + fn) if tp + fn > 0 else 0.0
|
||||
f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
|
||||
proba = y_pred.clip(0, 1)
|
||||
brier = float(((proba - y_true_bin) ** 2).mean())
|
||||
rows.append(
|
||||
{"target": target, "kind": "cls", "horizon_h": h, "f1": float(f1), "brier": brier}
|
||||
)
|
||||
|
||||
out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}_tuned"
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
pd.concat([forecast_series.rename("y_pred"), target_series.rename("y_true")], axis=1).to_csv(
|
||||
DATA_DIR / f"chronos_tuned_forecast_{out_prefix}.csv",
|
||||
index=True,
|
||||
)
|
||||
|
||||
df_metrics = pd.DataFrame(rows)
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
metrics_path = DATA_DIR / "chronos_tuned_metrics.csv"
|
||||
df_metrics.to_csv(metrics_path, index=False)
|
||||
|
||||
for target in df_metrics["target"].unique():
|
||||
sub = df_metrics[df_metrics["target"] == target].sort_values("horizon_h")
|
||||
_plot_metrics(sub, target, FIG_DIR / f"chronos_tuned_{target}.png")
|
||||
|
||||
print(df_metrics.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
|
||||
print(f"✔ Sauvegardé : {metrics_path}")
|
||||
print("✔ Figures par cible (réglages prudents) dans figures/chronos_tuned_<target>.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user