You've already forked donnees_meteo
Ajout des matrices de corrélation + Refactoring
This commit is contained in:
@@ -48,6 +48,11 @@ def plot_correlation_heatmap(
|
||||
output_path: str | Path,
|
||||
*,
|
||||
annotate: bool = True,
|
||||
title: str | None = None,
|
||||
cmap: str | None = None,
|
||||
vmin: float | None = None,
|
||||
vmax: float | None = None,
|
||||
colorbar_label: str | None = None,
|
||||
) -> Path:
|
||||
"""
|
||||
Trace une heatmap de la matrice de corrélation.
|
||||
@@ -63,6 +68,14 @@ def plot_correlation_heatmap(
|
||||
Chemin du fichier image à écrire.
|
||||
annotate :
|
||||
Si True, affiche la valeur numérique dans chaque case.
|
||||
title :
|
||||
Titre personalisé (par défaut, libellé générique).
|
||||
cmap :
|
||||
Nom de la palette matplotlib à utiliser (par défaut, palette standard).
|
||||
vmin / vmax :
|
||||
Borne d'échelle de couleurs. Si None, valeurs classiques [-1, 1].
|
||||
colorbar_label :
|
||||
Libellé pour la barre de couleur (par défaut "Corrélation").
|
||||
"""
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -77,7 +90,16 @@ def plot_correlation_heatmap(
|
||||
data = corr.to_numpy()
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
im = ax.imshow(data, vmin=-1.0, vmax=1.0)
|
||||
if vmin is None:
|
||||
vmin = -1.0
|
||||
if vmax is None:
|
||||
vmax = 1.0
|
||||
|
||||
im_kwargs = {"vmin": vmin, "vmax": vmax}
|
||||
if cmap is not None:
|
||||
im_kwargs["cmap"] = cmap
|
||||
|
||||
im = ax.imshow(data, **im_kwargs)
|
||||
|
||||
# Ticks et labels
|
||||
ax.set_xticks(np.arange(len(labels)))
|
||||
@@ -86,31 +108,45 @@ def plot_correlation_heatmap(
|
||||
ax.set_yticklabels(labels)
|
||||
|
||||
# Axe en haut/bas selon préférence (ici on laisse en bas)
|
||||
ax.set_title("Matrice de corrélation (coef. de Pearson)")
|
||||
ax.set_title(title or "Matrice de corrélation")
|
||||
|
||||
# Barre de couleur
|
||||
cbar = plt.colorbar(im, ax=ax)
|
||||
cbar.set_label("Corrélation")
|
||||
cbar.set_label(colorbar_label or "Corrélation")
|
||||
|
||||
# Annotation des cases
|
||||
if annotate:
|
||||
n = data.shape[0]
|
||||
norm = im.norm
|
||||
cmap_obj = im.cmap
|
||||
|
||||
def _text_color(value: float) -> str:
|
||||
rgba = cmap_obj(norm(value))
|
||||
r, g, b, _ = rgba
|
||||
luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
|
||||
return "white" if luminance < 0.5 else "black"
|
||||
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
val = data[i, j]
|
||||
if i == j:
|
||||
text = "—"
|
||||
elif np.isnan(val):
|
||||
text = ""
|
||||
else:
|
||||
val = data[i, j]
|
||||
if np.isnan(val):
|
||||
text = ""
|
||||
else:
|
||||
text = f"{val:.2f}"
|
||||
text = f"{val:.2f}"
|
||||
|
||||
if not text:
|
||||
continue
|
||||
|
||||
color = _text_color(0.0 if np.isnan(val) else val)
|
||||
ax.text(
|
||||
j,
|
||||
i,
|
||||
text,
|
||||
ha="center",
|
||||
va="center",
|
||||
color=color,
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
Reference in New Issue
Block a user