"""
Utility functions and classes for static and animated control input-output
plots.
This module provides helper functions used in control input-output data plot
generation, and a custom Matplotlib legend handler class for highlighting
initial measurement periods in animated plots.
"""
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure
from matplotlib.layout_engine import ConstrainedLayoutEngine
from matplotlib.legend import Legend
from matplotlib.legend_handler import HandlerPatch
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from matplotlib.text import Text
from matplotlib.transforms import Transform
[docs]
class HandlerInitMeasurementRect(HandlerPatch):
    """
    A custom legend handler for the rectangle representing the initial
    input-output measurement period in control input-output plot animations.

    The legend entry is drawn as a filled rectangle (using the face color and
    alpha of the original handle) bounded on its left and right edges by
    dashed black vertical lines.
    """

    def create_artists(
        self,
        legend: Legend,
        orig_handle: Artist,
        xdescent: float,
        ydescent: float,
        width: float,
        height: float,
        fontsize: float,
        trans: Transform,
    ) -> list[Rectangle | Line2D]:
        """
        Create the artists that make up the legend entry.

        Args:
            legend (Legend): The legend the artists are created for (unused).
            orig_handle (Artist): The original legend handle. Must be a
                `Rectangle`; its face color and alpha are copied.
            xdescent (float): Horizontal descent of the legend handle box.
            ydescent (float): Vertical descent of the legend handle box.
            width (float): Width of the legend handle box.
            height (float): Height of the legend handle box.
            fontsize (float): Legend font size (unused).
            trans (Transform): Transform applied to every created artist.

        Returns:
            list[Rectangle | Line2D]: The filled rectangle, followed by the
                left and right dashed vertical lines.

        Raises:
            TypeError: If `orig_handle` is not a `Rectangle`.
        """
        # Explicit check (rather than `assert`) so it survives `python -O`
        if not isinstance(orig_handle, Rectangle):
            raise TypeError(
                "HandlerInitMeasurementRect expects a Rectangle handle, got "
                f"{type(orig_handle).__name__}."
            )

        # Matplotlib places the handle box origin at (-xdescent, -ydescent)
        # (the built-in `HandlerPatch` anchors its rectangle there). Anchoring
        # at (+xdescent, +ydescent) would shift the entry out of the box
        # whenever the descent is non-zero (e.g. a custom `handleheight`).
        x0 = -xdescent
        y0 = -ydescent

        # Main filled rectangle
        rect = Rectangle(
            (x0, y0),
            width,
            height,
            transform=trans,
            color=orig_handle.get_facecolor(),
            alpha=orig_handle.get_alpha(),
        )

        # Dashed vertical lines at the left and right sides of the rectangle
        left_line, right_line = (
            Line2D(
                [x, x],
                [y0, y0 + height],
                color="black",
                linestyle=(0, (2, 2)),
                linewidth=1,
                transform=trans,
            )
            for x in (x0, x0 + width)
        )

        return [rect, left_line, right_line]
