1
donnees_meteo/meteo/plots/relationships.py
2025-11-19 17:01:45 +01:00

427 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Fonctions de tracé pour comparer directement deux ou trois variables."""
from __future__ import annotations
from pathlib import Path
from typing import Callable, Sequence
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import numpy as np
import pandas as pd
from .base import export_plot_dataset
from meteo.variables import Variable
# Public plotting entry points; this list is what `from ... import *` exposes.
__all__ = [
    "plot_scatter_pair",
    "plot_pairwise_relationship_grid",
    "plot_hexbin_with_third_variable",
    "plot_event_composite",
]
def _format_time_tick_label(ts: pd.Timestamp) -> str:
    """Format *ts* for a colorbar tick: date, then ``HH:MM`` on a second line, plus the tz name if tz-aware."""
    base = f"{ts.strftime('%Y-%m-%d')}\n{ts.strftime('%H:%M')}"
    tz_name = ts.tzname()
    return f"{base} ({tz_name})" if tz_name else base


def _draw_polar_scatter(
    ax: plt.Axes,
    df_pair: pd.DataFrame,
    direction_var: Variable,
    radial_var: Variable,
    scatter_kwargs: dict,
):
    """Scatter *df_pair* on the polar *ax* (angle = direction, radius = min-max scaled value); return the artist."""
    # Directions are read as degrees and wrapped into [0, 360) before conversion.
    theta = np.deg2rad(df_pair[direction_var.column].to_numpy(dtype=float) % 360.0)
    radius_raw = df_pair[radial_var.column].to_numpy(dtype=float)

    has_data = radius_raw.size > 0
    value_min = value_max = float("nan")
    is_constant = False
    radius = radius_raw
    if has_data:
        value_min = float(np.min(radius_raw))
        value_max = float(np.max(radius_raw))
        is_constant = bool(np.isclose(value_min, value_max))
        if is_constant:
            # No spread to normalise: every point sits at the centre.
            radius = np.zeros_like(radius_raw)
        else:
            radius = (radius_raw - value_min) / (value_max - value_min)

    scatter = ax.scatter(theta, radius, **scatter_kwargs)

    # Compass layout: 0° at the top (North), angles increasing clockwise.
    ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)
    ax.set_xticks(np.deg2rad(np.arange(0, 360, 45)))
    ax.set_xticklabels(["N", "NE", "E", "SE", "S", "SO", "O", "NO"])

    if has_data:
        # Radial ticks live in the normalised [0, 1] space but are labelled
        # with the corresponding raw values.
        if is_constant:
            radial_positions = [0.0]
            actual_values = [value_min]
        else:
            radial_positions = np.linspace(0.0, 1.0, num=5).tolist()
            actual_values = [
                value_min + pos * (value_max - value_min) for pos in radial_positions
            ]
        ax.set_yticks(radial_positions)
        ax.set_yticklabels([f"{val:.1f}" for val in actual_values])
        ax.set_rlabel_position(225)
        ax.set_ylim(0.0, 1.0)
        unit_suffix = f" {radial_var.unit}" if radial_var.unit else ""
        ax.text(
            0.5,
            -0.1,
            f"Centre = {value_min:.1f}{unit_suffix}, bord = {value_max:.1f}{unit_suffix}",
            transform=ax.transAxes,
            ha="center",
            va="top",
            fontsize=8,
        )

    radial_label = (
        f"{radial_var.label} ({radial_var.unit})" if radial_var.unit else radial_var.label
    )
    ax.set_ylabel(radial_label, labelpad=20)
    return scatter


def _add_time_colorbar(
    fig: plt.Figure,
    ax: plt.Axes,
    mappable,
    index: pd.DatetimeIndex,
    timestamps: np.ndarray,
) -> None:
    """Attach to *ax* a colorbar for *mappable* whose ticks are labelled with dates taken from *index*."""
    cbar = fig.colorbar(mappable, ax=ax)
    t_min, t_max = timestamps.min(), timestamps.max()
    if t_max > t_min:
        # Positions are evenly spaced in the index's own integer unit and the
        # labels come from an evenly spaced date range over the same bounds,
        # so they line up whatever the index resolution (ns, us, ms, s).
        tick_positions = np.linspace(t_min, t_max, num=5)
        tick_datetimes = pd.date_range(start=index.min(), end=index.max(), periods=5)
        cbar.set_ticks(tick_positions)
        cbar.set_ticklabels([_format_time_tick_label(ts) for ts in tick_datetimes])
    else:
        # A single instant: one tick only.
        cbar.set_ticks([timestamps[0]])
        cbar.set_ticklabels([_format_time_tick_label(index[0])])
    cbar.set_label("Temps (ancien → récent)")


def plot_scatter_pair(
    df: pd.DataFrame,
    var_x: Variable,
    var_y: Variable,
    output_path: str | Path,
    *,
    sample_step: int = 10,
    color_by_time: bool = True,
    cmap: str = "viridis",
) -> Path:
    """
    Draw a scatter plot of one pair of variables and save it to *output_path*.

    - Rows missing either value are dropped. Only one row out of
      ``sample_step`` (e.g. 1 in 10) is then kept, so the plot stays
      readable. The plotted rows are handed to ``export_plot_dataset``.
    - When ``color_by_time`` is True and the remaining index is a non-empty
      ``DatetimeIndex``, points are coloured along ``cmap`` from oldest to
      most recent (dark -> light with the default viridis). A date-labelled
      colorbar is also added.
    - When exactly one of the two variables has key ``"wind_direction"``, a
      polar plot is used instead (0° = North, clockwise). The other variable
      becomes the radius, min-max normalised so centre = minimum and
      edge = maximum.

    Args:
        df: Source data; must contain ``var_x.column`` and ``var_y.column``.
        var_x: Variable on the x axis (or the angle/radius in polar mode).
        var_y: Variable on the y axis (or the angle/radius in polar mode).
        output_path: Image file to write; parent folders are created.
        sample_step: Keep one row out of this many (values <= 1 keep all).
        color_by_time: Colour points by timestamp when the index allows it.
        cmap: Matplotlib colormap name used for the time colouring.

    Returns:
        The resolved path of the saved figure.

    Raises:
        KeyError: if one of the two columns is missing from *df*.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Keep only the two relevant columns and the fully populated rows.
    df_pair = df[[var_x.column, var_y.column]].dropna()
    if sample_step > 1:
        df_pair = df_pair.iloc[::sample_step, :]
    export_plot_dataset(df_pair, output_path)

    # Polar mode when exactly one side is the wind direction.
    direction_var: Variable | None = None
    radial_var: Variable | None = None
    if var_y.key == "wind_direction" and var_x.key != "wind_direction":
        direction_var, radial_var = var_y, var_x
    elif var_x.key == "wind_direction" and var_y.key != "wind_direction":
        direction_var, radial_var = var_x, var_y
    use_polar = direction_var is not None and radial_var is not None

    if use_polar:
        fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    else:
        fig, ax = plt.subplots()

    scatter_kwargs: dict = {"s": 5, "alpha": 0.5}
    idx = df_pair.index
    timestamps: np.ndarray | None = None
    # An empty index is excluded: np.ptp/min/max raise on zero-size arrays.
    if color_by_time and isinstance(idx, pd.DatetimeIndex) and len(idx) > 0:
        timestamps = np.asarray(idx.view("int64"))
        scatter_kwargs |= {"c": timestamps, "cmap": cmap}
        if np.ptp(timestamps) > 0:
            scatter_kwargs["norm"] = Normalize(
                vmin=timestamps.min(), vmax=timestamps.max()
            )

    if use_polar:
        assert direction_var is not None and radial_var is not None
        scatter = _draw_polar_scatter(
            ax, df_pair, direction_var, radial_var, scatter_kwargs
        )
        ax.set_title(f"{radial_var.label} en fonction de {direction_var.label}")
    else:
        scatter = ax.scatter(
            df_pair[var_x.column],
            df_pair[var_y.column],
            **scatter_kwargs,
        )
        # Omit the "(unit)" suffix for unit-less variables, as the polar branch does.
        ax.set_xlabel(f"{var_x.label} ({var_x.unit})" if var_x.unit else var_x.label)
        ax.set_ylabel(f"{var_y.label} ({var_y.unit})" if var_y.unit else var_y.label)
        ax.set_title(f"{var_y.label} en fonction de {var_x.label}")

    if timestamps is not None:
        _add_time_colorbar(fig, ax, scatter, idx, timestamps)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    return output_path.resolve()
