Source code for direct_data_driven_mpc.utilities.visualization.control_plot
"""
Functions for plotting control input-output data.
This module provides functions for plotting input-output trajectories with
setpoints using Matplotlib. It creates highly customizable figures with
separate subplots for inputs and outputs, with optional highlighting of the
initial measurement period for data-driven control systems.
"""
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure
from .plot_utilities import (
    create_input_output_figure,
    filter_and_reorder_legend,
    get_text_width_in_data,
    init_dict_if_none,
    validate_data_dimensions,
)
[docs]
def plot_input_output(
    u_k: np.ndarray,
    y_k: np.ndarray,
    y_s: np.ndarray,
    u_s: np.ndarray | None = None,
    u_bounds_list: list[tuple[float, float]] | None = None,
    y_bounds_list: list[tuple[float, float]] | None = None,
    inputs_line_params: dict[str, Any] | None = None,
    outputs_line_params: dict[str, Any] | None = None,
    setpoints_line_params: dict[str, Any] | None = None,
    bounds_line_params: dict[str, Any] | None = None,
    u_setpoint_var_symbol: str = "u^s",
    y_setpoint_var_symbol: str = "y^s",
    initial_steps: int | None = None,
    initial_excitation_text: str = "Init. Excitation",
    initial_measurement_text: str = "Init. Measurement",
    control_text: str = "Data-Driven MPC",
    display_initial_text: bool = True,
    display_control_text: bool = True,
    figsize: tuple[float, float] = (12.0, 8.0),
    dpi: int = 300,
    u_ylimits_list: list[tuple[float, float]] | None = None,
    y_ylimits_list: list[tuple[float, float]] | None = None,
    fontsize: int = 12,
    legend_params: dict[str, Any] | None = None,
    var_suffix: str = "",
    axs_u: list[Axes] | None = None,
    axs_y: list[Axes] | None = None,
    title: str | None = None,
    input_label: str | None = None,
    output_label: str | None = None,
    u_setpoint_labels: list[str] | None = None,
    y_setpoint_labels: list[str] | None = None,
    x_axis_labels: list[str] | None = None,
    input_y_axis_labels: list[str] | None = None,
    output_y_axis_labels: list[str] | None = None,
    plot_setpoint_lines: bool = True,
) -> None:
    """
    Plot input-output data with setpoints in a Matplotlib figure.

    This function creates a figure with two rows of subplots. The first row
    contains the control inputs and the second row contains the system
    outputs. Each subplot shows one data sequence alongside its setpoint.
    The appearance of plot lines and legends can be customized by passing
    dictionaries of Matplotlib line and legend properties.

    If provided, the first 'initial_steps' time steps are highlighted to
    emphasize the initial input-output data measurement period. This period
    represents the data-driven system characterization phase in a Data-Driven
    MPC algorithm. Custom labels can also be displayed to indicate the
    initial measurement and the subsequent MPC control periods. They are
    shown only if there is enough space to prevent them from overlapping
    with other plot elements.

    Note:
        If both `axs_u` and `axs_y` are provided, the data will be plotted on
        these external axes and no new figure will be created (and
        `plt.show()` is not called). This allows multiple data sequences to
        be plotted on the same external figure. Each data sequence can be
        differentiated using the `input_label` and `output_label`
        arguments. If only one of them is provided, a new figure is created
        and the provided axes are ignored.

    Args:
        u_k (np.ndarray): An array containing control input data of shape (T,
            m), where `m` is the number of inputs and `T` is the number of
            time steps.
        y_k (np.ndarray): An array containing system output data of shape (T,
            p), where `p` is the number of outputs and `T` is the number of
            time steps.
        y_s (np.ndarray): An array containing output setpoint values of shape
            (T, p), where `p` is the number of outputs and `T` is the number of
            time steps.
        u_s (np.ndarray | None): An array containing input setpoint values of
            shape (T, m), where `m` is the number of inputs and `T` is the
            number of time steps. If `None`, input setpoint lines will not be
            plotted. Defaults to `None`.
        u_bounds_list (list[tuple[float, float]] | None): A list of tuples
            (lower_bound, upper_bound) specifying bounds for each input data
            sequence. If provided, horizontal lines representing these bounds
            will be plotted in each subplot. If `None`, no horizontal lines
            will be plotted. The number of tuples must match the number of
            input data sequences. Defaults to `None`.
        y_bounds_list (list[tuple[float, float]] | None): A list of tuples
            (lower_bound, upper_bound) specifying bounds for each output data
            sequence. If provided, horizontal lines representing these bounds
            will be plotted in each subplot. If `None`, no horizontal lines
            will be plotted. The number of tuples must match the number of
            output data sequences. Defaults to `None`.
        inputs_line_params (dict[str, Any] | None): A dictionary of Matplotlib
            properties for customizing the lines used to plot the input data
            series (e.g., color, linestyle, linewidth). If not provided,
            Matplotlib's default line properties will be used.
        outputs_line_params (dict[str, Any] | None): A dictionary of Matplotlib
            properties for customizing the lines used to plot the output data
            series (e.g., color, linestyle, linewidth). If not provided,
            Matplotlib's default line properties will be used.
        setpoints_line_params (dict[str, Any] | None): A dictionary of
            Matplotlib properties for customizing the lines used to plot the
            setpoint values (e.g., color, linestyle, linewidth). If not
            provided, Matplotlib's default line properties will be used.
        bounds_line_params (dict[str, Any] | None): A dictionary of Matplotlib
            properties for customizing the lines used to plot the bounds of
            input-output data series (e.g., color, linestyle, linewidth). If
            not provided, Matplotlib's default line properties will be used.
        u_setpoint_var_symbol (str): The variable symbol used to label the
            input setpoint data series (e.g., "u^s").
        y_setpoint_var_symbol (str): The variable symbol used to label the
            output setpoint data series (e.g., "y^s").
        initial_steps (int | None): The number of initial time steps during
            which input-output measurements were taken for the data-driven
            characterization of the system. This highlights the initial
            measurement period in the plot. If `None`, no special highlighting
            will be applied. Defaults to `None`.
        initial_excitation_text (str): Label text to display over the initial
            excitation period of the input plots. Default is
            "Init. Excitation".
        initial_measurement_text (str): Label text to display over the initial
            measurement period of the output plots. Default is
            "Init. Measurement".
        control_text (str): Label text to display over the post-initial
            control period. Default is "Data-Driven MPC".
        display_initial_text (bool): Whether to display the initial-period
            labels (`initial_excitation_text` on input plots and
            `initial_measurement_text` on output plots). Default is True.
        display_control_text (bool): Whether to display the `control_text`
            label on the plot. Default is True.
        figsize (tuple[float, float]): The (width, height) dimensions of the
            created Matplotlib figure.
        dpi (int): The DPI resolution of the figure.
        u_ylimits_list (list[tuple[float, float]] | None): A list of tuples
            (lower_limit, upper_limit) specifying the Y-axis limits for each
            input subplot. If `None`, the Y-axis limits will be determined
            automatically. Defaults to `None`.
        y_ylimits_list (list[tuple[float, float]] | None): A list of tuples
            (lower_limit, upper_limit) specifying the Y-axis limits for each
            output subplot. If `None`, the Y-axis limits will be determined
            automatically. Defaults to `None`.
        fontsize (int): The fontsize for labels and axes ticks.
        legend_params (dict[str, Any] | None): A dictionary of Matplotlib
            properties for customizing the plot legends (e.g., fontsize,
            loc, handlelength). If not provided, Matplotlib's default legend
            properties will be used.
        var_suffix (str): A string appended to each data series label in the
            plot legend.
        axs_u (list[Axes] | None): A list of external axes for input plots.
            Defaults to `None`.
        axs_y (list[Axes] | None): A list of external axes for output plots.
            Defaults to `None`.
        title (str | None): The title for the created plot figure. Set only if
            the figure is created internally (i.e., `axs_u` and `axs_y` are not
            provided). If `None`, no title will be displayed. Defaults to
            `None`.
        input_label (str | None): A custom legend label for the input data
            series. If provided, this label will override the default label
            constructed using `var_suffix`.
        output_label (str | None): A custom legend label for the output data
            series. If provided, this label will override the default label
            constructed using `var_suffix`.
        u_setpoint_labels (list[str] | None): A list of strings specifying
            custom legend labels for input setpoint series. If provided, the
            label at each index will override the default label constructed
            using `u_setpoint_var_symbol`.
        y_setpoint_labels (list[str] | None): A list of strings specifying
            custom legend labels for output setpoint series. If provided, the
            label at each index will override the default label constructed
            using `y_setpoint_var_symbol`.
        x_axis_labels (list[str] | None): A list of strings specifying custom
            X-axis labels for each subplot. If provided, the label at each
            index will override the default "Time step $k$".
        input_y_axis_labels (list[str] | None): A list of strings specifying
            custom Y-axis labels for each input subplot. If provided, the label
            at each index will override the default constructed labels.
        output_y_axis_labels (list[str] | None): A list of strings specifying
            custom Y-axis labels for each output subplot. If provided, the
            label at each index will override the default constructed labels.
        plot_setpoint_lines (bool): Whether to plot setpoint lines. If `False`,
            no setpoint line will be plotted. Used for avoiding duplicate
            setpoint entries in multi-data plots. Defaults to `True`.

    Raises:
        ValueError: If any array dimensions mismatch expected shapes, or if
            the lengths of `u_bounds_list`, `y_bounds_list`, `u_ylimits_list`,
            or `y_ylimits_list` do not match the number of subplots.
    """
    # Validate data dimensions
    validate_data_dimensions(
        u_k=u_k,
        y_k=y_k,
        u_s=u_s,
        y_s=y_s,
        u_bounds_list=u_bounds_list,
        y_bounds_list=y_bounds_list,
        u_ylimits_list=u_ylimits_list,
        y_ylimits_list=y_ylimits_list,
        u_setpoint_labels=u_setpoint_labels,
        y_setpoint_labels=y_setpoint_labels,
        x_axis_labels=x_axis_labels,
        input_y_axis_labels=input_y_axis_labels,
        output_y_axis_labels=output_y_axis_labels,
    )

    # Initialize Matplotlib params if not provided
    inputs_line_params = init_dict_if_none(inputs_line_params)
    outputs_line_params = init_dict_if_none(outputs_line_params)
    setpoints_line_params = init_dict_if_none(setpoints_line_params)
    bounds_line_params = init_dict_if_none(bounds_line_params)
    legend_params = init_dict_if_none(legend_params)

    # Retrieve number of input and output data sequences (computed once)
    m = u_k.shape[1]  # Number of inputs
    p = y_k.shape[1]  # Number of outputs

    # Create figure if lists of Axes are not provided. Both lists must be
    # given for the external figure to be used.
    is_ext_fig = axs_u is not None and axs_y is not None  # External figure
    fig: Figure | SubFigure
    if not is_ext_fig:
        # Create figure and subplots
        fig, axs_u, axs_y = create_input_output_figure(
            m=m, p=p, figsize=figsize, dpi=dpi, fontsize=fontsize, title=title
        )
    else:
        assert axs_u is not None  # Prevent mypy [index] error
        # Use figure from the provided axes
        fig = axs_u[0].figure

    # Both axes lists are set in either branch above; narrow once for mypy
    # instead of re-asserting on every loop iteration.
    assert axs_u is not None and axs_y is not None

    # Plot input data
    for i in range(m):
        # A single input plot gets index -1 so no subscript is added to labels
        plot_index = -1 if m == 1 else i
        plot_data(
            axis=axs_u[i],
            data=u_k[:, i],
            setpoint=u_s[:, i] if u_s is not None else None,
            index=plot_index,
            data_line_params=inputs_line_params,
            bounds_line_params=bounds_line_params,
            setpoint_line_params=setpoints_line_params,
            var_symbol="u",
            setpoint_var_symbol=u_setpoint_var_symbol,
            var_label="Input",
            var_suffix=var_suffix,
            initial_text=initial_excitation_text,
            control_text=control_text,
            display_initial_text=display_initial_text,
            display_control_text=display_control_text,
            fontsize=fontsize,
            legend_params=legend_params,
            fig=fig,
            bounds=u_bounds_list[i] if u_bounds_list else None,
            initial_steps=initial_steps,
            plot_ylimits=u_ylimits_list[i] if u_ylimits_list else None,
            data_label=input_label,
            setpoint_labels=u_setpoint_labels,
            x_axis_labels=x_axis_labels,
            y_axis_labels=input_y_axis_labels,
            plot_setpoint_lines=plot_setpoint_lines,
        )

    # Plot output data
    for j in range(p):
        # A single output plot gets index -1 so no subscript is added to labels
        plot_index = -1 if p == 1 else j
        plot_data(
            axis=axs_y[j],
            data=y_k[:, j],
            setpoint=y_s[:, j],
            index=plot_index,
            data_line_params=outputs_line_params,
            bounds_line_params=bounds_line_params,
            setpoint_line_params=setpoints_line_params,
            var_symbol="y",
            setpoint_var_symbol=y_setpoint_var_symbol,
            var_label="Output",
            var_suffix=var_suffix,
            initial_text=initial_measurement_text,
            control_text=control_text,
            display_initial_text=display_initial_text,
            display_control_text=display_control_text,
            fontsize=fontsize,
            legend_params=legend_params,
            fig=fig,
            bounds=y_bounds_list[j] if y_bounds_list else None,
            initial_steps=initial_steps,
            plot_ylimits=y_ylimits_list[j] if y_ylimits_list else None,
            data_label=output_label,
            setpoint_labels=y_setpoint_labels,
            x_axis_labels=x_axis_labels,
            y_axis_labels=output_y_axis_labels,
            plot_setpoint_lines=plot_setpoint_lines,
        )

    # Show the plot if the figure was created internally
    if not is_ext_fig:
        plt.show()
