1
donnees_meteo/meteo/plots/correlations.py

318 lines
9.7 KiB
Python

"""Visualisations d'indicateurs de corrélation (heatmaps et séries décalées)."""
from __future__ import annotations
from pathlib import Path
from typing import Iterable, Sequence
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from meteo.correlation_presets import CorrelationBand
from .base import export_plot_dataset
from meteo.variables import Variable
# Public API of this module. ``CorrelationBand`` is re-exported from
# ``meteo.correlation_presets`` because it is the element type of the
# ``bands`` argument of ``plot_lagged_correlation_multi``.
__all__ = [
    "plot_lagged_correlation",
    "plot_lagged_correlation_multi",
    "plot_correlation_heatmap",
    "plot_rolling_correlation_heatmap",
    "CorrelationBand",
]
def plot_lagged_correlation(
    lag_df: pd.DataFrame,
    var_x: Variable,
    var_y: Variable,
    output_path: str | Path,
) -> Path:
    """
    Plot the correlation between two variables as a function of the lag.

    Parameters
    ----------
    lag_df :
        Table indexed by lag (in minutes) with a ``"correlation"`` column.
        It is also passed to ``export_plot_dataset`` together with
        ``output_path``.
    var_x, var_y :
        Correlated variables; only their ``label`` is used (in the title).
    output_path :
        Image file to write; missing parent directories are created.

    Returns
    -------
    Path
        Resolved path of the written image.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    export_plot_dataset(lag_df, output_path)

    fig, ax = plt.subplots()
    try:
        ax.plot(lag_df.index, lag_df["correlation"])
        ax.axvline(0, linestyle="--")  # reference line at lag = 0
        ax.set_xlabel("Décalage (minutes)\n(lag > 0 : X précède Y)")
        ax.set_ylabel("Corrélation")
        # Explicit separator between the labels; the arrow reads "X → Y",
        # matching the x-axis note "lag > 0 : X précède Y".
        ax.set_title(f"Corrélation décalée : {var_x.label} → {var_y.label}")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if drawing/saving fails, so repeated calls
        # do not accumulate open figures in pyplot.
        plt.close(fig)
    return output_path.resolve()
def plot_lagged_correlation_multi(
    lag_series: dict[str, pd.Series],
    var_x: Variable,
    var_y: Variable,
    output_path: str | Path,
    *,
    title_suffix: str | None = None,
    ylabel: str = "Corrélation",
    y_limits: tuple[float, float] | None = None,
    thresholds: Sequence[float] | None = None,
    bands: Iterable[CorrelationBand] | None = None,
) -> Path:
    """
    Plot several lag-correlation curves on one chart (e.g. Pearson/Spearman).

    Parameters
    ----------
    lag_series :
        Mapping ``curve label -> series`` indexed by lag (in minutes). The
        series are aligned on their index with ``pd.concat(..., axis=1)``,
        and the combined table is passed to ``export_plot_dataset``.
    var_x, var_y :
        Correlated variables; their ``label`` is used in the title.
    output_path :
        Image file to write; missing parent directories are created.
    title_suffix :
        Optional text appended to the title in parentheses.
    ylabel :
        Y-axis label.
    y_limits :
        Optional ``(ymin, ymax)`` applied after everything else is drawn.
    thresholds :
        Y values drawn as dashed horizontal lines, labelled at the right edge.
    bands :
        Shaded horizontal ranges (``min_value``, ``max_value``, ``color``,
        ``label`` attributes), labelled at the right edge.

    Returns
    -------
    Path
        Resolved path of the written image.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.concat(lag_series, axis=1)
    export_plot_dataset(df, output_path)

    palette = ["#1f77b4", "#d1495b", "#2ca02c", "#9467bd"]
    fig, ax = plt.subplots()
    try:
        for idx, (label, series) in enumerate(df.items()):
            ax.plot(
                series.index,
                series,
                label=label,
                color=palette[idx % len(palette)],
                linewidth=1.6,
            )

        # x in axes coordinates (1.0 = right border), y in data coordinates:
        # band and threshold labels are pinned to the same right edge,
        # independently of the x margins chosen by autoscaling.
        right_edge = ax.get_yaxis_transform()

        for band in bands or ():
            ax.axhspan(band.min_value, band.max_value, color=band.color, alpha=0.25, zorder=0)
            ax.text(
                1.0,
                (band.min_value + band.max_value) / 2.0,
                band.label,
                transform=right_edge,
                ha="right",
                va="center",
                fontsize=8,
                color="#444444",
                bbox=dict(facecolor="white", edgecolor="none", alpha=0.6, pad=1.5),
            )

        ax.axvline(0, linestyle="--", color="#666666", linewidth=1.0)
        ax.set_xlabel("Décalage (minutes)\n(lag > 0 : X précède Y)")
        ax.set_ylabel(ylabel)

        # Explicit separator between the labels ("X → Y").
        title = f"Corrélation décalée : {var_x.label} → {var_y.label}"
        if title_suffix:
            title = f"{title} ({title_suffix})"
        ax.set_title(title)

        for thr in thresholds or ():
            ax.axhline(thr, color="#999999", linestyle="--", linewidth=1.0, alpha=0.85)
            ax.text(
                1.0,
                thr,
                f"{thr:.2f}",
                transform=right_edge,
                ha="right",
                va="center",
                fontsize=8,
                color="#555555",
                bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=1.5),
            )

        if y_limits is not None:
            ax.set_ylim(*y_limits)
        ax.grid(True, alpha=0.7)
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if drawing/saving fails.
        plt.close(fig)
    return output_path.resolve()
