# Source code for direct_data_driven_mpc.utilities.visualization.control_plot

"""
Functions for plotting control input-output data.

This module provides functions for plotting input-output trajectories with
setpoints using Matplotlib. It creates highly customizable figures with
separate subplots for inputs and outputs, with optional highlighting of the
initial measurement period for data-driven control systems.
"""

from typing import Any

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure

from .plot_utilities import (
    create_input_output_figure,
    filter_and_reorder_legend,
    get_text_width_in_data,
    init_dict_if_none,
    validate_data_dimensions,
)


def plot_input_output(
    u_k: np.ndarray,
    y_k: np.ndarray,
    y_s: np.ndarray,
    u_s: np.ndarray | None = None,
    u_bounds_list: list[tuple[float, float]] | None = None,
    y_bounds_list: list[tuple[float, float]] | None = None,
    inputs_line_params: dict[str, Any] | None = None,
    outputs_line_params: dict[str, Any] | None = None,
    setpoints_line_params: dict[str, Any] | None = None,
    bounds_line_params: dict[str, Any] | None = None,
    u_setpoint_var_symbol: str = "u^s",
    y_setpoint_var_symbol: str = "y^s",
    initial_steps: int | None = None,
    initial_excitation_text: str = "Init. Excitation",
    initial_measurement_text: str = "Init. Measurement",
    control_text: str = "Data-Driven MPC",
    display_initial_text: bool = True,
    display_control_text: bool = True,
    figsize: tuple[float, float] = (12.0, 8.0),
    dpi: int = 300,
    u_ylimits_list: list[tuple[float, float]] | None = None,
    y_ylimits_list: list[tuple[float, float]] | None = None,
    fontsize: int = 12,
    legend_params: dict[str, Any] | None = None,
    var_suffix: str = "",
    axs_u: list[Axes] | None = None,
    axs_y: list[Axes] | None = None,
    title: str | None = None,
    input_label: str | None = None,
    output_label: str | None = None,
    u_setpoint_labels: list[str] | None = None,
    y_setpoint_labels: list[str] | None = None,
    x_axis_labels: list[str] | None = None,
    input_y_axis_labels: list[str] | None = None,
    output_y_axis_labels: list[str] | None = None,
    plot_setpoint_lines: bool = True,
) -> None:
    """
    Plot input-output data with setpoints in a Matplotlib figure.

    This function creates a figure with two rows of subplots: the first row
    contains the control inputs and the second row the system outputs. Each
    subplot shows one data series alongside its setpoint. The appearance of
    plot lines and legends can be customized by passing dictionaries of
    Matplotlib line and legend properties.

    If provided, the first `initial_steps` time steps are highlighted to
    emphasize the initial input-output measurement period, which represents
    the data-driven system characterization phase of a Data-Driven MPC
    algorithm. Optional text labels can mark the initial measurement and the
    subsequent MPC control periods; each label is only shown if it fits
    within its period without overflowing.

    Note:
        If `axs_u` and `axs_y` are both provided, the data is plotted on these
        external axes and no new figure is created (nor shown). This allows
        several data sequences to be plotted on the same external figure.
        Each data sequence can be distinguished using `input_label`,
        `output_label` or `var_suffix`.

    Args:
        u_k (np.ndarray): Control input data of shape (T, m), where `m` is
            the number of inputs and `T` the number of time steps.
        y_k (np.ndarray): System output data of shape (T, p), where `p` is
            the number of outputs and `T` the number of time steps.
        y_s (np.ndarray): Output setpoint values of shape (T, p).
        u_s (np.ndarray | None): Input setpoint values of shape (T, m). If
            `None`, input setpoint lines are not plotted. Defaults to `None`.
        u_bounds_list (list[tuple[float, float]] | None): One
            (lower_bound, upper_bound) tuple per input. If provided,
            horizontal bound lines are drawn in each input subplot. Defaults
            to `None`.
        y_bounds_list (list[tuple[float, float]] | None): One
            (lower_bound, upper_bound) tuple per output. If provided,
            horizontal bound lines are drawn in each output subplot. Defaults
            to `None`.
        inputs_line_params (dict[str, Any] | None): Matplotlib line
            properties for the input data lines (e.g., color, linestyle,
            linewidth). Matplotlib defaults are used if not provided.
        outputs_line_params (dict[str, Any] | None): Matplotlib line
            properties for the output data lines.
        setpoints_line_params (dict[str, Any] | None): Matplotlib line
            properties for the setpoint lines.
        bounds_line_params (dict[str, Any] | None): Matplotlib line
            properties for the bound lines.
        u_setpoint_var_symbol (str): Variable symbol used to label input
            setpoint series (e.g., "u^s").
        y_setpoint_var_symbol (str): Variable symbol used to label output
            setpoint series (e.g., "y^s").
        initial_steps (int | None): Number of initial time steps during which
            input-output measurements were taken for the data-driven
            characterization of the system. This period is highlighted in the
            plot. If `None` (or 0), no highlighting is applied. Defaults to
            `None`.
        initial_excitation_text (str): Label text displayed over the initial
            period of the input plots. Default is "Init. Excitation".
        initial_measurement_text (str): Label text displayed over the initial
            period of the output plots. Default is "Init. Measurement".
        control_text (str): Label text displayed over the post-initial
            control period. Default is "Data-Driven MPC".
        display_initial_text (bool): Whether to display the initial-period
            labels (`initial_excitation_text` on input plots and
            `initial_measurement_text` on output plots). Default is True.
        display_control_text (bool): Whether to display the `control_text`
            label. Default is True.
        figsize (tuple[float, float]): The (width, height) of the created
            figure. Ignored for external axes.
        dpi (int): The DPI resolution of the created figure. Ignored for
            external axes.
        u_ylimits_list (list[tuple[float, float]] | None): One
            (lower_limit, upper_limit) Y-axis limit tuple per input subplot.
            If `None`, limits are determined automatically. Defaults to
            `None`.
        y_ylimits_list (list[tuple[float, float]] | None): One
            (lower_limit, upper_limit) Y-axis limit tuple per output subplot.
            If `None`, limits are determined automatically. Defaults to
            `None`.
        fontsize (int): Font size for labels and axis ticks.
        legend_params (dict[str, Any] | None): Matplotlib legend properties
            (e.g., fontsize, loc, handlelength). Matplotlib defaults are used
            if not provided.
        var_suffix (str): String appended to each data series label in the
            legend.
        axs_u (list[Axes] | None): External axes for the input plots. Must be
            given together with `axs_y`. Defaults to `None`.
        axs_y (list[Axes] | None): External axes for the output plots. Must be
            given together with `axs_u`. Defaults to `None`.
        title (str | None): Title of the created figure. Only used when the
            figure is created internally. Defaults to `None`.
        input_label (str | None): Custom legend label for the input data
            series. Overrides the default label built from `var_suffix`.
        output_label (str | None): Custom legend label for the output data
            series. Overrides the default label built from `var_suffix`.
        u_setpoint_labels (list[str] | None): Custom legend labels for the
            input setpoint series, indexed per input.
        y_setpoint_labels (list[str] | None): Custom legend labels for the
            output setpoint series, indexed per output.
        x_axis_labels (list[str] | None): Custom X-axis labels, indexed per
            subplot. Overrides the default "Time step $k$".
        input_y_axis_labels (list[str] | None): Custom Y-axis labels, indexed
            per input subplot.
        output_y_axis_labels (list[str] | None): Custom Y-axis labels, indexed
            per output subplot.
        plot_setpoint_lines (bool): Whether to plot setpoint lines. Set to
            `False` to avoid duplicate setpoint entries in multi-data plots.
            Defaults to `True`.

    Raises:
        ValueError: If any array dimensions mismatch the expected shapes, if
            the lengths of `u_bounds_list`, `y_bounds_list`,
            `u_ylimits_list` or `y_ylimits_list` do not match the number of
            subplots (checked by `validate_data_dimensions`), or if only one
            of `axs_u` and `axs_y` is provided.
    """
    # Validate data dimensions
    validate_data_dimensions(
        u_k=u_k,
        y_k=y_k,
        u_s=u_s,
        y_s=y_s,
        u_bounds_list=u_bounds_list,
        y_bounds_list=y_bounds_list,
        u_ylimits_list=u_ylimits_list,
        y_ylimits_list=y_ylimits_list,
        u_setpoint_labels=u_setpoint_labels,
        y_setpoint_labels=y_setpoint_labels,
        x_axis_labels=x_axis_labels,
        input_y_axis_labels=input_y_axis_labels,
        output_y_axis_labels=output_y_axis_labels,
    )

    # A partially specified external layout would otherwise be silently
    # discarded in favor of a freshly created figure.
    if (axs_u is None) != (axs_y is None):
        raise ValueError(
            "`axs_u` and `axs_y` must either both be provided or both be "
            "None."
        )

    # Initialize Matplotlib params if not provided
    inputs_line_params = init_dict_if_none(inputs_line_params)
    outputs_line_params = init_dict_if_none(outputs_line_params)
    setpoints_line_params = init_dict_if_none(setpoints_line_params)
    bounds_line_params = init_dict_if_none(bounds_line_params)
    legend_params = init_dict_if_none(legend_params)

    # Retrieve number of input and output data sequences
    m = u_k.shape[1]  # Number of inputs
    p = y_k.shape[1]  # Number of outputs

    # Create figure if lists of Axes are not provided
    is_ext_fig = axs_u is not None and axs_y is not None  # External figure
    fig: Figure | SubFigure
    if not is_ext_fig:
        fig, axs_u, axs_y = create_input_output_figure(
            m=m, p=p, figsize=figsize, dpi=dpi, fontsize=fontsize, title=title
        )
    else:
        assert axs_u is not None  # Prevent mypy [index] error
        # Use figure from the provided axes
        fig = axs_u[0].figure

    # Narrow types for mypy; both lists are set on either branch above
    assert axs_u is not None and axs_y is not None

    # Plot input data
    for i in range(m):
        plot_data(
            axis=axs_u[i],
            data=u_k[:, i],
            setpoint=u_s[:, i] if u_s is not None else None,
            # A single subplot gets no subscript in its labels
            index=-1 if m == 1 else i,
            data_line_params=inputs_line_params,
            bounds_line_params=bounds_line_params,
            setpoint_line_params=setpoints_line_params,
            var_symbol="u",
            setpoint_var_symbol=u_setpoint_var_symbol,
            var_label="Input",
            var_suffix=var_suffix,
            initial_text=initial_excitation_text,
            control_text=control_text,
            display_initial_text=display_initial_text,
            display_control_text=display_control_text,
            fontsize=fontsize,
            legend_params=legend_params,
            fig=fig,
            bounds=u_bounds_list[i] if u_bounds_list else None,
            initial_steps=initial_steps,
            plot_ylimits=u_ylimits_list[i] if u_ylimits_list else None,
            data_label=input_label,
            setpoint_labels=u_setpoint_labels,
            x_axis_labels=x_axis_labels,
            y_axis_labels=input_y_axis_labels,
            plot_setpoint_lines=plot_setpoint_lines,
        )

    # Plot output data
    for j in range(p):
        plot_data(
            axis=axs_y[j],
            data=y_k[:, j],
            setpoint=y_s[:, j],
            # A single subplot gets no subscript in its labels
            index=-1 if p == 1 else j,
            data_line_params=outputs_line_params,
            bounds_line_params=bounds_line_params,
            setpoint_line_params=setpoints_line_params,
            var_symbol="y",
            setpoint_var_symbol=y_setpoint_var_symbol,
            var_label="Output",
            var_suffix=var_suffix,
            initial_text=initial_measurement_text,
            control_text=control_text,
            display_initial_text=display_initial_text,
            display_control_text=display_control_text,
            fontsize=fontsize,
            legend_params=legend_params,
            fig=fig,
            bounds=y_bounds_list[j] if y_bounds_list else None,
            initial_steps=initial_steps,
            plot_ylimits=y_ylimits_list[j] if y_ylimits_list else None,
            data_label=output_label,
            setpoint_labels=y_setpoint_labels,
            x_axis_labels=x_axis_labels,
            y_axis_labels=output_y_axis_labels,
            plot_setpoint_lines=plot_setpoint_lines,
        )

    # Show the plot only if the figure was created internally
    if not is_ext_fig:
        plt.show()