[docs]
def plot_data(
    axis: Axes,
    data: np.ndarray,
    setpoint: np.ndarray | None,
    index: int,
    data_line_params: dict[str, Any],
    setpoint_line_params: dict[str, Any],
    bounds_line_params: dict[str, Any],
    var_symbol: str,
    setpoint_var_symbol: str,
    var_label: str,
    var_suffix: str,
    initial_text: str,
    control_text: str,
    display_initial_text: bool,
    display_control_text: bool,
    fontsize: int,
    legend_params: dict[str, Any],
    fig: Figure | SubFigure,
    bounds: tuple[float, float] | None = None,
    initial_steps: int | None = None,
    plot_ylimits: tuple[float, float] | None = None,
    data_label: str | None = None,
    setpoint_labels: list[str] | None = None,
    x_axis_labels: list[str] | None = None,
    y_axis_labels: list[str] | None = None,
    plot_setpoint_lines: bool = True,
) -> None:
    """
    Plot a data series with setpoints in a specified axis.

    Optionally, highlight the initial measurement and control phases using
    shaded regions and text labels. The labels will be displayed if there is
    enough space to prevent them from overlapping with other plot elements.

    The final axis limits (X-limits `(0, T - 1)` and, if given,
    `plot_ylimits`) are applied before the phase labels are placed. This
    keeps the labels vertically centered in the final view and measures
    their widths against the final X scale.

    Note:
        The appearance of plot lines and legend can be customized by passing
        dictionaries of Matplotlib line and legend properties.

    Args:
        axis (Axes): The Matplotlib axis object to plot on.
        data (np.ndarray): An array containing data to be plotted. Its first
            dimension is used as the number of time steps `T`.
        setpoint (np.ndarray | None): An array containing setpoint values to be
            plotted. If `None`, the setpoint line will not be plotted.
        index (int): The index of the data used for labeling purposes (e.g.,
            "u_1", "u_2"). If set to -1, subscripts will not be added to
            labels (and list-based labels use their last element).
        data_line_params (dict[str, Any]): A dictionary of Matplotlib
            properties for customizing the line used to plot the data series
            (e.g., color, linestyle, linewidth).
        setpoint_line_params (dict[str, Any]): A dictionary of Matplotlib
            properties for customizing the line used to plot the setpoint
            value (e.g., color, linestyle, linewidth).
        bounds_line_params (dict[str, Any]): A dictionary of Matplotlib
            properties for customizing the lines used to plot the bounds of
            the data series (e.g., color, linestyle, linewidth).
        var_symbol (str): The variable symbol used to label the data series
            (e.g., "u" for inputs, "y" for outputs).
        setpoint_var_symbol (str): The variable symbol used to label the
            setpoint data series (e.g., "u^s" for inputs, "y^s" for outputs).
        var_label (str): The variable label representing the control signal
            (e.g., "Input", "Output").
        var_suffix (str): A string appended to each data series label in the
            plot legend.
        initial_text (str): Label text to display over the initial measurement
            period of the plot.
        control_text (str): Label text to display over the post-initial
            control period.
        display_initial_text (bool): Whether to display the `initial_text`
            label on the plot.
        display_control_text (bool): Whether to display the `control_text`
            label on the plot.
        fontsize (int): The fontsize for labels and axes ticks. Phase labels
            use `fontsize - 1`.
        legend_params (dict[str, Any]): A dictionary of Matplotlib properties
            for customizing the plot legend (e.g., fontsize, loc,
            handlelength).
        fig (Figure | SubFigure): The Matplotlib figure or subfigure that
            contains the axis.
        bounds (tuple[float, float] | None): A tuple (lower_bound,
            upper_bound) specifying the bounds of the data to be plotted. If
            provided, horizontal lines representing these bounds will be
            plotted. Defaults to `None`.
        initial_steps (int | None): The number of initial time steps during
            which input-output measurements were taken for the data-driven
            characterization of the system. This highlights the initial
            measurement period in the plot. `None` or 0 disables the
            highlighting. Defaults to `None`.
        plot_ylimits (tuple[float, float] | None): A tuple (lower_limit,
            upper_limit) specifying the Y-axis limits for the plot. If `None`,
            the Y-axis limits will be determined automatically. Defaults to
            `None`.
        data_label (str | None): A custom legend label for the data series. If
            provided, this label will override the default constructed label
            using `var_symbol` and `var_suffix`.
        setpoint_labels (list[str] | None): A list of strings specifying custom
            legend labels for the setpoint series. If provided, the label at
            `index` will be used instead of the default label constructed using
            `setpoint_var_symbol`.
        x_axis_labels (list[str] | None): A list of strings specifying custom
            X-axis labels for each subplot or data index. If provided, the
            label at `index` will override the default "Time step $k$".
        y_axis_labels (list[str] | None): A list of strings specifying custom
            Y-axis labels for each subplot or data index. If provided, the
            label at `index` will override the default label constructed from
            `var_label` and `var_symbol`.
        plot_setpoint_lines (bool): Whether to plot setpoint lines. If `False`,
            no setpoint line will be plotted.
    """
    T = data.shape[0]  # Data length
    time_steps = range(T)
    # Construct index label string based on index value
    index_str = f"_{index + 1}" if index != -1 else ""

    # Plot data series
    data_label_str = (
        data_label if data_label else f"${var_symbol}{index_str}${var_suffix}"
    )
    axis.plot(time_steps, data, **data_line_params, label=data_label_str)

    # Plot setpoint if provided. The label is computed unconditionally
    # because it is also used to reorder the legend below.
    setpoint_label_str = (
        setpoint_labels[index]
        if setpoint_labels
        else f"${setpoint_var_symbol}{index_str}$"
    )
    if setpoint is not None and plot_setpoint_lines:
        axis.plot(
            time_steps,
            setpoint,
            **setpoint_line_params,
            label=setpoint_label_str,
        )

    # Plot bounds if provided (only the lower line carries a legend label so
    # that "Constraints" appears once)
    bounds_label = "Constraints"
    if bounds is not None:
        lower_bound, upper_bound = bounds
        axis.axhline(y=lower_bound, **bounds_line_params, label=bounds_label)
        axis.axhline(y=upper_bound, **bounds_line_params)

    # Apply the final axis limits BEFORE placing phase labels: the labels are
    # centered on the current Y-limits and their widths are measured in data
    # units, which depend on the current X-limits.
    axis.set_xlim((0, T - 1))
    if plot_ylimits:
        axis.set_ylim(plot_ylimits)

    # Highlight initial input-output data measurement period if provided
    if initial_steps:
        # Highlight period with a grayed rectangle
        axis.axvspan(0, initial_steps, color="gray", alpha=0.1)
        # Add a dashed vertical line at the right side of the rectangle
        axis.axvline(
            x=initial_steps, color="black", linestyle=(0, (5, 5)), linewidth=1
        )
        # Label centered over the highlighted initial period
        if display_initial_text:
            _add_phase_label(
                axis=axis,
                fig=fig,
                x_center=initial_steps / 2,
                text=initial_text,
                available_width=initial_steps,
                fontsize=fontsize,
            )
        # Label centered over the remaining (control) period
        if display_control_text:
            _add_phase_label(
                axis=axis,
                fig=fig,
                x_center=(T + initial_steps) / 2,
                text=control_text,
                available_width=T - initial_steps,
                fontsize=fontsize,
            )

    # Format labels, legend and ticks
    x_axis_label_str = (
        x_axis_labels[index] if x_axis_labels else "Time step $k$"
    )
    y_axis_label_str = (
        y_axis_labels[index]
        if y_axis_labels
        else f"{var_label} ${var_symbol}{index_str}$"
    )
    axis.set_xlabel(x_axis_label_str, fontsize=fontsize)
    axis.set_ylabel(y_axis_label_str, fontsize=fontsize)
    axis.legend(**legend_params)
    axis.tick_params(axis="both", labelsize=fontsize)

    # Remove duplicate labels from legend (required for external figures
    # that plot multiple data sequences on the same plot to avoid label
    # repetition) and move setpoint/constraint entries to the end
    end_labels_list = [setpoint_label_str]
    if bounds is not None:
        end_labels_list.append(bounds_label)
    filter_and_reorder_legend(
        axis=axis, legend_params=legend_params, end_labels_list=end_labels_list
    )


def _add_phase_label(
    axis: Axes,
    fig: Figure | SubFigure,
    x_center: float,
    text: str,
    available_width: float,
    fontsize: int,
) -> None:
    """
    Place a boxed phase label at the vertical center of `axis`.

    The label is hidden if its width (in data units) exceeds
    `available_width`, i.e., if it would overflow the period it describes.
    """
    y_min, y_max = axis.get_ylim()
    label = axis.text(
        x_center,
        (y_min + y_max) / 2,
        text,
        fontsize=fontsize - 1,
        ha="center",
        va="center",
        color="black",
        bbox={"facecolor": "white", "edgecolor": "black"},
    )
    text_width = get_text_width_in_data(text_object=label, axis=axis, fig=fig)
    if available_width < text_width:
        label.set_visible(False)