wg-backend-django/dell-env/lib/python3.11/site-packages/plotly/matplotlylib/mplexporter/utils.py
2023-10-30 14:40:43 +07:00

383 lines
12 KiB
Python

"""
Utility Routines for Working with Matplotlib Objects
====================================================
"""
import itertools
import io
import base64
import numpy as np
import warnings
import matplotlib
from matplotlib.colors import colorConverter
from matplotlib.path import Path
from matplotlib.markers import MarkerStyle
from matplotlib.transforms import Affine2D
from matplotlib import ticker
def export_color(color):
    """Convert a matplotlib color code to a hex color or an RGBA color string.

    Parameters
    ----------
    color : matplotlib color or None
        Any value accepted by ``colorConverter.to_rgba``.

    Returns
    -------
    color_str : str
        ``"none"`` when *color* is None or fully transparent, ``"#RRGGBB"``
        when fully opaque, and ``"rgba(r, g, b, a)"`` otherwise.
    """
    if color is None:
        return "none"
    # Convert once instead of once per branch.
    red, green, blue, alpha = colorConverter.to_rgba(color)
    if alpha == 0:
        return "none"
    # Round (rather than truncate) every channel so both output forms agree;
    # truncation turned e.g. 0.5 into 127 here but 128 in the rgba() branch.
    channels = [int(np.round(val * 255)) for val in (red, green, blue)]
    if alpha == 1:
        return "#{0:02X}{1:02X}{2:02X}".format(*channels)
    return (
        "rgba("
        + ", ".join(str(val) for val in channels)
        + ", "
        + str(alpha)
        + ")"
    )
def _many_to_one(input_dict):
"""Convert a many-to-one mapping to a one-to-one mapping"""
return dict((key, val) for keys, val in input_dict.items() for key in keys)
# matplotlib line style codes -> SVG dasharray strings.
# "none" means a solid line (no dashes); None means the line is not drawn.
LINESTYLES = {
    # solid
    "solid": "none",
    "-": "none",
    (None, None): "none",
    # dashed
    "dashed": "6,6",
    "--": "6,6",
    # dotted
    "dotted": "2,2",
    ":": "2,2",
    # dash-dot
    "dashdot": "4,4,2,4",
    "-.": "4,4,2,4",
    # invisible
    "": None,
    " ": None,
    "None": None,
    "none": None,
}
def get_dasharray(obj):
    """Get an SVG dash array for the given matplotlib linestyle.

    Parameters
    ----------
    obj : matplotlib object
        The matplotlib line or path object. If its instance dict holds a
        non-None ``_dashSeq``, that sequence is used directly. Otherwise the
        object must have a ``get_linestyle()`` method returning a matplotlib
        line style code.

    Returns
    -------
    dasharray : str or None
        The HTML/SVG dasharray code associated with the object, as listed in
        ``LINESTYLES``, or the comma-joined ``_dashSeq``. An unrecognised
        line style triggers a warning and falls back to the solid-line value.
    """
    dash_seq = obj.__dict__.get("_dashSeq", None)
    if dash_seq is not None:
        return ",".join(map(str, dash_seq))
    ls = obj.get_linestyle()
    try:
        return LINESTYLES[ls]
    except (KeyError, TypeError):
        # KeyError: unknown style code. TypeError: an unhashable style value
        # (e.g. a list), which previously raised instead of falling back.
        warnings.warn(
            "line style '{0}' not understood: "
            "defaulting to solid line.".format(ls)
        )
        return LINESTYLES["solid"]
# matplotlib Path codes -> single-letter SVG path commands, in the order
# LINETO, MOVETO, CURVE3, CURVE4, CLOSEPOLY.
# NOTE(review): SVG's command for a quadratic segment is "Q", while "S" is a
# smooth cubic. CURVE3 is kept as "S" because consumers of SVG_path output
# may depend on it. Confirm before changing.
PATH_DICT = dict(
    zip(
        (Path.LINETO, Path.MOVETO, Path.CURVE3, Path.CURVE4, Path.CLOSEPOLY),
        ("L", "M", "S", "C", "Z"),
    )
)
def SVG_path(path, transform=None, simplify=False):
    """Build the vertex array and SVG command letters for a matplotlib Path.

    Parameters
    ----------
    path : matplotlib.Path object
    transform : matplotlib transform (optional)
        If given, the path is transformed before its segments are read.
    simplify : bool (optional)
        Forwarded to ``path.iter_segments`` to control path simplification.

    Returns
    -------
    vertices : array
        A (M, 2) array holding every vertex of every segment. Because one
        path code can carry several vertices, M may exceed the number of
        codes. CLOSEPOLY segments contribute no vertices.
    path_codes : list
        A length-N list of one-letter SVG commands from PATH_DICT
        ('L', 'M', 'S', 'C' or 'Z'), N <= M.
    """
    if transform is not None:
        path = path.transformed(transform)

    command_letters = []
    flat_coords = []
    for segment_coords, segment_code in path.iter_segments(simplify=simplify):
        if segment_code != Path.CLOSEPOLY:
            flat_coords.extend(segment_coords)
        command_letters.append(PATH_DICT[segment_code])

    if len(command_letters) == 0:
        # An empty path still returns a correctly shaped vertex array.
        return np.zeros((0, 2)), []
    return np.array(flat_coords).reshape(-1, 2), command_letters
