427 lines
14 KiB
Python
427 lines
14 KiB
Python
"""Fonctions de tracé pour comparer directement deux ou trois variables."""
|
||
|
||
from __future__ import annotations
|
||
|
||
from pathlib import Path
|
||
from typing import Callable, Sequence
|
||
|
||
import matplotlib.pyplot as plt
|
||
from matplotlib.colors import Normalize
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
from .base import export_plot_dataset
|
||
from meteo.variables import Variable
|
||
|
||
# Names exported by `from <module> import *`: the four plotting entry points of this module.
__all__ = ['plot_scatter_pair', 'plot_pairwise_relationship_grid', 'plot_hexbin_with_third_variable', 'plot_event_composite']
|
||
|
||
|
||
def plot_scatter_pair(
    df: pd.DataFrame,
    var_x: Variable,
    var_y: Variable,
    output_path: str | Path,
    *,
    sample_step: int = 10,
    color_by_time: bool = True,
    cmap: str = "viridis",
) -> Path:
    """
    Draw a scatter plot for a pair of variables and save it to ``output_path``.

    - Rows where either column is missing are dropped. The data is then
      subsampled with ``sample_step`` (e.g. 1 point out of 10) to keep the
      plot readable. The plotted subset is exported with
      ``export_plot_dataset``.
    - If ``color_by_time`` is True, the index is a ``DatetimeIndex`` and at
      least one point remains, points are coloured with ``cmap`` from the
      oldest to the most recent. With the default "viridis" this runs from
      dark to light. A time colorbar is added.
    - When exactly one of the two variables has ``key == "wind_direction"``,
      a polar plot is drawn instead (0° = North, clockwise). Its radius is
      the other variable normalised to [0, 1]: centre = minimum value,
      edge = maximum value.

    Parent directories of ``output_path`` are created. Returns the resolved
    path of the saved figure.
    """

    def _axis_label(var: Variable) -> str:
        # "Label (unit)", or just "Label" when the variable has no unit.
        return f"{var.label} ({var.unit})" if var.unit else var.label

    def _format_tick_label(ts: pd.Timestamp) -> str:
        # Date on the first line, time on the second, plus the time-zone name if any.
        base = f"{ts.strftime('%Y-%m-%d')}\n{ts.strftime('%H:%M')}"
        tz_name = ts.tzname()
        return f"{base} ({tz_name})" if tz_name else base

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Keep only the two relevant columns, and only rows where both are present.
    df_pair = df[[var_x.column, var_y.column]].dropna()
    if sample_step > 1:
        df_pair = df_pair.iloc[::sample_step, :]

    export_plot_dataset(df_pair, output_path)

    # (direction variable, radial variable) when exactly one axis is the wind direction.
    polar_pair: tuple[Variable, Variable] | None = None
    if var_y.key == "wind_direction" and var_x.key != "wind_direction":
        polar_pair = (var_y, var_x)
    elif var_x.key == "wind_direction" and var_y.key != "wind_direction":
        polar_pair = (var_x, var_y)

    scatter_kwargs: dict = {"s": 5, "alpha": 0.5}
    time_index: pd.DatetimeIndex | None = None
    timestamps: np.ndarray | None = None
    # An empty selection has nothing to colour, and min/max on an empty array would raise.
    if color_by_time and isinstance(df_pair.index, pd.DatetimeIndex) and len(df_pair) > 0:
        time_index = df_pair.index
        # Integer epoch values expressed in the index's own resolution.
        timestamps = np.asarray(time_index.view("int64"))
        scatter_kwargs |= {"c": timestamps, "cmap": cmap}
        if timestamps.max() > timestamps.min():
            scatter_kwargs["norm"] = Normalize(vmin=timestamps.min(), vmax=timestamps.max())

    if polar_pair is not None:
        fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    else:
        fig, ax = plt.subplots()

    try:
        if polar_pair is not None:
            direction_var, radial_var = polar_pair
            theta = np.deg2rad(df_pair[direction_var.column].to_numpy(dtype=float) % 360.0)
            radius_raw = df_pair[radial_var.column].to_numpy(dtype=float)

            if radius_raw.size == 0:
                radius = radius_raw
                value_min = value_max = float("nan")
                constant = False
            else:
                value_min = float(np.min(radius_raw))
                value_max = float(np.max(radius_raw))
                constant = bool(np.isclose(value_min, value_max))
                if constant:
                    # A constant series has no range to normalise: put every point at the centre.
                    radius = np.zeros_like(radius_raw)
                else:
                    radius = (radius_raw - value_min) / (value_max - value_min)

            scatter = ax.scatter(theta, radius, **scatter_kwargs)

            # Compass layout: 0° at the top (North), angles increasing clockwise.
            ax.set_theta_zero_location("N")
            ax.set_theta_direction(-1)
            ax.set_xticks(np.deg2rad(np.arange(0, 360, 45)))
            ax.set_xticklabels(["N", "NE", "E", "SE", "S", "SO", "O", "NO"])

            if radius_raw.size > 0:
                if constant:
                    radial_positions = [0.0]
                    actual_values = [value_min]
                else:
                    radial_positions = np.linspace(0.0, 1.0, num=5).tolist()
                    actual_values = [
                        value_min + pos * (value_max - value_min) for pos in radial_positions
                    ]
                # Radial ticks sit on the normalised scale but show the raw values.
                ax.set_yticks(radial_positions)
                ax.set_yticklabels([f"{val:.1f}" for val in actual_values])
                ax.set_rlabel_position(225)
            ax.set_ylim(0.0, 1.0)

            unit_suffix = f" {radial_var.unit}" if radial_var.unit else ""
            ax.text(
                0.5,
                -0.1,
                f"Centre = {value_min:.1f}{unit_suffix}, bord = {value_max:.1f}{unit_suffix}",
                transform=ax.transAxes,
                ha="center",
                va="top",
                fontsize=8,
            )
            ax.set_ylabel(_axis_label(radial_var), labelpad=20)
            ax.set_title(f"{radial_var.label} en fonction de {direction_var.label}")
        else:
            scatter = ax.scatter(
                df_pair[var_x.column],
                df_pair[var_y.column],
                **scatter_kwargs,
            )
            ax.set_xlabel(_axis_label(var_x))
            ax.set_ylabel(_axis_label(var_y))
            ax.set_title(f"{var_y.label} en fonction de {var_x.label}")

        if timestamps is not None and time_index is not None:
            cbar = fig.colorbar(scatter, ax=ax)
            ts_min, ts_max = timestamps.min(), timestamps.max()
            if ts_max > ts_min:
                # Ticks are interpolated in the same integer space as the colour
                # values, whatever the index resolution (ns, us, s...). Each tick
                # is labelled with the matching, time-zone-preserving Timestamp.
                t_first, t_last = time_index.min(), time_index.max()
                fractions = np.linspace(0.0, 1.0, num=5)
                cbar.set_ticks(ts_min + fractions * (ts_max - ts_min))
                cbar.set_ticklabels(
                    [_format_tick_label(t_first + (t_last - t_first) * float(frac)) for frac in fractions]
                )
            else:
                # All points share one timestamp: a single tick is enough.
                cbar.set_ticks([timestamps[0]])
                cbar.set_ticklabels([_format_tick_label(time_index[0])])
            cbar.set_label("Temps (ancien → récent)")

        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Always release the figure, even if plotting fails part-way.
        plt.close(fig)

    return output_path.resolve()
