"""Fonctions de tracé pour comparer directement deux ou trois variables.""" from __future__ import annotations from pathlib import Path from typing import Callable, Sequence import matplotlib.pyplot as plt from matplotlib.colors import Normalize import numpy as np import pandas as pd from .base import export_plot_dataset from meteo.variables import Variable __all__ = ['plot_scatter_pair', 'plot_pairwise_relationship_grid', 'plot_hexbin_with_third_variable', 'plot_event_composite'] def plot_scatter_pair( df: pd.DataFrame, var_x: Variable, var_y: Variable, output_path: str | Path, *, sample_step: int = 10, color_by_time: bool = True, cmap: str = "viridis", ) -> Path: """ Trace un nuage de points (scatter) pour une paire de variables. - On sous-échantillonne les données avec `sample_step` (par exemple, 1 point sur 10) pour éviter un graphique illisible. - Si `color_by_time` vaut True et que l'index est temporel, les points sont colorés du plus ancien (sombre) au plus récent (clair). - Lorsque l'axe Y correspond à la direction du vent, on bascule sur un graphique polaire plus adapté (0° = Nord, sens horaire) avec un rayon normalisé : centre = valeur minimale, bord = maximale. """ output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) # On ne garde que les colonnes pertinentes et les lignes complètes df_pair = df[[var_x.column, var_y.column]].dropna() if sample_step > 1: df_pair = df_pair.iloc[::sample_step, :] export_plot_dataset(df_pair, output_path) direction_var: Variable | None = None radial_var: Variable | None = None direction_series: pd.Series | None = None radial_series: pd.Series | None = None if var_y.key == "wind_direction" and var_x.key != "wind_direction": direction_var = var_y direction_series = df_pair[var_y.column] radial_var = var_x radial_series = df_pair[var_x.column] elif var_x.key == "wind_direction" and var_y.key != "wind_direction": direction_var = var_x direction_series = df_pair[var_x.column] radial_var = var_y radial_series = df_pair[var_y.column] use_polar = direction_var is not None and radial_var is not None if use_polar: fig, ax = plt.subplots(subplot_kw={"projection": "polar"}) else: fig, ax = plt.subplots() scatter_kwargs: dict = {"s": 5, "alpha": 0.5} colorbar_meta: dict | None = None if color_by_time and isinstance(df_pair.index, pd.DatetimeIndex): idx = df_pair.index timestamps = idx.view("int64") time_span = np.ptp(timestamps) norm = ( Normalize(vmin=timestamps.min(), vmax=timestamps.max()) if time_span > 0 else None ) scatter_kwargs |= {"c": timestamps, "cmap": cmap} if norm is not None: scatter_kwargs["norm"] = norm colorbar_meta = { "index": idx, "timestamps": timestamps, "time_span": time_span, } if use_polar: assert direction_series is not None and radial_series is not None assert direction_var is not None and radial_var is not None theta = np.deg2rad(direction_series.to_numpy(dtype=float) % 360.0) radius_raw = radial_series.to_numpy(dtype=float) if radius_raw.size == 0: radius = radius_raw value_min = value_max = float("nan") else: value_min = float(np.min(radius_raw)) value_max = float(np.max(radius_raw)) if np.isclose(value_min, value_max): radius = np.zeros_like(radius_raw) else: radius = (radius_raw - value_min) / (value_max - value_min) scatter = ax.scatter(theta, radius, **scatter_kwargs) cardinal_angles = np.deg2rad(np.arange(0, 360, 45)) cardinal_labels = ["N", "NE", "E", "SE", "S", "SO", "O", "NO"] ax.set_theta_zero_location("N") ax.set_theta_direction(-1) ax.set_xticks(cardinal_angles) ax.set_xticklabels(cardinal_labels) if radius_raw.size > 0: if np.isclose(value_min, value_max): radial_positions = [0.0] else: radial_positions = np.linspace(0.0, 1.0, num=5).tolist() if np.isclose(value_min, value_max): actual_values = [value_min] else: actual_values = [ value_min + pos * (value_max - value_min) for pos in radial_positions ] ax.set_yticks(radial_positions) ax.set_yticklabels([f"{val:.1f}" for val in actual_values]) ax.set_rlabel_position(225) ax.set_ylim(0.0, 1.0) unit_suffix = f" {radial_var.unit}" if radial_var.unit else "" ax.text( 0.5, -0.1, f"Centre = {value_min:.1f}{unit_suffix}, bord = {value_max:.1f}{unit_suffix}", transform=ax.transAxes, ha="center", va="top", fontsize=8, ) radial_label = f"{radial_var.label} ({radial_var.unit})" if radial_var.unit else radial_var.label ax.set_ylabel(radial_label, labelpad=20) else: scatter = ax.scatter( df_pair[var_x.column], df_pair[var_y.column], **scatter_kwargs, ) if colorbar_meta is not None: cbar = fig.colorbar(scatter, ax=ax) idx = colorbar_meta["index"] timestamps = colorbar_meta["timestamps"] time_span = colorbar_meta["time_span"] def _format_tick_label(ts: pd.Timestamp) -> str: base = f"{ts.strftime('%Y-%m-%d')}\n{ts.strftime('%H:%M')}" tz_name = ts.tzname() return f"{base} ({tz_name})" if tz_name else base if time_span > 0: tick_datetimes = pd.date_range(start=idx.min(), end=idx.max(), periods=5) tick_positions = tick_datetimes.view("int64") tick_labels = [_format_tick_label(ts) for ts in tick_datetimes] cbar.set_ticks(tick_positions) cbar.set_ticklabels(tick_labels) else: cbar.set_ticks([timestamps[0]]) ts = idx[0] cbar.set_ticklabels([_format_tick_label(ts)]) cbar.set_label("Temps (ancien → récent)") if use_polar: assert direction_var is not None and radial_var is not None ax.set_title(f"{radial_var.label} en fonction de {direction_var.label}") else: ax.set_xlabel(f"{var_x.label} ({var_x.unit})") ax.set_ylabel(f"{var_y.label} ({var_y.unit})") ax.set_title(f"{var_y.label} en fonction de {var_x.label}") fig.tight_layout() fig.savefig(output_path, dpi=150) plt.close(fig) return output_path.resolve() def plot_pairwise_relationship_grid( df: pd.DataFrame, variables: Sequence[Variable], output_path: str | Path, *, sample_step: int = 10, hist_bins: int = 40, scatter_kwargs: dict | None = None, ) -> Path: """Trace un tableau de nuages de points exhaustif (sans doublon).""" output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) if not variables: raise ValueError("La liste de variables ne peut pas être vide.") columns = [v.column for v in variables] for col in columns: if col not in df.columns: raise KeyError(f"Colonne absente dans le DataFrame : {col}") df_pairs = df[columns].dropna() if df_pairs.empty: raise RuntimeError("Aucune ligne complète pour générer les nuages de points.") if sample_step > 1: df_pairs = df_pairs.iloc[::sample_step, :] export_plot_dataset(df_pairs, output_path) n = len(variables) fig_size = max(3.0, 1.8 * n) fig, axes = plt.subplots(n, n, figsize=(fig_size, fig_size), squeeze=False) default_scatter_kwargs = {"s": 5, "alpha": 0.5} scatter_kwargs = {**default_scatter_kwargs, **(scatter_kwargs or {})} for row_idx, var_y in enumerate(variables): for col_idx, var_x in enumerate(variables): ax = axes[row_idx][col_idx] if row_idx < col_idx: # Triangle supérieur vide pour éviter les doublons ax.set_visible(False) continue if row_idx == col_idx: series = df_pairs[var_x.column].dropna() if series.empty: ax.text(0.5, 0.5, "(vide)", ha="center", va="center") ax.set_axis_off() else: bins = min(hist_bins, max(5, series.nunique())) ax.hist(series, bins=bins, color="tab:blue", alpha=0.7) ax.set_ylabel("") else: ax.scatter( df_pairs[var_x.column], df_pairs[var_y.column], **scatter_kwargs, ) if row_idx == n - 1: ax.set_xlabel(var_x.label) else: ax.set_xticklabels([]) if col_idx == 0: ax.set_ylabel(var_y.label) else: ax.set_yticklabels([]) fig.suptitle("Matrice de corrélations simples (nuages de points)") fig.tight_layout(rect=[0, 0, 1, 0.97]) fig.savefig(output_path, dpi=150) plt.close(fig) return output_path.resolve() def plot_hexbin_with_third_variable( df: pd.DataFrame, var_x: Variable, var_y: Variable, var_color: Variable, output_path: str | Path, *, gridsize: int = 60, mincnt: int = 5, reduce_func: Callable[[np.ndarray], float] | None = None, reduce_func_label: str | None = None, cmap: str = "viridis", ) -> Path: """ Trace une carte de densité hexbin où la couleur encode une 3e variable. """ output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) reduce_func = reduce_func or np.mean df_xyz = df[[var_x.column, var_y.column, var_color.column]].dropna() export_plot_dataset(df_xyz, output_path) if df_xyz.empty: fig, ax = plt.subplots() ax.text( 0.5, 0.5, "Pas de données valides pour cette combinaison.", ha="center", va="center", ) ax.set_axis_off() fig.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig) return output_path.resolve() fig, ax = plt.subplots() hb = ax.hexbin( df_xyz[var_x.column], df_xyz[var_y.column], C=df_xyz[var_color.column], reduce_C_function=reduce_func, gridsize=gridsize, cmap=cmap, mincnt=mincnt, ) func_label = reduce_func_label or getattr(reduce_func, "__name__", "statistique") colorbar_label = f"{func_label.capitalize()} de {var_color.label}" cbar = fig.colorbar(hb, ax=ax) cbar.set_label(colorbar_label) ax.set_xlabel(f"{var_x.label} ({var_x.unit})") ax.set_ylabel(f"{var_y.label} ({var_y.unit})") ax.set_title( f"{var_y.label} vs {var_x.label}\nCouleur : {func_label} de {var_color.label}" ) ax.grid(False) fig.tight_layout() fig.savefig(output_path, dpi=150) plt.close(fig) return output_path.resolve() def plot_event_composite( aligned_segments: pd.DataFrame, variables: Sequence[Variable], output_path: str | Path, *, quantiles: tuple[float, float] = (0.25, 0.75), baseline_label: str = "Début de l'événement", ) -> Path: """ Trace les moyennes/médianes autour d'événements détectés avec éventail inter-quantiles. """ output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) if aligned_segments.empty: fig, ax = plt.subplots() ax.text( 0.5, 0.5, "Aucun événement aligné à tracer.", ha="center", va="center", ) ax.set_axis_off() fig.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig) return output_path.resolve() if "offset_minutes" not in aligned_segments.index.names: raise ValueError("aligned_segments doit avoir un niveau 'offset_minutes'.") group = aligned_segments.groupby(level="offset_minutes") mean_df = group.mean() median_df = group.median() q_low, q_high = quantiles quantile_low = group.quantile(q_low) if q_low is not None else None quantile_high = group.quantile(q_high) if q_high is not None else None export_plot_dataset( { "mean": mean_df, "median": median_df, "quantile_low": quantile_low, "quantile_high": quantile_high, }, output_path, ) offsets = mean_df.index.to_numpy(dtype=float) n_vars = len(variables) fig, axes = plt.subplots(n_vars, 1, figsize=(10, 3 * n_vars), sharex=True) if n_vars == 1: axes = [axes] for ax, var in zip(axes, variables): col = var.column ax.axvline(0, color="black", linestyle="--", linewidth=1, label=baseline_label) ax.plot(offsets, mean_df[col], color="tab:blue", label="Moyenne") ax.plot(offsets, median_df[col], color="tab:orange", linestyle="--", label="Médiane") if quantile_low is not None and quantile_high is not None: ax.fill_between( offsets, quantile_low[col], quantile_high[col], color="tab:blue", alpha=0.2, label=f"IQR {int(q_low*100)}–{int(q_high*100)}%", ) ylabel = f"{var.label} ({var.unit})" if var.unit else var.label ax.set_ylabel(ylabel) ax.grid(True, linestyle=":", alpha=0.5) axes[-1].set_xlabel("Minutes autour de l'événement") axes[0].legend(loc="upper right") total_events = len(aligned_segments.index.get_level_values("event_id").unique()) fig.suptitle(f"Composites autour d'événements ({total_events} occurrences)") fig.tight_layout(rect=[0, 0, 1, 0.97]) fig.savefig(output_path, dpi=150) plt.close(fig) return output_path.resolve()