def get_path_style(path, fill=True):
    """Collect the style attributes of a matplotlib path-like artist.

    Parameters
    ----------
    path : matplotlib artist
        Must provide get_alpha, get_edgecolor, get_linewidth, get_zorder,
        get_facecolor (when *fill* is true), and whatever get_dasharray reads.
    fill : bool, optional
        When false, the face color is reported as "none" without querying
        the artist.

    Returns
    -------
    style : dict
        Keys alpha (1 when unset), edgecolor, facecolor, edgewidth,
        dasharray and zorder.
    """
    raw_alpha = path.get_alpha()
    edge = export_color(path.get_edgecolor())
    face = export_color(path.get_facecolor()) if fill else "none"
    return {
        "alpha": 1 if raw_alpha is None else raw_alpha,
        "edgecolor": edge,
        "facecolor": face,
        "edgewidth": path.get_linewidth(),
        "dasharray": get_dasharray(path),
        "zorder": path.get_zorder(),
    }
def get_line_style(line):
    """Collect the stroke style attributes of a matplotlib line.

    Returns
    -------
    style : dict
        Keys alpha (1 when unset), color, linewidth, dasharray, zorder
        and drawstyle.
    """
    raw_alpha = line.get_alpha()
    return {
        "alpha": 1 if raw_alpha is None else raw_alpha,
        "color": export_color(line.get_color()),
        "linewidth": line.get_linewidth(),
        "dasharray": get_dasharray(line),
        "zorder": line.get_zorder(),
        "drawstyle": line.get_drawstyle(),
    }
def get_marker_style(line):
    """Collect the marker style attributes of a matplotlib line.

    Returns
    -------
    style : dict
        Keys alpha (1 when unset), facecolor, edgecolor, edgewidth, marker,
        markerpath (the ``(vertices, codes)`` pair from SVG_path for the
        marker shape scaled to the marker size), markersize and zorder.
    """
    raw_alpha = line.get_alpha()
    face = export_color(line.get_markerfacecolor())
    edge = export_color(line.get_markeredgecolor())
    edge_width = line.get_markeredgewidth()
    marker = line.get_marker()

    size = line.get_markersize()
    shape = MarkerStyle(marker)
    # Scale the marker's unit path by the marker size. The negative y factor
    # flips it vertically (presumably because SVG's y axis points down).
    to_output = shape.get_transform() + Affine2D().scale(size, -size)

    return {
        "alpha": 1 if raw_alpha is None else raw_alpha,
        "facecolor": face,
        "edgecolor": edge,
        "edgewidth": edge_width,
        "marker": marker,
        "markerpath": SVG_path(shape.get_path(), to_output),
        "markersize": size,
        "zorder": line.get_zorder(),
    }
def get_text_style(text):
    """Collect the style attributes of a matplotlib Text instance.

    Returns
    -------
    style : dict
        Keys alpha (1 when unset), fontsize, color, halign, valign, malign,
        rotation and zorder.
    """
    raw_alpha = text.get_alpha()
    return {
        "alpha": 1 if raw_alpha is None else raw_alpha,
        "fontsize": text.get_size(),
        "color": export_color(text.get_color()),
        # Anchor of the text relative to its position.
        "halign": text.get_horizontalalignment(),
        "valign": text.get_verticalalignment(),
        # Alignment of the individual lines when the text contains '\n'
        # (read from a private matplotlib attribute).
        "malign": text._multialignment,
        "rotation": text.get_rotation(),
        "zorder": text.get_zorder(),
    }
