1
donnees_meteo/meteo/plots/basic_series.py

351 lines
11 KiB
Python

"""Tracés simples et réutilisables pour les séries temporelles de base."""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Callable
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from meteo.dataset import _circular_mean_deg
from meteo.variables import Variable
from .base import export_plot_dataset
# Public API: plotting styles and per-variable choices, then the helpers
# that resample and draw series.
__all__: list[str] = [
    # Configuration types.
    "PlotStyle",
    "PlotChoice",
    # Style selection and data preparation.
    "recommended_style",
    "resample_series_for_plot",
    # Figure builders.
    "plot_basic_series",
    "plot_dual_time_series",
]
class PlotStyle(str, Enum):
    """Rendering styles understood by the plotting helpers of this module.

    Members are also ``str`` instances, so ``PlotStyle("bar")`` resolves a
    member from its value and ``PlotStyle.BAR == "bar"`` holds.
    """

    LINE = "line"  # continuous line
    AREA = "area"  # line with a filled area underneath
    BAR = "bar"  # one vertical bar per timestamp
    SCATTER = "scatter"  # unconnected markers
@dataclass(frozen=True)
class PlotChoice:
    """Immutable per-variable plotting configuration.

    Attributes:
        style: the :class:`PlotStyle` used to draw the series.
        agg: aggregation applied when the series is resampled — either a
            pandas aggregation name (``"mean"`` by default) or a callable
            reducing a series to a single float.
    """

    style: PlotStyle
    agg: Callable[[pd.Series], float] | str = "mean"
# Default plot style and resampling aggregation for every known variable key.
# The helpers that read this table fall back to a line plot with "mean"
# aggregation for keys that are not listed here.
DEFAULT_CHOICES: dict[str, PlotChoice] = {
    # Continuous quantities: lines or filled areas.
    "temperature": PlotChoice(style=PlotStyle.LINE, agg="mean"),
    "pressure": PlotChoice(style=PlotStyle.LINE, agg="mean"),
    "humidity": PlotChoice(style=PlotStyle.AREA, agg="mean"),
    "illuminance": PlotChoice(style=PlotStyle.AREA, agg="mean"),
    "sun_elevation": PlotChoice(style=PlotStyle.AREA, agg="mean"),
    # Variables that read better with another representation.
    "rain_rate": PlotChoice(style=PlotStyle.BAR, agg="mean"),
    "wind_speed": PlotChoice(style=PlotStyle.LINE, agg="mean"),
    # Directions are aggregated with ``_circular_mean_deg`` (presumably a
    # circular mean in degrees, so 350 and 10 do not average to 180 —
    # confirm in meteo.dataset).
    "wind_direction": PlotChoice(style=PlotStyle.SCATTER, agg=_circular_mean_deg),
}
# Soft but contrasting palette, keyed by variable name.
PALETTE = dict(
    temperature="#d1495b",
    pressure="#5c677d",
    humidity="#2c7bb6",
    rain_rate="#1b9e77",
    illuminance="#f4a259",
    wind_speed="#118ab2",
    wind_direction="#8e6c8a",
    sun_elevation="#f08c42",
)
# Fallback color for variables that have no palette entry.
DEFAULT_COLOR = "#386cb0"
def recommended_style(variable: Variable, override: str | None = None) -> PlotChoice:
    """Return the plotting configuration to use for ``variable``.

    Without ``override`` (``None`` or empty), this is the variable's entry in
    ``DEFAULT_CHOICES``, or a line plot with ``"mean"`` aggregation for an
    unknown variable. With ``override``, the style becomes
    ``PlotStyle(override)`` while the variable's default aggregation is kept.

    Raises:
        ValueError: if ``override`` is not a valid :class:`PlotStyle` value.
    """
    if not override:
        return DEFAULT_CHOICES.get(variable.key, PlotChoice(PlotStyle.LINE))
    forced_style = PlotStyle(override)
    fallback = PlotChoice(forced_style)
    kept_agg = DEFAULT_CHOICES.get(variable.key, fallback).agg
    return PlotChoice(forced_style, kept_agg)
