# Source code for direct_data_driven_mpc.utilities.visualization.comparison_plot

"""
Functions for plotting multiple input-output data for control system
comparison.

This module provides functions for plotting multiple input-output trajectories
with setpoints using Matplotlib. It enables comparing different control systems
by plotting their control data in the same figure.
"""

from typing import Any

import matplotlib.pyplot as plt
import numpy as np

from .control_plot import (
    plot_input_output,
)
from .plot_utilities import (
    check_list_length,
    create_input_output_figure,
    init_dict_if_none,
)


def plot_input_output_comparison(
    u_data: list[np.ndarray],
    y_data: list[np.ndarray],
    y_s: np.ndarray,
    u_s: np.ndarray | None = None,
    u_bounds_list: list[tuple[float, float]] | None = None,
    y_bounds_list: list[tuple[float, float]] | None = None,
    inputs_line_param_list: list[dict[str, Any]] | None = None,
    outputs_line_param_list: list[dict[str, Any]] | None = None,
    setpoints_line_params: dict[str, Any] | None = None,
    bounds_line_params: dict[str, Any] | None = None,
    var_suffix_list: list[str] | None = None,
    legend_params: dict[str, Any] | None = None,
    figsize: tuple[int, int] = (14, 8),
    dpi: int = 300,
    u_ylimits_list: list[tuple[float, float]] | None = None,
    y_ylimits_list: list[tuple[float, float]] | None = None,
    fontsize: int = 12,
    title: str | None = None,
    input_labels: list[str] | None = None,
    output_labels: list[str] | None = None,
    u_setpoint_labels: list[str] | None = None,
    y_setpoint_labels: list[str] | None = None,
    x_axis_labels: list[str] | None = None,
    input_y_axis_labels: list[str] | None = None,
    output_y_axis_labels: list[str] | None = None,
    show: bool = True,
) -> None:
    """
    Overlay several input-output trajectories in one Matplotlib figure.

    A single figure is created with a row of control input subplots and a
    row of system output subplots. Every simulation in `u_data` / `y_data`
    is drawn onto those shared axes through `plot_input_output`, which makes
    it possible to compare the performance of different control systems.
    The setpoint lines are requested only while drawing the last simulation
    so that they are not repeated for every data series.

    Args:
        u_data (list[np.ndarray]): `M` arrays of shape (T, m) with the
            control input data of `M` simulations, where `T` is the number
            of time steps and `m` the number of control inputs.
        y_data (list[np.ndarray]): `M` arrays of shape (T, p) with the
            system output data of `M` simulations, where `T` is the number
            of time steps and `p` the number of system outputs.
        y_s (np.ndarray): Array of shape (T, p) with the `p` output
            setpoint values associated with the outputs in `y_data`.
        u_s (np.ndarray | None): Array of shape (T, m) with the `m` input
            setpoint values associated with the inputs in `u_data`. If
            `None`, input setpoint lines are not plotted. Defaults to
            `None`.
        u_bounds_list (list[tuple[float, float]] | None): Tuples
            (lower_bound, upper_bound), one per input data sequence. If
            provided, horizontal lines representing the bounds are plotted
            in each subplot. Defaults to `None` (no bound lines).
        y_bounds_list (list[tuple[float, float]] | None): Tuples
            (lower_bound, upper_bound), one per output data sequence. If
            provided, horizontal lines representing the bounds are plotted
            in each subplot. Defaults to `None` (no bound lines).
        inputs_line_param_list (list[dict[str, Any]] | None): `M`
            dictionaries of Matplotlib line properties, one for each input
            array in `u_data`. If not provided, Matplotlib's default line
            properties are used.
        outputs_line_param_list (list[dict[str, Any]] | None): `M`
            dictionaries of Matplotlib line properties, one for each output
            array in `y_data`. If not provided, Matplotlib's default line
            properties are used.
        setpoints_line_params (dict[str, Any] | None): Matplotlib line
            properties for the setpoint lines (e.g., color, linestyle,
            linewidth). If not provided, Matplotlib's defaults are used.
        bounds_line_params (dict[str, Any] | None): Matplotlib line
            properties for the bound lines (e.g., color, linestyle,
            linewidth). If not provided, Matplotlib's defaults are used.
        var_suffix_list (list[str] | None): Strings appended to each data
            series label in the plot legend. If not provided, an empty
            suffix is used for every series.
        legend_params (dict[str, Any] | None): Matplotlib properties for
            the plot legends (e.g., fontsize, loc, handlelength). If not
            provided, Matplotlib's defaults are used.
        figsize (tuple[int, int]): The (width, height) dimensions of the
            created Matplotlib figure.
        dpi (int): The DPI resolution of the figure.
        u_ylimits_list (list[tuple[float, float]] | None): Tuples
            (lower_limit, upper_limit) with the Y-axis limits of each input
            subplot. If `None`, limits are determined automatically.
        y_ylimits_list (list[tuple[float, float]] | None): Tuples
            (lower_limit, upper_limit) with the Y-axis limits of each output
            subplot. If `None`, limits are determined automatically.
        fontsize (int): The fontsize for labels, legends and axes ticks.
        title (str | None): The title for the created plot figure.
        input_labels (list[str] | None): Custom legend labels for the input
            data series. When provided, the label at each index overrides
            the default label built with `var_suffix_list`.
        output_labels (list[str] | None): Custom legend labels for the
            output data series. When provided, the label at each index
            overrides the default label built with `var_suffix_list`.
        u_setpoint_labels (list[str] | None): Custom legend labels for the
            input setpoint series, overriding the corresponding defaults.
        y_setpoint_labels (list[str] | None): Custom legend labels for the
            output setpoint series, overriding the corresponding defaults.
        x_axis_labels (list[str] | None): Custom X-axis labels for each
            subplot, overriding the corresponding defaults.
        input_y_axis_labels (list[str] | None): Custom Y-axis labels for
            each input subplot, overriding the corresponding defaults.
        output_y_axis_labels (list[str] | None): Custom Y-axis labels for
            each output subplot, overriding the corresponding defaults.
        show (bool): Whether to call `plt.show()` for the figure or not.
            Useful when adding plot elements externally before rendering
            the figure. Defaults to `True`.

    Raises:
        ValueError: If input/output array shapes, or line parameter list
            lengths, are not as expected.
    """
    validate_comparison_plot_parameters(
        u_data=u_data,
        y_data=y_data,
        inputs_line_param_list=inputs_line_param_list,
        outputs_line_param_list=outputs_line_param_list,
        var_suffix_list=var_suffix_list,
        input_labels=input_labels,
        output_labels=output_labels,
    )

    # Replace missing Matplotlib parameter dictionaries with defaults
    setpoints_line_params = init_dict_if_none(setpoints_line_params)
    legend_params = init_dict_if_none(legend_params)

    # Build the shared figure: one subplot per input and per output
    n_inputs = u_data[0].shape[1]
    n_outputs = y_data[0].shape[1]
    _, axs_u, axs_y = create_input_output_figure(
        m=n_inputs,
        p=n_outputs,
        figsize=figsize,
        dpi=dpi,
        fontsize=fontsize,
        title=title,
    )

    # Keyword arguments that are identical for every plotted simulation
    shared_kwargs: dict[str, Any] = {
        "u_s": u_s,
        "y_s": y_s,
        "u_bounds_list": u_bounds_list,
        "y_bounds_list": y_bounds_list,
        "setpoints_line_params": setpoints_line_params,
        "bounds_line_params": bounds_line_params,
        "dpi": dpi,
        "u_ylimits_list": u_ylimits_list,
        "y_ylimits_list": y_ylimits_list,
        "fontsize": fontsize,
        "legend_params": legend_params,
        "axs_u": axs_u,
        "axs_y": axs_y,
        "u_setpoint_labels": u_setpoint_labels,
        "y_setpoint_labels": y_setpoint_labels,
        "x_axis_labels": x_axis_labels,
        "input_y_axis_labels": input_y_axis_labels,
        "output_y_axis_labels": output_y_axis_labels,
    }

    last_idx = len(u_data) - 1

    # Draw every simulation onto the shared axes
    for idx, (u_k, y_k) in enumerate(zip(u_data, y_data)):
        # Per-series line styles (only when style lists were supplied)
        if inputs_line_param_list:
            inputs_line_params = init_dict_if_none(
                inputs_line_param_list[idx]
            )
        else:
            inputs_line_params = None

        if outputs_line_param_list:
            outputs_line_params = init_dict_if_none(
                outputs_line_param_list[idx]
            )
        else:
            outputs_line_params = None

        # Per-series legend label components
        if var_suffix_list:
            var_suffix = var_suffix_list[idx]
        else:
            var_suffix = ""

        if input_labels:
            input_label = input_labels[idx]
        else:
            input_label = None

        if output_labels:
            output_label = output_labels[idx]
        else:
            output_label = None

        plot_input_output(
            u_k=u_k,
            y_k=y_k,
            inputs_line_params=inputs_line_params,
            outputs_line_params=outputs_line_params,
            var_suffix=var_suffix,
            input_label=input_label,
            output_label=output_label,
            # Setpoints are drawn once, with the final series, to prevent
            # cluttering the figure
            plot_setpoint_lines=(idx == last_idx),
            **shared_kwargs,
        )

    # Render the figure unless the caller wants to keep editing it
    if show:
        plt.show()