|
||
|
||
|
||
def plot_pairwise_relationship_grid(
    df: pd.DataFrame,
    variables: Sequence[Variable],
    output_path: str | Path,
    *,
    sample_step: int = 10,
    hist_bins: int = 40,
    scatter_kwargs: dict | None = None,
) -> Path:
    """
    Draw a lower-triangular scatter-plot matrix of ``variables`` (no duplicate pairs).

    - Diagonal cells show a histogram of each variable. The bin count is
      ``min(hist_bins, max(5, number of distinct values))``.
    - Cells below the diagonal plot the row variable (y) against the column
      variable (x).
    - Cells above the diagonal are hidden.

    Only rows complete for every variable are kept, then subsampled with
    ``sample_step``. That subset is exported with ``export_plot_dataset``.
    ``scatter_kwargs`` overrides the defaults ``{"s": 5, "alpha": 0.5}`` for
    every scatter call.

    Raises:
        ValueError: ``variables`` is empty.
        KeyError: a variable's column is absent from ``df``.
        RuntimeError: no row is complete for all the variables.

    Returns the resolved path of the saved figure.
    """
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    if not variables:
        raise ValueError("La liste de variables ne peut pas être vide.")

    wanted = [var.column for var in variables]
    absent = [name for name in wanted if name not in df.columns]
    if absent:
        raise KeyError(f"Colonne absente dans le DataFrame : {absent[0]}")

    complete = df[wanted].dropna()
    if complete.empty:
        raise RuntimeError("Aucune ligne complète pour générer les nuages de points.")

    sampled = complete.iloc[::sample_step, :] if sample_step > 1 else complete
    export_plot_dataset(sampled, target)

    count = len(variables)
    side = max(3.0, 1.8 * count)
    fig, grid = plt.subplots(count, count, figsize=(side, side), squeeze=False)

    point_style = {"s": 5, "alpha": 0.5}
    point_style.update(scatter_kwargs or {})

    last_row = count - 1
    for i, row_var in enumerate(variables):
        for j, col_var in enumerate(variables):
            cell = grid[i][j]

            if j > i:
                # Upper triangle is left blank: it would mirror the lower one.
                cell.set_visible(False)
                continue

            if i != j:
                cell.scatter(
                    sampled[col_var.column],
                    sampled[row_var.column],
                    **point_style,
                )
            else:
                values = sampled[col_var.column].dropna()
                if values.empty:
                    cell.text(0.5, 0.5, "(vide)", ha="center", va="center")
                    cell.set_axis_off()
                else:
                    n_bins = min(hist_bins, max(5, values.nunique()))
                    cell.hist(values, bins=n_bins, color="tab:blue", alpha=0.7)
                cell.set_ylabel("")

            # Tick labels only on the outer edges of the matrix.
            if i == last_row:
                cell.set_xlabel(col_var.label)
            else:
                cell.set_xticklabels([])

            if j == 0:
                cell.set_ylabel(row_var.label)
            else:
                cell.set_yticklabels([])

    fig.suptitle("Matrice de corrélations simples (nuages de points)")
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.savefig(target, dpi=150)
    plt.close(fig)

    return target.resolve()
