383 lines
12 KiB
Python
383 lines
12 KiB
Python
"""
|
|
Utility Routines for Working with Matplotlib Objects
|
|
====================================================
|
|
"""
|
|
import itertools
|
|
import io
|
|
import base64
|
|
|
|
import numpy as np
|
|
|
|
import warnings
|
|
|
|
import matplotlib
|
|
from matplotlib.colors import colorConverter
|
|
from matplotlib.path import Path
|
|
from matplotlib.markers import MarkerStyle
|
|
from matplotlib.transforms import Affine2D
|
|
from matplotlib import ticker
|
|
|
|
|
|
def export_color(color):
    """Convert a matplotlib color specification to a CSS color string.

    Parameters
    ----------
    color : matplotlib color spec or None
        Any value accepted by ``colorConverter.to_rgba``.

    Returns
    -------
    css_color : str
        ``"none"`` when ``color`` is None or fully transparent, an upper-case
        ``"#RRGGBB"`` string when fully opaque, and an ``"rgba(r, g, b, a)"``
        string otherwise.
    """
    if color is None:
        return "none"

    # Resolve the color once instead of re-parsing it in every branch.
    rgba = colorConverter.to_rgba(color)
    alpha = rgba[3]
    if alpha == 0:
        return "none"

    # Round (rather than truncate) the channels so that the hex and rgba()
    # outputs agree on the 8-bit value of the same color.
    channels = [int(np.round(val * 255)) for val in rgba[:3]]
    if alpha == 1:
        return "#{0:02X}{1:02X}{2:02X}".format(*channels)
    return "rgba(" + ", ".join(str(c) for c in channels) + ", " + str(alpha) + ")"
|
|
|
|
|
|
def _many_to_one(input_dict):
    """Flatten a mapping keyed by groups of aliases into a one-to-one mapping.

    Parameters
    ----------
    input_dict : dict
        Mapping whose keys are iterables of aliases and whose values are the
        value shared by every alias in that group.

    Returns
    -------
    flat : dict
        Mapping from each individual alias to its group's value. When an
        alias appears in several groups, the last group wins.
    """
    return {key: val for keys, val in input_dict.items() for key in keys}
|
|
|
|
|
|
# Lookup table from every accepted line-style alias to its dash-array value.
# Each tuple of aliases on the left is expanded by ``_many_to_one`` so that
# every alias becomes its own key.
LINESTYLES = _many_to_one(
    {
        # Solid lines map to the string "none" (no dash pattern).
        ("solid", "-", (None, None)): "none",
        ("dashed", "--"): "6,6",
        ("dotted", ":"): "2,2",
        ("dashdot", "-."): "4,4,2,4",
        # The "no line" aliases map to the Python value None, not a string.
        ("", " ", "None", "none"): None,
    }
)
|
|
|
|
|
|
def get_dasharray(obj):
    """Get an SVG dash array for the given matplotlib linestyle

    Parameters
    ----------
    obj : matplotlib object
        The matplotlib line or path object, which must have a get_linestyle()
        method which returns a valid matplotlib line code

    Returns
    -------
    dasharray : string
        The HTML/SVG dasharray code associated with the object. This is
        ``None`` for the "no line" styles registered in ``LINESTYLES``.

    Warns
    -----
    UserWarning
        If the line style is not a known (or not a hashable) style; the
        solid-line dash array is returned in that case.
    """
    # An explicit dash sequence stored on the object takes precedence over
    # the named line style.
    dash_seq = getattr(obj, "_dashSeq", None)
    if dash_seq is not None:
        return ",".join(map(str, dash_seq))

    ls = obj.get_linestyle()
    try:
        # Direct indexing keeps a legitimate ``None`` entry distinguishable
        # from a missing one without needing a sentinel value.
        return LINESTYLES[ls]
    except (KeyError, TypeError):
        # KeyError: unknown style. TypeError: unhashable style (e.g. a list),
        # which cannot be looked up at all.
        warnings.warn(
            "line style '{0}' not understood: "
            "defaulting to solid line.".format(ls)
        )
        return LINESTYLES["solid"]
