"""Visualisations d'indicateurs de corrélation (heatmaps et séries décalées).""" from __future__ import annotations from pathlib import Path from typing import Sequence import matplotlib.pyplot as plt import numpy as np import pandas as pd from .base import export_plot_dataset from meteo.variables import Variable __all__ = ['plot_lagged_correlation', 'plot_correlation_heatmap', 'plot_rolling_correlation_heatmap'] def plot_lagged_correlation( lag_df: pd.DataFrame, var_x: Variable, var_y: Variable, output_path: str | Path, ) -> Path: """ Trace la corrélation en fonction du lag (en minutes) entre deux variables. """ output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) export_plot_dataset(lag_df, output_path) plt.figure() plt.plot(lag_df.index, lag_df["correlation"]) plt.axvline(0, linestyle="--") # lag = 0 plt.xlabel("Décalage (minutes)\n(lag > 0 : X précède Y)") plt.ylabel("Corrélation") plt.title(f"Corrélation décalée : {var_x.label} → {var_y.label}") plt.grid(True) plt.tight_layout() plt.savefig(output_path, dpi=150) plt.close() return output_path.resolve() def plot_correlation_heatmap( corr: pd.DataFrame, variables: Sequence[Variable], output_path: str | Path, *, annotate: bool = True, ) -> Path: """ Trace une heatmap de la matrice de corrélation. Paramètres ---------- corr : Matrice de corrélation (index et colonnes doivent correspondre aux noms de colonnes des variables). variables : Liste de Variable, dans l'ordre où elles doivent apparaître. output_path : Chemin du fichier image à écrire. annotate : Si True, affiche la valeur numérique dans chaque case. """ output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) columns = [v.column for v in variables] labels = [v.label for v in variables] # On aligne la matrice sur l'ordre désiré corr = corr.loc[columns, columns] export_plot_dataset(corr, output_path) data = corr.to_numpy() fig, ax = plt.subplots() im = ax.imshow(data, vmin=-1.0, vmax=1.0) # Ticks et labels ax.set_xticks(np.arange(len(labels))) ax.set_yticks(np.arange(len(labels))) ax.set_xticklabels(labels, rotation=45, ha="right") ax.set_yticklabels(labels) # Axe en haut/bas selon préférence (ici on laisse en bas) ax.set_title("Matrice de corrélation (coef. de Pearson)") # Barre de couleur cbar = plt.colorbar(im, ax=ax) cbar.set_label("Corrélation") # Annotation des cases if annotate: n = data.shape[0] for i in range(n): for j in range(n): if i == j: text = "—" else: val = data[i, j] if np.isnan(val): text = "" else: text = f"{val:.2f}" ax.text( j, i, text, ha="center", va="center", ) plt.tight_layout() plt.savefig(output_path, dpi=150) plt.close(fig) return output_path.resolve() def plot_rolling_correlation_heatmap( rolling_corr: pd.DataFrame, output_path: str | Path, *, cmap: str = "coolwarm", vmin: float = -1.0, vmax: float = 1.0, time_tick_count: int = 6, ) -> Path: """ Visualise l'évolution de corrélations glissantes pour plusieurs paires. """ output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) export_plot_dataset(rolling_corr, output_path) if rolling_corr.empty: fig, ax = plt.subplots() ax.text(0.5, 0.5, "Aucune donnée de corrélation glissante.", ha="center", va="center") ax.set_axis_off() fig.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig) return output_path.resolve() labels = list(rolling_corr.columns) data = rolling_corr.to_numpy().T height = max(3.0, 0.6 * len(labels)) fig, ax = plt.subplots(figsize=(10, height)) im = ax.imshow(data, aspect="auto", cmap=cmap, vmin=vmin, vmax=vmax) ax.set_yticks(np.arange(len(labels))) ax.set_yticklabels(labels) if isinstance(rolling_corr.index, pd.DatetimeIndex): times = rolling_corr.index if len(times) > 1: tick_idx = np.linspace(0, len(times) - 1, num=min(time_tick_count, len(times)), dtype=int) else: tick_idx = np.array([0]) tick_labels = [times[i].strftime("%Y-%m-%d\n%H:%M") for i in tick_idx] else: tick_idx = np.linspace(0, len(rolling_corr.index) - 1, num=min(time_tick_count, len(rolling_corr.index)), dtype=int) tick_labels = [str(rolling_corr.index[i]) for i in tick_idx] ax.set_xticks(tick_idx) ax.set_xticklabels(tick_labels, rotation=30, ha="right") ax.set_xlabel("Temps (fin de fenêtre)") ax.set_ylabel("Paire de variables") ax.set_title("Corrélations glissantes") cbar = fig.colorbar(im, ax=ax) cbar.set_label("Coefficient de corrélation") fig.tight_layout() fig.savefig(output_path, dpi=150) plt.close(fig) return output_path.resolve()