183 lines
5.3 KiB
Python
183 lines
5.3 KiB
Python
"""Visualisations d'indicateurs de corrélation (heatmaps et séries décalées)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Sequence
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from .base import export_plot_dataset
|
|
from meteo.variables import Variable
|
|
|
|
# Public plotting entry points of this module (private helpers stay unexported).
__all__ = [
    "plot_lagged_correlation",
    "plot_correlation_heatmap",
    "plot_rolling_correlation_heatmap",
]
|
|
|
|
|
|
def plot_lagged_correlation(
    lag_df: pd.DataFrame,
    var_x: Variable,
    var_y: Variable,
    output_path: str | Path,
) -> Path:
    """
    Plot the correlation between two variables as a function of the lag.

    Parameters
    ----------
    lag_df :
        Table indexed by lag, with a ``"correlation"`` column. The index is
        drawn on the x axis, which is labelled in minutes, so it is presumably
        expressed in minutes (confirm with the caller).
    var_x, var_y :
        Variables whose ``label`` attributes are used in the plot title.
    output_path :
        Image file to write. Parent directories are created if needed, and the
        plotted data are also exported via ``export_plot_dataset``.

    Returns
    -------
    Path
        Absolute path of the written image.

    Raises
    ------
    KeyError
        If ``lag_df`` has no ``"correlation"`` column (checked before anything
        is written to disk).
    """
    if "correlation" not in lag_df.columns:
        raise KeyError(
            "lag_df must contain a 'correlation' column; "
            f"got columns {list(lag_df.columns)!r}"
        )

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    export_plot_dataset(lag_df, output_path)

    # Work on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", as the other plotting functions of this module do.
    fig, ax = plt.subplots()
    try:
        ax.plot(lag_df.index, lag_df["correlation"])
        # Dashed reference line at lag = 0 (no shift between the two series).
        ax.axvline(0, linestyle="--")
        ax.set_xlabel("Décalage (minutes)\n(lag > 0 : X précède Y)")
        ax.set_ylabel("Corrélation")
        ax.set_title(f"Corrélation décalée : {var_x.label} → {var_y.label}")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even when plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

    return output_path.resolve()
|
|
|
|
def plot_correlation_heatmap(
|
|
corr: pd.DataFrame,
|
|
variables: Sequence[Variable],
|
|
output_path: str | Path,
|
|
*,
|
|
annotate: bool = True,
|
|
) -> Path:
|
|
"""
|
|
Trace une heatmap de la matrice de corrélation.
|
|
|
|
Paramètres
|
|
----------
|
|
corr :
|
|
Matrice de corrélation (index et colonnes doivent correspondre
|
|
aux noms de colonnes des variables).
|
|
variables :
|
|
Liste de Variable, dans l'ordre où elles doivent apparaître.
|
|
output_path :
|
|
Chemin du fichier image à écrire.
|
|
annotate :
|
|
Si True, affiche la valeur numérique dans chaque case.
|
|
"""
|
|
output_path = Path(output_path)
|
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
columns = [v.column for v in variables]
|
|
labels = [v.label for v in variables]
|
|
|
|
# On aligne la matrice sur l'ordre désiré
|
|
corr = corr.loc[columns, columns]
|
|
export_plot_dataset(corr, output_path)
|
|
|
|
data = corr.to_numpy()
|
|
|
|
fig, ax = plt.subplots()
|
|
im = ax.imshow(data, vmin=-1.0, vmax=1.0)
|
|
|
|
# Ticks et labels
|
|
ax.set_xticks(np.arange(len(labels)))
|
|
ax.set_yticks(np.arange(len(labels)))
|
|
ax.set_xticklabels(labels, rotation=45, ha="right")
|
|
ax.set_yticklabels(labels)
|
|
|
|
# Axe en haut/bas selon préférence (ici on laisse en bas)
|
|
ax.set_title("Matrice de corrélation (coef. de Pearson)")
|
|
|
|
# Barre de couleur
|
|
cbar = plt.colorbar(im, ax=ax)
|
|
cbar.set_label("Corrélation")
|
|
|
|
# Annotation des cases
|
|
if annotate:
|
|
n = data.shape[0]
|
|
for i in range(n):
|
|
for j in range(n):
|
|
if i == j:
|
|
text = "—"
|
|
else:
|
|
val = data[i, j]
|
|
if np.isnan(val):
|
|
text = ""
|
|
else:
|
|
text = f"{val:.2f}"
|
|
ax.text(
|
|
j,
|
|
i,
|
|
text,
|
|
ha="center",
|
|
va="center",
|
|
)
|
|
|
|
plt.tight_layout()
|
|
plt.savefig(output_path, dpi=150)
|
|
plt.close(fig)
|
|
|
|
return output_path.resolve()
|
|
|
|
def plot_rolling_correlation_heatmap(
|
|
rolling_corr: pd.DataFrame,
|
|
output_path: str | Path,
|
|
*,
|
|
cmap: str = "coolwarm",
|
|
vmin: float = -1.0,
|
|
vmax: float = 1.0,
|
|
time_tick_count: int = 6,
|
|
) -> Path:
|
|
"""
|
|
Visualise l'évolution de corrélations glissantes pour plusieurs paires.
|
|
"""
|
|
output_path = Path(output_path)
|
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
export_plot_dataset(rolling_corr, output_path)
|
|
|
|
if rolling_corr.empty:
|
|
fig, ax = plt.subplots()
|
|
ax.text(0.5, 0.5, "Aucune donnée de corrélation glissante.", ha="center", va="center")
|
|
ax.set_axis_off()
|
|
fig.savefig(output_path, dpi=150, bbox_inches="tight")
|
|
plt.close(fig)
|
|
return output_path.resolve()
|
|
|
|
labels = list(rolling_corr.columns)
|
|
data = rolling_corr.to_numpy().T
|
|
|
|
height = max(3.0, 0.6 * len(labels))
|
|
fig, ax = plt.subplots(figsize=(10, height))
|
|
im = ax.imshow(data, aspect="auto", cmap=cmap, vmin=vmin, vmax=vmax)
|
|
|
|
ax.set_yticks(np.arange(len(labels)))
|
|
ax.set_yticklabels(labels)
|
|
|
|
if isinstance(rolling_corr.index, pd.DatetimeIndex):
|
|
times = rolling_corr.index
|
|
if len(times) > 1:
|
|
tick_idx = np.linspace(0, len(times) - 1, num=min(time_tick_count, len(times)), dtype=int)
|
|
else:
|
|
tick_idx = np.array([0])
|
|
tick_labels = [times[i].strftime("%Y-%m-%d\n%H:%M") for i in tick_idx]
|
|
else:
|
|
tick_idx = np.linspace(0, len(rolling_corr.index) - 1, num=min(time_tick_count, len(rolling_corr.index)), dtype=int)
|
|
tick_labels = [str(rolling_corr.index[i]) for i in tick_idx]
|
|
|
|
ax.set_xticks(tick_idx)
|
|
ax.set_xticklabels(tick_labels, rotation=30, ha="right")
|
|
|
|
ax.set_xlabel("Temps (fin de fenêtre)")
|
|
ax.set_ylabel("Paire de variables")
|
|
ax.set_title("Corrélations glissantes")
|
|
|
|
cbar = fig.colorbar(im, ax=ax)
|
|
cbar.set_label("Coefficient de corrélation")
|
|
|
|
fig.tight_layout()
|
|
fig.savefig(output_path, dpi=150)
|
|
plt.close(fig)
|
|
|
|
return output_path.resolve()
|