def plot_pairwise_relationship_grid(
    df: pd.DataFrame,
    variables: Sequence[Variable],
    output_path: str | Path,
    *,
    sample_step: int = 10,
    hist_bins: int = 40,
    scatter_kwargs: dict | None = None,
) -> Path:
    """
    Draw a lower-triangular scatter matrix showing every pair of *variables* once.

    Diagonal cells hold each variable's histogram. Cells below the diagonal
    hold the scatter of the row variable (y) against the column variable (x).
    Cells above the diagonal are hidden, since they would duplicate the
    lower triangle.

    Args:
        df: Source data containing every ``variable.column``.
        variables: Variables to compare; their order defines rows/columns.
        output_path: Image file to write; parent folders are created.
        sample_step: Keep one complete row out of this many (<= 1 keeps all).
        hist_bins: Upper bound on the number of histogram bins.
        scatter_kwargs: Extra ``Axes.scatter`` options overriding the defaults
            ``s=5, alpha=0.5``.

    Returns:
        The resolved path of the saved figure.

    Raises:
        ValueError: if *variables* is empty.
        KeyError: if a variable's column is missing from *df*.
        RuntimeError: if no row has all the requested columns populated.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if not variables:
        raise ValueError("La liste de variables ne peut pas être vide.")

    columns = [var.column for var in variables]
    for col in columns:
        if col not in df.columns:
            raise KeyError(f"Colonne absente dans le DataFrame : {col}")

    df_pairs = df[columns].dropna()
    if df_pairs.empty:
        raise RuntimeError("Aucune ligne complète pour générer les nuages de points.")
    if sample_step > 1:
        df_pairs = df_pairs.iloc[::sample_step, :]
    export_plot_dataset(df_pairs, output_path)

    n = len(variables)
    fig_size = max(3.0, 1.8 * n)
    fig, axes = plt.subplots(n, n, figsize=(fig_size, fig_size), squeeze=False)
    scatter_kwargs = {"s": 5, "alpha": 0.5, **(scatter_kwargs or {})}

    for row_idx, var_y in enumerate(variables):
        for col_idx, var_x in enumerate(variables):
            ax = axes[row_idx][col_idx]
            if row_idx < col_idx:
                # Upper triangle would duplicate the lower one: hide it.
                ax.set_visible(False)
                continue

            is_diagonal = row_idx == col_idx
            if is_diagonal:
                # df_pairs is non-empty and NaN-free, so the series is never empty.
                series = df_pairs[var_x.column]
                bins = min(hist_bins, max(5, series.nunique()))
                ax.hist(series, bins=bins, color="tab:blue", alpha=0.7)
            else:
                ax.scatter(
                    df_pairs[var_x.column],
                    df_pairs[var_y.column],
                    **scatter_kwargs,
                )

            # Only the outer edges carry tick labels and axis names.
            if row_idx == n - 1:
                ax.set_xlabel(var_x.label)
            else:
                ax.tick_params(axis="x", labelbottom=False)
            if col_idx > 0:
                ax.tick_params(axis="y", labelleft=False)
            elif not is_diagonal:
                # A diagonal cell's y axis is a count, not var_y: don't name it.
                ax.set_ylabel(var_y.label)

    fig.suptitle("Matrice de corrélations simples (nuages de points)")
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    return output_path.resolve()
def plot_hexbin_with_third_variable(
    df: pd.DataFrame,
    var_x: Variable,
    var_y: Variable,
    var_color: Variable,
    output_path: str | Path,
    *,
    gridsize: int = 60,
    mincnt: int = 5,
    reduce_func: Callable[[np.ndarray], float] | None = None,
    reduce_func_label: str | None = None,
    cmap: str = "viridis",
) -> Path:
    """
    Draw a hexbin map of *var_y* vs *var_x* whose colour encodes a third variable.

    Each hexagon is coloured by ``reduce_func`` applied to the *var_color*
    values falling in it (``np.mean`` by default). Cell filtering by
    ``mincnt`` is delegated to ``Axes.hexbin``. When no row has all three
    values, a placeholder figure with a message is saved instead.

    Args:
        df: Source data containing the three variables' columns.
        var_x: Variable on the x axis.
        var_y: Variable on the y axis.
        var_color: Variable aggregated into the colour of each cell.
        output_path: Image file to write; parent folders are created.
        gridsize: Number of hexagons along x, passed to ``Axes.hexbin``.
        mincnt: Minimum-count threshold, passed to ``Axes.hexbin``.
        reduce_func: Aggregation applied per cell (default ``np.mean``).
        reduce_func_label: Name of the statistic shown in the labels;
            defaults to ``reduce_func.__name__``.
        cmap: Matplotlib colormap name.

    Returns:
        The resolved path of the saved figure.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    reduce_func = reduce_func or np.mean

    df_xyz = df[[var_x.column, var_y.column, var_color.column]].dropna()
    export_plot_dataset(df_xyz, output_path)

    if df_xyz.empty:
        fig, ax = plt.subplots()
        ax.text(
            0.5,
            0.5,
            "Pas de données valides pour cette combinaison.",
            ha="center",
            va="center",
        )
        ax.set_axis_off()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return output_path.resolve()

    fig, ax = plt.subplots()
    hb = ax.hexbin(
        df_xyz[var_x.column],
        df_xyz[var_y.column],
        C=df_xyz[var_color.column],
        reduce_C_function=reduce_func,
        gridsize=gridsize,
        cmap=cmap,
        mincnt=mincnt,
    )

    func_label = reduce_func_label or getattr(reduce_func, "__name__", "statistique")
    # Upper-case only the first character: str.capitalize() would also
    # lower-case the rest and turn e.g. "IQR" into "Iqr".
    colorbar_label = f"{func_label[:1].upper()}{func_label[1:]} de {var_color.label}"
    cbar = fig.colorbar(hb, ax=ax)
    cbar.set_label(colorbar_label)

    # Omit the "(unit)" suffix for unit-less variables instead of printing "()".
    ax.set_xlabel(f"{var_x.label} ({var_x.unit})" if var_x.unit else var_x.label)
    ax.set_ylabel(f"{var_y.label} ({var_y.unit})" if var_y.unit else var_y.label)
    ax.set_title(
        f"{var_y.label} vs {var_x.label}\nCouleur : {func_label} de {var_color.label}"
    )
    ax.grid(False)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    return output_path.resolve()
