import logging
from typing import List, TypeVar
import matplotlib
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from jmetal.lab.visualization.plotting import Plot
logger = logging.getLogger(__name__)
S = TypeVar("S") 
"""
.. module:: streaming
   :platform: Unix, Windows
   :synopsis: Classes for plotting solutions in real-time.
.. moduleauthor:: Antonio Benítez-Hidalgo <antonio.b@uma.es>
"""
[docs]
class StreamingPlot:
    def __init__(
        self,
        plot_title: str = "Pareto front approximation",
        reference_front: List[S] = None,
        reference_point: list = None,
        axis_labels: list = None,
    ):
        """
        :param plot_title: Title of the graph.
        :param axis_labels: List of axis labels.
        :param reference_point: Reference point (e.g., [0.4, 1.2]).
        :param reference_front: Reference Pareto front (if any) as solutions.
        """
        self.plot_title = plot_title
        self.axis_labels = axis_labels
        if reference_point and not isinstance(reference_point[0], list):
            reference_point = [reference_point]
        self.reference_point = reference_point
        self.reference_front = reference_front
        self.dimension = None
        import warnings
        warnings.filterwarnings("ignore", ".*GUI is implemented.*")
        self.fig, self.ax = plt.subplots()
        self.sc = None
        self.axis = None
[docs]
    def plot(self, front):
        # Get data
        points, dimension = Plot.get_points(front)
        # Create an empty figure
        self.create_layout(dimension)
        # If any reference point, plot
        if self.reference_point:
            for point in self.reference_point:
                (self.scp,) = self.ax.plot(*[[p] for p in point], c="r", ls="None", marker="*", markersize=3)
        # If any reference front, plot
        if self.reference_front:
            rpoints, _ = Plot.get_points(self.reference_front)
            (self.scf,) = self.ax.plot(
                *[rpoints[column].tolist() for column in rpoints.columns.values],
                c="k",
                ls="None",
                marker="*",
                markersize=1
            )
        # Plot data
        (self.sc,) = self.ax.plot(
            *[points[column].tolist() for column in points.columns.values], ls="None", marker="o", markersize=4
        )
        # Show plot
        plt.show(block=False) 
[docs]
    def update(self, front: List[S], reference_point: list = None) -> None:
        if self.sc is None:
            raise Exception("Figure is none")
        points, dimension = Plot.get_points(front)
        # Replace with new points
        self.sc.set_data(points[0], points[1])
        if dimension == 3:
            self.sc.set_3d_properties(points[2])
        # If any new reference point, plot
        if reference_point:
            self.scp.set_data([p[0] for p in reference_point], [p[1] for p in reference_point])
        # Re-align the axis
        self.ax.relim()
        self.ax.autoscale_view(True, True, True)
        try:
            # self.fig.canvas.draw()
            self.fig.canvas.flush_events()
        except KeyboardInterrupt:
            pass
        pause(0.01) 
[docs]
    def create_layout(self, dimension: int) -> None:
        logger.info("Creating figure layout")
        self.fig.canvas.manager.set_window_title(self.plot_title)
        self.fig.suptitle(self.plot_title, fontsize=16)
        if dimension == 2:
            # Stylize axis
            self.ax.spines["top"].set_visible(False)
            self.ax.spines["right"].set_visible(False)
            self.ax.get_xaxis().tick_bottom()
            self.ax.get_yaxis().tick_left()
        elif dimension == 3:
            self.ax = Axes3D(self.fig)
            self.ax.autoscale(enable=True, axis="both")
        else:
            raise Exception("Dimension must be either 2 or 3")
        self.ax.set_autoscale_on(True)
        self.ax.autoscale_view(True, True, True)
        # Style options
        self.ax.grid(color="#f0f0f5", linestyle="-", linewidth=0.5, alpha=0.5) 
 
def pause(interval: float):
    backend = plt.rcParams["backend"]
    if backend in matplotlib.rcsetup.interactive_bk:
        figManager = matplotlib._pylab_helpers.Gcf.get_active()
        if figManager is not None:
            canvas = figManager.canvas
            if canvas.figure.stale:
                canvas.draw()
            canvas.start_event_loop(interval)
            return