"""Plotting utilities for Pareto front approximations (jmetal.lab.visualization.plotting)."""

import logging
from typing import TypeVar, List, Tuple

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from pandas import plotting

# Logger shared by the jMetal code base (looked up by its fixed name).
LOGGER = logging.getLogger(name="jmetal")

# Generic placeholder for a solution type; the plotting code only reads the
# ``objectives`` / ``number_of_objectives`` attributes of such objects.
S = TypeVar("S")


class Plot:
    """Plot one or several Pareto front approximations.

    Fronts whose solutions have two objectives are drawn as 2D scatter plots,
    fronts with three objectives as 3D scatter plots, and fronts with any other
    number of objectives in parallel coordinates. When several fronts are given,
    each one gets its own subplot on an n-by-n grid inside a single figure.
    """

    def __init__(self,
                 title: str = 'Pareto front approximation',
                 reference_front: List[S] = None,
                 reference_point: list = None,
                 axis_labels: list = None):
        """
        :param title: Title of the graph.
        :param reference_front: Reference Pareto front (if any) as solutions.
        :param reference_point: Reference point (e.g., [0.4, 1.2]) or a list of reference
            points (e.g., [[0.4, 1.2], [0.6, 0.9]]). Points may be lists or tuples.
        :param axis_labels: List of axis labels.
        """
        self.plot_title = title
        self.axis_labels = axis_labels

        # A single point is wrapped into a one-element list so the plotting
        # methods can always iterate over ``self.reference_point``.
        if reference_point and not isinstance(reference_point[0], (list, tuple)):
            reference_point = [reference_point]

        self.reference_point = reference_point
        self.reference_front = reference_front
        self.dimension = None

    @staticmethod
    def get_points(solutions: List[S]) -> Tuple[pd.DataFrame, int]:
        """ Get points for each solution of the front.

        :param solutions: List of solutions.
        :return: Pandas dataframe with one column for each objective and one row for each
            solution, together with its number of columns (objectives).
        :raises Exception: If ``solutions`` is None.
        """
        if solutions is None:
            raise Exception('Front is none!')

        points = pd.DataFrame([solution.objectives for solution in solutions])
        return points, points.shape[1]

    def plot(self, front, label='', normalize: bool = False, filename: str = None, format: str = 'eps'):
        """ Plot any arbitrary number of fronts in 2D, 3D or p-coords.

        The plot type is chosen from ``number_of_objectives`` of the first solution of the
        first front: 2 -> :meth:`two_dim`, 3 -> :meth:`three_dim`, otherwise :meth:`pcoords`.

        :param front: Pareto front or a list of them.
        :param label: Pareto front title or a list of them (one per front).
        :param normalize: If True, normalize data (for p-coords).
        :param filename: Output filename (without extension); if None the figure is shown.
        :param format: Output file format.
        :raises Exception: If the number of fronts and labels differ.
        """
        # A single front (a list of solutions) is wrapped into a list of fronts.
        if not isinstance(front[0], list):
            front = [front]

        if not isinstance(label, list):
            label = [label]

        if len(front) != len(label):
            raise Exception('Number of fronts and labels must be the same')

        dimension = front[0][0].number_of_objectives

        if dimension == 2:
            self.two_dim(front, label, filename, format)
        elif dimension == 3:
            self.three_dim(front, label, filename, format)
        else:
            self.pcoords(front, normalize, filename, format, labels=label)

    def two_dim(self, fronts: List[list], labels: List[str] = None, filename: str = None, format: str = 'eps'):
        """ Plot any arbitrary number of fronts in 2D.

        :param fronts: List of fronts (containing solutions).
        :param labels: List of fronts title (if any).
        :param filename: Output filename (without extension); if None the figure is shown.
        :param format: Output file format.
        """
        n = int(np.ceil(np.sqrt(len(fronts))))
        fig = plt.figure()
        fig.suptitle(self.plot_title, fontsize=16)

        reference = None
        if self.reference_front:
            reference, _ = self.get_points(self.reference_front)
            # The reference front is drawn as a line; sort it by the first objective so
            # the line follows the front instead of zig-zagging between unordered solutions.
            reference = reference.sort_values(by=0)

        for i, front in enumerate(fronts):
            points, _ = self.get_points(front)

            ax = fig.add_subplot(n, n, i + 1)
            points.plot(kind='scatter', x=0, y=1, ax=ax, s=10, color='#236FA4', alpha=1.0)

            if labels:
                ax.set_title(labels[i])

            if reference is not None:
                reference.plot(x=0, y=1, ax=ax, color='k', legend=False)

            if self.reference_point:
                # Each reference point is a red dot with dotted guide lines through it.
                for point in self.reference_point:
                    ax.plot([point[0]], [point[1]], marker='o', markersize=5, color='r')
                    ax.axvline(x=point[0], color='r', linestyle=':')
                    ax.axhline(y=point[1], color='r', linestyle=':')

            if self.axis_labels:
                ax.set_xlabel(self.axis_labels[0])
                ax.set_ylabel(self.axis_labels[1])

        self._save_or_show(fig, filename, format, dpi=200)

    def three_dim(self, fronts: List[list], labels: List[str] = None, filename: str = None, format: str = 'eps'):
        """ Plot any arbitrary number of fronts in 3D.

        :param fronts: List of fronts (containing solutions).
        :param labels: List of fronts title (if any).
        :param filename: Output filename (without extension); if None the figure is shown.
        :param format: Output file format.
        """
        n = int(np.ceil(np.sqrt(len(fronts))))
        fig = plt.figure()
        fig.suptitle(self.plot_title, fontsize=16)

        for i, front in enumerate(fronts):
            ax = fig.add_subplot(n, n, i + 1, projection='3d')
            ax.scatter([s.objectives[0] for s in front],
                       [s.objectives[1] for s in front],
                       [s.objectives[2] for s in front])

            if labels:
                ax.set_title(labels[i])

            if self.reference_front:
                ax.scatter([s.objectives[0] for s in self.reference_front],
                           [s.objectives[1] for s in self.reference_front],
                           [s.objectives[2] for s in self.reference_front])

            if self.reference_point:
                for point in self.reference_point:
                    if len(point) < 3:
                        # Points with fewer coordinates cannot be placed in 3D; skip them
                        # instead of failing the whole plot.
                        LOGGER.warning('Skipping reference point %s: 3D plots need three coordinates', point)
                        continue
                    ax.scatter([point[0]], [point[1]], [point[2]], marker='o', s=25, color='r')

            if self.axis_labels:
                # zip() stops at the shorter sequence, so fewer than three labels is fine.
                for set_label, text in zip((ax.set_xlabel, ax.set_ylabel, ax.set_zlabel), self.axis_labels):
                    set_label(text)

            ax.relim()
            ax.autoscale_view(True, True, True)
            ax.view_init(elev=30.0, azim=15.0)
            ax.locator_params(nbins=4)

        self._save_or_show(fig, filename, format, dpi=1000)

    def pcoords(self, fronts: List[list], normalize: bool = False, filename: str = None, format: str = 'eps',
                labels: List[str] = None):
        """ Plot any arbitrary number of fronts in parallel coordinates.

        :param fronts: List of fronts (containing solutions).
        :param normalize: If True, min-max normalize every objective of each front to [0, 1];
            objectives that are constant within a front are mapped to 0.
        :param filename: Output filename (without extension); if None the figure is shown.
        :param format: Output file format.
        :param labels: List of fronts title (if any).
        """
        n = int(np.ceil(np.sqrt(len(fronts))))
        fig = plt.figure()
        fig.suptitle(self.plot_title, fontsize=16)

        for i, front in enumerate(fronts):
            points, _ = self.get_points(front)

            if normalize:
                lowest = points.min()
                span = points.max() - lowest
                # A zero range would divide by zero and yield NaN; use 1 so the
                # (constant) column simply becomes 0.
                span = span.where(span != 0, 1.0)
                points = (points - lowest) / span

            ax = fig.add_subplot(n, n, i + 1)
            min_, max_ = points.values.min(), points.values.max()
            # parallel_coordinates needs a class column; giving every solution its own
            # value makes each line a separate class (and so, presumably, its own colour).
            points['scale'] = np.linspace(0, 1, len(points)) * (max_ - min_) + min_

            pd.plotting.parallel_coordinates(points, 'scale', ax=ax)

            # One legend entry per solution would be unreadable, so drop it.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

            if labels:
                ax.set_title(labels[i])

            if self.axis_labels:
                ax.set_xticklabels(self.axis_labels)

        self._save_or_show(fig, filename, format, dpi=1000)

    @staticmethod
    def _save_or_show(fig, filename: str, fmt: str, dpi: int):
        """ Save ``fig`` as ``<filename>.<fmt>`` (or show it when no filename is given), then close it.

        The figure is closed even if saving or showing raises, so failed plots do not leak
        open pyplot figures.
        """
        try:
            if filename:
                fig.savefig(filename + '.' + fmt, format=fmt, dpi=dpi)
            else:
                plt.show()
        finally:
            plt.close(fig=fig)