def plot_event_composite(
    aligned_segments: pd.DataFrame,
    variables: Sequence[Variable],
    output_path: str | Path,
    *,
    quantiles: tuple[float, float] = (0.25, 0.75),
    baseline_label: str = "Début de l'événement",
) -> Path:
    """
    Plot mean/median curves around detected events with an inter-quantile band.

    Rows are grouped by the ``offset_minutes`` index level. For each variable,
    one subplot shows the per-offset mean, the median and the band between
    the two requested quantiles. A dashed vertical line marks offset 0. The
    aggregated frames are handed to ``export_plot_dataset``. An empty
    *aligned_segments* produces a placeholder figure.

    Args:
        aligned_segments: Data indexed by (at least) the ``offset_minutes`` and
            ``event_id`` levels, one column per variable.
        variables: Variables to plot, one subplot each.
        output_path: Image file to write; parent folders are created.
        quantiles: Lower and upper quantiles delimiting the shaded band.
        baseline_label: Legend label of the vertical line at offset 0.

    Returns:
        The resolved path of the saved figure.

    Raises:
        ValueError: if a required index level is missing or *variables* is
            empty (only checked when *aligned_segments* is not empty).
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if aligned_segments.empty:
        fig, ax = plt.subplots()
        ax.text(
            0.5,
            0.5,
            "Aucun événement aligné à tracer.",
            ha="center",
            va="center",
        )
        ax.set_axis_off()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return output_path.resolve()

    # Validate up front so nothing is exported for unusable input
    # (event_id is needed for the title, plt.subplots rejects 0 rows).
    level_names = aligned_segments.index.names
    if "offset_minutes" not in level_names:
        raise ValueError("aligned_segments doit avoir un niveau 'offset_minutes'.")
    if "event_id" not in level_names:
        raise ValueError("aligned_segments doit avoir un niveau 'event_id'.")
    if not variables:
        raise ValueError("La liste de variables ne peut pas être vide.")

    group = aligned_segments.groupby(level="offset_minutes")
    mean_df = group.mean()
    median_df = group.median()
    q_low, q_high = quantiles
    quantile_low = group.quantile(q_low) if q_low is not None else None
    quantile_high = group.quantile(q_high) if q_high is not None else None
    export_plot_dataset(
        {
            "mean": mean_df,
            "median": median_df,
            "quantile_low": quantile_low,
            "quantile_high": quantile_high,
        },
        output_path,
    )

    band_label: str | None = None
    if quantile_low is not None and quantile_high is not None:
        # ":g" avoids int() truncation (0.29 * 100 -> 28) and keeps fractional
        # quantiles such as 2.5; the en dash separates the two bounds.
        band_label = f"IQR {q_low * 100:g}–{q_high * 100:g}%"

    offsets = mean_df.index.to_numpy(dtype=float)
    n_vars = len(variables)
    fig, axes_grid = plt.subplots(
        n_vars, 1, figsize=(10, 3 * n_vars), sharex=True, squeeze=False
    )
    axes = axes_grid[:, 0]
    for ax, var in zip(axes, variables):
        col = var.column
        ax.axvline(0, color="black", linestyle="--", linewidth=1, label=baseline_label)
        ax.plot(offsets, mean_df[col], color="tab:blue", label="Moyenne")
        ax.plot(offsets, median_df[col], color="tab:orange", linestyle="--", label="Médiane")
        if band_label is not None:
            ax.fill_between(
                offsets,
                quantile_low[col],
                quantile_high[col],
                color="tab:blue",
                alpha=0.2,
                label=band_label,
            )
        ax.set_ylabel(f"{var.label} ({var.unit})" if var.unit else var.label)
        ax.grid(True, linestyle=":", alpha=0.5)

    axes[-1].set_xlabel("Minutes autour de l'événement")
    axes[0].legend(loc="upper right")
    total_events = len(aligned_segments.index.get_level_values("event_id").unique())
    fig.suptitle(f"Composites autour d'événements ({total_events} occurrences)")
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    return output_path.resolve()