Ajout des matrices de corrélation + Refactoring
|
Before Width: | Height: | Size: 46 KiB After Width: | Height: | Size: 46 KiB |
|
Before Width: | Height: | Size: 55 KiB After Width: | Height: | Size: 55 KiB |
|
Before Width: | Height: | Size: 49 KiB After Width: | Height: | Size: 49 KiB |
|
Before Width: | Height: | Size: 45 KiB After Width: | Height: | Size: 45 KiB |
|
Before Width: | Height: | Size: 46 KiB After Width: | Height: | Size: 46 KiB |
|
Before Width: | Height: | Size: 45 KiB After Width: | Height: | Size: 45 KiB |
|
Before Width: | Height: | Size: 127 KiB After Width: | Height: | Size: 117 KiB |
|
After Width: | Height: | Size: 116 KiB |
@ -63,10 +63,12 @@ python "docs/04 - Corrélations binaires/scripts/plot_all_pairwise_scatter.py"
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
## Heatmap globale
|
## Matrices de corrélation
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
python "docs/04 - Corrélations binaires/scripts/plot_correlation_heatmap.py"
|
python "docs/04 - Corrélations binaires/scripts/plot_correlation_heatmap.py"
|
||||||
```
|
```
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|

|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# scripts/plot_correlation_heatmap.py
|
# scripts/plot_correlation_heatmap.py
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@ -11,13 +12,54 @@ if str(PROJECT_ROOT) not in sys.path:
|
|||||||
|
|
||||||
from meteo.dataset import load_raw_csv
|
from meteo.dataset import load_raw_csv
|
||||||
from meteo.variables import VARIABLES
|
from meteo.variables import VARIABLES
|
||||||
from meteo.analysis import compute_correlation_matrix_for_variables
|
from meteo.analysis import compute_correlation_matrices_for_methods
|
||||||
from meteo.plots import plot_correlation_heatmap
|
from meteo.plots import plot_correlation_heatmap
|
||||||
|
|
||||||
|
|
||||||
CSV_PATH = Path("data/weather_minutely.csv")
|
CSV_PATH = Path("data/weather_minutely.csv")
|
||||||
DOC_DIR = Path(__file__).resolve().parent.parent
|
DOC_DIR = Path(__file__).resolve().parent.parent
|
||||||
OUTPUT_PATH = DOC_DIR / "figures" / "correlation_heatmap.png"
|
|
||||||
|
CORRELATION_METHODS: tuple[str, ...] = ("pearson", "spearman")
|
||||||
|
CORRELATION_TRANSFORM = "square"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class HeatmapConfig:
|
||||||
|
filename: str
|
||||||
|
title: str
|
||||||
|
colorbar_label: str
|
||||||
|
cmap: str = "viridis"
|
||||||
|
vmin: float = 0.0
|
||||||
|
vmax: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
HEATMAP_CONFIGS: dict[str, HeatmapConfig] = {
|
||||||
|
"pearson": HeatmapConfig(
|
||||||
|
filename="correlation_heatmap.png",
|
||||||
|
title="Corrélations R² (coef. de Pearson)",
|
||||||
|
colorbar_label="Coefficient de corrélation R²",
|
||||||
|
),
|
||||||
|
"spearman": HeatmapConfig(
|
||||||
|
filename="correlation_heatmap_spearman.png",
|
||||||
|
title="Corrélations R² (coef. de Spearman)",
|
||||||
|
colorbar_label="Coefficient de corrélation R²",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_heatmap_config(method: str) -> HeatmapConfig:
|
||||||
|
if method in HEATMAP_CONFIGS:
|
||||||
|
return HEATMAP_CONFIGS[method]
|
||||||
|
|
||||||
|
# Valeurs par défaut pour une méthode absente de HEATMAP_CONFIGS.
|
||||||
|
return HeatmapConfig(
|
||||||
|
filename=f"correlation_heatmap_{method}.png",
|
||||||
|
title=f"Matrice de corrélation ({method})",
|
||||||
|
colorbar_label="Coefficient de corrélation",
|
||||||
|
cmap="viridis" if CORRELATION_TRANSFORM == "square" else "coolwarm",
|
||||||
|
vmin=0.0 if CORRELATION_TRANSFORM == "square" else -1.0,
|
||||||
|
vmax=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
@ -32,20 +74,32 @@ def main() -> None:
|
|||||||
print(f" Colonnes : {list(df.columns)}")
|
print(f" Colonnes : {list(df.columns)}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
corr = compute_correlation_matrix_for_variables(df, VARIABLES, method="pearson")
|
matrices = compute_correlation_matrices_for_methods(
|
||||||
|
df=df,
|
||||||
print("Matrice de corrélation (aperçu) :")
|
|
||||||
print(corr)
|
|
||||||
print()
|
|
||||||
|
|
||||||
output_path = plot_correlation_heatmap(
|
|
||||||
corr=corr,
|
|
||||||
variables=VARIABLES,
|
variables=VARIABLES,
|
||||||
output_path=OUTPUT_PATH,
|
methods=CORRELATION_METHODS,
|
||||||
annotate=True,
|
transform=CORRELATION_TRANSFORM,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"✔ Heatmap de corrélation sauvegardée dans : {output_path}")
|
for method, corr in matrices.items():
|
||||||
|
print(f"Matrice de corrélation (méthode={method}, transform={CORRELATION_TRANSFORM}) :")
|
||||||
|
print(corr)
|
||||||
|
print()
|
||||||
|
|
||||||
|
config = _get_heatmap_config(method)
|
||||||
|
output_path = plot_correlation_heatmap(
|
||||||
|
corr=corr,
|
||||||
|
variables=VARIABLES,
|
||||||
|
output_path=DOC_DIR / "figures" / config.filename,
|
||||||
|
annotate=True,
|
||||||
|
title=config.title,
|
||||||
|
cmap=config.cmap,
|
||||||
|
vmin=config.vmin,
|
||||||
|
vmax=config.vmax,
|
||||||
|
colorbar_label=config.colorbar_label,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"✔ Heatmap de corrélation sauvegardée dans : {output_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -6,9 +6,11 @@ from .core import BinnedStatistics, DiurnalCycleStats, MONTH_ORDER
|
|||||||
from .correlations import (
|
from .correlations import (
|
||||||
compute_correlation_matrix,
|
compute_correlation_matrix,
|
||||||
compute_correlation_matrix_for_variables,
|
compute_correlation_matrix_for_variables,
|
||||||
|
compute_correlation_matrices_for_methods,
|
||||||
compute_lagged_correlation,
|
compute_lagged_correlation,
|
||||||
compute_rolling_correlation_series,
|
compute_rolling_correlation_series,
|
||||||
compute_rolling_correlations_for_pairs,
|
compute_rolling_correlations_for_pairs,
|
||||||
|
transform_correlation_matrix,
|
||||||
)
|
)
|
||||||
from .events import build_event_aligned_segments, detect_threshold_events
|
from .events import build_event_aligned_segments, detect_threshold_events
|
||||||
from .filters import filter_by_condition
|
from .filters import filter_by_condition
|
||||||
@ -28,9 +30,11 @@ __all__ = [
|
|||||||
"MONTH_ORDER",
|
"MONTH_ORDER",
|
||||||
"compute_correlation_matrix",
|
"compute_correlation_matrix",
|
||||||
"compute_correlation_matrix_for_variables",
|
"compute_correlation_matrix_for_variables",
|
||||||
|
"compute_correlation_matrices_for_methods",
|
||||||
"compute_lagged_correlation",
|
"compute_lagged_correlation",
|
||||||
"compute_rolling_correlation_series",
|
"compute_rolling_correlation_series",
|
||||||
"compute_rolling_correlations_for_pairs",
|
"compute_rolling_correlations_for_pairs",
|
||||||
|
"transform_correlation_matrix",
|
||||||
"build_event_aligned_segments",
|
"build_event_aligned_segments",
|
||||||
"detect_threshold_events",
|
"detect_threshold_events",
|
||||||
"filter_by_condition",
|
"filter_by_condition",
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Literal, Sequence
|
from typing import Callable, Literal, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -11,13 +11,16 @@ from meteo.variables import Variable
|
|||||||
|
|
||||||
from .core import _ensure_datetime_index
|
from .core import _ensure_datetime_index
|
||||||
|
|
||||||
__all__ = ['compute_correlation_matrix', 'compute_correlation_matrix_for_variables', 'compute_lagged_correlation', 'compute_rolling_correlation_series', 'compute_rolling_correlations_for_pairs']
|
__all__ = ['compute_correlation_matrix', 'compute_correlation_matrix_for_variables', 'compute_correlation_matrices_for_methods', 'compute_lagged_correlation', 'compute_rolling_correlation_series', 'compute_rolling_correlations_for_pairs', 'transform_correlation_matrix']
|
||||||
|
|
||||||
|
CorrelationMethod = Literal["pearson", "spearman", "kendall"]
|
||||||
|
CorrelationTransform = Literal["identity", "absolute", "square"]
|
||||||
|
|
||||||
|
|
||||||
def compute_correlation_matrix(
|
def compute_correlation_matrix(
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
*,
|
*,
|
||||||
method: Literal["pearson", "spearman"] = "pearson",
|
method: CorrelationMethod = "pearson",
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
Calcule la matrice de corrélation entre toutes les colonnes numériques
|
Calcule la matrice de corrélation entre toutes les colonnes numériques
|
||||||
@ -36,7 +39,7 @@ def compute_correlation_matrix_for_variables(
|
|||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
variables: Sequence[Variable],
|
variables: Sequence[Variable],
|
||||||
*,
|
*,
|
||||||
method: Literal["pearson", "spearman"] = "pearson",
|
method: CorrelationMethod = "pearson",
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
Calcule la matrice de corrélation pour un sous-ensemble de variables,
|
Calcule la matrice de corrélation pour un sous-ensemble de variables,
|
||||||
@ -70,6 +73,46 @@ def compute_correlation_matrix_for_variables(
|
|||||||
corr = corr.loc[columns, columns]
|
corr = corr.loc[columns, columns]
|
||||||
return corr
|
return corr
|
||||||
|
|
||||||
|
|
||||||
|
def transform_correlation_matrix(
|
||||||
|
corr: pd.DataFrame,
|
||||||
|
*,
|
||||||
|
transform: CorrelationTransform | Callable[[pd.DataFrame], pd.DataFrame] = "identity",
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Applique une transformation générique sur une matrice de corrélation."""
|
||||||
|
|
||||||
|
if callable(transform):
|
||||||
|
return transform(corr)
|
||||||
|
|
||||||
|
if transform == "identity":
|
||||||
|
return corr
|
||||||
|
if transform == "absolute":
|
||||||
|
return corr.abs()
|
||||||
|
if transform == "square":
|
||||||
|
return corr.pow(2)
|
||||||
|
|
||||||
|
raise ValueError(f"Transformation de corrélation inconnue : {transform!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def compute_correlation_matrices_for_methods(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
variables: Sequence[Variable],
|
||||||
|
*,
|
||||||
|
methods: Sequence[CorrelationMethod],
|
||||||
|
transform: CorrelationTransform | Callable[[pd.DataFrame], pd.DataFrame] = "identity",
|
||||||
|
) -> dict[str, pd.DataFrame]:
|
||||||
|
"""Calcule plusieurs matrices de corrélation en une seule passe."""
|
||||||
|
|
||||||
|
if not methods:
|
||||||
|
raise ValueError("La liste des méthodes de corrélation est vide.")
|
||||||
|
|
||||||
|
matrices: dict[str, pd.DataFrame] = {}
|
||||||
|
for method in methods:
|
||||||
|
corr = compute_correlation_matrix_for_variables(df, variables, method=method)
|
||||||
|
matrices[method] = transform_correlation_matrix(corr, transform=transform)
|
||||||
|
|
||||||
|
return matrices
|
||||||
|
|
||||||
def compute_lagged_correlation(
|
def compute_lagged_correlation(
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
var_x: Variable,
|
var_x: Variable,
|
||||||
|
|||||||
@ -48,6 +48,11 @@ def plot_correlation_heatmap(
|
|||||||
output_path: str | Path,
|
output_path: str | Path,
|
||||||
*,
|
*,
|
||||||
annotate: bool = True,
|
annotate: bool = True,
|
||||||
|
title: str | None = None,
|
||||||
|
cmap: str | None = None,
|
||||||
|
vmin: float | None = None,
|
||||||
|
vmax: float | None = None,
|
||||||
|
colorbar_label: str | None = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""
|
"""
|
||||||
Trace une heatmap de la matrice de corrélation.
|
Trace une heatmap de la matrice de corrélation.
|
||||||
@ -63,6 +68,14 @@ def plot_correlation_heatmap(
|
|||||||
Chemin du fichier image à écrire.
|
Chemin du fichier image à écrire.
|
||||||
annotate :
|
annotate :
|
||||||
Si True, affiche la valeur numérique dans chaque case.
|
Si True, affiche la valeur numérique dans chaque case.
|
||||||
|
title :
|
||||||
|
Titre personnalisé (par défaut, libellé générique).
|
||||||
|
cmap :
|
||||||
|
Nom de la palette matplotlib à utiliser (si None, palette par défaut de matplotlib).
|
||||||
|
vmin / vmax :
|
||||||
|
Bornes de l'échelle de couleurs. Si None, vmin vaut -1.0 et vmax vaut 1.0.
|
||||||
|
colorbar_label :
|
||||||
|
Libellé pour la barre de couleur (par défaut "Corrélation").
|
||||||
"""
|
"""
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@ -77,7 +90,16 @@ def plot_correlation_heatmap(
|
|||||||
data = corr.to_numpy()
|
data = corr.to_numpy()
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
im = ax.imshow(data, vmin=-1.0, vmax=1.0)
|
if vmin is None:
|
||||||
|
vmin = -1.0
|
||||||
|
if vmax is None:
|
||||||
|
vmax = 1.0
|
||||||
|
|
||||||
|
im_kwargs = {"vmin": vmin, "vmax": vmax}
|
||||||
|
if cmap is not None:
|
||||||
|
im_kwargs["cmap"] = cmap
|
||||||
|
|
||||||
|
im = ax.imshow(data, **im_kwargs)
|
||||||
|
|
||||||
# Ticks et labels
|
# Ticks et labels
|
||||||
ax.set_xticks(np.arange(len(labels)))
|
ax.set_xticks(np.arange(len(labels)))
|
||||||
@ -86,31 +108,45 @@ def plot_correlation_heatmap(
|
|||||||
ax.set_yticklabels(labels)
|
ax.set_yticklabels(labels)
|
||||||
|
|
||||||
# Axe en haut/bas selon préférence (ici on laisse en bas)
|
# Axe en haut/bas selon préférence (ici on laisse en bas)
|
||||||
ax.set_title("Matrice de corrélation (coef. de Pearson)")
|
ax.set_title(title or "Matrice de corrélation")
|
||||||
|
|
||||||
# Barre de couleur
|
# Barre de couleur
|
||||||
cbar = plt.colorbar(im, ax=ax)
|
cbar = plt.colorbar(im, ax=ax)
|
||||||
cbar.set_label("Corrélation")
|
cbar.set_label(colorbar_label or "Corrélation")
|
||||||
|
|
||||||
# Annotation des cases
|
# Annotation des cases
|
||||||
if annotate:
|
if annotate:
|
||||||
n = data.shape[0]
|
n = data.shape[0]
|
||||||
|
norm = im.norm
|
||||||
|
cmap_obj = im.cmap
|
||||||
|
|
||||||
|
def _text_color(value: float) -> str:
|
||||||
|
rgba = cmap_obj(norm(value))
|
||||||
|
r, g, b, _ = rgba
|
||||||
|
luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
|
||||||
|
return "white" if luminance < 0.5 else "black"
|
||||||
|
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
for j in range(n):
|
for j in range(n):
|
||||||
|
val = data[i, j]
|
||||||
if i == j:
|
if i == j:
|
||||||
text = "—"
|
text = "—"
|
||||||
|
elif np.isnan(val):
|
||||||
|
text = ""
|
||||||
else:
|
else:
|
||||||
val = data[i, j]
|
text = f"{val:.2f}"
|
||||||
if np.isnan(val):
|
|
||||||
text = ""
|
if not text:
|
||||||
else:
|
continue
|
||||||
text = f"{val:.2f}"
|
|
||||||
|
color = _text_color(0.0 if np.isnan(val) else val)
|
||||||
ax.text(
|
ax.text(
|
||||||
j,
|
j,
|
||||||
i,
|
i,
|
||||||
text,
|
text,
|
||||||
ha="center",
|
ha="center",
|
||||||
va="center",
|
va="center",
|
||||||
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|||||||