[docs] def validate_comparison_plot_parameters( u_data: list[np.ndarray], y_data: list[np.ndarray], inputs_line_param_list: list[dict[str, Any]] | None = None, outputs_line_param_list: list[dict[str, Any]] | None = None, var_suffix_list: list[str] | None = None, input_labels: list[str] | None = None, output_labels: list[str] | None = None, ) -> None: """ Validate that input/output data and plot parameter lists match the expected dimensions for generating comparison plots. Args: u_data (list[np.ndarray]): A list of `M` arrays of shape (T, m) containing control input data from `M` simulations. `T` is the number of time steps, and `m` is the number of control inputs. y_data (list[np.ndarray]): A list of `M` arrays of shape (T, p) containing system output data from `M` simulations. `T` is the number of time steps, and `p` is the number of system outputs. inputs_line_param_list (list[dict[str, Any]] | None): A list of `M` dictionaries, where each dictionary specifies Matplotlib properties for customizing the plot lines corresponding to one of the `M` input data arrays in `u_data`. outputs_line_param_list (list[dict[str, Any]] | None): A list of `M` dictionaries, where each dictionary specifies Matplotlib properties for customizing the plot lines corresponding to one of the `M` output data arrays in `y_data`. var_suffix_list (list[str] | None): A list of strings appended to each data series label in the plot legend. input_labels (list[str] | None): A list of strings specifying custom legend labels for input data series. output_labels (list[str] | None): A list of strings specifying custom legend labels for output data series. Raises: ValueError: If any parameter does not match the expected dimension. """ if not u_data or not y_data: raise ValueError( "`u_data` and `y_data` must contain at least one simulation." ) if len(u_data) != len(y_data): raise ValueError( "`u_data` and `y_data` must have the same number of trajectories." 
) # Validate input-output data dimensions u_shape = u_data[0].shape y_shape = y_data[0].shape if not all(u.shape == u_shape for u in u_data): raise ValueError( f"All `u_data` arrays must have the same shape ({u_shape})." ) if not all(y.shape == y_shape for y in y_data): raise ValueError( f"All `y_data` arrays must have the same shape ({y_shape})." ) # Validate list lengths if provided n_sim = len(u_data) # Lists for input plots check_list_length("inputs_line_param_list", inputs_line_param_list, n_sim) check_list_length( "outputs_line_param_list", outputs_line_param_list, n_sim ) check_list_length("var_suffix_list", var_suffix_list, n_sim) check_list_length("input_labels", input_labels, n_sim) check_list_length("output_labels", output_labels, n_sim)