Source code for jmetal.lab.visualization.interactive

import logging
from typing import List, TypeVar

import pandas as pd
from plotly import graph_objs as go
from plotly import io as pio
from plotly import offline

from jmetal.lab.visualization.plotting import Plot

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

# Unconstrained type variable; used in this module to annotate ``reference_front`` as ``List[S]``.
S = TypeVar("S")


class InteractivePlot(Plot):
    """Plot fronts of solutions as interactive plotly figures.

    Fronts with two columns are drawn with ``Scattergl``, three columns with
    ``Scatter3d`` and anything else with parallel coordinates (``Parcoords``).
    The resulting figure can be exported to a stand-alone HTML page, to an
    embeddable ``div``, or to a static image.
    """

    def __init__(
        self,
        title: str = "Pareto front approximation",
        reference_front: List[S] = None,
        reference_point: list = None,
        axis_labels: list = None,
    ):
        """Create the plot; all arguments are forwarded unchanged to ``Plot.__init__``.

        :param title: Title of the figure.
        :param reference_front: Optional reference front, drawn in black by :meth:`plot`.
        :param reference_point: Optional reference point(s), drawn in red by :meth:`plot`.
        :param axis_labels: Optional axis labels, looked up by objective index.
        """
        super(InteractivePlot, self).__init__(title, reference_front, reference_point, axis_labels)

        self.figure = None  # plotly Figure, built by plot()
        self.layout = None  # plotly Layout, built by plot()
        self.data = []  # traces; never cleared, so they accumulate across plot() calls

    def plot(self, front, label=None, normalize: bool = False, filename: str = None, format: str = "HTML"):
        """Plot a front of solutions (2D, 3D or parallel coordinates).

        :param front: List of solutions.
        :param label: Front name, or a list of names of which only the first is used in the title.
        :param normalize: Normalize the input front between 0 and 1 (for problems with more than 3 objectives).
        :param filename: Output filename without extension. Nothing is written if empty or None.
        :param format: ``"HTML"`` exports an interactive page to ``<filename>.html``; any other value is
            appended to ``filename`` as the extension and the figure is saved with ``plotly.io.write_image``.
        """
        if not isinstance(label, list):
            label = [label]

        # ``axis_labels[i:i + 1]`` is an empty list when no label exists for axis ``i``,
        # which avoids an IndexError and falls back to an untitled axis.
        self.layout = go.Layout(
            margin=dict(l=80, r=80, b=80, t=150),
            height=800,
            title="{}<br>{}".format(self.plot_title, label[0]),
            scene=dict(
                xaxis=dict(title=self.axis_labels[0:1][0] if self.axis_labels[0:1] else None),
                yaxis=dict(title=self.axis_labels[1:2][0] if self.axis_labels[1:2] else None),
                zaxis=dict(title=self.axis_labels[2:3][0] if self.axis_labels[2:3] else None),
            ),
            hovermode="closest",
        )

        # If any reference front, plot
        if self.reference_front:
            points, _ = self.get_points(self.reference_front)
            trace = self.__generate_trace(
                points=points, legend="Reference front", normalize=normalize, color="black", size=2
            )
            self.data.append(trace)

        # If any reference point, plot
        if self.reference_point:
            points = pd.DataFrame(self.reference_point)
            trace = self.__generate_trace(points=points, legend="Reference point", color="red", size=8)
            self.data.append(trace)

        # Get points and metadata (the string form of each solution, shown when a point is clicked)
        points, _ = self.get_points(front)
        metadata = [solution.__str__() for solution in front]

        trace = self.__generate_trace(
            points=points, metadata=metadata, legend="Front approximation", normalize=normalize
        )
        self.data.append(trace)

        self.figure = go.Figure(data=self.data, layout=self.layout)

        # Plot the figure
        if filename:
            if format == "HTML":
                self.export_to_html(filename)
                # export_to_html() appends the ".html" extension itself.
                logger.info("Figure %s exported to HTML file", filename + ".html")
            else:
                _filename = filename + "." + format
                pio.write_image(self.figure, _filename)
                logger.info("Figure %s saved to file", _filename)

    def export_to_html(self, filename: str) -> str:
        """Export the graph to an interactive HTML (solutions can be selected to show some metadata).

        The page is written to ``<filename>.html`` encoded as UTF-8, matching the
        ``charset`` declared in the document head.

        :param filename: Output file name (without the ``.html`` extension).
        :return: The full HTML document as a string.
        """
        html_string = (
            """
        <!DOCTYPE html>
        <html>
            <head>
                <meta charset="utf-8"/>
                <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
                <script src="https://unpkg.com/sweetalert2@7.7.0/dist/sweetalert2.all.js"></script>
                <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.1/css/bootstrap.min.css">
            </head>
            <body>
                """
            + self.export_to_div(filename=None, include_plotlyjs=False)
            + """
                <script>
                    var myPlot = document.querySelectorAll('div')[0];
                    myPlot.on('plotly_click', function(data){
                        var pts = '';

                        for(var i=0; i < data.points.length; i++){
                            pts = '(x, y) = ('+data.points[i].x +', '+ data.points[i].y.toPrecision(4)+')';
                            cs = data.points[i].customdata
                        }

                        if(typeof cs !== "undefined"){
                            swal({
                              title: 'Closest solution clicked:',
                              text: cs,
                              type: 'info',
                              position: 'bottom-end'
                            })
                        }
                    });

                    window.onresize = function() {
                       Plotly.Plots.resize(myPlot);
                    };
                </script>
            </body>
        </html>"""
        )

        with open(filename + ".html", "w", encoding="utf-8") as outf:
            outf.write(html_string)

        return html_string

    def export_to_div(self, filename=None, include_plotlyjs: bool = False) -> str:
        """Export as a `div` for embedding the graph in an HTML file.

        :param filename: Output file name without extension (if desired, default to None). When given,
            the ``div`` is also written to ``<filename>.html`` encoded as UTF-8.
        :param include_plotlyjs: If True, include plot.ly JS script (default to False).
        :return: Script as string.
        """
        script = offline.plot(self.figure, output_type="div", include_plotlyjs=include_plotlyjs, show_link=False)

        if filename:
            with open(filename + ".html", "w", encoding="utf-8") as outf:
                outf.write(script)

        return script

    def __generate_trace(
        self, points: pd.DataFrame, legend: str, metadata: list = None, normalize: bool = False, **kwargs
    ):
        """Build the plotly trace matching the number of columns in ``points``.

        :param points: One row per solution, one column per objective.
        :param legend: Trace name shown in the legend.
        :param metadata: Optional per-point custom data (used by the 2D and 3D traces only).
        :param normalize: If True, min-max scale every column before plotting.
        :param kwargs: Overrides for the default marker properties (e.g. ``color``, ``size``).
        :return: A ``Scattergl`` (2 columns), ``Scatter3d`` (3 columns) or ``Parcoords`` trace.
        """
        dimension = points.shape[1]

        # tweak points size for 3D plots
        marker_size = 4 if dimension == 3 else 8

        # if indicated, perform normalization
        # NOTE(review): a column whose max equals its min divides 0 by 0 here - confirm callers never hit this.
        if normalize:
            points = (points - points.min()) / (points.max() - points.min())

        marker = dict(
            color="#236FA4", size=marker_size, symbol="circle", line=dict(color="#236FA4", width=1), opacity=0.8
        )
        marker.update(**kwargs)

        if dimension == 2:
            trace = go.Scattergl(
                x=points[0], y=points[1], mode="markers", marker=marker, name=legend, customdata=metadata
            )
        elif dimension == 3:
            trace = go.Scatter3d(
                x=points[0], y=points[1], z=points[2], mode="markers", marker=marker, name=legend, customdata=metadata
            )
        else:
            # One parallel-coordinates axis per column; the slice lookup yields no label
            # (None) when ``axis_labels`` has no entry for that column.
            dimensions = [
                dict(
                    range=[0, 1],
                    label=self.axis_labels[column : column + 1][0]
                    if self.axis_labels[column : column + 1]
                    else None,
                    values=points[column],
                )
                for column in points
            ]

            trace = go.Parcoords(
                line=dict(color="#236FA4"),
                dimensions=dimensions,
                name=legend,
            )

        return trace