def plot_correlation_heatmap(
    corr: pd.DataFrame,
    variables: Sequence[Variable],
    output_path: str | Path,
    *,
    annotate: bool = True,
    annotate_values: pd.DataFrame | None = None,
    title: str | None = None,
    figsize: tuple[float, float] | None = None,
    cmap: str | None = None,
    vmin: float | None = None,
    vmax: float | None = None,
    colorbar_label: str | None = None,
) -> Path:
    """
    Draw a heatmap of a correlation matrix.

    Parameters
    ----------
    corr :
        Correlation matrix whose index and columns are the variables'
        column names (``Variable.column``). The reordered matrix is passed
        to ``export_plot_dataset``.
    variables :
        Variables to display, in display order; ``column`` selects rows and
        columns of ``corr``, ``label`` is used for the tick labels.
    output_path :
        Image file to write; missing parent directories are created.
    annotate :
        If True, write a value in every off-diagonal cell (the diagonal is
        always left blank).
    annotate_values :
        Optional table indexed like ``corr`` whose raw values (converted
        with ``str``) are written instead of the formatted correlation.
        Missing entries (NaN, None, pd.NA, ...) leave the cell blank.
    title :
        Custom title (default: "Matrice de corrélation").
    figsize :
        Figure size in inches; by default a square that grows with the
        number of variables (at least 6 x 6).
    cmap :
        Matplotlib colormap name (default: matplotlib's default colormap).
    vmin / vmax :
        Colour-scale bounds; ``None`` means -1 and 1 respectively.
    colorbar_label :
        Colour-bar label (default: "Corrélation").

    Returns
    -------
    Path
        Resolved path of the written image.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    columns = [v.column for v in variables]
    labels = [v.label for v in variables]
    # Align the matrix on the requested display order.
    corr = corr.loc[columns, columns]
    export_plot_dataset(corr, output_path)
    data = corr.to_numpy()

    if figsize is None:
        # Grow with the number of variables so annotations have room.
        side = max(6.0, len(variables) * 0.9)
        figsize = (side, side)

    im_kwargs = {
        "vmin": -1.0 if vmin is None else vmin,
        "vmax": 1.0 if vmax is None else vmax,
    }
    if cmap is not None:
        im_kwargs["cmap"] = cmap

    def _is_missing(value: object) -> bool:
        # Scalar-safe missing test: covers float NaN, None, pd.NA, NaT,
        # unlike a plain ``np.isnan`` restricted to float/int instances.
        return bool(pd.api.types.is_scalar(value) and pd.isna(value))

    fig, ax = plt.subplots(figsize=figsize)
    try:
        im = ax.imshow(data, **im_kwargs)

        ticks = np.arange(len(labels))
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_yticklabels(labels)
        ax.set_title(title or "Matrice de corrélation")

        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(colorbar_label or "Corrélation")

        if annotate:
            annot_data = (
                annotate_values.loc[columns, columns].to_numpy()
                if annotate_values is not None
                else data
            )
            norm = im.norm
            cmap_obj = im.cmap

            def _text_color(value: float) -> str:
                # Choose black or white text from the relative luminance of
                # the cell colour so the value stays readable.
                r, g, b, _ = cmap_obj(norm(value))
                luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
                return "white" if luminance < 0.5 else "black"

            n_rows, n_cols = data.shape
            for i in range(n_rows):
                for j in range(n_cols):
                    if i == j:
                        continue  # diagonal left blank
                    val_annot = annot_data[i, j]
                    if _is_missing(val_annot):
                        continue
                    val_corr = data[i, j]
                    # With annotate_values, show the raw annotated value.
                    text = str(val_annot) if annotate_values is not None else f"{val_corr:.2f}"
                    if not text:
                        continue
                    color = _text_color(0.0 if _is_missing(val_corr) else val_corr)
                    ax.text(j, i, text, ha="center", va="center", color=color)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if drawing/saving fails.
        plt.close(fig)
    return output_path.resolve()
def plot_rolling_correlation_heatmap(
    rolling_corr: pd.DataFrame,
    output_path: str | Path,
    *,
    cmap: str = "coolwarm",
    vmin: float = -1.0,
    vmax: float = 1.0,
    time_tick_count: int = 6,
) -> Path:
    """
    Visualise how rolling correlations of several variable pairs evolve.

    Parameters
    ----------
    rolling_corr :
        Table with one column per pair and one row per window; the index is
        used for the x tick labels (formatted as date/time when it is a
        DatetimeIndex). It is passed to ``export_plot_dataset``.
    output_path :
        Image file to write; missing parent directories are created. When
        ``rolling_corr`` is empty, a placeholder image with a message is
        written instead.
    cmap, vmin, vmax :
        Colormap and colour-scale bounds passed to ``imshow``.
    time_tick_count :
        Maximum number of evenly spaced x ticks; values <= 0 give no tick.

    Returns
    -------
    Path
        Resolved path of the written image.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    export_plot_dataset(rolling_corr, output_path)

    if rolling_corr.empty:
        fig, ax = plt.subplots()
        try:
            ax.text(0.5, 0.5, "Aucune donnée de corrélation glissante.", ha="center", va="center")
            ax.set_axis_off()
            fig.savefig(output_path, dpi=150, bbox_inches="tight")
        finally:
            plt.close(fig)
        return output_path.resolve()

    labels = list(rolling_corr.columns)
    # Transpose: one image row per pair, one image column per time step.
    data = rolling_corr.to_numpy().T
    index = rolling_corr.index

    # Evenly spaced tick positions over the index; clamped at 0 so a
    # negative count does not make np.linspace raise.
    n_ticks = max(0, min(time_tick_count, len(index)))
    tick_idx = np.linspace(0, len(index) - 1, num=n_ticks, dtype=int)
    if isinstance(index, pd.DatetimeIndex):
        # NaT has no strftime; show an empty label instead of crashing.
        tick_labels = [
            "" if pd.isna(index[i]) else index[i].strftime("%Y-%m-%d\n%H:%M")
            for i in tick_idx
        ]
    else:
        tick_labels = [str(index[i]) for i in tick_idx]

    height = max(3.0, 0.6 * len(labels))
    fig, ax = plt.subplots(figsize=(10, height))
    try:
        im = ax.imshow(data, aspect="auto", cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_yticks(np.arange(len(labels)))
        ax.set_yticklabels(labels)
        ax.set_xticks(tick_idx)
        ax.set_xticklabels(tick_labels, rotation=30, ha="right")
        ax.set_xlabel("Temps (fin de fenêtre)")
        ax.set_ylabel("Paire de variables")
        ax.set_title("Corrélations glissantes")
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label("Coefficient de corrélation")
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even if drawing/saving fails.
        plt.close(fig)
    return output_path.resolve()