def _get_tick_format(axis):
    """Return the ``tickformat`` value for *axis*'s major tick labels.

    Returns ``""`` when labels are suppressed, a list of label strings when
    they can be read from the formatter, and None otherwise.
    """
    formatter = axis.get_major_formatter()
    if isinstance(formatter, ticker.NullFormatter):
        return ""
    if isinstance(formatter, ticker.FixedFormatter):
        return list(formatter.seq)
    if isinstance(formatter, ticker.FuncFormatter):
        # Labels are only recoverable when the wrapped callable carries a
        # dict as its first bound positional argument (presumably a
        # functools.partial over a value->label mapping). Any other
        # FuncFormatter used to raise here; it now falls through to the
        # generic checks below.
        try:
            return list(formatter.func.args[0].values())
        except (AttributeError, IndexError, KeyError, TypeError):
            pass
    if not any(label.get_visible() for label in axis.get_ticklabels()):
        return ""
    return None


def get_axis_properties(axis):
    """Return the property dictionary for a matplotlib.Axis instance.

    Parameters
    ----------
    axis : matplotlib.axis.XAxis or matplotlib.axis.YAxis

    Returns
    -------
    props : dict
        Keys position, nticks, tickvalues, tickformat, scale, fontsize,
        grid and visible.

    Raises
    ------
    ValueError
        If *axis* is neither an XAxis nor a YAxis.
    """
    props = {}
    # label1 maps to the bottom (x) / left (y) side below; when it is off,
    # the labels are reported on the opposite side.
    label1On = axis._major_tick_kw.get("label1On", True)
    if isinstance(axis, matplotlib.axis.XAxis):
        props["position"] = "bottom" if label1On else "top"
    elif isinstance(axis, matplotlib.axis.YAxis):
        props["position"] = "left" if label1On else "right"
    else:
        raise ValueError("{0} should be an Axis instance".format(axis))

    # Use explicit tick values only for fixed locators.
    locator = axis.get_major_locator()
    tick_locations = locator()
    props["nticks"] = len(tick_locations)
    if isinstance(locator, ticker.FixedLocator):
        props["tickvalues"] = list(tick_locations)
    else:
        props["tickvalues"] = None

    props["tickformat"] = _get_tick_format(axis)
    props["scale"] = axis.get_scale()

    # Major tick label size (assumes that's all we really care about!)
    labels = axis.get_ticklabels()
    props["fontsize"] = labels[0].get_fontsize() if labels else None

    props["grid"] = get_grid_style(axis)
    props["visible"] = axis.get_visible()
    return props
def get_grid_style(axis):
    """Return the major-grid style dictionary for a matplotlib Axis.

    Returns
    -------
    style : dict
        ``{"gridOn": False}`` when the major grid is off or has no lines.
        Otherwise ``gridOn=True`` plus the color, dasharray and alpha of
        the first gridline.
    """
    gridlines = axis.get_gridlines()
    # "gridOn" is a private tick-kwargs key that may be absent. Use .get()
    # (like the "label1On" lookup in get_axis_properties) instead of
    # raising KeyError.
    if axis._major_tick_kw.get("gridOn", False) and len(gridlines) > 0:
        first = gridlines[0]
        color = export_color(first.get_color())
        alpha = first.get_alpha()
        dasharray = get_dasharray(first)
        return dict(gridOn=True, color=color, dasharray=dasharray, alpha=alpha)
    return {"gridOn": False}
def get_figure_properties(fig):
    """Return the size (in inches) and dpi of a matplotlib figure as a dict."""
    width = fig.get_figwidth()
    height = fig.get_figheight()
    return dict(figwidth=width, figheight=height, dpi=fig.dpi)