def _nice_frequencies() -> list[tuple[str, pd.Timedelta]]:
return [
("5min", pd.Timedelta(minutes=5)),
("10min", pd.Timedelta(minutes=10)),
("15min", pd.Timedelta(minutes=15)),
("30min", pd.Timedelta(minutes=30)),
("1h", pd.Timedelta(hours=1)),
("3h", pd.Timedelta(hours=3)),
("6h", pd.Timedelta(hours=6)),
("12h", pd.Timedelta(hours=12)),
("1d", pd.Timedelta(days=1)),
("3d", pd.Timedelta(days=3)),
("7d", pd.Timedelta(days=7)),
]
def _auto_resample_frequency(index: pd.DatetimeIndex, *, target_points: int = 420) -> str:
    """Pick a resampling frequency giving at most about ``target_points`` points.

    Returns the finest alias from ``_nice_frequencies()`` for which the index
    span divided by the frequency is at most ``target_points``. When none
    qualifies, the coarsest alias is returned. Indexes with fewer than two
    timestamps, or with a span that is not positive, fall back to ``"1h"``.
    """
    if len(index) < 2:
        return "1h"
    span = index.max() - index.min()
    if span <= pd.Timedelta(0):
        return "1h"
    candidates = _nice_frequencies()
    fitting = (label for label, delta in candidates if span / delta <= target_points)
    return next(fitting, candidates[-1][0])
def _format_time_axis(ax: plt.Axes) -> None:
    """Give ``ax`` a compact date x-axis with roughly 4 to 8 major ticks."""
    # One format per ConciseDateFormatter level: years, months, days, hours,
    # minutes, seconds. Hour-level ticks also show the day.
    tick_formats = ["%Y", "%b", "%d", "%d %H:%M", "%H:%M", "%S"]
    locator = mdates.AutoDateLocator(minticks=4, maxticks=8)
    concise = mdates.ConciseDateFormatter(locator, formats=tick_formats)
    xaxis = ax.xaxis
    xaxis.set_major_locator(locator)
    xaxis.set_major_formatter(concise)
def _infer_bar_width(index: pd.DatetimeIndex) -> float:
"""
Calcule une largeur de barre raisonnable (en jours) pour les histogrammes temporels.
"""
if len(index) < 2:
return 0.3 # ~7 heures, pour rendre le point visible même isolé
diffs = np.diff(index.asi8) # nanosecondes
median_ns = float(np.median(diffs))
if not np.isfinite(median_ns) or median_ns <= 0:
return 0.1
return pd.to_timedelta(median_ns, unit="ns") / pd.Timedelta(days=1) * 0.8
def _ensure_datetime_index(series: pd.Series) -> pd.Series:
if not isinstance(series.index, pd.DatetimeIndex):
raise TypeError("Une série temporelle (DatetimeIndex) est attendue pour le tracé.")
return series
def _series_color(variable: Variable) -> str:
    """Return the palette color for ``variable``.

    ``variable.key`` is looked up first, then ``variable.column``. When
    neither is in ``PALETTE``, ``DEFAULT_COLOR`` is returned.
    """
    key = variable.key
    if key not in PALETTE:
        return PALETTE.get(variable.column, DEFAULT_COLOR)
    return PALETTE[key]
def _format_label(var: Variable) -> str:
unit_text = f" ({var.unit})" if var.unit else ""
return f"{var.label}{unit_text}"
def resample_series_for_plot(
    series: pd.Series,
    *,
    variable: Variable,
    freq: str | None = None,
    target_points: int = 420,
) -> tuple[pd.Series, str]:
    """
    Resample ``series`` for display using the variable's default aggregation.

    Args:
        series: time series to resample; must be indexed by a DatetimeIndex.
        variable: variable whose ``DEFAULT_CHOICES`` entry provides the
            aggregation (``"mean"`` when the key is not listed).
        freq: pandas frequency alias. When ``None``, one is chosen from the
            index span so that about ``target_points`` points remain.
        target_points: point budget used only when ``freq`` is ``None``.

    Returns:
        The resampled series with NaN buckets dropped, and the frequency used.

    Raises:
        TypeError: if ``series`` is not indexed by a DatetimeIndex.
    """
    _ensure_datetime_index(series)
    if freq is not None:
        chosen_freq = freq
    else:
        chosen_freq = _auto_resample_frequency(series.index, target_points=target_points)
    choice = DEFAULT_CHOICES.get(variable.key, PlotChoice(PlotStyle.LINE))
    aggregated = series.resample(chosen_freq).agg(choice.agg)
    return aggregated.dropna(), chosen_freq