|
|
|
|
|
|
# Translation table from matplotlib ``Path`` segment codes to the
# single-character path codes emitted by ``SVG_path``.
# NOTE(review): in the SVG path specification "S" is the smooth *cubic*
# command and "Q" is the quadratic one -- confirm that "S" is the intended
# code for CURVE3 and that downstream consumers rely on it before changing.
PATH_DICT = {
    Path.LINETO: "L",
    Path.MOVETO: "M",
    Path.CURVE3: "S",
    Path.CURVE4: "C",
    Path.CLOSEPOLY: "Z",
}
|
|
|
|
|
|
def SVG_path(path, transform=None, simplify=False):
    """Build the vertex array and SVG path codes describing a path

    Parameters
    ----------
    path : matplotlib.Path object

    transform : matplotlib transform (optional)
        when given, it is applied to the path before anything is computed.

    simplify : bool (optional)
        forwarded unchanged to ``path.iter_segments``.

    Returns
    -------
    vertices : array
        Array of shape (M, 2) holding the vertices of the path. A single
        path code may consume several vertices, so M can exceed the number
        of codes.
    path_codes : list
        List of N single-character codes (N <= M), each one taken from
        ``PATH_DICT``, i.e. one of ['L','M','S','C','Z'].
    """
    if transform is not None:
        path = path.transformed(transform)

    per_segment_vertices = []
    svg_codes = []
    for seg_vertices, seg_code in path.iter_segments(simplify=simplify):
        # Vertices attached to a CLOSEPOLY segment are dropped.
        if seg_code == Path.CLOSEPOLY:
            per_segment_vertices.append([])
        else:
            per_segment_vertices.append(seg_vertices)
        svg_codes.append(PATH_DICT[seg_code])

    # A path without any segment is a special case.
    if not svg_codes:
        return np.zeros((0, 2)), []

    flat_coords = list(itertools.chain(*per_segment_vertices))
    return np.array(flat_coords).reshape(-1, 2), svg_codes
|
|
|
|
|
|
def get_path_style(path, fill=True):
    """Build the style dictionary describing a matplotlib path object

    The face color is reported as "none" when ``fill`` is false, and a
    missing alpha is reported as 1.
    """
    alpha = path.get_alpha()
    return {
        "alpha": 1 if alpha is None else alpha,
        "edgecolor": export_color(path.get_edgecolor()),
        "facecolor": export_color(path.get_facecolor()) if fill else "none",
        "edgewidth": path.get_linewidth(),
        "dasharray": get_dasharray(path),
        "zorder": path.get_zorder(),
    }
|
|
|
|
|
|
def get_line_style(line):
    """Build the style dictionary describing a matplotlib line object

    A missing alpha is reported as 1.
    """
    alpha = line.get_alpha()
    return {
        "alpha": 1 if alpha is None else alpha,
        "color": export_color(line.get_color()),
        "linewidth": line.get_linewidth(),
        "dasharray": get_dasharray(line),
        "zorder": line.get_zorder(),
        "drawstyle": line.get_drawstyle(),
    }
|
|
|
|
|
|
def get_marker_style(line):
    """Build the style dictionary describing the markers of a matplotlib line

    A missing alpha is reported as 1. The marker outline is exported through
    ``SVG_path`` after being scaled by the marker size.
    """
    alpha = line.get_alpha()
    style = {
        "alpha": 1 if alpha is None else alpha,
        "facecolor": export_color(line.get_markerfacecolor()),
        "edgecolor": export_color(line.get_markeredgecolor()),
        "edgewidth": line.get_markeredgewidth(),
        "marker": line.get_marker(),
    }

    marker = MarkerStyle(line.get_marker())
    size = line.get_markersize()
    # Scale to the marker size; the y factor is negated (presumably because
    # the SVG y axis points downwards -- TODO confirm).
    base_transform = marker.get_transform()
    full_transform = base_transform + Affine2D().scale(size, -size)

    style["markerpath"] = SVG_path(marker.get_path(), full_transform)
    style["markersize"] = size
    style["zorder"] = line.get_zorder()
    return style
|
|
|
|
|
|
def get_text_style(text):
    """Build the style dictionary describing a matplotlib text instance

    A missing alpha is reported as 1.
    """
    alpha = text.get_alpha()
    return {
        "alpha": 1 if alpha is None else alpha,
        "fontsize": text.get_size(),
        "color": export_color(text.get_color()),
        # left, center, right
        "halign": text.get_horizontalalignment(),
        # baseline, center, top
        "valign": text.get_verticalalignment(),
        # alignment of the individual lines when the text contains '\n'
        "malign": text._multialignment,
        "rotation": text.get_rotation(),
        "zorder": text.get_zorder(),
    }
|
|
|
|
|
|
def _funcformatter_labels(formatter):
    """Return the labels bound into a FuncFormatter's callable, or None.

    Labels can only be recovered when ``formatter.func`` exposes bound
    positional arguments (as ``functools.partial`` does) whose first entry
    is a mapping; any other callable yields None.
    """
    bound_args = getattr(getattr(formatter, "func", None), "args", None)
    if bound_args and hasattr(bound_args[0], "values"):
        return list(bound_args[0].values())
    return None


