1

Modèle Chronos-2

This commit is contained in:
2025-11-26 01:03:41 +01:00
parent ccd2195d27
commit 9a393972eb
22 changed files with 743 additions and 1 deletions

View File

@@ -0,0 +1,68 @@
# scripts/compare_chronos.py
from __future__ import annotations
from pathlib import Path
import re
import matplotlib.pyplot as plt
import pandas as pd
DOC_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = DOC_DIR / "data"
FIG_DIR = DOC_DIR / "figures"
def _load_forecasts() -> pd.DataFrame:
records = []
pattern = re.compile(r"chronos_forecast_(.+)\.csv")
for csv in DATA_DIR.glob("chronos_forecast_*.csv"):
match = pattern.match(csv.name)
if not match:
continue
model = match.group(1).replace("__", "/")
df = pd.read_csv(csv)
if not {"y_true", "y_pred"}.issubset(df.columns):
continue
err = df["y_pred"] - df["y_true"]
records.append(
{
"model": model,
"mae": err.abs().mean(),
"rmse": (err.pow(2).mean()) ** 0.5,
}
)
return pd.DataFrame(records)
def _plot_comparison(df: pd.DataFrame, output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(7, 4))
x = range(len(df))
ax.bar([i - 0.15 for i in x], df["mae"], width=0.3, label="MAE")
ax.bar([i + 0.15 for i in x], df["rmse"], width=0.3, label="RMSE")
ax.set_xticks(list(x))
ax.set_xticklabels(df["model"], rotation=15)
ax.set_ylabel("Erreur (°C)")
ax.set_title("Chronos T5 comparaison des tailles")
ax.grid(True, linestyle=":", alpha=0.4, axis="y")
ax.legend()
fig.tight_layout()
fig.savefig(output_path, dpi=150)
plt.close(fig)
def main() -> None:
df = _load_forecasts()
if df.empty:
raise SystemExit("Aucune sortie chronos_forecast_*.csv trouvée dans data/. Lancez run_chronos.py d'abord.")
df_sorted = df.sort_values("mae")
summary_path = DATA_DIR / "chronos_summary.csv"
df_sorted.to_csv(summary_path, index=False)
_plot_comparison(df_sorted, FIG_DIR / "chronos_models_comparison.png")
print(df_sorted.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
print(f"✔ Sauvegardé : {summary_path}")
print(f"✔ Figure : {FIG_DIR / 'chronos_models_comparison.png'}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,152 @@
# scripts/run_chronos.py
from __future__ import annotations
import os
from pathlib import Path
import sys
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from meteo.dataset import load_raw_csv
try:
from chronos import ChronosPipeline
except ImportError as exc: # pragma: no cover - guidance if deps missing
raise SystemExit(
"chronos-forecasting est manquant. Installez-le (pip install chronos-forecasting) "
"puis relancez."
) from exc
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
DOC_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = DOC_DIR / "data"
FIG_DIR = DOC_DIR / "figures"
MODEL_ID = os.getenv("CHRONOS_MODEL", "amazon/chronos-t5-small")
CONTEXT_LEN = int(os.getenv("CHRONOS_CONTEXT", "336")) # 14 jours (1H)
HORIZON_LEN = int(os.getenv("CHRONOS_HORIZON", "96")) # 4 jours (1H)
RESAMPLE_RULE = os.getenv("CHRONOS_RESAMPLE", "1h")
NUM_SAMPLES = int(os.getenv("CHRONOS_SAMPLES", "20")) # échantillons stochastiques
def _load_series(csv_path: Path, target_col: str, rule: str) -> pd.Series:
df = load_raw_csv(csv_path)
if target_col not in df.columns:
raise SystemExit(f"Colonne absente dans le CSV : {target_col!r}")
series = (
df[target_col]
.resample(rule)
.mean()
.interpolate(limit_direction="both")
)
return series.dropna()
def _split_context_target(series: pd.Series, context_len: int, horizon_len: int) -> Tuple[pd.Series, pd.Series]:
needed = context_len + horizon_len
if len(series) < needed:
raise SystemExit(f"Pas assez de données après resampling : {len(series)} < {needed} (context+horizon).")
window = series.iloc[-needed:]
context = window.iloc[:context_len]
target = window.iloc[context_len:]
return context, target
def _plot_forecast(target: pd.Series, forecast: pd.Series, output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(target.index, target.values, label="Observé", linewidth=2)
ax.plot(forecast.index, forecast.values, label="Chronos-2 (préd.)", linewidth=2, linestyle="--")
ax.set_title("Prévision Chronos-2 vs observation")
ax.set_xlabel("Date")
ax.set_ylabel("Température (°C)")
ax.grid(True, linestyle=":", alpha=0.4)
ax.legend()
fig.autofmt_xdate()
fig.tight_layout()
fig.savefig(output_path, dpi=150)
plt.close(fig)
def _plot_errors(errors: pd.Series, output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(errors.index, errors.values, marker="o")
ax.set_title("Erreur absolue par horizon (heures)")
ax.set_xlabel("Horizon (h)")
ax.set_ylabel("Erreur absolue (°C)")
ax.grid(True, linestyle=":", alpha=0.4)
fig.tight_layout()
fig.savefig(output_path, dpi=150)
plt.close(fig)
def main() -> None:
if not CSV_PATH.exists():
raise SystemExit(f"Fichier introuvable : {CSV_PATH}")
series = _load_series(CSV_PATH, target_col="temperature", rule=RESAMPLE_RULE)
context, target = _split_context_target(series, CONTEXT_LEN, HORIZON_LEN)
model_slug = MODEL_ID.replace("/", "__")
print(f"Contexte : {len(context)} points ({context.index[0]} -> {context.index[-1]})")
print(f"Cible : {len(target)} points ({target.index[0]} -> {target.index[-1]})")
print(f"Modèle HF : {MODEL_ID}")
pipeline = ChronosPipeline.from_pretrained(
MODEL_ID,
device_map="auto", # CUDA/MPS/CPU automatique
dtype="auto",
)
context_array = context.to_numpy(dtype=float)
context_tensor = torch.tensor(context_array, dtype=torch.float32)
forecasts = pipeline.predict(
[context_tensor], # batch de 1 série (Tensor attendu)
prediction_length=HORIZON_LEN,
num_samples=NUM_SAMPLES,
)
forecast_mean = forecasts.mean(0)
# Si batch dimension présente, on prend la première série
if forecast_mean.ndim == 2:
forecast_mean = forecast_mean[0]
forecast_mean = np.asarray(forecast_mean).ravel()
forecast_index = pd.date_range(target.index[0], periods=HORIZON_LEN, freq=RESAMPLE_RULE)
forecast_series = pd.Series(forecast_mean, index=forecast_index, name="y_pred")
target_series = target.rename("y_true")
# Métriques
abs_errors = (forecast_series - target_series).abs()
mae = abs_errors.mean()
rmse = np.sqrt(((forecast_series - target_series) ** 2).mean())
print(f"MAE : {mae:.3f} °C | RMSE : {rmse:.3f} °C")
# Sauvegardes
DATA_DIR.mkdir(parents=True, exist_ok=True)
out_csv = DATA_DIR / f"chronos_forecast_{model_slug}.csv"
per_horizon_csv = DATA_DIR / f"chronos_errors_{model_slug}.csv"
pd.concat([target_series, forecast_series], axis=1).to_csv(out_csv, index=True)
pd.DataFrame({"horizon_h": np.arange(1, HORIZON_LEN + 1), "abs_error": abs_errors.values}).to_csv(per_horizon_csv, index=False)
# Figures
_plot_forecast(target_series, forecast_series, FIG_DIR / f"chronos_forecast_{model_slug}.png")
_plot_errors(abs_errors.rename("abs_error").rename_axis("h"), FIG_DIR / f"chronos_errors_{model_slug}.png")
print(f"✔ Sauvegardé : {out_csv}")
print(f"✔ Sauvegardé : {per_horizon_csv}")
print(f"✔ Figures : {FIG_DIR / f'chronos_forecast_{model_slug}.png'} ; {FIG_DIR / f'chronos_errors_{model_slug}.png'}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,147 @@
# scripts/run_chronos_holdout6.py
from __future__ import annotations
from pathlib import Path
import sys
from typing import Dict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from meteo.dataset import load_raw_csv
try:
from chronos import ChronosPipeline
except ImportError as exc:
raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
DOC_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = DOC_DIR / "data"
FIG_DIR = DOC_DIR / "figures"
MODEL_ID = "amazon/chronos-t5-small"
RESAMPLE_RULE = "1h"
CONTEXT_H = 288 # 12 jours
HOLDOUT_H = 6
NUM_SAMPLES = 50
TARGETS = {
"temperature": {"kind": "reg", "unit": "°C"},
"wind_speed": {"kind": "reg", "unit": "km/h"},
"wind_direction": {"kind": "angle", "unit": "deg"},
"humidity": {"kind": "reg", "unit": "%"},
"pressure": {"kind": "reg", "unit": "hPa"},
}
def _load_series(target: str) -> pd.Series:
df = load_raw_csv(CSV_PATH)
if target not in df.columns:
raise SystemExit(f"Colonne absente : {target}")
return (
df[target]
.resample(RESAMPLE_RULE)
.mean()
.interpolate(limit_direction="both")
.dropna()
)
def _angular_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
diff = (y_pred - y_true).abs() % 360
return diff.apply(lambda x: min(x, 360 - x))
def _plot_errors(df: pd.DataFrame, output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(8, 4))
for target in df["target"].unique():
sub = df[df["target"] == target]
ax.plot(sub["h"], sub["abs_error"], marker="o", label=target)
ax.set_xlabel("Heure du horizon (1-6)")
ax.set_ylabel("Erreur absolue (unité cible)")
ax.set_title("Chronos small erreurs sur les 6 dernières heures (holdout)")
ax.grid(True, linestyle=":", alpha=0.4)
ax.legend()
fig.tight_layout()
fig.savefig(output_path, dpi=150)
plt.close(fig)
def main() -> None:
if not CSV_PATH.exists():
raise SystemExit(f"Fichier introuvable : {CSV_PATH}")
pipeline = ChronosPipeline.from_pretrained(
MODEL_ID,
device_map="auto",
dtype="auto",
)
rows: list[Dict[str, object]] = []
for target, meta in TARGETS.items():
series = _load_series(target)
if len(series) < CONTEXT_H + HOLDOUT_H:
raise SystemExit("Pas assez de données pour contexte+holdout.")
context = series.iloc[-(CONTEXT_H + HOLDOUT_H) : -HOLDOUT_H]
holdout = series.iloc[-HOLDOUT_H:]
context_tensor = torch.tensor(context.to_numpy(dtype=float), dtype=torch.float32)
forecasts = pipeline.predict(
[context_tensor],
prediction_length=HOLDOUT_H,
num_samples=NUM_SAMPLES,
)
forecast_mean = forecasts.mean(0)
if forecast_mean.ndim == 2:
forecast_mean = forecast_mean[0]
forecast_series = pd.Series(
np.asarray(forecast_mean).ravel(),
index=holdout.index,
)
if meta["kind"] == "angle":
err = _angular_error(holdout, forecast_series)
else:
err = (forecast_series - holdout).abs()
for i, (ts, e) in enumerate(err.items(), start=1):
rows.append(
{
"target": target,
"timestamp": ts,
"h": i,
"abs_error": float(e),
"unit": meta["unit"],
}
)
out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}_holdout6"
DATA_DIR.mkdir(parents=True, exist_ok=True)
pd.concat([forecast_series.rename("y_pred"), holdout.rename("y_true")], axis=1).to_csv(
DATA_DIR / f"chronos_holdout6_forecast_{out_prefix}.csv",
index=True,
)
df_errors = pd.DataFrame(rows)
err_path = DATA_DIR / "chronos_holdout6_errors.csv"
df_errors.to_csv(err_path, index=False)
_plot_errors(df_errors, FIG_DIR / "chronos_holdout6_errors.png")
print(df_errors.groupby("target")["abs_error"].mean().to_string(float_format=lambda x: f"{x:.3f}"))
print(f"✔ Sauvegardé : {err_path}")
print("✔ Figure : figures/chronos_holdout6_errors.png")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,163 @@
# scripts/run_chronos_multi.py
from __future__ import annotations
from pathlib import Path
import sys
from typing import Iterable
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from meteo.dataset import load_raw_csv
try:
from chronos import ChronosPipeline
except ImportError as exc:
raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
DOC_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = DOC_DIR / "data"
FIG_DIR = DOC_DIR / "figures"
MODEL_ID = "amazon/chronos-t5-small"
RESAMPLE_RULE = "1h" # on reste sur l'heure pour rester aligné avec le pré-entraînement Chronos
CONTEXT_H = 336 # 14 jours
HORIZONS_H = (1, 6, 24) # horizons demandés (10 min exclu car modèle horaire)
NUM_SAMPLES = 50
TARGETS = {
"temperature": {"kind": "reg"},
"wind_speed": {"kind": "reg"},
"rain_rate": {"kind": "rain"},
}
def _load_series(target: str) -> pd.Series:
df = load_raw_csv(CSV_PATH)
if target not in df.columns:
raise SystemExit(f"Colonne absente : {target}")
s = df[target].resample(RESAMPLE_RULE).mean().interpolate(limit_direction="both")
return s.dropna()
def _prepare_window(series: pd.Series, context_h: int, horizon_h: int) -> tuple[np.ndarray, pd.Series]:
needed = context_h + horizon_h
if len(series) < needed:
raise SystemExit(f"Pas assez de données pour {needed} heures.")
window = series.iloc[-needed:]
context = window.iloc[:context_h]
target = window.iloc[context_h:]
return context.to_numpy(dtype=float), target
def _plot_metrics(df: pd.DataFrame, target: str, output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(6, 4))
if df["kind"].iloc[0] == "reg":
ax.plot(df["horizon_h"], df["mae"], marker="o", label="MAE")
ax.plot(df["horizon_h"], df["rmse"], marker="o", label="RMSE")
ax.set_ylabel("Erreur (°C)" if target == "temperature" else "Erreur (unité)")
else:
ax.plot(df["horizon_h"], df["f1"], marker="o", label="F1")
ax.plot(df["horizon_h"], df["brier"], marker="o", label="Brier")
ax.set_ylabel("Score")
ax.set_xlabel("Horizon (heures)")
ax.set_title(f"Chronos (small) {target}")
ax.grid(True, linestyle=":", alpha=0.4)
ax.legend()
fig.tight_layout()
fig.savefig(output_path, dpi=150)
plt.close(fig)
def main() -> None:
if not CSV_PATH.exists():
raise SystemExit(f"Fichier introuvable : {CSV_PATH}")
pipeline = ChronosPipeline.from_pretrained(
MODEL_ID,
device_map="auto",
dtype="auto",
)
rows: list[dict[str, object]] = []
for target, meta in TARGETS.items():
series = _load_series(target)
# On prend la plus grande fenêtre (max horizon)
context_arr, target_series = _prepare_window(series, CONTEXT_H, max(HORIZONS_H))
context_tensor = torch.tensor(context_arr, dtype=torch.float32)
forecasts = pipeline.predict(
[context_tensor],
prediction_length=max(HORIZONS_H),
num_samples=NUM_SAMPLES,
)
forecast_mean = forecasts.mean(0)
if forecast_mean.ndim == 2:
forecast_mean = forecast_mean[0]
forecast_series = pd.Series(
np.asarray(forecast_mean).ravel(),
index=pd.date_range(target_series.index[0], periods=max(HORIZONS_H), freq=RESAMPLE_RULE),
)
for h in HORIZONS_H:
y_true = target_series.iloc[:h]
y_pred = forecast_series.iloc[:h]
if meta["kind"] == "reg":
mae = float((y_pred - y_true).abs().mean())
rmse = float(np.sqrt(((y_pred - y_true) ** 2).mean()))
rows.append(
{"target": target, "kind": "reg", "horizon_h": h, "mae": mae, "rmse": rmse}
)
else:
y_true_bin = (y_true > 0).astype(int)
y_pred_bin = (y_pred > 0).astype(int)
if y_true_bin.sum() == 0:
f1 = 0.0
else:
tp = ((y_true_bin == 1) & (y_pred_bin == 1)).sum()
fp = ((y_true_bin == 0) & (y_pred_bin == 1)).sum()
fn = ((y_true_bin == 1) & (y_pred_bin == 0)).sum()
prec = tp / (tp + fp) if tp + fp > 0 else 0.0
rec = tp / (tp + fn) if tp + fn > 0 else 0.0
f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
# Brier approximé en traitant y_pred comme proba clampée [0,1]
proba = y_pred.clip(0, 1)
brier = float(((proba - y_true_bin) ** 2).mean())
rows.append(
{"target": target, "kind": "cls", "horizon_h": h, "f1": float(f1), "brier": brier}
)
# Sauvegarde forecast/target complets pour inspection
out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}"
DATA_DIR.mkdir(parents=True, exist_ok=True)
pd.concat([forecast_series.rename("y_pred"), target_series.rename("y_true")], axis=1).to_csv(
DATA_DIR / f"chronos_multi_forecast_{out_prefix}.csv",
index=True,
)
df_metrics = pd.DataFrame(rows)
DATA_DIR.mkdir(parents=True, exist_ok=True)
metrics_path = DATA_DIR / "chronos_multi_metrics.csv"
df_metrics.to_csv(metrics_path, index=False)
for target in df_metrics["target"].unique():
sub = df_metrics[df_metrics["target"] == target].sort_values("horizon_h")
_plot_metrics(sub, target, FIG_DIR / f"chronos_multi_{target}.png")
print(df_metrics.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
print(f"✔ Sauvegardé : {metrics_path}")
print("✔ Figures par cible dans figures/chronos_multi_<target>.png")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,157 @@
# scripts/run_chronos_tuned.py
from __future__ import annotations
from pathlib import Path
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from meteo.dataset import load_raw_csv
try:
from chronos import ChronosPipeline
except ImportError as exc:
raise SystemExit("chronos-forecasting manquant : pip install -r requirements.txt") from exc
CSV_PATH = PROJECT_ROOT / "data" / "weather_minutely.csv"
DOC_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = DOC_DIR / "data"
FIG_DIR = DOC_DIR / "figures"
# Réglages plus prudents : contexte raccourci, horizon 64 max, plus d'échantillons
MODEL_ID = "amazon/chronos-t5-small"
RESAMPLE_RULE = "1h"
CONTEXT_H = 288 # 12 jours
MAX_HORIZON_H = 64 # <=64 conseillé
HORIZONS_H = (1, 6, 24, 48) # on reste cohérents avec nos jalons
NUM_SAMPLES = 100
TARGETS = {
"temperature": {"kind": "reg"},
"wind_speed": {"kind": "reg"},
"rain_rate": {"kind": "rain"},
}
def _load_series(target: str) -> pd.Series:
df = load_raw_csv(CSV_PATH)
if target not in df.columns:
raise SystemExit(f"Colonne absente : {target}")
s = df[target].resample(RESAMPLE_RULE).mean().interpolate(limit_direction="both")
return s.dropna()
def _prepare_window(series: pd.Series, context_h: int, horizon_h: int) -> tuple[np.ndarray, pd.Series]:
needed = context_h + horizon_h
if len(series) < needed:
raise SystemExit(f"Pas assez de données pour {needed} heures.")
window = series.iloc[-needed:]
context = window.iloc[:context_h]
target = window.iloc[context_h:]
return context.to_numpy(dtype=float), target
def _plot_metrics(df: pd.DataFrame, target: str, output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(6, 4))
if df["kind"].iloc[0] == "reg":
ax.plot(df["horizon_h"], df["mae"], marker="o", label="MAE")
ax.plot(df["horizon_h"], df["rmse"], marker="o", label="RMSE")
ax.set_ylabel("Erreur (°C)" if target == "temperature" else "Erreur (unité)")
else:
ax.plot(df["horizon_h"], df["f1"], marker="o", label="F1")
ax.plot(df["horizon_h"], df["brier"], marker="o", label="Brier")
ax.set_ylabel("Score")
ax.set_xlabel("Horizon (heures)")
ax.set_title(f"Chronos (small, réglages prudents) {target}")
ax.grid(True, linestyle=":", alpha=0.4)
ax.legend()
fig.tight_layout()
fig.savefig(output_path, dpi=150)
plt.close(fig)
def main() -> None:
pipeline = ChronosPipeline.from_pretrained(
MODEL_ID,
device_map="auto",
dtype="auto",
)
rows: list[dict[str, object]] = []
for target, meta in TARGETS.items():
series = _load_series(target)
context_arr, target_series = _prepare_window(series, CONTEXT_H, MAX_HORIZON_H)
context_tensor = torch.tensor(context_arr, dtype=torch.float32)
forecasts = pipeline.predict(
[context_tensor],
prediction_length=MAX_HORIZON_H,
num_samples=NUM_SAMPLES,
)
forecast_mean = forecasts.mean(0)
if forecast_mean.ndim == 2:
forecast_mean = forecast_mean[0]
forecast_series = pd.Series(
np.asarray(forecast_mean).ravel(),
index=pd.date_range(target_series.index[0], periods=MAX_HORIZON_H, freq=RESAMPLE_RULE),
)
for h in HORIZONS_H:
y_true = target_series.iloc[:h]
y_pred = forecast_series.iloc[:h]
if meta["kind"] == "reg":
mae = float((y_pred - y_true).abs().mean())
rmse = float(np.sqrt(((y_pred - y_true) ** 2).mean()))
rows.append(
{"target": target, "kind": "reg", "horizon_h": h, "mae": mae, "rmse": rmse}
)
else:
y_true_bin = (y_true > 0).astype(int)
y_pred_bin = (y_pred > 0).astype(int)
if y_true_bin.sum() == 0:
f1 = 0.0
else:
tp = ((y_true_bin == 1) & (y_pred_bin == 1)).sum()
fp = ((y_true_bin == 0) & (y_pred_bin == 1)).sum()
fn = ((y_true_bin == 1) & (y_pred_bin == 0)).sum()
prec = tp / (tp + fp) if tp + fp > 0 else 0.0
rec = tp / (tp + fn) if tp + fn > 0 else 0.0
f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
proba = y_pred.clip(0, 1)
brier = float(((proba - y_true_bin) ** 2).mean())
rows.append(
{"target": target, "kind": "cls", "horizon_h": h, "f1": float(f1), "brier": brier}
)
out_prefix = f"{target}_{MODEL_ID.replace('/', '__')}_tuned"
DATA_DIR.mkdir(parents=True, exist_ok=True)
pd.concat([forecast_series.rename("y_pred"), target_series.rename("y_true")], axis=1).to_csv(
DATA_DIR / f"chronos_tuned_forecast_{out_prefix}.csv",
index=True,
)
df_metrics = pd.DataFrame(rows)
DATA_DIR.mkdir(parents=True, exist_ok=True)
metrics_path = DATA_DIR / "chronos_tuned_metrics.csv"
df_metrics.to_csv(metrics_path, index=False)
for target in df_metrics["target"].unique():
sub = df_metrics[df_metrics["target"] == target].sort_values("horizon_h")
_plot_metrics(sub, target, FIG_DIR / f"chronos_tuned_{target}.png")
print(df_metrics.to_string(index=False, float_format=lambda x: f"{x:.3f}"))
print(f"✔ Sauvegardé : {metrics_path}")
print("✔ Figures par cible (réglages prudents) dans figures/chronos_tuned_<target>.png")
if __name__ == "__main__":
main()