def plot_basic_series(
    series: pd.Series,
    *,
    variable: Variable,
    output_path: str | Path,
    style: PlotStyle,
    title: str,
    ylabel: str,
    annotate_freq: str | None = None,
) -> Path:
    """
    Plot a time series in a simple style (line, area, bars or scatter) and save it.

    Args:
        series: values to plot, indexed by a DatetimeIndex.
        variable: plotted variable. It selects the color and names the
            exported data column.
        output_path: image file to write. Missing parent directories are created.
        style: rendering style.
        title: figure title.
        ylabel: y-axis label.
        annotate_freq: when truthy, written in the bottom-right corner as
            ``"Agrégation : <freq>"``.

    Returns:
        The resolved path of the saved figure. ``export_plot_dataset`` is then
        called with the plotted data and that path.

    Raises:
        TypeError: if ``series`` is not indexed by a DatetimeIndex.
        ValueError: if ``series`` is empty or ``style`` is unknown.
    """
    _ensure_datetime_index(series)
    if series.empty:
        raise ValueError(f"Aucune donnée disponible pour {variable.key} après filtrage.")

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    color = _series_color(variable)
    x = mdates.date2num(series.index)
    values = series.to_numpy(dtype=float)

    fig, ax = plt.subplots(figsize=(11, 4.2))
    try:
        # ``Axes.plot_date`` is deprecated since Matplotlib 3.9. ``xaxis_date()``
        # followed by ``plot`` is exactly what it did internally.
        if style is PlotStyle.LINE:
            ax.xaxis_date()
            ax.plot(x, values, "-", linewidth=1.8, color=color, label=variable.label)
        elif style is PlotStyle.AREA:
            ax.fill_between(x, values, step="mid", color=color, alpha=0.2)
            ax.xaxis_date()
            ax.plot(x, values, "-", linewidth=1.6, color=color)
        elif style is PlotStyle.BAR:
            width = _infer_bar_width(series.index)
            ax.bar(x, values, width=width, color=color, edgecolor=color, linewidth=0.5, alpha=0.85)
        elif style is PlotStyle.SCATTER:
            ax.scatter(x, values, s=16, color=color, alpha=0.9)
        else:
            raise ValueError(f"Style inconnu : {style}")

        ax.set_title(title)
        ax.set_ylabel(ylabel)
        _format_time_axis(ax)
        ax.grid(True, color="#e0e0e0", linewidth=0.8, alpha=0.7)
        ax.margins(x=0.02, y=0.05)
        if annotate_freq:
            ax.text(
                0.99,
                0.02,
                f"Agrégation : {annotate_freq}",
                transform=ax.transAxes,
                ha="right",
                va="bottom",
                fontsize=9,
                color="#555555",
            )
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even when drawing/saving fails, so pyplot does
        # not accumulate open figures in long-running processes.
        plt.close(fig)

    export_plot_dataset(series.to_frame(name=variable.column), output_path)
    return output_path.resolve()
def _draw_series(ax: plt.Axes, series: pd.Series, *, choice: PlotChoice, color: str, label: str):
    """Draw ``series`` on ``ax`` in ``choice.style`` and return the artist(s) for the legend.

    Returns:
        For LINE/AREA, the list of ``Line2D`` returned by ``Axes.plot``; the
        area fill is not included. For BAR, the ``BarContainer``. For SCATTER,
        the ``PathCollection``.

    Raises:
        ValueError: if ``choice.style`` is not a known :class:`PlotStyle`.
    """
    x = mdates.date2num(series.index)
    values = series.to_numpy(dtype=float)
    style = choice.style
    # ``Axes.plot_date`` is deprecated since Matplotlib 3.9. ``xaxis_date()``
    # followed by ``plot`` is exactly what it did internally.
    if style is PlotStyle.LINE:
        ax.xaxis_date()
        return ax.plot(x, values, "-", linewidth=1.8, color=color, label=label)
    if style is PlotStyle.AREA:
        ax.fill_between(x, values, step="mid", color=color, alpha=0.15)
        ax.xaxis_date()
        return ax.plot(x, values, "-", linewidth=1.6, color=color, label=label)
    if style is PlotStyle.BAR:
        # Slightly narrower than in single plots, leaving room between bars.
        width = _infer_bar_width(series.index) * 0.9
        return ax.bar(x, values, width=width, color=color, edgecolor=color, linewidth=0.5, alpha=0.75, label=label)
    if style is PlotStyle.SCATTER:
        return ax.scatter(x, values, s=16, color=color, alpha=0.9, label=label)
    raise ValueError(f"Style inconnu : {choice.style}")