def get_axis_properties(axis):
    """Return the property dictionary for a matplotlib.Axis instance

    Parameters
    ----------
    axis : matplotlib.axis.XAxis or matplotlib.axis.YAxis

    Returns
    -------
    props : dict
        Dictionary with the keys ``position``, ``nticks``, ``tickvalues``,
        ``tickformat``, ``scale``, ``fontsize``, ``grid`` and ``visible``.

    Raises
    ------
    ValueError
        If ``axis`` is neither an XAxis nor a YAxis.
    """
    props = {}
    label1On = axis._major_tick_kw.get("label1On", True)

    if isinstance(axis, matplotlib.axis.XAxis):
        props["position"] = "bottom" if label1On else "top"
    elif isinstance(axis, matplotlib.axis.YAxis):
        props["position"] = "left" if label1On else "right"
    else:
        raise ValueError("{0} should be an Axis instance".format(axis))

    # Use tick values if appropriate. The locator is evaluated only once and
    # the result is shared by both properties.
    locator = axis.get_major_locator()
    tick_locations = locator()
    props["nticks"] = len(tick_locations)
    if isinstance(locator, ticker.FixedLocator):
        props["tickvalues"] = list(tick_locations)
    else:
        props["tickvalues"] = None

    # Find tick formats
    formatter = axis.get_major_formatter()
    func_labels = None
    if isinstance(formatter, ticker.FuncFormatter):
        # A FuncFormatter wrapping an arbitrary callable carries no label
        # list; in that case fall through to the generic handling below
        # instead of failing on a missing ``args`` attribute.
        func_labels = _funcformatter_labels(formatter)

    if isinstance(formatter, ticker.NullFormatter):
        props["tickformat"] = ""
    elif isinstance(formatter, ticker.FixedFormatter):
        props["tickformat"] = list(formatter.seq)
    elif func_labels is not None:
        props["tickformat"] = func_labels
    elif not any(label.get_visible() for label in axis.get_ticklabels()):
        props["tickformat"] = ""
    else:
        props["tickformat"] = None

    # Get axis scale
    props["scale"] = axis.get_scale()

    # Get major tick label size (assumes that's all we really care about!)
    labels = axis.get_ticklabels()
    if labels:
        props["fontsize"] = labels[0].get_fontsize()
    else:
        props["fontsize"] = None

    # Get associated grid
    props["grid"] = get_grid_style(axis)

    # get axis visibility
    props["visible"] = axis.get_visible()

    return props
|
|
|
|
|
|
def get_grid_style(axis):
    """Return the grid style dictionary for a matplotlib axis

    Parameters
    ----------
    axis : matplotlib axis object
        Must provide ``get_gridlines()`` and a ``_major_tick_kw`` mapping.

    Returns
    -------
    grid : dict
        ``{"gridOn": False}`` when the major grid is off (or the flag is
        missing) or when there are no grid lines; otherwise a dictionary with
        ``gridOn=True`` plus the ``color``, ``dasharray`` and ``alpha`` of
        the first grid line.
    """
    gridlines = axis.get_gridlines()
    # A missing "gridOn" entry is treated as "grid off" rather than raising
    # KeyError, mirroring how "label1On" is read with a default.
    grid_on = axis._major_tick_kw.get("gridOn", False)
    if not (grid_on and len(gridlines) > 0):
        return {"gridOn": False}

    # The first grid line is taken as representative of the whole grid.
    first_line = gridlines[0]
    color = export_color(first_line.get_color())
    alpha = first_line.get_alpha()
    dasharray = get_dasharray(first_line)
    return dict(gridOn=True, color=color, dasharray=dasharray, alpha=alpha)
|
|
|
|
|
|
def get_figure_properties(fig):
    """Return the width, height and dpi of a figure as a dictionary"""
    width = fig.get_figwidth()
    height = fig.get_figheight()
    return dict(figwidth=width, figheight=height, dpi=fig.dpi)
|
|
|
|
|
|
def _axis_converter(axis):
    """Return the unit converter attached to ``axis``.

    Uses the ``get_converter()`` accessor when the axis provides one and
    falls back to the plain ``converter`` attribute otherwise.
    """
    getter = getattr(axis, "get_converter", None)
    if callable(getter):
        return getter()
    return axis.converter


