import logging
from typing import List, Tuple, TypeVar
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
logger = logging.getLogger(__name__)

# Generic type of the solution objects accepted by the plotting helpers; they
# are only required to expose an ``objectives`` sequence.
S = TypeVar("S")
class Plot:
    """Plot one or more fronts of solutions.

    A front is a list of solution objects exposing an ``objectives`` sequence.
    Fronts with two objectives are drawn as 2D scatter plots, fronts with three
    objectives as 3D scatter plots, and any other number of objectives in
    parallel coordinates.
    """

    def __init__(
        self,
        title: str = "Pareto front approximation",
        reference_front: List[S] = None,
        reference_point: list = None,
        axis_labels: list = None,
    ):
        """
        :param title: Title of the graph.
        :param reference_front: Reference Pareto front (if any) as solutions.
        :param reference_point: Reference point (e.g., [0.4, 1.2]) or a list of reference points.
        :param axis_labels: List of axis labels.
        """
        self.plot_title = title
        self.axis_labels = axis_labels
        # Accept a single point as well as a list of points; internally a list of
        # points is always stored. Tuples are accepted as points too, so that a
        # list of tuples is not wrapped a second time.
        if reference_point and not isinstance(reference_point[0], (list, tuple)):
            reference_point = [reference_point]
        self.reference_point = reference_point
        self.reference_front = reference_front
        # Kept for backward compatibility; no method of this class reads it.
        self.dimension = None

    @staticmethod
    def get_points(solutions: List[S]) -> Tuple[pd.DataFrame, int]:
        """Get points for each solution of the front.

        :param solutions: List of solutions (objects with an ``objectives`` sequence).
        :return: Tuple ``(points, n_columns)``: a pandas dataframe with one column for
            each objective and one row for each solution, and its number of columns.
        :raises Exception: If ``solutions`` is None.
        """
        if solutions is None:
            raise Exception("Front is none!")
        points = pd.DataFrame([solution.objectives for solution in solutions])
        return points, points.shape[1]

    def plot(self, front, label="", normalize: bool = False, filename: str = None, format: str = "eps"):
        """Plot any arbitrary number of fronts in 2D, 3D or p-coords.

        The plot type is chosen from the number of objectives of the first solution
        of the first front: 2 -> 2D scatter, 3 -> 3D scatter, otherwise p-coords.

        :param front: Pareto front or a list of them.
        :param label: Pareto front title or a list of them (not used for p-coords).
        :param normalize: If True, normalize data (for p-coords).
        :param filename: Output filename (without extension). If not given, the figure is shown.
        :param format: Output file format.
        :raises Exception: If the number of fronts and labels differ.
        """
        # A single front is a list of solutions (not of lists); wrap it so that a
        # single front and a list of fronts are handled uniformly.
        if not isinstance(front[0], list):
            front = [front]
        if not isinstance(label, list):
            label = [label]
        if len(front) != len(label):
            raise Exception("Number of fronts and labels must be the same")

        dimension = len(front[0][0].objectives)
        if dimension == 2:
            self.two_dim(front, label, filename, format)
        elif dimension == 3:
            self.three_dim(front, label, filename, format)
        else:
            self.pcoords(front, normalize, filename, format)

    @staticmethod
    def _save_or_show(fig, filename: str = None, format: str = "eps"):
        """Save ``fig`` to ``<filename>.<format>`` if a filename is given, otherwise show it; then close it."""
        if filename:
            _filename = filename + "." + format
            fig.savefig(_filename, format=format, dpi=1000)
            logger.info("Figure %s saved to file", _filename)
        else:
            plt.show()
        plt.close(fig=fig)

    def two_dim(self, fronts: List[list], labels: List[str] = None, filename: str = None, format: str = "eps"):
        """Plot any arbitrary number of fronts in 2D.

        Fronts are laid out on an ``n x n`` grid of subplots, where ``n`` is the
        ceiling of the square root of the number of fronts.

        :param fronts: List of fronts (containing solutions).
        :param labels: List of fronts title (if any).
        :param filename: Output filename (without extension). If not given, the figure is shown.
        :param format: Output file format.
        """
        n = int(np.ceil(np.sqrt(len(fronts))))
        fig = plt.figure()
        fig.suptitle(self.plot_title, fontsize=16)

        reference = None
        if self.reference_front:
            reference, _ = self.get_points(self.reference_front)
            # The reference front is drawn as a line; sort it by the first objective
            # so the line does not zig-zag between unordered points.
            reference = reference.sort_values(by=0)

        for i, front in enumerate(fronts):
            points, _ = self.get_points(front)
            ax = fig.add_subplot(n, n, i + 1)
            points.plot(kind="scatter", x=0, y=1, ax=ax, s=10, color="#236FA4", alpha=1.0)

            if labels:
                ax.set_title(labels[i])

            if reference is not None:
                reference.plot(x=0, y=1, ax=ax, color="k", legend=False)

            if self.reference_point:
                # Each reference point is drawn as a red dot with dotted guide lines.
                for point in self.reference_point:
                    ax.plot([point[0]], [point[1]], marker="o", markersize=5, color="r")
                    ax.axvline(x=point[0], color="r", linestyle=":")
                    ax.axhline(y=point[1], color="r", linestyle=":")

            if self.axis_labels:
                ax.set_xlabel(self.axis_labels[0])
                ax.set_ylabel(self.axis_labels[1])

        self._save_or_show(fig, filename, format)

    def three_dim(self, fronts: List[list], labels: List[str] = None, filename: str = None, format: str = "eps"):
        """Plot any arbitrary number of fronts in 3D.

        :param fronts: List of fronts (containing solutions).
        :param labels: List of fronts title (if any).
        :param filename: Output filename (without extension). If not given, the figure is shown.
        :param format: Output file format.
        """
        n = int(np.ceil(np.sqrt(len(fronts))))
        fig = plt.figure()
        fig.suptitle(self.plot_title, fontsize=16)

        for i, front in enumerate(fronts):
            ax = fig.add_subplot(n, n, i + 1, projection="3d")
            ax.scatter(
                [s.objectives[0] for s in front],
                [s.objectives[1] for s in front],
                [s.objectives[2] for s in front],
            )

            if labels:
                ax.set_title(labels[i])

            if self.reference_front:
                ax.scatter(
                    [s.objectives[0] for s in self.reference_front],
                    [s.objectives[1] for s in self.reference_front],
                    [s.objectives[2] for s in self.reference_front],
                )

            # TODO: reference points (self.reference_point) are not drawn in 3D plots yet.

            ax.relim()
            ax.autoscale_view(True, True, True)
            ax.view_init(elev=30.0, azim=15.0)
            ax.locator_params(nbins=4)

        self._save_or_show(fig, filename, format)

    def pcoords(self, fronts: List[list], normalize: bool = False, filename: str = None, format: str = "eps"):
        """Plot any arbitrary number of fronts in parallel coordinates.

        :param fronts: List of fronts (containing solutions).
        :param normalize: If True, min-max scale every objective to [0, 1] per front.
        :param filename: Output filename (without extension). If not given, the figure is shown.
        :param format: Output file format.
        """
        n = int(np.ceil(np.sqrt(len(fronts))))
        fig = plt.figure()
        fig.suptitle(self.plot_title, fontsize=16)

        for i, front in enumerate(fronts):
            points, _ = self.get_points(front)
            if normalize:
                lowest = points.min()
                value_range = points.max() - lowest
                # A constant objective has zero range; divide by 1 instead so it maps
                # to 0 rather than NaN (0 / 0).
                points = (points - lowest) / value_range.replace(0, 1)

            ax = fig.add_subplot(n, n, i + 1)
            # parallel_coordinates needs a class column. A synthetic "scale" column,
            # evenly spread over the data range, gives every solution its own class
            # value (presumably so lines are coloured individually).
            min_, max_ = points.values.min(), points.values.max()
            points["scale"] = np.linspace(0, 1, len(points)) * (max_ - min_) + min_
            pd.plotting.parallel_coordinates(points, "scale", ax=ax)
            # A per-class legend would be meaningless here; drop it if one was added.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

            if self.axis_labels:
                ax.set_xticklabels(self.axis_labels)

        self._save_or_show(fig, filename, format)