[docs]
def validate_data_dimensions(
u_k: np.ndarray,
y_k: np.ndarray,
y_s: np.ndarray,
u_s: np.ndarray | None = None,
u_bounds_list: list[tuple[float, float]] | None = None,
y_bounds_list: list[tuple[float, float]] | None = None,
u_ylimits_list: list[tuple[float, float]] | None = None,
y_ylimits_list: list[tuple[float, float]] | None = None,
u_setpoint_labels: list[str] | None = None,
y_setpoint_labels: list[str] | None = None,
x_axis_labels: list[str] | None = None,
input_y_axis_labels: list[str] | None = None,
output_y_axis_labels: list[str] | None = None,
) -> None:
"""
Validate that input-output data arrays, and bound and ylimit lists have the
expected shapes and lengths.
Args:
u_k (np.ndarray): An array containing control input data of shape (T,
m), where `m` is the number of inputs and `T` is the number of
time steps.
y_k (np.ndarray): An array containing system output data of shape (T,
p), where `p` is the number of outputs and `T` is the number of
time steps.
y_s (np.ndarray): An array containing output setpoint values of shape
(T, p), where `p` is the number of outputs and `T` is the number of
time steps.
u_s (np.ndarray | None): An array containing input setpoint values of
shape (T, m), where `m` is the number of inputs and `T` is the
number of time steps.
u_bounds_list (list[tuple[float, float]] | None): A list of tuples
(lower_bound, upper_bound) specifying bounds for each input data
sequence.
y_bounds_list (list[tuple[float, float]] | None): A list of tuples
(lower_bound, upper_bound) specifying bounds for each output data
sequence.
u_ylimits_list (list[tuple[float, float]] | None): A list of tuples
(lower_limit, upper_limit) specifying the Y-axis limits for each
input subplot.
y_ylimits_list (list[tuple[float, float]] | None): A list of tuples
(lower_limit, upper_limit) specifying the Y-axis limits for each
output subplot.
u_setpoint_labels (list[str] | None): A list of strings specifying
custom legend labels for input setpoint series.
y_setpoint_labels (list[str] | None): A list of strings specifying
custom legend labels for output setpoint series.
x_axis_labels (list[str] | None): A list of strings specifying custom
X-axis labels for each subplot.
input_y_axis_labels (list[str] | None): A list of strings specifying
custom Y-axis labels for each input subplot.
output_y_axis_labels (list[str] | None): A list of strings specifying
custom Y-axis labels for each output subplot.
Raises:
ValueError: If any array dimensions mismatch expected shapes, or if
the lengths of the list arguments do not match the number of
subplots.
"""
# Check input-output data dimensions
if u_k.shape[0] != y_k.shape[0]:
raise ValueError(
"Dimension mismatch. The number of time steps for `u_k` "
f"({u_k.shape[0]}) and `y_k` ({y_k.shape[0]}) must match."
)
if y_k.shape != y_s.shape:
raise ValueError(
f"Shape mismatch. The shapes of `y_k` ({y_k.shape}) and "
f"`y_s` ({y_s.shape}) must match."
)
# If input setpoint is passed, verify input data dimension match
if u_s is not None:
if u_k.shape != u_s.shape:
raise ValueError(
f"Shape mismatch. The shape of `u_k` ({u_k.shape}) and "
f"`u_s` ({u_s.shape}) must match."
)
# Error handling for list lengths
m = u_k.shape[1] # Number of inputs
p = y_k.shape[1] # Number of outputs
check_list_length("u_bounds_list", u_bounds_list, m)
check_list_length("y_bounds_list", y_bounds_list, p)
check_list_length("u_ylimits_list", u_ylimits_list, m)
check_list_length("y_ylimits_list", y_ylimits_list, p)
check_list_length("u_setpoint_labels", u_setpoint_labels, m)
check_list_length("y_setpoint_labels", y_setpoint_labels, p)
# Lists for Y-axis labels
check_list_length("x_axis_labels", x_axis_labels, max(m, p))
check_list_length("input_y_axis_labels", input_y_axis_labels, m)
check_list_length("output_y_axis_labels", output_y_axis_labels, p)
[docs]
def get_padded_limits(
X: np.ndarray,
X_s: np.ndarray | None = None,
pad_percentage: float = 0.05,
) -> tuple[float, float]:
"""
Get the minimum and maximum limits from two data sequences extended by
a specified percentage of the combined data range.
Args:
X (np.ndarray): First data array.
X_s (np.ndarray | None): Second data array. If `None`, only `X` is
considered. Defaults to `None`.
pad_percentage (float): The percentage of the data range to be used
as padding. Defaults to 0.05.
Returns:
tuple[float, float]: A tuple containing padded minimum and maximum
limits for the combined data from `X` and `X_s`.
"""
# Get minimum and maximum limits from data sequences
X_min, X_max = np.min(X), np.max(X)
if X_s is not None:
X_s_min, X_s_max = np.min(X_s), np.max(X_s)
X_lim_min = min(X_min, X_s_min)
X_lim_max = max(X_max, X_s_max)
else:
X_lim_min, X_lim_max = X_min, X_max
# Extend limits by a percentage of the overall data range
X_range = X_lim_max - X_lim_min
X_lim_min -= X_range * pad_percentage
X_lim_max += X_range * pad_percentage
return (X_lim_min, X_lim_max)
[docs]
def get_text_width_in_data(
    text_object: Text, axis: Axes, fig: Figure | SubFigure
) -> float:
    """
    Measure the width of a rendered text object in the data units of an axis.

    The text's bounding box is obtained in display (pixel) coordinates using
    the figure canvas renderer. Both of its corners are then mapped into the
    data coordinates of `axis`, and the difference between their X values is
    returned.

    Args:
        text_object (Text): A Matplotlib text object.
        axis (Axes): The axis on which the text object is displayed.
        fig (Figure | SubFigure): The Matplotlib figure or subfigure that
            contains the axis.

    Returns:
        float: The width of the text object's bounding box in data coordinates.
    """
    # Measuring text extents requires a renderer from the canvas
    renderer = fig.canvas.get_renderer()  # type: ignore[attr-defined]
    pixel_bbox = text_object.get_window_extent(renderer=renderer)

    # Map both bounding-box corners from display space into data space
    display_to_data = axis.transData.inverted()
    (left_x, _), (right_x, _) = display_to_data.transform(pixel_bbox)

    return right_x - left_x