def get_axes_properties(ax):
    """Return the property dictionary for a matplotlib Axes instance.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes to describe.

    Returns
    -------
    props : dict
        Background color/alpha, position bounds, navigation and visibility
        flags, the per-axis properties under "axes" (from
        get_axis_properties), and for each of x and y:
        "<x|y>scale" ("date", "linear" or "log"), "<x|y>lim" (the raw
        limits) and "<x|y>domain" (the limits, converted to date tuples for
        date axes).

    Raises
    ------
    ValueError
        If an axis without a date converter uses a scale other than
        linear or log.
    """
    props = {
        "axesbg": export_color(ax.patch.get_facecolor()),
        "axesbgalpha": ax.patch.get_alpha(),
        "bounds": ax.get_position().bounds,
        "dynamic": ax.get_navigate(),
        "axison": ax.axison,
        "frame_on": ax.get_frame_on(),
        "patch_visible": ax.patch.get_visible(),
        "axes": [get_axis_properties(ax.xaxis), get_axis_properties(ax.yaxis)],
    }
    for axname in ["x", "y"]:
        axis = getattr(ax, axname + "axis")
        domain = getattr(ax, "get_{0}lim".format(axname))()
        lim = domain
        # Date axes are recognised by their unit converter, not by their
        # scale name.
        # NOTE(review): newer matplotlib versions may register a date
        # converter that does not subclass DateConverter. Confirm date axes
        # are still detected on the matplotlib version in use.
        if isinstance(axis.converter, matplotlib.dates.DateConverter):
            scale = "date"
            try:
                import pandas as pd
                # NOTE(review): this module path comes from older pandas
                # releases. If it is missing, the ImportError below silently
                # disables the Period branch. Confirm for the pandas in use.
                from pandas.tseries.converter import PeriodConverter
            except ImportError:
                # pandas or its converter is unavailable: pd = None disables
                # the Period branch below.
                pd = None
            if pd is not None and isinstance(axis.converter, PeriodConverter):
                # Limits are treated as Period ordinals. NOTE(review):
                # 'axis.freq' is not set anywhere in this file; confirm it
                # exists on the axis objects seen here.
                _dates = [pd.Period(ordinal=int(d), freq=axis.freq) for d in domain]
                # Months are made 0-based (presumably for JavaScript Date
                # compatibility; verify against the consumer).
                domain = [
                    (d.year, d.month - 1, d.day, d.hour, d.minute, d.second, 0)
                    for d in _dates
                ]
            else:
                domain = [
                    (
                        d.year,
                        d.month - 1,
                        d.day,
                        d.hour,
                        d.minute,
                        d.second,
                        d.microsecond * 1e-3,  # microseconds -> milliseconds
                    )
                    for d in matplotlib.dates.num2date(domain)
                ]
        else:
            scale = axis.get_scale()
        if scale not in ["date", "linear", "log"]:
            raise ValueError("Unknown axis scale: " "{0}".format(axis.get_scale()))
        props[axname + "scale"] = scale
        props[axname + "lim"] = lim
        props[axname + "domain"] = domain
    return props
def _iter_descendants(node, skipContainers):
    """Yield *node* and its nested children in pre-order, each exactly once.

    Childless nodes are always yielded. Nodes with children are yielded only
    when *skipContainers* is false.
    """
    children = node.get_children() if hasattr(node, "get_children") else []
    if len(children) > 0:
        if not skipContainers:
            yield node
        for child in children:
            for descendant in _iter_descendants(child, skipContainers):
                yield descendant
    else:
        yield node


def iter_all_children(obj, skipContainers=False):
    """
    Return an iterator over all children and nested children of *obj*,
    found through each object's get_children() method, in pre-order.

    If skipContainers is true, only childless objects are returned.
    Otherwise every descendant is returned exactly once. Previously each
    childless descendant was yielded twice in this mode. If *obj* itself
    has no children, *obj* is yielded.
    """
    children = obj.get_children() if hasattr(obj, "get_children") else []
    if len(children) == 0:
        # Historical behaviour: a childless root yields itself.
        yield obj
        return
    for child in children:
        for descendant in _iter_descendants(child, skipContainers):
            yield descendant
def get_legend_properties(ax, legend):
    """Return the legend handles, labels and visibility for *ax* as a dict."""
    legend_handles, legend_labels = ax.get_legend_handles_labels()
    return dict(
        handles=legend_handles,
        labels=legend_labels,
        visible=legend.get_visible(),
    )
def image_to_base64(image):
    """
    Convert a matplotlib image to a base64 png representation.

    Parameters
    ----------
    image : matplotlib image object
        The image to be converted. Uses its ``axes``, ``get_extent()`` and
        ``write_png(file_like)``.

    Returns
    -------
    image_base64 : string
        The base64 encoding of the png bytes, decoded to a str.
    """
    ax = image.axes
    binary_buffer = io.BytesIO()
    # The image is saved in axes coordinates, so temporarily set the axes
    # limits to the image extent to get the correct image.
    lim = ax.axis()
    try:
        ax.axis(image.get_extent())
        image.write_png(binary_buffer)
    finally:
        # Always restore the caller's limits, even when rendering fails.
        ax.axis(lim)
    return base64.b64encode(binary_buffer.getvalue()).decode("utf-8")