|
||
|
||
def plot_hexbin_with_third_variable(
    df: pd.DataFrame,
    var_x: Variable,
    var_y: Variable,
    var_color: Variable,
    output_path: str | Path,
    *,
    gridsize: int = 60,
    mincnt: int = 5,
    reduce_func: Callable[[np.ndarray], float] | None = None,
    reduce_func_label: str | None = None,
    cmap: str = "viridis",
) -> Path:
    """
    Draw a hexbin map of ``var_y`` against ``var_x``, where each cell's colour
    encodes a statistic of a third variable ``var_color``.

    Only rows where all three columns are present are used. That subset is
    exported with ``export_plot_dataset``. If it is empty, a placeholder
    figure with a message is saved instead.

    Parameters
    ----------
    gridsize, mincnt, cmap:
        Passed to ``Axes.hexbin``. ``mincnt`` is the minimum number of points
        a cell needs in order to be drawn.
    reduce_func:
        Reduction applied to the ``var_color`` values of each cell
        (``np.mean`` by default).
    reduce_func_label:
        Human-readable name of the reduction used in the colorbar label and
        title. Defaults to ``reduce_func.__name__``, or "statistique" when
        there is no ``__name__``.

    Returns the resolved path of the saved figure.
    """

    def _axis_label(var: Variable) -> str:
        # "Label (unit)", or just "Label" when the variable has no unit.
        return f"{var.label} ({var.unit})" if var.unit else var.label

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    reduce_func = reduce_func or np.mean

    df_xyz = df[[var_x.column, var_y.column, var_color.column]].dropna()
    export_plot_dataset(df_xyz, output_path)

    if df_xyz.empty:
        fig, ax = plt.subplots()
        ax.text(
            0.5,
            0.5,
            "Pas de données valides pour cette combinaison.",
            ha="center",
            va="center",
        )
        ax.set_axis_off()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return output_path.resolve()

    func_label = reduce_func_label or getattr(reduce_func, "__name__", "statistique")
    # Upper-case only the first character. str.capitalize() would also
    # lower-case the rest and mangle labels such as "RMS" or "P90 (Q)".
    colorbar_label = f"{func_label[:1].upper()}{func_label[1:]} de {var_color.label}"

    fig, ax = plt.subplots()
    try:
        hb = ax.hexbin(
            df_xyz[var_x.column],
            df_xyz[var_y.column],
            C=df_xyz[var_color.column],
            reduce_C_function=reduce_func,
            gridsize=gridsize,
            cmap=cmap,
            mincnt=mincnt,
        )

        cbar = fig.colorbar(hb, ax=ax)
        cbar.set_label(colorbar_label)

        ax.set_xlabel(_axis_label(var_x))
        ax.set_ylabel(_axis_label(var_y))
        ax.set_title(
            f"{var_y.label} vs {var_x.label}\nCouleur : {func_label} de {var_color.label}"
        )
        ax.grid(False)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Always release the figure, even if plotting fails part-way.
        plt.close(fig)

    return output_path.resolve()