def _date_domain(axis, converter, domain):
    """Convert numeric date-axis limits into date component tuples.

    Each returned tuple is (year, month - 1, day, hour, minute, second,
    milliseconds); the zero-based month presumably targets JavaScript's Date
    convention -- TODO confirm against the consumer.
    """
    # Imported explicitly: ``import matplotlib`` alone does not guarantee
    # that the ``matplotlib.dates`` submodule has been loaded.
    from matplotlib import dates as mdates

    # NOTE(review): ``pandas.tseries.converter`` appears to be unavailable in
    # recent pandas releases, in which case period axes silently take the
    # generic branch below -- confirm and update the import location.
    try:
        import pandas as pd
        from pandas.tseries.converter import PeriodConverter
    except ImportError:
        pd = None

    if pd is not None and isinstance(converter, PeriodConverter):
        # Period axes store ordinals, which need the axis frequency to be
        # turned back into calendar dates.
        periods = [pd.Period(ordinal=int(d), freq=axis.freq) for d in domain]
        return [
            (d.year, d.month - 1, d.day, d.hour, d.minute, d.second, 0)
            for d in periods
        ]

    return [
        (
            d.year,
            d.month - 1,
            d.day,
            d.hour,
            d.minute,
            d.second,
            d.microsecond * 1e-3,
        )
        for d in mdates.num2date(domain)
    ]


def get_axes_properties(ax):
    """Return the property dictionary for a matplotlib Axes instance

    Parameters
    ----------
    ax : matplotlib Axes object

    Returns
    -------
    props : dict
        Background, bounds, navigation and visibility flags, the properties
        of both axes (key ``axes``) and, for each of x and y, the keys
        ``<name>scale``, ``<name>lim`` and ``<name>domain``. For date axes
        the scale is "date" and the domain holds date component tuples while
        the lim keeps the raw numeric limits.

    Raises
    ------
    ValueError
        If an axis scale is not one of "date", "linear" or "log".
    """
    from matplotlib import dates as mdates

    props = {
        "axesbg": export_color(ax.patch.get_facecolor()),
        "axesbgalpha": ax.patch.get_alpha(),
        "bounds": ax.get_position().bounds,
        "dynamic": ax.get_navigate(),
        "axison": ax.axison,
        "frame_on": ax.get_frame_on(),
        "patch_visible": ax.patch.get_visible(),
        "axes": [get_axis_properties(ax.xaxis), get_axis_properties(ax.yaxis)],
    }

    for axname in ["x", "y"]:
        axis = getattr(ax, axname + "axis")
        domain = getattr(ax, "get_{0}lim".format(axname))()
        lim = domain

        converter = _axis_converter(axis)
        if isinstance(converter, mdates.DateConverter):
            scale = "date"
            domain = _date_domain(axis, converter, domain)
        else:
            scale = axis.get_scale()

        if scale not in ["date", "linear", "log"]:
            raise ValueError("Unknown axis scale: " "{0}".format(axis.get_scale()))

        props[axname + "scale"] = scale
        props[axname + "lim"] = lim
        props[axname + "domain"] = domain

    return props
|
|
|
|
|
|
def _get_child_list(obj):
    """Return ``obj.get_children()``, or an empty list if obj has no such method."""
    if hasattr(obj, "get_children"):
        return obj.get_children()
    return []


def iter_all_children(obj, skipContainers=False):
    """
    Returns an iterator over all children and nested children using
    obj's get_children() method

    Every descendant is yielded exactly once, parents before their own
    children. If ``obj`` itself has no children, ``obj`` is yielded instead.

    if skipContainers is true, only childless objects are returned.
    """
    children = _get_child_list(obj)
    if len(children) == 0:
        yield obj
        return

    for child in children:
        is_leaf = len(_get_child_list(child)) == 0
        if is_leaf or not skipContainers:
            yield child
        # Recurse only into real containers: recursing into a leaf would
        # yield that leaf a second time.
        if not is_leaf:
            for descendant in iter_all_children(child, skipContainers):
                yield descendant
|
|
|
|
|
|
def get_legend_properties(ax, legend):
    """Return the legend handles and labels of ``ax`` plus the legend visibility"""
    legend_handles, legend_labels = ax.get_legend_handles_labels()
    return dict(
        handles=legend_handles,
        labels=legend_labels,
        visible=legend.get_visible(),
    )
|
|
|
|
|
|
def image_to_base64(image):
    """
    Convert a matplotlib image to a base64 png representation

    Parameters
    ----------
    image : matplotlib image object
        The image to be converted.

    Returns
    -------
    image_base64 : string
        The UTF8-encoded base64 string representation of the png image.

    Notes
    -----
    The limits of the image's axes are changed while the png is written and
    are restored afterwards, even if writing the png fails.
    """
    ax = image.axes
    binary_buffer = io.BytesIO()

    # image is saved in axes coordinates: we need to temporarily
    # set the correct limits to get the correct image
    lim = ax.axis()
    ax.axis(image.get_extent())
    try:
        image.write_png(binary_buffer)
    finally:
        # Always put the caller's limits back, also on failure.
        ax.axis(lim)

    return base64.b64encode(binary_buffer.getvalue()).decode("utf-8")
|