1
donnees_meteo/model/splits.py

66 lines
1.9 KiB
Python

"""Fonctions de découpe temporelle pour l'entraînement et la validation."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator
import pandas as pd
@dataclass(frozen=True)
class Split:
"""Indices pour une paire (train, validation)."""
train: pd.Index
validation: pd.Index
def chronological_split(
df: pd.DataFrame,
*,
train_frac: float = 0.7,
val_frac: float = 0.15,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Coupe un DataFrame chronologiquement en (train, validation, test) sans fuite temporelle.
"""
if not 0 < train_frac < 1 or not 0 < val_frac < 1:
raise ValueError("train_frac et val_frac doivent être dans ]0, 1[.")
if train_frac + val_frac >= 1:
raise ValueError("train_frac + val_frac doit être < 1.")
n = len(df)
n_train = int(n * train_frac)
n_val = int(n * val_frac)
train_df = df.iloc[:n_train]
val_df = df.iloc[n_train : n_train + n_val]
test_df = df.iloc[n_train + n_val :]
return train_df, val_df, test_df
def rolling_time_series_splits(
df: pd.DataFrame,
*,
n_splits: int = 3,
train_frac: float = 0.7,
val_frac: float = 0.15,
) -> Iterator[Split]:
"""
Génère plusieurs paires (train, validation) chronologiques en “roulant” la fenêtre.
Chaque _fold_ commence en début de série et pousse progressivement la frontière
train/validation vers le futur. Le test final reste en dehors de ces folds.
"""
if n_splits < 1:
raise ValueError("n_splits doit être >= 1.")
for split_idx in range(n_splits):
# On avance la fenêtre de validation à chaque itération
offset = int(len(df) * 0.05 * split_idx)
sub_df = df.iloc[offset:]
train_df, val_df, _ = chronological_split(sub_df, train_frac=train_frac, val_frac=val_frac)
yield Split(train=train_df.index, validation=val_df.index)