def _place_phase_label(
    axis: Axes,
    fig: Figure | SubFigure,
    x_center: float,
    text: str,
    available_width: float,
    fontsize: int,
) -> None:
    """
    Draw a boxed text label centered at `x_center` and vertically centered in
    the axis' current Y range, hiding it if it is wider (in data units) than
    `available_width`.
    """
    y_min, y_max = axis.get_ylim()
    text_obj = axis.text(
        x_center,
        (y_min + y_max) / 2,
        text,
        fontsize=fontsize - 1,
        ha="center",
        va="center",
        color="black",
        bbox={"facecolor": "white", "edgecolor": "black"},
    )

    # Hide text box if it overflows the period it annotates
    text_width = get_text_width_in_data(
        text_object=text_obj, axis=axis, fig=fig
    )
    if available_width < text_width:
        text_obj.set_visible(False)


def plot_data(
    axis: Axes,
    data: np.ndarray,
    setpoint: np.ndarray | None,
    index: int,
    data_line_params: dict[str, Any],
    setpoint_line_params: dict[str, Any],
    bounds_line_params: dict[str, Any],
    var_symbol: str,
    setpoint_var_symbol: str,
    var_label: str,
    var_suffix: str,
    initial_text: str,
    control_text: str,
    display_initial_text: bool,
    display_control_text: bool,
    fontsize: int,
    legend_params: dict[str, Any],
    fig: Figure | SubFigure,
    bounds: tuple[float, float] | None = None,
    initial_steps: int | None = None,
    plot_ylimits: tuple[float, float] | None = None,
    data_label: str | None = None,
    setpoint_labels: list[str] | None = None,
    x_axis_labels: list[str] | None = None,
    y_axis_labels: list[str] | None = None,
    plot_setpoint_lines: bool = True,
) -> None:
    """
    Plot a data series with setpoints in a specified axis.

    Optionally, highlight the initial measurement and control phases using
    shaded regions and text labels. Each label is displayed only if it fits
    within the period it annotates. The axis limits are fixed before the
    labels are placed, so the labels are centered within the final visible
    Y range (including a user-supplied `plot_ylimits`).

    Note:
        The appearance of plot lines and legend can be customized by passing
        dictionaries of Matplotlib line and legend properties.

    Args:
        axis (Axes): The Matplotlib axis object to plot on.
        data (np.ndarray): Data to be plotted; its first dimension is used as
            the number of time steps `T`.
        setpoint (np.ndarray | None): Setpoint values to be plotted. If
            `None`, the setpoint line is not plotted.
        index (int): Index of the data used for labeling (e.g., "u_1",
            "u_2"). If -1, no subscripts are added to labels. The same value
            is used to index `setpoint_labels`, `x_axis_labels` and
            `y_axis_labels` (so -1 selects their last element).
        data_line_params (dict[str, Any]): Matplotlib line properties for the
            data series.
        setpoint_line_params (dict[str, Any]): Matplotlib line properties for
            the setpoint line.
        bounds_line_params (dict[str, Any]): Matplotlib line properties for
            the bound lines.
        var_symbol (str): Symbol used to label the data series (e.g., "u",
            "y").
        setpoint_var_symbol (str): Symbol used to label the setpoint series
            (e.g., "u^s", "y^s").
        var_label (str): Name of the signal (e.g., "Input", "Output").
        var_suffix (str): String appended to the data series legend label.
        initial_text (str): Label text displayed over the initial period.
        control_text (str): Label text displayed over the post-initial
            control period.
        display_initial_text (bool): Whether to display `initial_text`.
        display_control_text (bool): Whether to display `control_text`.
        fontsize (int): Font size for labels and axis ticks.
        legend_params (dict[str, Any]): Matplotlib legend properties (e.g.,
            fontsize, loc, handlelength).
        fig (Figure | SubFigure): The figure or subfigure containing `axis`.
        bounds (tuple[float, float] | None): (lower_bound, upper_bound). If
            provided, horizontal bound lines are drawn. Defaults to `None`.
        initial_steps (int | None): Number of initial time steps used for the
            data-driven characterization of the system. This period is
            highlighted in the plot. If `None` (or 0), no highlighting is
            applied. Defaults to `None`.
        plot_ylimits (tuple[float, float] | None): (lower_limit, upper_limit)
            for the Y axis. If `None`, limits are determined automatically.
            Defaults to `None`.
        data_label (str | None): Custom legend label for the data series,
            overriding the label built from `var_symbol` and `var_suffix`.
        setpoint_labels (list[str] | None): Custom setpoint legend labels;
            the entry at `index` overrides the label built from
            `setpoint_var_symbol`.
        x_axis_labels (list[str] | None): Custom X-axis labels; the entry at
            `index` overrides the default "Time step $k$".
        y_axis_labels (list[str] | None): Custom Y-axis labels; the entry at
            `index` overrides the label built from `var_label` and
            `var_symbol`.
        plot_setpoint_lines (bool): Whether to plot the setpoint line.
    """
    T = data.shape[0]  # Data length
    time_steps = range(0, T)

    # Subscript like "_1" for multi-signal plots; omitted when index == -1
    index_str = f"_{index + 1}" if index != -1 else ""

    # Plot data series
    data_label_str = (
        data_label if data_label else f"${var_symbol}{index_str}${var_suffix}"
    )
    axis.plot(time_steps, data, **data_line_params, label=data_label_str)

    # The setpoint label is always computed because it is also used below to
    # move the setpoint legend entry to the end, even when this call does not
    # draw the setpoint line itself (multi-data plots on external axes).
    setpoint_label_str = (
        setpoint_labels[index]
        if setpoint_labels
        else f"${setpoint_var_symbol}{index_str}$"
    )
    if setpoint is not None and plot_setpoint_lines:
        axis.plot(
            time_steps,
            setpoint,
            **setpoint_line_params,
            label=setpoint_label_str,
        )

    # Plot bounds if provided (only the lower line carries the legend label)
    bounds_label = "Constraints"
    if bounds is not None:
        lower_bound, upper_bound = bounds
        axis.axhline(y=lower_bound, **bounds_line_params, label=bounds_label)
        axis.axhline(y=upper_bound, **bounds_line_params)

    # Highlight initial input-output data measurement period if provided
    if initial_steps:
        axis.axvspan(0, initial_steps, color="gray", alpha=0.1)
        axis.axvline(
            x=initial_steps, color="black", linestyle=(0, (5, 5)), linewidth=1
        )

    # Fix the final axis limits BEFORE placing the phase labels: the labels
    # are vertically centered using `get_ylim()` and their overflow check
    # converts text width to data units, so both must see the final view.
    axis.set_xlim((0, T - 1))
    if plot_ylimits:
        axis.set_ylim(plot_ylimits)

    if initial_steps:
        # Initial measurement label, centered in the highlighted area
        if display_initial_text:
            _place_phase_label(
                axis=axis,
                fig=fig,
                x_center=initial_steps / 2,
                text=initial_text,
                available_width=initial_steps,
                fontsize=fontsize,
            )

        # Control label, centered in the remaining area
        if display_control_text:
            _place_phase_label(
                axis=axis,
                fig=fig,
                x_center=(T + initial_steps) / 2,
                text=control_text,
                available_width=T - initial_steps,
                fontsize=fontsize,
            )

    # Format labels, legend and ticks
    x_axis_label_str = (
        x_axis_labels[index] if x_axis_labels else "Time step $k$"
    )
    y_axis_label_str = (
        y_axis_labels[index]
        if y_axis_labels
        else f"{var_label} ${var_symbol}{index_str}$"
    )
    axis.set_xlabel(x_axis_label_str, fontsize=fontsize)
    axis.set_ylabel(y_axis_label_str, fontsize=fontsize)
    axis.legend(**legend_params)
    axis.tick_params(axis="both", labelsize=fontsize)

    # Remove duplicate labels from legend (required for external figures
    # that plot multiple data sequences on the same plot to avoid label
    # repetition) and move setpoint/constraint entries to the end
    end_labels_list = [setpoint_label_str]
    if bounds is not None:
        end_labels_list.append(bounds_label)
    filter_and_reorder_legend(
        axis=axis, legend_params=legend_params, end_labels_list=end_labels_list
    )