|
||
|
||
def plot_event_composite(
    aligned_segments: pd.DataFrame,
    variables: Sequence[Variable],
    output_path: str | Path,
    *,
    quantiles: tuple[float | None, float | None] = (0.25, 0.75),
    baseline_label: str = "Début de l'événement",
) -> Path:
    """
    Plot, for each variable, the mean and median around detected events,
    with a shaded band between two quantiles.

    Parameters
    ----------
    aligned_segments:
        Event-aligned data. Its index must contain an ``offset_minutes``
        level, whose values are converted to float for the x axis. Mean,
        median and quantiles are computed per offset over all rows. If the
        index also has an ``event_id`` level, the number of distinct event
        ids is shown in the title.
    variables:
        Variables to plot. Each gets its own subplot; the subplots are
        stacked and share the x axis.
    output_path:
        Destination of the figure; parent directories are created.
    quantiles:
        ``(low, high)`` quantiles delimiting the shaded band. If either is
        ``None``, no band is drawn.
    baseline_label:
        Legend label of the dashed vertical line drawn at offset 0.

    Returns the resolved path of the saved figure. If ``aligned_segments``
    is empty, a placeholder figure with a message is saved instead.

    Raises
    ------
    ValueError
        If the index has no ``offset_minutes`` level, or ``variables`` is empty.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if aligned_segments.empty:
        fig, ax = plt.subplots()
        ax.text(
            0.5,
            0.5,
            "Aucun événement aligné à tracer.",
            ha="center",
            va="center",
        )
        ax.set_axis_off()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return output_path.resolve()

    if "offset_minutes" not in aligned_segments.index.names:
        raise ValueError("aligned_segments doit avoir un niveau 'offset_minutes'.")
    if not variables:
        # Checked up front: plt.subplots(0, 1) would otherwise fail late,
        # after the dataset has already been exported.
        raise ValueError("La liste de variables ne peut pas être vide.")

    group = aligned_segments.groupby(level="offset_minutes")
    mean_df = group.mean()
    median_df = group.median()

    q_low, q_high = quantiles
    quantile_low = group.quantile(q_low) if q_low is not None else None
    quantile_high = group.quantile(q_high) if q_high is not None else None

    export_plot_dataset(
        {
            "mean": mean_df,
            "median": median_df,
            "quantile_low": quantile_low,
            "quantile_high": quantile_high,
        },
        output_path,
    )

    band_label: str | None = None
    if quantile_low is not None and quantile_high is not None:
        # ":g" rather than int(): int(0.29 * 100) truncates to 28 because of
        # floating-point error, and 0.025 would be shown as 2.
        band_label = f"IQR {q_low * 100:g}–{q_high * 100:g}%"

    if "event_id" in aligned_segments.index.names:
        total_events = len(aligned_segments.index.get_level_values("event_id").unique())
        title = f"Composites autour d'événements ({total_events} occurrences)"
    else:
        # Without an event identifier the count is unknown. Keep a generic
        # title instead of failing after the export.
        title = "Composites autour d'événements"

    offsets = mean_df.index.to_numpy(dtype=float)
    n_vars = len(variables)
    fig, axes_grid = plt.subplots(
        n_vars, 1, figsize=(10, 3 * n_vars), sharex=True, squeeze=False
    )
    axes = axes_grid[:, 0]

    try:
        for ax, var in zip(axes, variables):
            col = var.column
            ax.axvline(0, color="black", linestyle="--", linewidth=1, label=baseline_label)
            ax.plot(offsets, mean_df[col], color="tab:blue", label="Moyenne")
            ax.plot(offsets, median_df[col], color="tab:orange", linestyle="--", label="Médiane")

            if band_label is not None:
                ax.fill_between(
                    offsets,
                    quantile_low[col],
                    quantile_high[col],
                    color="tab:blue",
                    alpha=0.2,
                    label=band_label,
                )

            ylabel = f"{var.label} ({var.unit})" if var.unit else var.label
            ax.set_ylabel(ylabel)
            ax.grid(True, linestyle=":", alpha=0.5)

        axes[-1].set_xlabel("Minutes autour de l'événement")
        axes[0].legend(loc="upper right")
        fig.suptitle(title)
        fig.tight_layout(rect=[0, 0, 1, 0.97])
        fig.savefig(output_path, dpi=150)
    finally:
        # Always release the figure, even if a variable's column is missing.
        plt.close(fig)

    return output_path.resolve()
|