1

Ajout des matrices de corrélation + Refactoring

This commit is contained in:
Richard Dern 2025-11-19 23:31:38 +01:00
parent 3a1f7e2a7e
commit a4d3ce7b49
13 changed files with 165 additions and 26 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 46 KiB

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 55 KiB

After

Width:  |  Height:  |  Size: 55 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 49 KiB

After

Width:  |  Height:  |  Size: 49 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 45 KiB

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 46 KiB

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 45 KiB

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 127 KiB

After

Width:  |  Height:  |  Size: 117 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 116 KiB

View File

@ -63,10 +63,12 @@ python "docs/04 - Corrélations binaires/scripts/plot_all_pairwise_scatter.py"
![](figures/pairwise_scatter/scatter_wind_speed_vs_wind_direction.png) ![](figures/pairwise_scatter/scatter_wind_speed_vs_wind_direction.png)
## Heatmap globale ## Matrices de corrélation
```shell ```shell
python "docs/04 - Corrélations binaires/scripts/plot_correlation_heatmap.py" python "docs/04 - Corrélations binaires/scripts/plot_correlation_heatmap.py"
``` ```
![](figures/correlation_heatmap.png) ![](figures/correlation_heatmap.png)
![](figures/correlation_heatmap_spearman.png)

View File

@ -1,6 +1,7 @@
# scripts/plot_correlation_heatmap.py # scripts/plot_correlation_heatmap.py
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import sys import sys
@ -11,13 +12,54 @@ if str(PROJECT_ROOT) not in sys.path:
from meteo.dataset import load_raw_csv from meteo.dataset import load_raw_csv
from meteo.variables import VARIABLES from meteo.variables import VARIABLES
from meteo.analysis import compute_correlation_matrix_for_variables from meteo.analysis import compute_correlation_matrices_for_methods
from meteo.plots import plot_correlation_heatmap from meteo.plots import plot_correlation_heatmap
CSV_PATH = Path("data/weather_minutely.csv") CSV_PATH = Path("data/weather_minutely.csv")
DOC_DIR = Path(__file__).resolve().parent.parent DOC_DIR = Path(__file__).resolve().parent.parent
OUTPUT_PATH = DOC_DIR / "figures" / "correlation_heatmap.png"
CORRELATION_METHODS: tuple[str, ...] = ("pearson", "spearman")
CORRELATION_TRANSFORM = "square"
@dataclass(frozen=True)
class HeatmapConfig:
filename: str
title: str
colorbar_label: str
cmap: str = "viridis"
vmin: float = 0.0
vmax: float = 1.0
HEATMAP_CONFIGS: dict[str, HeatmapConfig] = {
"pearson": HeatmapConfig(
filename="correlation_heatmap.png",
title="Corrélations R² (coef. de Pearson)",
colorbar_label="Coefficient de corrélation R²",
),
"spearman": HeatmapConfig(
filename="correlation_heatmap_spearman.png",
title="Corrélations R² (coef. de Spearman)",
colorbar_label="Coefficient de corrélation R²",
),
}
def _get_heatmap_config(method: str) -> HeatmapConfig:
if method in HEATMAP_CONFIGS:
return HEATMAP_CONFIGS[method]
# Valeurs par défaut pour un scénario non prévu.
return HeatmapConfig(
filename=f"correlation_heatmap_{method}.png",
title=f"Matrice de corrélation ({method})",
colorbar_label="Coefficient de corrélation",
cmap="viridis" if CORRELATION_TRANSFORM == "square" else "coolwarm",
vmin=0.0 if CORRELATION_TRANSFORM == "square" else -1.0,
vmax=1.0,
)
def main() -> None: def main() -> None:
@ -32,20 +74,32 @@ def main() -> None:
print(f" Colonnes : {list(df.columns)}") print(f" Colonnes : {list(df.columns)}")
print() print()
corr = compute_correlation_matrix_for_variables(df, VARIABLES, method="pearson") matrices = compute_correlation_matrices_for_methods(
df=df,
print("Matrice de corrélation (aperçu) :")
print(corr)
print()
output_path = plot_correlation_heatmap(
corr=corr,
variables=VARIABLES, variables=VARIABLES,
output_path=OUTPUT_PATH, methods=CORRELATION_METHODS,
annotate=True, transform=CORRELATION_TRANSFORM,
) )
print(f"✔ Heatmap de corrélation sauvegardée dans : {output_path}") for method, corr in matrices.items():
print(f"Matrice de corrélation (méthode={method}, transform={CORRELATION_TRANSFORM}) :")
print(corr)
print()
config = _get_heatmap_config(method)
output_path = plot_correlation_heatmap(
corr=corr,
variables=VARIABLES,
output_path=DOC_DIR / "figures" / config.filename,
annotate=True,
title=config.title,
cmap=config.cmap,
vmin=config.vmin,
vmax=config.vmax,
colorbar_label=config.colorbar_label,
)
print(f"✔ Heatmap de corrélation sauvegardée dans : {output_path}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -6,9 +6,11 @@ from .core import BinnedStatistics, DiurnalCycleStats, MONTH_ORDER
from .correlations import ( from .correlations import (
compute_correlation_matrix, compute_correlation_matrix,
compute_correlation_matrix_for_variables, compute_correlation_matrix_for_variables,
compute_correlation_matrices_for_methods,
compute_lagged_correlation, compute_lagged_correlation,
compute_rolling_correlation_series, compute_rolling_correlation_series,
compute_rolling_correlations_for_pairs, compute_rolling_correlations_for_pairs,
transform_correlation_matrix,
) )
from .events import build_event_aligned_segments, detect_threshold_events from .events import build_event_aligned_segments, detect_threshold_events
from .filters import filter_by_condition from .filters import filter_by_condition
@ -28,9 +30,11 @@ __all__ = [
"MONTH_ORDER", "MONTH_ORDER",
"compute_correlation_matrix", "compute_correlation_matrix",
"compute_correlation_matrix_for_variables", "compute_correlation_matrix_for_variables",
"compute_correlation_matrices_for_methods",
"compute_lagged_correlation", "compute_lagged_correlation",
"compute_rolling_correlation_series", "compute_rolling_correlation_series",
"compute_rolling_correlations_for_pairs", "compute_rolling_correlations_for_pairs",
"transform_correlation_matrix",
"build_event_aligned_segments", "build_event_aligned_segments",
"detect_threshold_events", "detect_threshold_events",
"filter_by_condition", "filter_by_condition",

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal, Sequence from typing import Callable, Literal, Sequence
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -11,13 +11,16 @@ from meteo.variables import Variable
from .core import _ensure_datetime_index from .core import _ensure_datetime_index
__all__ = ['compute_correlation_matrix', 'compute_correlation_matrix_for_variables', 'compute_lagged_correlation', 'compute_rolling_correlation_series', 'compute_rolling_correlations_for_pairs'] __all__ = ['compute_correlation_matrix', 'compute_correlation_matrix_for_variables', 'compute_correlation_matrices_for_methods', 'compute_lagged_correlation', 'compute_rolling_correlation_series', 'compute_rolling_correlations_for_pairs', 'transform_correlation_matrix']
CorrelationMethod = Literal["pearson", "spearman", "kendall"]
CorrelationTransform = Literal["identity", "absolute", "square"]
def compute_correlation_matrix( def compute_correlation_matrix(
df: pd.DataFrame, df: pd.DataFrame,
*, *,
method: Literal["pearson", "spearman"] = "pearson", method: CorrelationMethod = "pearson",
) -> pd.DataFrame: ) -> pd.DataFrame:
""" """
Calcule la matrice de corrélation entre toutes les colonnes numériques Calcule la matrice de corrélation entre toutes les colonnes numériques
@ -36,7 +39,7 @@ def compute_correlation_matrix_for_variables(
df: pd.DataFrame, df: pd.DataFrame,
variables: Sequence[Variable], variables: Sequence[Variable],
*, *,
method: Literal["pearson", "spearman"] = "pearson", method: CorrelationMethod = "pearson",
) -> pd.DataFrame: ) -> pd.DataFrame:
""" """
Calcule la matrice de corrélation pour un sous-ensemble de variables, Calcule la matrice de corrélation pour un sous-ensemble de variables,
@ -70,6 +73,46 @@ def compute_correlation_matrix_for_variables(
corr = corr.loc[columns, columns] corr = corr.loc[columns, columns]
return corr return corr
def transform_correlation_matrix(
corr: pd.DataFrame,
*,
transform: CorrelationTransform | Callable[[pd.DataFrame], pd.DataFrame] = "identity",
) -> pd.DataFrame:
"""Applique une transformation générique sur une matrice de corrélation."""
if callable(transform):
return transform(corr)
if transform == "identity":
return corr
if transform == "absolute":
return corr.abs()
if transform == "square":
return corr.pow(2)
raise ValueError(f"Transformation de corrélation inconnue : {transform!r}")
def compute_correlation_matrices_for_methods(
df: pd.DataFrame,
variables: Sequence[Variable],
*,
methods: Sequence[CorrelationMethod],
transform: CorrelationTransform | Callable[[pd.DataFrame], pd.DataFrame] = "identity",
) -> dict[str, pd.DataFrame]:
"""Calcule plusieurs matrices de corrélation en une seule passe."""
if not methods:
raise ValueError("La liste des méthodes de corrélation est vide.")
matrices: dict[str, pd.DataFrame] = {}
for method in methods:
corr = compute_correlation_matrix_for_variables(df, variables, method=method)
matrices[method] = transform_correlation_matrix(corr, transform=transform)
return matrices
def compute_lagged_correlation( def compute_lagged_correlation(
df: pd.DataFrame, df: pd.DataFrame,
var_x: Variable, var_x: Variable,

View File

@ -48,6 +48,11 @@ def plot_correlation_heatmap(
output_path: str | Path, output_path: str | Path,
*, *,
annotate: bool = True, annotate: bool = True,
title: str | None = None,
cmap: str | None = None,
vmin: float | None = None,
vmax: float | None = None,
colorbar_label: str | None = None,
) -> Path: ) -> Path:
""" """
Trace une heatmap de la matrice de corrélation. Trace une heatmap de la matrice de corrélation.
@ -63,6 +68,14 @@ def plot_correlation_heatmap(
Chemin du fichier image à écrire. Chemin du fichier image à écrire.
annotate : annotate :
Si True, affiche la valeur numérique dans chaque case. Si True, affiche la valeur numérique dans chaque case.
title :
Titre personalisé (par défaut, libellé générique).
cmap :
Nom de la palette matplotlib à utiliser (par défaut, palette standard).
vmin / vmax :
Borne d'échelle de couleurs. Si None, valeurs classiques [-1, 1].
colorbar_label :
Libellé pour la barre de couleur (par défaut "Corrélation").
""" """
output_path = Path(output_path) output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
@ -77,7 +90,16 @@ def plot_correlation_heatmap(
data = corr.to_numpy() data = corr.to_numpy()
fig, ax = plt.subplots() fig, ax = plt.subplots()
im = ax.imshow(data, vmin=-1.0, vmax=1.0) if vmin is None:
vmin = -1.0
if vmax is None:
vmax = 1.0
im_kwargs = {"vmin": vmin, "vmax": vmax}
if cmap is not None:
im_kwargs["cmap"] = cmap
im = ax.imshow(data, **im_kwargs)
# Ticks et labels # Ticks et labels
ax.set_xticks(np.arange(len(labels))) ax.set_xticks(np.arange(len(labels)))
@ -86,31 +108,45 @@ def plot_correlation_heatmap(
ax.set_yticklabels(labels) ax.set_yticklabels(labels)
# Axe en haut/bas selon préférence (ici on laisse en bas) # Axe en haut/bas selon préférence (ici on laisse en bas)
ax.set_title("Matrice de corrélation (coef. de Pearson)") ax.set_title(title or "Matrice de corrélation")
# Barre de couleur # Barre de couleur
cbar = plt.colorbar(im, ax=ax) cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Corrélation") cbar.set_label(colorbar_label or "Corrélation")
# Annotation des cases # Annotation des cases
if annotate: if annotate:
n = data.shape[0] n = data.shape[0]
norm = im.norm
cmap_obj = im.cmap
def _text_color(value: float) -> str:
rgba = cmap_obj(norm(value))
r, g, b, _ = rgba
luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
return "white" if luminance < 0.5 else "black"
for i in range(n): for i in range(n):
for j in range(n): for j in range(n):
val = data[i, j]
if i == j: if i == j:
text = "" text = ""
elif np.isnan(val):
text = ""
else: else:
val = data[i, j] text = f"{val:.2f}"
if np.isnan(val):
text = "" if not text:
else: continue
text = f"{val:.2f}"
color = _text_color(0.0 if np.isnan(val) else val)
ax.text( ax.text(
j, j,
i, i,
text, text,
ha="center", ha="center",
va="center", va="center",
color=color,
) )
plt.tight_layout() plt.tight_layout()