def plot_dual_time_series(
    series_left: pd.Series,
    variable_left: Variable,
    choice_left: PlotChoice,
    series_right: pd.Series,
    variable_right: Variable,
    choice_right: PlotChoice,
    *,
    output_path: str | Path,
    title: str,
    annotate_freq: str | None = None,
) -> Path:
    """Overlay two time series on a shared time axis with separate y axes.

    The left series is drawn on the primary axis and the right series on a
    ``twinx`` axis. Each is drawn in its own ``PlotChoice`` style and palette
    color, and a single legend combines both.

    Args:
        series_left, series_right: series to plot, indexed by DatetimeIndex.
        variable_left, variable_right: variables supplying colors, axis labels
            and exported column names.
        choice_left, choice_right: plotting style for each series.
        output_path: image file to write. Missing parent directories are created.
        title: figure title.
        annotate_freq: when truthy, written in the bottom-right corner as
            ``"Agrégation : <freq>"``.

    Returns:
        The resolved path of the saved figure. ``export_plot_dataset`` is then
        called with both series side by side and that path.

    Raises:
        TypeError: if either series is not indexed by a DatetimeIndex.
        ValueError: if either series is empty or a style is unknown.
    """
    _ensure_datetime_index(series_left)
    _ensure_datetime_index(series_right)
    if series_left.empty or series_right.empty:
        raise ValueError("Les séries à tracer ne peuvent pas être vides.")

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    color_left = _series_color(variable_left)
    color_right = _series_color(variable_right)
    label_left = _format_label(variable_left)
    label_right = _format_label(variable_right)

    fig, ax_left = plt.subplots(figsize=(11, 4.6))
    try:
        ax_right = ax_left.twinx()
        artists_left = _draw_series(
            ax_left,
            series_left,
            choice=choice_left,
            color=color_left,
            label=label_left,
        )
        artists_right = _draw_series(
            ax_right,
            series_right,
            choice=choice_right,
            color=color_right,
            label=label_right,
        )

        ax_left.set_ylabel(label_left, color=color_left)
        ax_right.set_ylabel(label_right, color=color_right)
        ax_left.tick_params(axis="y", labelcolor=color_left)
        ax_right.tick_params(axis="y", labelcolor=color_right)
        _format_time_axis(ax_left)  # the twin shares this x axis
        ax_left.grid(True, color="#e0e0e0", linewidth=0.8, alpha=0.7)
        ax_left.margins(x=0.02, y=0.05)
        ax_right.margins(x=0.02, y=0.05)
        ax_left.set_title(title)

        # ``_draw_series`` returns a list for line/area plots and a single
        # artist/container otherwise. Flatten both into one legend, since
        # each twin axis would otherwise only know its own artists.
        handles = []
        for artists in (artists_left, artists_right):
            handles.extend(artists if isinstance(artists, list) else [artists])
        ax_left.legend(handles, [handle.get_label() for handle in handles], loc="upper left")

        if annotate_freq:
            ax_left.text(
                0.99,
                0.02,
                f"Agrégation : {annotate_freq}",
                transform=ax_left.transAxes,
                ha="right",
                va="bottom",
                fontsize=9,
                color="#555555",
            )
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even when drawing/saving fails.
        plt.close(fig)

    # NOTE(review): if both variables share the same ``column`` name, the dict
    # below keeps only the right-hand series in the exported dataset.
    export_plot_dataset(
        pd.concat(
            {variable_left.column: series_left, variable_right.column: series_right},
            axis=1,
        ),
        output_path,
    )
    return output_path.resolve()