[docs]
def filter_and_reorder_legend(
    axis: Axes,
    legend_params: dict[str, Any],
    end_labels_list: list[str] | None = None,
) -> None:
    """
    Rebuild the legend of a Matplotlib axis without duplicate labels,
    optionally moving selected labels to the end of the legend.

    Note:
        The appearance of the plot legend can be customized by passing a
        dictionary of Matplotlib legend properties.

    Args:
        axis (Axes): The Matplotlib axis containing the legend to modify.
        legend_params (dict[str, Any]): A dictionary of Matplotlib properties
            for customizing the plot legend (e.g., fontsize, loc,
            handlelength).
        end_labels_list (list[str] | None): Labels to move to the end of the
            legend, processed in the order given, so the last label in this
            list becomes the final legend entry. Labels not present in the
            legend are ignored. If `None` or empty, the legend order is kept.
            Defaults to `None`.
    """
    handles, labels = axis.get_legend_handles_labels()

    # Deduplicate by label: a repeated label keeps the position of its first
    # occurrence, but is paired with the handle of its last occurrence
    label_to_handle: dict[str, Any] = {}
    for handle, label in zip(handles, labels):
        label_to_handle[label] = handle

    # Re-inserting a popped key appends it, pushing that label to the end
    for label in end_labels_list or []:
        if label in label_to_handle:
            label_to_handle[label] = label_to_handle.pop(label)

    # Replace the axis legend with the filtered, reordered entries
    axis.legend(
        label_to_handle.values(), label_to_handle.keys(), **legend_params
    )
[docs]
def init_dict_if_none(input_dict: dict | None) -> dict:
"""
Return an empty dictionary if the input is `None`, otherwise return the
input.
Args:
input_dict (dict | None): A dictionary or `None`.
Returns:
dict: The original dictionary or an empty one.
"""
return {} if input_dict is None else input_dict
[docs]
def get_label_from_list(
    label_list: list[str], index: int, fallback: str
) -> str:
    """
    Look up a label by index, falling back to a default for an empty list.

    Args:
        label_list (list[str]): A list of label strings. An empty list (or
            any other falsy value) means no custom labels were provided.
        index (int): The index of the desired label.
        fallback (str): The string returned when `label_list` is empty.

    Returns:
        str: `label_list[index]` when the list is non-empty, otherwise
            `fallback`.
    """
    if not label_list:
        return fallback
    return label_list[index]
[docs]
def check_list_length(
name: str, data_list: list | None, expected: int
) -> None:
"""
Verify whether a list contains the expected number of elements.
Args:
name (str): The name of the list (for error message context).
data_list (list | None): The list to check.
expected (int): The expected number of elements in the list.
Raises:
ValueError: If the list length does not match the expected value.
"""
if data_list and len(data_list) != expected:
raise ValueError(
f"The length of `{name}` ({len(data_list)}) does not match "
f"the expected value ({expected})."
)