{}
\n" " \n" " \n" "| {k}: | {v!r} |
# pre-release
[-_\.]?
(?P(a|b|c|rc|alpha|beta|pre|preview))
[-_\.]?
(?P[0-9]+)?
)?
(?P # post release
(?:-(?P[0-9]+))
|
(?:
[-_\.]?
(?Ppost|rev|r)
[-_\.]?
(?P[0-9]+)?
)
)?
(?P # dev release
[-_\.]?
(?Pdev)
[-_\.]?
(?P[0-9]+)?
)?
)
(?:\+(?P[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
"""
class Version(_BaseVersion):
    """A parsed version string that can be compared and inspected.

    The input is matched against ``VERSION_PATTERN`` (case-insensitively,
    ignoring surrounding whitespace), normalized into a ``_Version`` record,
    and given a precomputed sort key built by ``_cmpkey``.
    """

    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)

    def __init__(self, version: str) -> None:
        """Parse *version* into its components.

        Raises
        ------
        InvalidVersion
            If *version* does not match the version pattern.
        """
        # Validate the version and parse it into pieces
        match = self._regex.search(version)
        if not match:
            raise InvalidVersion(f"Invalid version: '{version}'")

        # Store the parsed out pieces of the version
        self._version = _Version(
            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
            release=tuple(int(i) for i in match.group("release").split(".")),
            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
            post=_parse_letter_version(
                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
            ),
            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
            local=_parse_local_version(match.group("local")),
        )

        # Generate a key which will be used for sorting
        self._key = _cmpkey(
            self._version.epoch,
            self._version.release,
            self._version.pre,
            self._version.post,
            self._version.dev,
            self._version.local,
        )

    def __repr__(self) -> str:
        """Return a representation such as ``<Version('1.0rc1')>``."""
        # Embed the normalized string so the repr identifies the version.
        return f"<Version('{self}')>"

    def __str__(self) -> str:
        """Return the normalized version string."""
        parts = []

        # Epoch
        if self.epoch != 0:
            parts.append(f"{self.epoch}!")

        # Release segment
        parts.append(".".join(str(x) for x in self.release))

        # Pre-release: the (letter, number) pair is rendered without a
        # separator, e.g. ("rc", 1) -> "rc1"
        if self.pre is not None:
            parts.append("".join(str(x) for x in self.pre))

        # Post-release
        if self.post is not None:
            parts.append(f".post{self.post}")

        # Development release
        if self.dev is not None:
            parts.append(f".dev{self.dev}")

        # Local version segment
        if self.local is not None:
            parts.append(f"+{self.local}")

        return "".join(parts)

    @property
    def epoch(self) -> int:
        """The epoch (0 when not given)."""
        _epoch: int = self._version.epoch
        return _epoch

    @property
    def release(self) -> Tuple[int, ...]:
        """The release segment as a tuple of ints."""
        _release: Tuple[int, ...] = self._version.release
        return _release

    @property
    def pre(self) -> Optional[Tuple[str, int]]:
        """The normalized (letter, number) pre-release pair, or None."""
        _pre: Optional[Tuple[str, int]] = self._version.pre
        return _pre

    @property
    def post(self) -> Optional[int]:
        """The post-release number, or None."""
        return self._version.post[1] if self._version.post else None

    @property
    def dev(self) -> Optional[int]:
        """The development-release number, or None."""
        return self._version.dev[1] if self._version.dev else None

    @property
    def local(self) -> Optional[str]:
        """The local segment joined with ".", or None."""
        if self._version.local:
            return ".".join(str(x) for x in self._version.local)
        else:
            return None

    @property
    def public(self) -> str:
        """The version string without its local ("+...") segment."""
        return str(self).split("+", 1)[0]

    @property
    def base_version(self) -> str:
        """The epoch (if nonzero) and release segment only."""
        parts = []

        # Epoch
        if self.epoch != 0:
            parts.append(f"{self.epoch}!")

        # Release segment
        parts.append(".".join(str(x) for x in self.release))

        return "".join(parts)

    @property
    def is_prerelease(self) -> bool:
        """True for dev releases and pre-releases."""
        return self.dev is not None or self.pre is not None

    @property
    def is_postrelease(self) -> bool:
        """True when a post-release segment is present."""
        return self.post is not None

    @property
    def is_devrelease(self) -> bool:
        """True when a development-release segment is present."""
        return self.dev is not None

    @property
    def major(self) -> int:
        """First release component, or 0 if absent."""
        return self.release[0] if len(self.release) >= 1 else 0

    @property
    def minor(self) -> int:
        """Second release component, or 0 if absent."""
        return self.release[1] if len(self.release) >= 2 else 0

    @property
    def micro(self) -> int:
        """Third release component, or 0 if absent."""
        return self.release[2] if len(self.release) >= 3 else 0
def _parse_letter_version(
letter: str, number: Union[str, bytes, SupportsInt]
) -> Optional[Tuple[str, int]]:
if letter:
# We consider there to be an implicit 0 in a pre-release if there is
# not a numeral associated with it.
if number is None:
number = 0
# We normalize any letters to their lower case form
letter = letter.lower()
# We consider some words to be alternate spellings of other words and
# in those cases we want to normalize the spellings to our preferred
# spelling.
if letter == "alpha":
letter = "a"
elif letter == "beta":
letter = "b"
elif letter in ["c", "pre", "preview"]:
letter = "rc"
elif letter in ["rev", "r"]:
letter = "post"
return letter, int(number)
if not letter and number:
# We assume if we are given a number, but we are not given a letter
# then this is using the implicit post release syntax (e.g. 1.0-1)
letter = "post"
return letter, int(number)
return None
_local_version_separators = re.compile(r"[\._-]")


def _parse_local_version(local: str) -> Optional[LocalType]:
    """Split a local version label into comparable parts.

    Parts are separated by ".", "_" or "-". All-digit parts become ints and
    the rest are lower-cased, so "abc.1.twelve" -> ("abc", 1, "twelve").
    Returns None when *local* is None.
    """
    if local is None:
        return None
    pieces = _local_version_separators.split(local)
    return tuple(int(piece) if piece.isdigit() else piece.lower() for piece in pieces)
def _cmpkey(
    epoch: int,
    release: Tuple[int, ...],
    pre: Optional[Tuple[str, int]],
    post: Optional[Tuple[str, int]],
    dev: Optional[Tuple[str, int]],
    local: Optional[Tuple[SubLocalType]],
) -> CmpKey:
    """Build the tuple that determines how versions sort.

    Missing segments are replaced with the ``Infinity`` /
    ``NegativeInfinity`` sentinels so that ordinary tuple comparison yields
    the intended ordering.
    """
    # Trailing zeros in the release do not affect ordering (1.0 == 1.0.0),
    # so strip them before building the key.
    trimmed = list(release)
    while trimmed and trimmed[-1] == 0:
        trimmed.pop()
    _release = tuple(trimmed)

    if pre is not None:
        _pre: PrePostDevType = pre
    elif post is None and dev is not None:
        # A bare dev release (1.0.dev0) must sort before any pre-release of
        # the same version (1.0a0). This only applies when there is neither
        # a pre nor a post segment; otherwise normal rules already work.
        _pre = NegativeInfinity
    else:
        # Without a pre-release, sort after versions that have one.
        _pre = Infinity

    # Without a post segment, sort before versions that have one.
    _post: PrePostDevType = NegativeInfinity if post is None else post

    # Without a development segment, sort after versions that have one.
    _dev: PrePostDevType = Infinity if dev is None else dev

    if local is None:
        # Without a local segment, sort before versions that have one.
        _local: LocalType = NegativeInfinity
    else:
        # PEP 440 local ordering: alphanumeric parts sort before numeric
        # ones and lexicographically among themselves; numeric parts sort
        # numerically; with equal prefixes the shorter label sorts first.
        _local = tuple(
            (part, "") if isinstance(part, int) else (NegativeInfinity, part)
            for part in local
        )

    return epoch, _release, _pre, _post, _dev, _local
================================================
FILE: seaborn/matrix.py
================================================
"""Functions to visualize matrices of data."""
import warnings
import matplotlib as mpl
from matplotlib.collections import LineCollection
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import pandas as pd
try:
from scipy.cluster import hierarchy
_no_scipy = False
except ImportError:
_no_scipy = True
from . import cm
from .axisgrid import Grid
from ._compat import get_colormap
from .utils import (
despine,
axis_ticklabels_overlap,
relative_luminance,
to_utf8,
_draw_figure,
)
__all__ = ["heatmap", "clustermap"]
def _index_to_label(index):
"""Convert a pandas index or multiindex to an axis label."""
if isinstance(index, pd.MultiIndex):
return "-".join(map(to_utf8, index.names))
else:
return index.name
def _index_to_ticklabels(index):
"""Convert a pandas index or multiindex into ticklabels."""
if isinstance(index, pd.MultiIndex):
return ["-".join(map(to_utf8, i)) for i in index.values]
else:
return index.values
def _convert_colors(colors):
    """Convert either a list of colors or nested lists of colors to RGB.

    Parameters
    ----------
    colors : sequence
        Either a flat sequence of matplotlib color specs, or a sequence of
        such sequences (one per side-color track).

    Returns
    -------
    list
        RGB tuples with the same (flat or nested) structure as the input.

    Raises
    ------
    ValueError
        If an entry is not a valid matplotlib color.
    """
    to_rgb = mpl.colors.to_rgb

    try:
        # Probe only the first entry to decide whether the input is flat.
        to_rgb(colors[0])
    except ValueError:
        # The first entry is not a color, so we have nested lists
        return [list(map(to_rgb, color_list)) for color_list in colors]
    else:
        # There is only one level of colors. Converting outside the ``try``
        # means an invalid later entry raises here, instead of silently
        # re-interpreting the input as nested (e.g. splitting a string like
        # "bg" into the single-letter colors "b" and "g").
        return list(map(to_rgb, colors))
def _matrix_mask(data, mask):
"""Ensure that data and mask are compatible and add missing values.
Values will be plotted for cells where ``mask`` is ``False``.
``data`` is expected to be a DataFrame; ``mask`` can be an array or
a DataFrame.
"""
if mask is None:
mask = np.zeros(data.shape, bool)
if isinstance(mask, pd.DataFrame):
# For DataFrame masks, ensure that semantic labels match data
if not mask.index.equals(data.index) \
and mask.columns.equals(data.columns):
err = "Mask must have the same index and columns as data."
raise ValueError(err)
elif hasattr(mask, "__array__"):
mask = np.asarray(mask)
# For array masks, ensure that shape matches data then convert
if mask.shape != data.shape:
raise ValueError("Mask must have the same shape as data.")
mask = pd.DataFrame(mask,
index=data.index,
columns=data.columns,
dtype=bool)
# Add any cells with missing data to the mask
# This works around an issue where `plt.pcolormesh` doesn't represent
# missing data properly
mask = mask | pd.isnull(data)
return mask
class _HeatMapper:
    """Draw a heatmap plot of a matrix with nice labels and colormaps."""

    def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt,
                 annot_kws, cbar, cbar_kws,
                 xticklabels=True, yticklabels=True, mask=None):
        """Initialize the plotting object.

        Normalizes ``data`` into a DataFrame (used for labels) and a masked
        ndarray (used for drawing). Resolves tick labels, colormap
        parameters and annotation data, and stores what ``plot`` needs.

        Raises
        ------
        ValueError
            If ``mask`` is incompatible with ``data`` (via ``_matrix_mask``),
            or an ``annot`` array does not have the same shape as the data.
        """
        # We always want to have a DataFrame with semantic information
        # and an ndarray to pass to matplotlib
        if isinstance(data, pd.DataFrame):
            plot_data = data.values
        else:
            plot_data = np.asarray(data)
            data = pd.DataFrame(plot_data)

        # Validate the mask and convert to DataFrame
        mask = _matrix_mask(data, mask)

        # Masked cells (including missing data) are hidden from pcolormesh
        plot_data = np.ma.masked_where(np.asarray(mask), plot_data)

        # Get good names for the rows and columns
        # NOTE: bool is a subclass of int, so True/False take the
        # ``isinstance(..., int)`` branch (a step of 1 shows every label; a
        # step of 0 makes _skip_ticks return no ticks). The ``is True`` /
        # ``is False`` branches are therefore effectively unreachable.
        xtickevery = 1
        if isinstance(xticklabels, int):
            xtickevery = xticklabels
            xticklabels = _index_to_ticklabels(data.columns)
        elif xticklabels is True:
            xticklabels = _index_to_ticklabels(data.columns)
        elif xticklabels is False:
            xticklabels = []

        ytickevery = 1
        if isinstance(yticklabels, int):
            ytickevery = yticklabels
            yticklabels = _index_to_ticklabels(data.index)
        elif yticklabels is True:
            yticklabels = _index_to_ticklabels(data.index)
        elif yticklabels is False:
            yticklabels = []

        # "auto" defers tick placement to plot(), where the Axes size is known
        if not len(xticklabels):
            self.xticks = []
            self.xticklabels = []
        elif isinstance(xticklabels, str) and xticklabels == "auto":
            self.xticks = "auto"
            self.xticklabels = _index_to_ticklabels(data.columns)
        else:
            self.xticks, self.xticklabels = self._skip_ticks(xticklabels,
                                                             xtickevery)

        if not len(yticklabels):
            self.yticks = []
            self.yticklabels = []
        elif isinstance(yticklabels, str) and yticklabels == "auto":
            self.yticks = "auto"
            self.yticklabels = _index_to_ticklabels(data.index)
        else:
            self.yticks, self.yticklabels = self._skip_ticks(yticklabels,
                                                             ytickevery)

        # Get good names for the axis labels
        xlabel = _index_to_label(data.columns)
        ylabel = _index_to_label(data.index)
        self.xlabel = xlabel if xlabel is not None else ""
        self.ylabel = ylabel if ylabel is not None else ""

        # Determine good default values for the colormapping
        self._determine_cmap_params(plot_data, vmin, vmax,
                                    cmap, center, robust)

        # Sort out the annotations: True annotates with the data itself,
        # an array-like supplies the annotation values directly
        if annot is None or annot is False:
            annot = False
            annot_data = None
        else:
            if isinstance(annot, bool):
                annot_data = plot_data
            else:
                annot_data = np.asarray(annot)
                if annot_data.shape != plot_data.shape:
                    err = "`data` and `annot` must have same shape."
                    raise ValueError(err)
            annot = True

        # Save other attributes to the object
        self.data = data
        self.plot_data = plot_data

        self.annot = annot
        self.annot_data = annot_data

        self.fmt = fmt
        # Copies, so later mutation does not leak back to the caller's dicts
        self.annot_kws = {} if annot_kws is None else annot_kws.copy()
        self.cbar = cbar
        self.cbar_kws = {} if cbar_kws is None else cbar_kws.copy()

    def _determine_cmap_params(self, plot_data, vmin, vmax,
                               cmap, center, robust):
        """Use some heuristics to set good defaults for colorbar and range.

        Sets ``self.vmin``, ``self.vmax`` and ``self.cmap``. When ``center``
        is given, ``self.cmap`` is replaced by a ListedColormap. It samples
        the original colormap over the part of a range symmetric about
        ``center`` that [vmin, vmax] covers.
        """

        # plot_data is a np.ma.array instance
        # Masked cells become NaN so the nan-aware reductions ignore them
        calc_data = plot_data.astype(float).filled(np.nan)
        if vmin is None:
            if robust:
                # Robust range: 2nd/98th percentiles instead of extremes
                vmin = np.nanpercentile(calc_data, 2)
            else:
                vmin = np.nanmin(calc_data)
        if vmax is None:
            if robust:
                vmax = np.nanpercentile(calc_data, 98)
            else:
                vmax = np.nanmax(calc_data)
        self.vmin, self.vmax = vmin, vmax

        # Choose default colormaps if not provided
        if cmap is None:
            if center is None:
                self.cmap = cm.rocket
            else:
                self.cmap = cm.icefire
        elif isinstance(cmap, str):
            self.cmap = get_colormap(cmap)
        elif isinstance(cmap, list):
            self.cmap = mpl.colors.ListedColormap(cmap)
        else:
            self.cmap = cmap

        # Recenter a divergent colormap
        if center is not None:

            # Copy bad values
            # in mpl<3.2 only masked values are honored with "bad" color spec
            # (see https://github.com/matplotlib/matplotlib/pull/14257)
            bad = self.cmap(np.ma.masked_invalid([np.nan]))[0]

            # under/over values are set for sure when cmap extremes
            # do not map to the same color as +-inf
            under = self.cmap(-np.inf)
            over = self.cmap(np.inf)
            under_set = under != self.cmap(0)
            over_set = over != self.cmap(self.cmap.N - 1)

            # Symmetric range around center that contains [vmin, vmax];
            # sample only the slice of the colormap that the data spans
            vrange = max(vmax - center, center - vmin)
            normlize = mpl.colors.Normalize(center - vrange, center + vrange)
            cmin, cmax = normlize([vmin, vmax])
            cc = np.linspace(cmin, cmax, 256)
            self.cmap = mpl.colors.ListedColormap(self.cmap(cc))
            self.cmap.set_bad(bad)
            if under_set:
                self.cmap.set_under(under)
            if over_set:
                self.cmap.set_over(over)

    def _annotate_heatmap(self, ax, mesh):
        """Add textual labels with the value in each cell.

        Masked cells are skipped. Text is dark (".15") on light cells and
        white on dark cells, chosen by the relative luminance of the cell's
        face color. ``self.annot_kws`` overrides these defaults.
        """
        # Make sure the face colors reflect the current colormapping
        mesh.update_scalarmappable()
        height, width = self.annot_data.shape
        # Cell centers in data coordinates
        xpos, ypos = np.meshgrid(np.arange(width) + .5, np.arange(height) + .5)
        for x, y, m, color, val in zip(xpos.flat, ypos.flat,
                                       mesh.get_array().flat, mesh.get_facecolors(),
                                       self.annot_data.flat):
            if m is not np.ma.masked:
                lum = relative_luminance(color)
                text_color = ".15" if lum > .408 else "w"
                annotation = ("{:" + self.fmt + "}").format(val)
                text_kwargs = dict(color=text_color, ha="center", va="center")
                text_kwargs.update(self.annot_kws)
                ax.text(x, y, annotation, **text_kwargs)

    def _skip_ticks(self, labels, tickevery):
        """Return ticks and labels at evenly spaced intervals.

        Ticks sit at cell centers (index + .5). A ``tickevery`` of 0 yields
        no ticks, 1 yields every label, and larger values keep every n-th
        label starting from the first.
        """
        n = len(labels)
        if tickevery == 0:
            ticks, labels = [], []
        elif tickevery == 1:
            ticks, labels = np.arange(n) + .5, labels
        else:
            start, end, step = 0, n, tickevery
            ticks = np.arange(start, end, step) + .5
            labels = labels[start:end:step]
        return ticks, labels

    def _auto_ticks(self, ax, labels, axis):
        """Determine ticks and ticklabels that minimize overlap.

        ``axis`` is 0 for x and 1 for y. The number of labels that fit is
        estimated from the Axes extent in inches divided by the tick label
        font size in inches (points / 72).
        """
        transform = ax.figure.dpi_scale_trans.inverted()
        bbox = ax.get_window_extent().transformed(transform)
        size = [bbox.width, bbox.height][axis]
        axis = [ax.xaxis, ax.yaxis][axis]
        # Create a single throwaway tick just to read the label font size
        tick, = axis.set_ticks([0])
        fontsize = tick.label1.get_size()
        max_ticks = int(size // (fontsize / 72))
        if max_ticks < 1:
            return [], []
        tick_every = len(labels) // max_ticks + 1
        tick_every = 1 if tick_every == 0 else tick_every
        ticks, labels = self._skip_ticks(labels, tick_every)
        return ticks, labels

    def plot(self, ax, cax, kws):
        """Draw the heatmap on the provided Axes.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to draw the heatmap into.
        cax : matplotlib Axes or None
            Axes for the colorbar; passed straight to ``Figure.colorbar``.
        kws : dict
            Keyword arguments for ``pcolormesh``. NOTE: ``vmin``/``vmax``
            defaults are inserted into this dict in place.
        """
        # Remove all the Axes spines
        despine(ax=ax, left=True, bottom=True)

        # setting vmin/vmax in addition to norm is deprecated
        # so avoid setting if norm is set
        if kws.get("norm") is None:
            kws.setdefault("vmin", self.vmin)
            kws.setdefault("vmax", self.vmax)

        # Draw the heatmap
        mesh = ax.pcolormesh(self.plot_data, cmap=self.cmap, **kws)

        # Set the axis limits
        ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))

        # Invert the y axis to show the plot in matrix form
        ax.invert_yaxis()

        # Possibly add a colorbar
        if self.cbar:
            cb = ax.figure.colorbar(mesh, cax, ax, **self.cbar_kws)
            cb.outline.set_linewidth(0)
            # If rasterized is passed to pcolormesh, also rasterize the
            # colorbar to avoid white lines on the PDF rendering
            if kws.get('rasterized', False):
                cb.solids.set_rasterized(True)

        # Add row and column labels
        if isinstance(self.xticks, str) and self.xticks == "auto":
            xticks, xticklabels = self._auto_ticks(ax, self.xticklabels, 0)
        else:
            xticks, xticklabels = self.xticks, self.xticklabels

        if isinstance(self.yticks, str) and self.yticks == "auto":
            yticks, yticklabels = self._auto_ticks(ax, self.yticklabels, 1)
        else:
            yticks, yticklabels = self.yticks, self.yticklabels

        ax.set(xticks=xticks, yticks=yticks)
        xtl = ax.set_xticklabels(xticklabels)
        ytl = ax.set_yticklabels(yticklabels, rotation="vertical")
        plt.setp(ytl, va="center")  # GH2484

        # Possibly rotate them if they overlap
        # (presumably the draw is needed so label extents are up to date)
        _draw_figure(ax.figure)

        if axis_ticklabels_overlap(xtl):
            plt.setp(xtl, rotation="vertical")
        if axis_ticklabels_overlap(ytl):
            plt.setp(ytl, rotation="horizontal")

        # Add the axis labels
        ax.set(xlabel=self.xlabel, ylabel=self.ylabel)

        # Annotate the cells with the formatted values
        if self.annot:
            self._annotate_heatmap(ax, mesh)
def heatmap(
    data, *,
    vmin=None, vmax=None, cmap=None, center=None, robust=False,
    annot=None, fmt=".2g", annot_kws=None,
    linewidths=0, linecolor="white",
    cbar=True, cbar_kws=None, cbar_ax=None,
    square=False, xticklabels="auto", yticklabels="auto",
    mask=None, ax=None,
    **kwargs
):
    """Plot rectangular data as a color-encoded matrix.

    This is an Axes-level function and will draw the heatmap into the
    currently-active Axes if none is provided to the ``ax`` argument. Part of
    this Axes space will be taken and used to plot a colormap, unless ``cbar``
    is False or a separate Axes is provided to ``cbar_ax``.

    Parameters
    ----------
    data : rectangular dataset
        2D dataset that can be coerced into an ndarray. If a Pandas DataFrame
        is provided, the index/column information will be used to label the
        columns and rows.
    vmin, vmax : floats, optional
        Values to anchor the colormap, otherwise they are inferred from the
        data and other keyword arguments.
    cmap : matplotlib colormap name or object, or list of colors, optional
        The mapping from data values to color space. If not provided, the
        default will depend on whether ``center`` is set.
    center : float, optional
        The value at which to center the colormap when plotting divergent data.
        Using this parameter will change the default ``cmap`` if none is
        specified.
    robust : bool, optional
        If True and ``vmin`` or ``vmax`` are absent, the colormap range is
        computed with robust quantiles instead of the extreme values.
    annot : bool or rectangular dataset, optional
        If True, write the data value in each cell. If an array-like with the
        same shape as ``data``, then use this to annotate the heatmap instead
        of the data. Note that DataFrames will match on position, not index.
    fmt : str, optional
        String formatting code to use when adding annotations.
    annot_kws : dict of key, value mappings, optional
        Keyword arguments for :meth:`matplotlib.axes.Axes.text` when ``annot``
        is True.
    linewidths : float, optional
        Width of the lines that will divide each cell.
    linecolor : color, optional
        Color of the lines that will divide each cell.
    cbar : bool, optional
        Whether to draw a colorbar.
    cbar_kws : dict of key, value mappings, optional
        Keyword arguments for :meth:`matplotlib.figure.Figure.colorbar`.
    cbar_ax : matplotlib Axes, optional
        Axes in which to draw the colorbar, otherwise take space from the
        main Axes.
    square : bool, optional
        If True, set the Axes aspect to "equal" so each cell will be
        square-shaped.
    xticklabels, yticklabels : "auto", bool, list-like, or int, optional
        If True, plot the column names of the dataframe. If False, don't plot
        the column names. If list-like, plot these alternate labels as the
        xticklabels. If an integer, use the column names but plot only every
        n-th label. If "auto", try to densely plot non-overlapping labels.
    mask : bool array or DataFrame, optional
        If passed, data will not be shown in cells where ``mask`` is True.
        Cells with missing values are automatically masked.
    ax : matplotlib Axes, optional
        Axes in which to draw the plot, otherwise use the currently-active
        Axes.
    kwargs : other keyword arguments
        All other keyword arguments are passed to
        :meth:`matplotlib.axes.Axes.pcolormesh`.

    Returns
    -------
    ax : matplotlib Axes
        Axes object with the heatmap.

    See Also
    --------
    clustermap : Plot a matrix using hierarchical clustering to arrange the
                 rows and columns.

    Examples
    --------

    .. include:: ../docstrings/heatmap.rst

    """
    # Validate inputs and resolve colormap/tick settings before any figure
    # state is touched, so bad input fails without creating an Axes.
    plotter = _HeatMapper(
        data, vmin, vmax, cmap, center, robust, annot, fmt,
        annot_kws, cbar, cbar_kws, xticklabels, yticklabels, mask,
    )

    # Cell-divider styling travels to pcolormesh with the user's kwargs;
    # the explicit parameters take precedence over anything in kwargs.
    mesh_kws = {**kwargs, "linewidths": linewidths, "edgecolor": linecolor}

    target_ax = plt.gca() if ax is None else ax
    if square:
        target_ax.set_aspect("equal")

    plotter.plot(target_ax, cbar_ax, mesh_kws)
    return target_ax
class _DendrogramPlotter:
    """Object for drawing tree of similarities between data rows/columns"""

    def __init__(self, data, linkage, metric, method, axis, label, rotate):
        """Plot a dendrogram of the relationships between the columns of data

        Computes (or accepts) a linkage matrix, asks scipy for the
        dendrogram layout, and precomputes tick positions/labels for
        ``plot``.

        Parameters
        ----------
        data : pandas.DataFrame
            Rectangular data
        linkage : array-like or None
            Precomputed linkage matrix. When None it is computed from the
            data using ``metric`` and ``method``.
        metric, method : str
            Passed to the linkage routine (fastcluster if importable,
            otherwise scipy).
        axis : int
            1 clusters the columns (the data is transposed first); 0 the rows.
        label : bool
            Whether to label the leaves with index names.
        rotate : bool
            When True, ticks/labels go on the y axis (leaves along y).
        """
        self.axis = axis
        if self.axis == 1:
            # Cluster columns by treating them as the observations (rows)
            data = data.T

        if isinstance(data, pd.DataFrame):
            array = data.values
        else:
            array = np.asarray(data)
            data = pd.DataFrame(array)

        self.array = array
        self.data = data

        self.shape = self.data.shape
        self.metric = metric
        self.method = method
        # NOTE(review): redundant -- ``self.axis`` was already assigned above
        self.axis = axis
        self.label = label
        self.rotate = rotate

        if linkage is None:
            # ``calculated_linkage`` is a property; this access computes it
            self.linkage = self.calculated_linkage
        else:
            self.linkage = linkage
        self.dendrogram = self.calculate_dendrogram()

        # Dendrogram ends are always at multiples of 5, who knows why
        # (leaves are spaced 10 units apart, matching the axis limits that
        # plot() sets from the scipy constants)
        ticks = 10 * np.arange(self.data.shape[0]) + 5

        if self.label:
            ticklabels = _index_to_ticklabels(self.data.index)
            # Put the labels in the leaf order produced by the clustering
            ticklabels = [ticklabels[i] for i in self.reordered_ind]
            if self.rotate:
                self.xticks = []
                self.yticks = ticks
                self.xticklabels = []

                self.yticklabels = ticklabels
                self.ylabel = _index_to_label(self.data.index)
                self.xlabel = ''
            else:
                self.xticks = ticks
                self.yticks = []
                self.xticklabels = ticklabels
                self.yticklabels = []
                self.ylabel = ''
                self.xlabel = _index_to_label(self.data.index)
        else:
            self.xticks, self.yticks = [], []
            self.yticklabels, self.xticklabels = [], []
            self.xlabel, self.ylabel = '', ''

        # Line-segment coordinates from scipy: 'dcoord' holds the link
        # heights and 'icoord' the positions along the leaf axis
        self.dependent_coord = self.dendrogram['dcoord']
        self.independent_coord = self.dendrogram['icoord']

    def _calculate_linkage_scipy(self):
        """Compute the linkage matrix with scipy.cluster.hierarchy.linkage."""
        linkage = hierarchy.linkage(self.array, method=self.method,
                                    metric=self.metric)
        return linkage

    def _calculate_linkage_fastcluster(self):
        """Compute the linkage matrix with fastcluster.

        Raises ImportError if fastcluster is not installed.
        """
        import fastcluster
        # Fastcluster has a memory-saving vectorized version, but only
        # with certain linkage methods, and mostly with euclidean metric
        # vector_methods = ('single', 'centroid', 'median', 'ward')
        euclidean_methods = ('centroid', 'median', 'ward')
        euclidean = self.metric == 'euclidean' and self.method in \
            euclidean_methods
        if euclidean or self.method == 'single':
            return fastcluster.linkage_vector(self.array,
                                              method=self.method,
                                              metric=self.metric)
        else:
            linkage = fastcluster.linkage(self.array, method=self.method,
                                          metric=self.metric)
            return linkage

    @property
    def calculated_linkage(self):
        """Linkage matrix via fastcluster, falling back to scipy.

        Recomputed on every access. A warning is issued when scipy is
        used on a matrix with 10000 or more elements.
        """

        try:
            return self._calculate_linkage_fastcluster()
        except ImportError:
            if np.prod(self.shape) >= 10000:
                msg = ("Clustering large matrix with scipy. Installing "
                       "`fastcluster` may give better performance.")
                warnings.warn(msg)

        return self._calculate_linkage_scipy()

    def calculate_dendrogram(self):
        """Calculates a dendrogram based on the linkage matrix

        Made a separate function, not a property because don't want to
        recalculate the dendrogram every time it is accessed.

        Returns
        -------
        dendrogram : dict
            Dendrogram dictionary as returned by scipy.cluster.hierarchy
            .dendrogram. The important key-value pairing is
            "reordered_ind" which indicates the re-ordering of the matrix
        """
        # no_plot=True: only the layout is wanted; drawing is done in plot().
        # color_threshold=-np.inf presumably disables scipy's per-cluster
        # coloring (colors are not used here) -- TODO confirm intent
        return hierarchy.dendrogram(self.linkage, no_plot=True,
                                    color_threshold=-np.inf)

    @property
    def reordered_ind(self):
        """Indices of the matrix, reordered by the dendrogram"""
        return self.dendrogram['leaves']

    def plot(self, ax, tree_kws):
        """Plots a dendrogram of the similarities between data on the axes

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes object upon which the dendrogram is plotted
        tree_kws : dict or None
            Keyword arguments for the LineCollection that draws the tree.
            ``color`` is accepted as an alias for ``colors``.

        Returns
        -------
        self : _DendrogramPlotter
        """
        tree_kws = {} if tree_kws is None else tree_kws.copy()
        tree_kws.setdefault("linewidths", .5)
        tree_kws.setdefault("colors", tree_kws.pop("color", (.2, .2, .2)))

        # With rotate and axis == 0 the heights go on x, so the tree is
        # drawn horizontally; otherwise heights go on y
        if self.rotate and self.axis == 0:
            coords = zip(self.dependent_coord, self.independent_coord)
        else:
            coords = zip(self.independent_coord, self.dependent_coord)
        lines = LineCollection([list(zip(x, y)) for x, y in coords],
                               **tree_kws)

        ax.add_collection(lines)
        number_of_leaves = len(self.reordered_ind)
        max_dependent_coord = max(map(max, self.dependent_coord))

        if self.rotate:
            ax.yaxis.set_ticks_position('right')

            # Constants 10 and 1.05 come from
            # `scipy.cluster.hierarchy._plot_dendrogram`
            ax.set_ylim(0, number_of_leaves * 10)
            ax.set_xlim(0, max_dependent_coord * 1.05)

            ax.invert_xaxis()
            ax.invert_yaxis()
        else:
            # Constants 10 and 1.05 come from
            # `scipy.cluster.hierarchy._plot_dendrogram`
            ax.set_xlim(0, number_of_leaves * 10)
            ax.set_ylim(0, max_dependent_coord * 1.05)

        despine(ax=ax, bottom=True, left=True)

        ax.set(xticks=self.xticks, yticks=self.yticks,
               xlabel=self.xlabel, ylabel=self.ylabel)
        xtl = ax.set_xticklabels(self.xticklabels)
        ytl = ax.set_yticklabels(self.yticklabels, rotation='vertical')

        # Force a draw of the plot to avoid matplotlib window error
        _draw_figure(ax.figure)

        if len(ytl) > 0 and axis_ticklabels_overlap(ytl):
            plt.setp(ytl, rotation="horizontal")
        if len(xtl) > 0 and axis_ticklabels_overlap(xtl):
            plt.setp(xtl, rotation="vertical")
        return self
def dendrogram(
    data, *,
    linkage=None, axis=1, label=True, metric='euclidean',
    method='average', rotate=False, tree_kws=None, ax=None
):
    """Draw a tree diagram of relationships within a matrix

    Parameters
    ----------
    data : pandas.DataFrame
        Rectangular data
    linkage : numpy.array, optional
        Linkage matrix
    axis : int, optional
        Which axis to use to calculate linkage. 0 is rows, 1 is columns.
    label : bool, optional
        If True, label the dendrogram at leaves with column or row names
    metric : str, optional
        Distance metric. Anything valid for scipy.spatial.distance.pdist
    method : str, optional
        Linkage method to use. Anything valid for
        scipy.cluster.hierarchy.linkage
    rotate : bool, optional
        When plotting the matrix, whether to rotate it 90 degrees
        counter-clockwise, so the leaves face right
    tree_kws : dict, optional
        Keyword arguments for the ``matplotlib.collections.LineCollection``
        that is used for plotting the lines of the dendrogram tree.
    ax : matplotlib axis, optional
        Axis to plot on, otherwise uses current axis

    Returns
    -------
    dendrogramplotter : _DendrogramPlotter
        A Dendrogram plotter object.

    Raises
    ------
    RuntimeError
        If scipy is not installed.

    Notes
    -----
    Access the reordered dendrogram indices with
    dendrogramplotter.reordered_ind

    """
    if _no_scipy:
        raise RuntimeError("dendrogram requires scipy to be installed")

    # Clustering happens at construction time, before any Axes is chosen.
    plotter = _DendrogramPlotter(
        data,
        linkage=linkage,
        metric=metric,
        method=method,
        axis=axis,
        label=label,
        rotate=rotate,
    )
    target_ax = ax if ax is not None else plt.gca()
    return plotter.plot(ax=target_ax, tree_kws=tree_kws)
class ClusterGrid(Grid):
def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None,
             figsize=None, row_colors=None, col_colors=None, mask=None,
             dendrogram_ratio=None, colors_ratio=None, cbar_pos=None):
    """Grid object for organizing clustered heatmap input on to axes

    Prepares the (optionally pivoted and normalized) matrix, the mask and
    the side colors. Creates a figure whose GridSpec has a dendrogram row
    and column, optional side-color rows/columns, and the heatmap in the
    last row and column.

    Parameters
    ----------
    dendrogram_ratio, colors_ratio : float or pair of floats
        A single value is used for both rows and columns; a pair is
        unpacked as (row, col).
    cbar_pos : tuple or None
        When None no colorbar Axes is created.

    Raises
    ------
    RuntimeError
        If scipy is not available.
    """

    if _no_scipy:
        raise RuntimeError("ClusterGrid requires scipy to be available")

    if isinstance(data, pd.DataFrame):
        self.data = data
    else:
        self.data = pd.DataFrame(data)

    self.data2d = self.format_data(self.data, pivot_kws, z_score,
                                   standard_scale)

    self.mask = _matrix_mask(self.data2d, mask)

    self._figure = plt.figure(figsize=figsize)

    # NOTE: the raw ``data`` argument (not ``self.data``) is passed, so
    # _preprocess_colors can reject indexed colors for unindexed input
    self.row_colors, self.row_color_labels = \
        self._preprocess_colors(data, row_colors, axis=0)
    self.col_colors, self.col_color_labels = \
        self._preprocess_colors(data, col_colors, axis=1)

    # Unpacking a scalar raises TypeError; then use it for both dimensions
    try:
        row_dendrogram_ratio, col_dendrogram_ratio = dendrogram_ratio
    except TypeError:
        row_dendrogram_ratio = col_dendrogram_ratio = dendrogram_ratio

    try:
        row_colors_ratio, col_colors_ratio = colors_ratio
    except TypeError:
        row_colors_ratio = col_colors_ratio = colors_ratio

    width_ratios = self.dim_ratios(self.row_colors,
                                   row_dendrogram_ratio,
                                   row_colors_ratio)
    height_ratios = self.dim_ratios(self.col_colors,
                                    col_dendrogram_ratio,
                                    col_colors_ratio)

    # An extra row/column is added only when side colors are present
    nrows = 2 if self.col_colors is None else 3
    ncols = 2 if self.row_colors is None else 3

    self.gs = gridspec.GridSpec(nrows, ncols,
                                width_ratios=width_ratios,
                                height_ratios=height_ratios)

    # Row dendrogram: bottom-left; column dendrogram: top-right
    self.ax_row_dendrogram = self._figure.add_subplot(self.gs[-1, 0])
    self.ax_col_dendrogram = self._figure.add_subplot(self.gs[0, -1])
    self.ax_row_dendrogram.set_axis_off()
    self.ax_col_dendrogram.set_axis_off()

    self.ax_row_colors = None
    self.ax_col_colors = None

    if self.row_colors is not None:
        self.ax_row_colors = self._figure.add_subplot(
            self.gs[-1, 1])
    if self.col_colors is not None:
        self.ax_col_colors = self._figure.add_subplot(
            self.gs[1, -1])

    # The heatmap always occupies the bottom-right cell
    self.ax_heatmap = self._figure.add_subplot(self.gs[-1, -1])
    if cbar_pos is None:
        self.ax_cbar = self.cax = None
    else:
        # Initialize the colorbar axes in the gridspec so that tight_layout
        # works. We will move it where it belongs later. This is a hack.
        self.ax_cbar = self._figure.add_subplot(self.gs[0, 0])
        self.cax = self.ax_cbar  # Backwards compatibility
    self.cbar_pos = cbar_pos

    self.dendrogram_row = None
    self.dendrogram_col = None
def _preprocess_colors(self, data, colors, axis):
    """Preprocess {row/col}_colors to extract labels and convert colors.

    Parameters
    ----------
    data : DataFrame or array-like
        The data given to the grid; its index/columns align pandas colors.
    colors : list, array, Series, DataFrame, or None
        Side colors for the rows (``axis=0``) or columns (``axis=1``).
    axis : int
        0 for row colors, 1 for column colors.

    Returns
    -------
    colors : list or None
        RGB colors (nested when there are several tracks), or None.
    labels : list or None
        Track labels from the Series name / DataFrame columns, or None when
        ``colors`` was not a pandas object.

    Raises
    ------
    TypeError
        If pandas colors are given but ``data`` has no index/columns to
        align them with.
    """
    labels = None

    if colors is not None:
        if isinstance(colors, (pd.DataFrame, pd.Series)):

            # If data is unindexed, raise
            if (not hasattr(data, "index") and axis == 0) or (
                not hasattr(data, "columns") and axis == 1
            ):
                axis_name = "col" if axis else "row"
                msg = (f"{axis_name}_colors indices can't be matched with data "
                       f"indices. Provide {axis_name}_colors as a non-indexed "
                       "datatype, e.g. by using `.to_numpy()`")
                raise TypeError(msg)

            # Ensure colors match data indices
            if axis == 0:
                colors = colors.reindex(data.index)
            else:
                colors = colors.reindex(data.columns)

            # Replace na's with white color
            # TODO We should set these to transparent instead
            colors = colors.astype(object).fillna('white')

            # Extract color values and labels from frame/series
            if isinstance(colors, pd.DataFrame):
                labels = list(colors.columns)
                # One track per DataFrame column
                colors = colors.T.values
            else:
                if colors.name is None:
                    labels = [""]
                else:
                    labels = [colors.name]
                colors = colors.values

        colors = _convert_colors(colors)

    return colors, labels
def format_data(self, data, pivot_kws, z_score=None,
                standard_scale=None):
    """Return the 2D matrix to cluster, optionally pivoted and normalized.

    ``pivot_kws`` (when given) is forwarded to ``DataFrame.pivot``. At most
    one of ``z_score`` / ``standard_scale`` may be set; each names the
    axis to normalize along.

    Raises
    ------
    ValueError
        If both ``z_score`` and ``standard_scale`` are given.
    """
    data2d = data if pivot_kws is None else data.pivot(**pivot_kws)

    if z_score is not None:
        if standard_scale is not None:
            raise ValueError(
                'Cannot perform both z-scoring and standard-scaling on data')
        return self.z_score(data2d, z_score)

    if standard_scale is not None:
        return self.standard_scale(data2d, standard_scale)

    return data2d
@staticmethod
def z_score(data2d, axis=1):
    """Standardize the data to zero mean and unit variance along an axis.

    Parameters
    ----------
    data2d : pandas.DataFrame
        Data to normalize
    axis : int
        1 standardizes each column; 0 standardizes each row.

    Returns
    -------
    normalized : pandas.DataFrame
        Normalized data with a mean of 0 and a standard deviation of 1
        along the specified axis.
    """
    # pandas reductions work per column, so transpose when rows are wanted.
    frame = data2d if axis == 1 else data2d.T
    standardized = (frame - frame.mean()) / frame.std()
    return standardized if axis == 1 else standardized.T
@staticmethod
def standard_scale(data2d, axis=1):
    """Rescale the data linearly so each row or column spans [0, 1].

    For every column (``axis=1``) or row (``axis=0``) the minimum is
    subtracted and the result is divided by the range (max - min).

    Parameters
    ----------
    data2d : pandas.DataFrame
        Data to normalize
    axis : int
        1 scales each column; 0 scales each row.

    Returns
    -------
    standardized : pandas.DataFrame
        Scaled data with a minimum of 0 and a maximum of 1 along the
        specified axis. A row/column of constant values has zero range and
        comes out as NaN.
    """
    # pandas reductions work per column, so transpose when rows are wanted.
    if axis == 1:
        standardized = data2d
    else:
        standardized = data2d.T

    # Normalize these values to range from 0 to 1, reusing the minimum
    # instead of computing it a second time.
    subtract = standardized.min()
    standardized = (standardized - subtract) / (
        standardized.max() - subtract)

    if axis == 1:
        return standardized
    else:
        return standardized.T
def dim_ratios(self, colors, dendrogram_ratio, colors_ratio):
    """Return the fractions of one figure dimension used by each axes.

    The list holds the dendrogram share, then (if ``colors`` is given) the
    side-color share, and finally whatever remains for the heatmap.
    """
    ratios = [dendrogram_ratio]

    if colors is not None:
        # RGB-encoded colors add a trailing dimension, so more than two
        # dimensions means several stacked color vectors.
        n_colors = len(colors) if np.ndim(colors) > 2 else 1
        ratios.append(n_colors * colors_ratio)

    # The heatmap gets whatever is left over
    ratios.append(1 - sum(ratios))
    return ratios
@staticmethod
def color_list_to_matrix_and_cmap(colors, ind, axis=0):
    """Encode side colors as an integer matrix plus a matching colormap.

    The resulting pair can be drawn with ``heatmap(matrix, cmap=cmap)`` to
    reproduce the given colors.

    Parameters
    ----------
    colors : list of matplotlib colors, or list of such lists
        Colors labeling the rows or columns of a dataframe.
    ind : list of ints
        Order (e.g. dendrogram leaf order) applied to the color positions.
    axis : int
        Which axis this is labeling; ``0`` transposes the matrix.

    Returns
    -------
    matrix : numpy.array
        Integer codes, each indexing into ``cmap``.
    cmap : matplotlib.colors.ListedColormap
        One entry per distinct color, in first-seen order.

    Raises
    ------
    ValueError
        If nested color vectors differ in length.
    """
    try:
        mpl.colors.to_rgb(colors[0])
    except ValueError:
        # The first entry is not a single color: nested color vectors
        n_rows, n_cols = len(colors), len(colors[0])
        if any(len(row) != n_cols for row in colors[1:]):
            raise ValueError("Multiple side color vectors must have same size")
    else:
        # A single flat vector of colors
        n_rows, n_cols = 1, len(colors)
        colors = [colors]

    # Give each distinct color a consecutive integer code
    codes = {}
    matrix = np.zeros((n_rows, n_cols), int)
    for row_idx, row in enumerate(colors):
        for col_idx, color in enumerate(row):
            matrix[row_idx, col_idx] = codes.setdefault(color, len(codes))

    # Apply the clustering order, then orient for the labeled axis
    matrix = matrix[:, ind]
    if axis == 0:
        matrix = matrix.T

    cmap = mpl.colors.ListedColormap(list(codes))
    return matrix, cmap
def plot_dendrograms(self, row_cluster, col_cluster, metric, method,
                     row_linkage, col_linkage, tree_kws):
    """Draw the row and column dendrograms, or blank their axes.

    When clustering is enabled for an axis, the result of ``dendrogram`` is
    stored as ``self.dendrogram_row`` / ``self.dendrogram_col``. Otherwise,
    that axes has its ticks removed. Both dendrogram axes are despined.
    """
    row_ax = self.ax_row_dendrogram
    if not row_cluster:
        row_ax.set_xticks([])
        row_ax.set_yticks([])
    else:
        self.dendrogram_row = dendrogram(
            self.data2d, metric=metric, method=method, label=False, axis=0,
            ax=row_ax, rotate=True, linkage=row_linkage,
            tree_kws=tree_kws
        )

    col_ax = self.ax_col_dendrogram
    if not col_cluster:
        col_ax.set_xticks([])
        col_ax.set_yticks([])
    else:
        self.dendrogram_col = dendrogram(
            self.data2d, metric=metric, method=method, label=False,
            axis=1, ax=col_ax, linkage=col_linkage,
            tree_kws=tree_kws
        )

    despine(ax=row_ax, bottom=True, left=True)
    despine(ax=col_ax, bottom=True, left=True)
def plot_colors(self, xind, yind, **kws):
    """Plot color labels between the dendrogram and the heatmap.

    Parameters
    ----------
    xind : array-like of ints
        Column order used to reorder ``self.col_colors``.
    yind : array-like of ints
        Row order used to reorder ``self.row_colors``.
    kws : dict
        Keyword arguments for :func:`heatmap`. Options that only apply to
        the main data matrix, or that the color strips set themselves, are
        dropped before plotting.
    """
    # Remove any custom colormap and centering
    # TODO this code has consistently caused problems when we
    # have missed kwargs that need to be excluded that it might
    # be better to rewrite *in*clusively.
    kws = kws.copy()
    # "cbar" must go too: the strips always pass cbar=False explicitly, so
    # forwarding a user-supplied value would raise a duplicate-keyword
    # TypeError.
    for key in ("cmap", "norm", "center", "annot", "vmin", "vmax",
                "robust", "xticklabels", "yticklabels", "cbar"):
        kws.pop(key, None)

    # Plot the row colors
    if self.row_colors is not None:
        matrix, cmap = self.color_list_to_matrix_and_cmap(
            self.row_colors, yind, axis=0)

        # Get row_color labels
        if self.row_color_labels is not None:
            row_color_labels = self.row_color_labels
        else:
            row_color_labels = False

        heatmap(matrix, cmap=cmap, cbar=False, ax=self.ax_row_colors,
                xticklabels=row_color_labels, yticklabels=False, **kws)

        # Adjust rotation of labels
        if row_color_labels is not False:
            plt.setp(self.ax_row_colors.get_xticklabels(), rotation=90)
    else:
        despine(self.ax_row_colors, left=True, bottom=True)

    # Plot the column colors
    if self.col_colors is not None:
        matrix, cmap = self.color_list_to_matrix_and_cmap(
            self.col_colors, xind, axis=1)

        # Get col_color labels
        if self.col_color_labels is not None:
            col_color_labels = self.col_color_labels
        else:
            col_color_labels = False

        heatmap(matrix, cmap=cmap, cbar=False, ax=self.ax_col_colors,
                xticklabels=False, yticklabels=col_color_labels, **kws)

        # Adjust rotation of labels, place on right side
        if col_color_labels is not False:
            self.ax_col_colors.yaxis.tick_right()
            plt.setp(self.ax_col_colors.get_yticklabels(), rotation=0)
    else:
        despine(self.ax_col_colors, left=True, bottom=True)
def plot_matrix(self, colorbar_kws, xind, yind, **kws):
    """Draw the reordered data as a heatmap and finalize the grid layout.

    Parameters
    ----------
    colorbar_kws : dict
        Passed to :func:`heatmap` as ``cbar_kws``.
    xind, yind : array-like of ints
        Column and row orders applied to ``self.data2d``, ``self.mask``,
        explicitly provided tick labels, and an explicit ``annot`` array.
    kws : dict
        Other keyword arguments for :func:`heatmap`.

    Raises
    ------
    ValueError
        If an explicit ``annot`` array does not match the data's shape.
    """
    # Store the reordered data and mask back on the grid
    self.data2d = self.data2d.iloc[yind, xind]
    self.mask = self.mask.iloc[yind, xind]

    # Try to reorganize specified tick labels, if provided
    xtl = kws.pop("xticklabels", "auto")
    try:
        xtl = np.asarray(xtl)[xind]
    except (TypeError, IndexError):
        # Values that cannot be indexed this way (e.g. the default "auto")
        # are passed through unchanged
        pass
    ytl = kws.pop("yticklabels", "auto")
    try:
        ytl = np.asarray(ytl)[yind]
    except (TypeError, IndexError):
        pass

    # Reorganize the annotations to match the heatmap
    annot = kws.pop("annot", None)
    if annot is None or annot is False:
        pass
    else:
        if isinstance(annot, bool):
            # annot=True: annotate with the (already reordered) data values
            annot_data = self.data2d
        else:
            # Explicit annotations must match the data and are reordered
            # the same way
            annot_data = np.asarray(annot)
            if annot_data.shape != self.data2d.shape:
                err = "`data` and `annot` must have same shape."
                raise ValueError(err)
            annot_data = annot_data[yind][:, xind]
        annot = annot_data

    # Setting ax_cbar=None in clustermap call implies no colorbar
    kws.setdefault("cbar", self.ax_cbar is not None)
    heatmap(self.data2d, ax=self.ax_heatmap, cbar_ax=self.ax_cbar,
            cbar_kws=colorbar_kws, mask=self.mask,
            xticklabels=xtl, yticklabels=ytl, annot=annot, **kws)

    # Record the current y tick label rotation, move the y ticks and label
    # to the right side, then re-apply the recorded rotation to the
    # re-fetched tick labels
    ytl = self.ax_heatmap.get_yticklabels()
    ytl_rot = None if not ytl else ytl[0].get_rotation()
    self.ax_heatmap.yaxis.set_ticks_position('right')
    self.ax_heatmap.yaxis.set_label_position('right')
    if ytl_rot is not None:
        ytl = self.ax_heatmap.get_yticklabels()
        plt.setp(ytl, rotation=ytl_rot)

    tight_params = dict(h_pad=.02, w_pad=.02)
    if self.ax_cbar is None:
        self._figure.tight_layout(**tight_params)
    else:
        # Turn the colorbar axes off for tight layout so that its
        # ticks don't interfere with the rest of the plot layout.
        # Then move it.
        self.ax_cbar.set_axis_off()
        self._figure.tight_layout(**tight_params)
        self.ax_cbar.set_axis_on()
        self.ax_cbar.set_position(self.cbar_pos)
def plot(self, metric, method, colorbar_kws, row_cluster, col_cluster,
         row_linkage, col_linkage, tree_kws, **kws):
    """Draw dendrograms, side colors, and the heatmap; return the grid.

    Without a dendrogram for an axis, the identity order is used for it.
    Remaining keyword arguments are forwarded to the side-color and matrix
    plotting steps.
    """
    # heatmap's square=True fixes the axes aspect ratio, which does not fit
    # the multi-axes layout of the cluster grid, so drop it with a warning.
    if kws.get("square", False):
        warnings.warn("``square=True`` ignored in clustermap")
        del kws["square"]

    if colorbar_kws is None:
        colorbar_kws = {}

    self.plot_dendrograms(row_cluster, col_cluster, metric, method,
                          row_linkage=row_linkage, col_linkage=col_linkage,
                          tree_kws=tree_kws)

    def _leaf_order(attr, dim):
        # Use the dendrogram's leaf order when one was drawn; otherwise the
        # natural order along that data dimension.
        try:
            return getattr(self, attr).reordered_ind
        except AttributeError:
            return np.arange(self.data2d.shape[dim])

    xind = _leaf_order("dendrogram_col", 1)
    yind = _leaf_order("dendrogram_row", 0)

    self.plot_colors(xind, yind, **kws)
    self.plot_matrix(colorbar_kws, xind, yind, **kws)
    return self
def clustermap(
    data, *,
    pivot_kws=None, method='average', metric='euclidean',
    z_score=None, standard_scale=None, figsize=(10, 10),
    cbar_kws=None, row_cluster=True, col_cluster=True,
    row_linkage=None, col_linkage=None,
    row_colors=None, col_colors=None, mask=None,
    dendrogram_ratio=.2, colors_ratio=0.03,
    cbar_pos=(.02, .8, .05, .18), tree_kws=None,
    **kwargs
):
    """
    Plot a matrix dataset as a hierarchically-clustered heatmap.

    This function requires scipy to be available.

    Parameters
    ----------
    data : 2D array-like
        Rectangular data for clustering. Cannot contain NAs.
    pivot_kws : dict, optional
        If `data` is a tidy dataframe, can provide keyword arguments for
        pivot to create a rectangular dataframe.
    method : str, optional
        Linkage method to use for calculating clusters. See
        :func:`scipy.cluster.hierarchy.linkage` documentation for more
        information.
    metric : str, optional
        Distance metric to use for the data. See
        :func:`scipy.spatial.distance.pdist` documentation for more options.
        To use different metrics (or methods) for rows and columns, you may
        construct each linkage matrix yourself and provide them as
        `{row,col}_linkage`.
    z_score : int or None, optional
        Either 0 (rows) or 1 (columns). Whether or not to calculate z-scores
        for the rows or the columns. Z scores are: z = (x - mean)/std, so
        values in each row (column) will get the mean of the row (column)
        subtracted, then divided by the standard deviation of the row (column).
        This ensures that each row (column) has mean of 0 and variance of 1.
    standard_scale : int or None, optional
        Either 0 (rows) or 1 (columns). Whether or not to standardize that
        dimension, meaning for each row or column, subtract the minimum and
        divide by the range (maximum minus minimum), so that the values span
        0 to 1.
    figsize : tuple of (width, height), optional
        Overall size of the figure.
    cbar_kws : dict, optional
        Keyword arguments to pass to `cbar_kws` in :func:`heatmap`, e.g. to
        add a label to the colorbar.
    {row,col}_cluster : bool, optional
        If ``True``, cluster the {rows, columns}.
    {row,col}_linkage : :class:`numpy.ndarray`, optional
        Precomputed linkage matrix for the rows or columns. See
        :func:`scipy.cluster.hierarchy.linkage` for specific formats.
    {row,col}_colors : list-like or pandas DataFrame/Series, optional
        List of colors to label for either the rows or columns. Useful to evaluate
        whether samples within a group are clustered together. Can use nested lists or
        DataFrame for multiple color levels of labeling. If given as a
        :class:`pandas.DataFrame` or :class:`pandas.Series`, labels for the colors are
        extracted from the DataFrame's column names or from the name of the Series.
        DataFrame/Series colors are also matched to the data by their index, ensuring
        colors are drawn in the correct order.
    mask : bool array or DataFrame, optional
        If passed, data will not be shown in cells where `mask` is True.
        Cells with missing values are automatically masked. Only used for
        visualizing, not for calculating.
    {dendrogram,colors}_ratio : float, or pair of floats, optional
        Proportion of the figure size devoted to the two marginal elements. If
        a pair is given, they correspond to (row, col) ratios.
    cbar_pos : tuple of (left, bottom, width, height), optional
        Position of the colorbar axes in the figure. Setting to ``None`` will
        disable the colorbar.
    tree_kws : dict, optional
        Parameters for the :class:`matplotlib.collections.LineCollection`
        that is used to plot the lines of the dendrogram tree.
    kwargs : other keyword arguments
        All other keyword arguments are passed to :func:`heatmap`.

    Returns
    -------
    :class:`ClusterGrid`
        A :class:`ClusterGrid` instance.

    Raises
    ------
    RuntimeError
        If scipy is not available.

    See Also
    --------
    heatmap : Plot rectangular data as a color-encoded matrix.

    Notes
    -----
    The returned object has a ``savefig`` method that should be used if you
    want to save the figure object without clipping the dendrograms.

    To access the reordered row indices, use:
    ``clustergrid.dendrogram_row.reordered_ind``

    Column indices, use:
    ``clustergrid.dendrogram_col.reordered_ind``

    Examples
    --------
    .. include:: ../docstrings/clustermap.rst
    """
    if _no_scipy:
        raise RuntimeError("clustermap requires scipy to be available")

    plotter = ClusterGrid(data, pivot_kws=pivot_kws, figsize=figsize,
                          row_colors=row_colors, col_colors=col_colors,
                          z_score=z_score, standard_scale=standard_scale,
                          mask=mask, dendrogram_ratio=dendrogram_ratio,
                          colors_ratio=colors_ratio, cbar_pos=cbar_pos)

    return plotter.plot(metric=metric, method=method,
                        colorbar_kws=cbar_kws,
                        row_cluster=row_cluster, col_cluster=col_cluster,
                        row_linkage=row_linkage, col_linkage=col_linkage,
                        tree_kws=tree_kws, **kwargs)
================================================
FILE: seaborn/miscplot.py
================================================
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# Public names exported by ``from seaborn.miscplot import *``.
__all__ = ["palplot", "dogplot"]
def palplot(pal, size=1):
    """Show a color palette as a single row of square swatches.

    Parameters
    ----------
    pal : sequence of matplotlib colors
        Colors to display, e.g. as returned by ``seaborn.color_palette()``.
    size : float
        Scaling factor: the figure is ``len(pal) * size`` wide and ``size``
        tall.
    """
    n_swatches = len(pal)
    _, ax = plt.subplots(1, 1, figsize=(n_swatches * size, size))

    swatch_ids = np.arange(n_swatches)
    ax.imshow(swatch_ids.reshape(1, n_swatches),
              cmap=mpl.colors.ListedColormap(list(pal)),
              interpolation="nearest", aspect="auto")

    # Ticks placed on cell boundaries give a clean border between swatches
    ax.set_xticks(swatch_ids - .5)
    ax.set_yticks([-.5, .5])
    ax.set_xticklabels([""] * n_swatches)

    # Remove the y ticks entirely via a null locator
    ax.yaxis.set_major_locator(ticker.NullLocator())
def dogplot(*_, **__):
    """Who's a good boy?

    Downloads a random dog picture from the seaborn-data repository and
    shows it on a new 5x5 figure with no axes. All arguments are ignored.
    """
    from urllib.request import urlopen
    from io import BytesIO

    url = "https://github.com/mwaskom/seaborn-data/raw/master/png/img{}.png"
    pic = np.random.randint(2, 7)
    # Use the response as a context manager so the connection is closed
    # deterministically instead of being left for the garbage collector.
    with urlopen(url.format(pic)) as response:
        data = BytesIO(response.read())
    img = plt.imread(data)
    f, ax = plt.subplots(figsize=(5, 5), dpi=100)
    f.subplots_adjust(0, 0, 1, 1)
    ax.imshow(img)
    ax.set_axis_off()
================================================
FILE: seaborn/objects.py
================================================
"""
A declarative, object-oriented interface for creating statistical graphics.
The seaborn.objects namespace contains a number of classes that can be composed
together to build a customized visualization.
The main object is :class:`Plot`, which is the starting point for all figures.
Pass :class:`Plot` a dataset and specify assignments from its variables to
roles in the plot. Build up the visualization by calling its methods.
There are four other general types of objects in this interface; the first three are:
- :class:`Mark` subclasses, which create matplotlib artists for visualization
- :class:`Stat` subclasses, which apply statistical transforms before plotting
- :class:`Move` subclasses, which make further adjustments to reduce overplotting
These classes are passed to :meth:`Plot.add` to define a layer in the plot.
Each layer has a :class:`Mark` and optional :class:`Stat` and/or :class:`Move`.
Plots can have multiple layers.
The fourth general type of object is a :class:`Scale` subclass, which provides an
interface for controlling the mappings between data values and visual properties.
Pass :class:`Scale` objects to :meth:`Plot.scale`.
See the documentation for other :class:`Plot` methods to learn about the many
ways that a plot can be enhanced and customized.
"""
from seaborn._core.plot import Plot # noqa: F401
from seaborn._marks.base import Mark # noqa: F401
from seaborn._marks.area import Area, Band # noqa: F401
from seaborn._marks.bar import Bar, Bars # noqa: F401
from seaborn._marks.dot import Dot, Dots # noqa: F401
from seaborn._marks.line import Dash, Line, Lines, Path, Paths, Range # noqa: F401
from seaborn._marks.text import Text # noqa: F401
from seaborn._stats.base import Stat # noqa: F401
from seaborn._stats.aggregation import Agg, Est # noqa: F401
from seaborn._stats.counting import Count, Hist # noqa: F401
from seaborn._stats.density import KDE # noqa: F401
from seaborn._stats.order import Perc # noqa: F401
from seaborn._stats.regression import PolyFit # noqa: F401
from seaborn._core.moves import Dodge, Jitter, Norm, Shift, Stack, Move # noqa: F401
from seaborn._core.scales import ( # noqa: F401
Boolean, Continuous, Nominal, Temporal, Scale
)
================================================
FILE: seaborn/palettes.py
================================================
import colorsys
from itertools import cycle
import numpy as np
import matplotlib as mpl
from .external import husl
from .utils import desaturate, get_color_cycle
from .colors import xkcd_rgb, crayons
from ._compat import get_colormap
# Public names exported by ``from seaborn.palettes import *``.
__all__ = ["color_palette", "hls_palette", "husl_palette", "mpl_palette",
           "dark_palette", "light_palette", "diverging_palette",
           "blend_palette", "xkcd_palette", "crayon_palette",
           "cubehelix_palette", "set_color_codes"]

# Hex codes for seaborn's named qualitative palettes. Each "<name>6" entry
# is a six-color subset of the corresponding ten-color "<name>" palette.
SEABORN_PALETTES = dict(
    deep=["#4C72B0", "#DD8452", "#55A868", "#C44E52", "#8172B3",
          "#937860", "#DA8BC3", "#8C8C8C", "#CCB974", "#64B5CD"],
    deep6=["#4C72B0", "#55A868", "#C44E52",
           "#8172B3", "#CCB974", "#64B5CD"],
    muted=["#4878D0", "#EE854A", "#6ACC64", "#D65F5F", "#956CB4",
           "#8C613C", "#DC7EC0", "#797979", "#D5BB67", "#82C6E2"],
    muted6=["#4878D0", "#6ACC64", "#D65F5F",
            "#956CB4", "#D5BB67", "#82C6E2"],
    pastel=["#A1C9F4", "#FFB482", "#8DE5A1", "#FF9F9B", "#D0BBFF",
            "#DEBB9B", "#FAB0E4", "#CFCFCF", "#FFFEA3", "#B9F2F0"],
    pastel6=["#A1C9F4", "#8DE5A1", "#FF9F9B",
             "#D0BBFF", "#FFFEA3", "#B9F2F0"],
    bright=["#023EFF", "#FF7C00", "#1AC938", "#E8000B", "#8B2BE2",
            "#9F4800", "#F14CC1", "#A3A3A3", "#FFC400", "#00D7FF"],
    bright6=["#023EFF", "#1AC938", "#E8000B",
             "#8B2BE2", "#FFC400", "#00D7FF"],
    dark=["#001C7F", "#B1400D", "#12711C", "#8C0800", "#591E71",
          "#592F0D", "#A23582", "#3C3C3C", "#B8850A", "#006374"],
    dark6=["#001C7F", "#12711C", "#8C0800",
           "#591E71", "#B8850A", "#006374"],
    colorblind=["#0173B2", "#DE8F05", "#029E73", "#D55E00", "#CC78BC",
                "#CA9161", "#FBAFE4", "#949494", "#ECE133", "#56B4E9"],
    colorblind6=["#0173B2", "#029E73", "#D55E00",
                 "#CC78BC", "#ECE133", "#56B4E9"]
)

# Matplotlib qualitative colormap names mapped to their size. mpl_palette
# samples exactly this many evenly spaced points from these maps (indexing
# rather than interpolating), and color_palette uses the size as the
# default n_colors for the name.
MPL_QUAL_PALS = {
    "tab10": 10, "tab20": 20, "tab20b": 20, "tab20c": 20,
    "Set1": 9, "Set2": 8, "Set3": 12,
    "Accent": 8, "Paired": 12,
    "Pastel1": 9, "Pastel2": 8, "Dark2": 8,
}

# Default sizes for every qualitative palette name, matplotlib and seaborn
QUAL_PALETTE_SIZES = MPL_QUAL_PALS.copy()
QUAL_PALETTE_SIZES.update({k: len(v) for k, v in SEABORN_PALETTES.items()})
# Names of all known qualitative palettes
QUAL_PALETTES = list(QUAL_PALETTE_SIZES.keys())
class _ColorPalette(list):
"""Set the color palette in a with statement, otherwise be a list."""
def __enter__(self):
"""Open the context."""
from .rcmod import set_palette
self._orig_palette = color_palette()
set_palette(self)
return self
def __exit__(self, *args):
"""Close the context."""
from .rcmod import set_palette
set_palette(self._orig_palette)
def as_hex(self):
"""Return a color palette with hex codes instead of RGB values."""
hex = [mpl.colors.rgb2hex(rgb) for rgb in self]
return _ColorPalette(hex)
def _repr_html_(self):
"""Rich display of the color palette in an HTML frontend."""
s = 55
n = len(self)
html = f''
return html
def _patch_colormap_display():
    """Simplify the rich display of matplotlib color maps in a notebook.

    Installs ``_repr_png_`` and ``_repr_html_`` on
    :class:`matplotlib.colors.Colormap`, so every colormap renders as a
    400x50 gradient strip.
    """
    def _repr_png_(self):
        """Generate a PNG representation of the Colormap."""
        import io
        from PIL import Image
        import numpy as np
        IMAGE_SIZE = (400, 50)
        X = np.tile(np.linspace(0, 1, IMAGE_SIZE[0]), (IMAGE_SIZE[1], 1))
        pixels = self(X, bytes=True)
        png_bytes = io.BytesIO()
        Image.fromarray(pixels).save(png_bytes, format='png')
        return png_bytes.getvalue()

    def _repr_html_(self):
        """Generate an HTML representation of the Colormap.

        Embeds the PNG from ``_repr_png_`` as a base64 data URI in an
        ``<img>`` tag, using the colormap name for the alt/title text.
        """
        import base64
        from html import escape
        png_bytes = self._repr_png_()
        png_base64 = base64.b64encode(png_bytes).decode('ascii')
        # Escape the name so it cannot break out of the HTML attributes
        name = escape(self.name, quote=True)
        return (
            f'<img alt="{name} color map" title="{name}" '
            f'src="data:image/png;base64,{png_base64}">'
        )

    mpl.colors.Colormap._repr_png_ = _repr_png_
    mpl.colors.Colormap._repr_html_ = _repr_html_
def color_palette(palette=None, n_colors=None, desat=None, as_cmap=False):
    """Return a list of colors or continuous colormap defining a palette.

    Possible ``palette`` values include:
        - Name of a seaborn palette (deep, muted, bright, pastel, dark, colorblind)
        - Name of matplotlib colormap
        - 'husl' or 'hls'
        - 'ch:<cubehelix arguments>'
        - 'light:<color>', 'dark:<color>', 'blend:<color>,<color>',
        - A sequence of colors in any format matplotlib accepts

    Calling this function with ``palette=None`` will return the current
    matplotlib color cycle.

    This function can also be used in a ``with`` statement to temporarily
    set the color cycle for a plot or set of plots.

    See the :ref:`tutorial <palette_tutorial>` for more information.

    Parameters
    ----------
    palette : None, string, or sequence, optional
        Name of palette or None to return current palette. If a sequence, input
        colors are used but possibly cycled and desaturated.
    n_colors : int, optional
        Number of colors in the palette. If ``None``, the default will depend
        on how ``palette`` is specified. Named palettes default to 6 colors,
        but grabbing the current palette or passing in a list of colors will
        not change the number of colors unless this is specified. Asking for
        more colors than exist in the palette will cause it to cycle. Ignored
        when ``as_cmap`` is True.
    desat : float, optional
        Proportion to desaturate each color by.
    as_cmap : bool
        If True, return a :class:`matplotlib.colors.Colormap` (the concrete
        subclass depends on ``palette``). For a seaborn palette name or an
        explicit sequence of colors, the colors are returned without
        conversion.

    Returns
    -------
    list of RGB tuples or :class:`matplotlib.colors.Colormap`

    Raises
    ------
    ValueError
        If ``palette`` is not a recognized name, or the colors cannot be
        converted to RGB.

    See Also
    --------
    set_palette : Set the default color cycle for all plots.
    set_color_codes : Reassign color codes like ``"b"``, ``"g"``, etc. to
                      colors from one of the seaborn palettes.

    Examples
    --------

    .. include:: ../docstrings/color_palette.rst

    """
    if palette is None:
        palette = get_color_cycle()
        if n_colors is None:
            n_colors = len(palette)

    elif not isinstance(palette, str):
        # An explicit sequence of colors is used as given
        if n_colors is None:
            n_colors = len(palette)

    else:
        if n_colors is None:
            # Use all colors in a qualitative palette or 6 of another kind
            n_colors = QUAL_PALETTE_SIZES.get(palette, 6)

        if palette in SEABORN_PALETTES:
            # Named "seaborn variant" of matplotlib default color cycle
            palette = SEABORN_PALETTES[palette]

        elif palette == "hls":
            # Evenly spaced colors in cylindrical RGB space
            palette = hls_palette(n_colors, as_cmap=as_cmap)

        elif palette == "husl":
            # Evenly spaced colors in cylindrical Lab space
            palette = husl_palette(n_colors, as_cmap=as_cmap)

        elif palette.lower() == "jet":
            # Paternalism
            raise ValueError("No.")

        elif palette.startswith("ch:"):
            # Cubehelix palette with params specified in string
            args, kwargs = _parse_cubehelix_args(palette)
            palette = cubehelix_palette(n_colors, *args, **kwargs, as_cmap=as_cmap)

        elif palette.startswith("light:"):
            # Light palette to color specified in string. Split only once so
            # that prefixed color names ("tab:blue", "xkcd:sky blue") survive.
            _, color = palette.split(":", 1)
            reverse = color.endswith("_r")
            if reverse:
                color = color[:-2]
            palette = light_palette(color, n_colors, reverse=reverse, as_cmap=as_cmap)

        elif palette.startswith("dark:"):
            # Dark palette to color specified in string (split once, as above)
            _, color = palette.split(":", 1)
            reverse = color.endswith("_r")
            if reverse:
                color = color[:-2]
            palette = dark_palette(color, n_colors, reverse=reverse, as_cmap=as_cmap)

        elif palette.startswith("blend:"):
            # Blend palette between comma-separated colors in string
            _, colors = palette.split(":", 1)
            colors = colors.split(",")
            palette = blend_palette(colors, n_colors, as_cmap=as_cmap)

        else:
            try:
                # Perhaps a named matplotlib colormap?
                palette = mpl_palette(palette, n_colors, as_cmap=as_cmap)
            except (ValueError, KeyError) as err:  # Error class changed in mpl36
                raise ValueError(f"{palette!r} is not a valid palette name") from err

    if desat is not None:
        palette = [desaturate(c, desat) for c in palette]

    if not as_cmap:

        # Always return as many colors as we asked for
        pal_cycle = cycle(palette)
        palette = [next(pal_cycle) for _ in range(n_colors)]

        # Always return in r, g, b tuple format. Keep ``palette`` bound to
        # the list until conversion succeeds so the error message shows the
        # offending colors rather than a ``map`` object.
        try:
            palette = _ColorPalette(map(mpl.colors.to_rgb, palette))
        except ValueError as err:
            raise ValueError(f"Could not generate a palette for {palette}") from err

    return palette
def hls_palette(n_colors=6, h=.01, l=.6, s=.65, as_cmap=False):  # noqa
    """
    Return hues with constant lightness and saturation in the HLS system.

    The hues are evenly sampled along a circular path. The resulting palette will be
    appropriate for categorical or cyclical data.

    The `h`, `l`, and `s` values should be between 0 and 1.

    .. note::
        While the separation of the resulting colors will be mathematically
        constant, the HLS system does not construct a perceptually-uniform space,
        so their apparent intensity will vary.

    Parameters
    ----------
    n_colors : int
        Number of colors in the palette (forced to 256 when ``as_cmap``).
    h : float
        The value of the first hue.
    l : float
        The lightness value.
    s : float
        The saturation intensity.
    as_cmap : bool
        If True, return a matplotlib colormap object.

    Returns
    -------
    palette
        list of RGB tuples or :class:`matplotlib.colors.ListedColormap`

    See Also
    --------
    husl_palette : Make a palette using evenly spaced hues in the HUSL system.

    Examples
    --------
    .. include:: ../docstrings/hls_palette.rst

    """
    if as_cmap:
        n_colors = 256

    # Evenly spaced hues on [0, 1), rotated so the first one sits at h
    hues = (np.linspace(0, 1, int(n_colors) + 1)[:-1] + h) % 1
    # Floating-point ``% 1`` can yield exactly 1.0 for tiny negative values;
    # removing the integer part wraps such a hue back to 0.
    hues -= hues.astype(int)

    colors = [colorsys.hls_to_rgb(hue, l, s) for hue in hues]

    if not as_cmap:
        return _ColorPalette(colors)
    return mpl.colors.ListedColormap(colors, "hls")
def husl_palette(n_colors=6, h=.01, s=.9, l=.65, as_cmap=False):  # noqa
    """
    Return hues with constant lightness and saturation in the HUSL system.

    The hues are evenly sampled along a circular path. The resulting palette will be
    appropriate for categorical or cyclical data.

    The `h`, `l`, and `s` values should be between 0 and 1.

    This function is similar to :func:`hls_palette`, but it uses a nonlinear color
    space that is more perceptually uniform.

    Parameters
    ----------
    n_colors : int
        Number of colors in the palette (forced to 256 when ``as_cmap``).
    h : float
        The value of the first hue.
    l : float
        The lightness value.
    s : float
        The saturation intensity.
    as_cmap : bool
        If True, return a matplotlib colormap object.

    Returns
    -------
    palette
        list of RGB tuples or :class:`matplotlib.colors.ListedColormap`

    See Also
    --------
    hls_palette : Make a palette using evenly spaced hues in the HSL system.

    Examples
    --------
    .. include:: ../docstrings/husl_palette.rst

    """
    if as_cmap:
        n_colors = 256

    # Evenly spaced fractional hues rotated by h, then scaled to 0-359
    hues = ((np.linspace(0, 1, int(n_colors) + 1)[:-1] + h) % 1) * 359
    # Saturation and lightness are passed to husl on a 0-99 scale
    saturation = s * 99
    lightness = l * 99

    colors = [
        _color_to_rgb((hue, saturation, lightness), input="husl")
        for hue in hues
    ]

    if not as_cmap:
        return _ColorPalette(colors)
    return mpl.colors.ListedColormap(colors, "hsl")
def mpl_palette(name, n_colors=6, as_cmap=False):
    """
    Return a palette or colormap from the matplotlib registry.

    For continuous palettes, evenly-spaced discrete samples are chosen while
    excluding the minimum and maximum value in the colormap to provide better
    contrast at the extremes.

    For qualitative palettes (e.g. those from colorbrewer), exact values are
    indexed (rather than interpolated), but fewer than `n_colors` can be returned
    if the palette does not define that many.

    A name ending in ``_d`` (optionally ``_r_d``) blends the first two colors
    of the base palette into dark gray (reversed for ``_r_d``).

    Parameters
    ----------
    name : string
        Name of the palette. This should be a named matplotlib colormap.
    n_colors : int
        Number of discrete colors in the palette.
    as_cmap : bool
        If True, return the colormap itself instead of sampled colors.

    Returns
    -------
    list of RGB tuples or :class:`matplotlib.colors.Colormap`

    Examples
    --------
    .. include:: ../docstrings/mpl_palette.rst

    """
    if not name.endswith("_d"):
        cmap = get_colormap(name)
    else:
        base_name = name[:-2]
        reverse = base_name.endswith("_r")
        if reverse:
            base_name = base_name[:-2]
        anchors = color_palette(base_name, 2) + ["#333333"]
        if reverse:
            anchors = anchors[::-1]
        cmap = blend_palette(anchors, n_colors, as_cmap=True)

    if name in MPL_QUAL_PALS:
        # Index the discrete entries of a qualitative map exactly
        bins = np.linspace(0, 1, MPL_QUAL_PALS[name])[:n_colors]
    else:
        # Sample the interior of a continuous map, skipping both endpoints
        bins = np.linspace(0, 1, int(n_colors) + 2)[1:-1]
    palette = [tuple(rgb) for rgb in cmap(bins)[:, :3]]

    return cmap if as_cmap else _ColorPalette(palette)
def _color_to_rgb(color, input):
    """Convert ``color``, interpreted in the ``input`` space, to an RGB tuple.

    ``input`` may be "hls" or "husl" (tuple inputs) or "xkcd" (a name in the
    xkcd color table). Any other value passes ``color`` straight to
    matplotlib.
    """
    if input == "hls":
        rgb = colorsys.hls_to_rgb(*color)
    elif input == "husl":
        # Clamp the husl result into [0, 1] before handing it to matplotlib
        rgb = tuple(np.clip(husl.husl_to_rgb(*color), 0, 1))
    elif input == "xkcd":
        rgb = xkcd_rgb[color]
    else:
        rgb = color
    return mpl.colors.to_rgb(rgb)
def dark_palette(color, n_colors=6, reverse=False, as_cmap=False, input="rgb"):
    """Make a sequential palette that blends from dark to ``color``.

    This kind of palette is good for data that range between relatively
    uninteresting low values and interesting high values.

    The ``color`` parameter can be specified in a number of ways, including
    all options for defining a color in matplotlib and several additional
    color spaces that are handled by seaborn. You can also use the database
    of named colors from the XKCD color survey.

    If you are using the IPython notebook, you can also choose this palette
    interactively with the :func:`choose_dark_palette` function.

    Parameters
    ----------
    color : base color for high values
        hex, rgb-tuple, or html color name
    n_colors : int, optional
        number of colors in the palette
    reverse : bool, optional
        if True, reverse the direction of the blend
    as_cmap : bool, optional
        If True, return a :class:`matplotlib.colors.LinearSegmentedColormap`.
    input : {'rgb', 'hls', 'husl', 'xkcd'}
        Color space to interpret the input color. The first three options
        apply to tuple inputs and the latter applies to string inputs.

    Returns
    -------
    palette
        list of RGB tuples or :class:`matplotlib.colors.LinearSegmentedColormap`

    See Also
    --------
    light_palette : Create a sequential palette with bright low values.
    diverging_palette : Create a diverging palette with two colors.

    Examples
    --------
    .. include:: ../docstrings/dark_palette.rst

    """
    base_rgb = _color_to_rgb(color, input)
    hue, saturation, _ = husl.rgb_to_husl(*base_rgb)
    # The dark anchor keeps the hue at 15% of the saturation and husl
    # lightness 15
    dark_end = _color_to_rgb((hue, .15 * saturation, 15), input="husl")
    endpoints = [dark_end, base_rgb]
    if reverse:
        endpoints.reverse()
    return blend_palette(endpoints, n_colors, as_cmap)
def light_palette(color, n_colors=6, reverse=False, as_cmap=False, input="rgb"):
    """Make a sequential palette that blends from light to ``color``.

    The ``color`` parameter can be specified in a number of ways, including
    all options for defining a color in matplotlib and several additional
    color spaces that are handled by seaborn. You can also use the database
    of named colors from the XKCD color survey.

    If you are using a Jupyter notebook, you can also choose this palette
    interactively with the :func:`choose_light_palette` function.

    Parameters
    ----------
    color : base color for high values
        hex code, html color name, or tuple in `input` space.
    n_colors : int, optional
        number of colors in the palette
    reverse : bool, optional
        if True, reverse the direction of the blend
    as_cmap : bool, optional
        If True, return a :class:`matplotlib.colors.LinearSegmentedColormap`.
    input : {'rgb', 'hls', 'husl', 'xkcd'}
        Color space to interpret the input color. The first three options
        apply to tuple inputs and the latter applies to string inputs.

    Returns
    -------
    palette
        list of RGB tuples or :class:`matplotlib.colors.LinearSegmentedColormap`

    See Also
    --------
    dark_palette : Create a sequential palette with dark low values.
    diverging_palette : Create a diverging palette with two colors.

    Examples
    --------
    .. include:: ../docstrings/light_palette.rst

    """
    base_rgb = _color_to_rgb(color, input)
    hue, saturation, _ = husl.rgb_to_husl(*base_rgb)
    # The light anchor keeps the hue at 15% of the saturation and husl
    # lightness 95
    light_end = _color_to_rgb((hue, .15 * saturation, 95), input="husl")
    endpoints = [light_end, base_rgb]
    if reverse:
        endpoints.reverse()
    return blend_palette(endpoints, n_colors, as_cmap)
def diverging_palette(h_neg, h_pos, s=75, l=50, sep=1, n=6,  # noqa
                      center="light", as_cmap=False):
    """Make a diverging palette between two HUSL colors.

    If you are using the IPython notebook, you can also choose this palette
    interactively with the :func:`choose_diverging_palette` function.

    Parameters
    ----------
    h_neg, h_pos : float in [0, 359]
        Anchor hues for negative and positive extents of the map.
    s : float in [0, 100], optional
        Anchor saturation for both extents of the map.
    l : float in [0, 100], optional
        Anchor lightness for both extents of the map.
    sep : int, optional
        Size of the intermediate region: the number of neutral midpoint
        colors inserted between the two halves before blending.
    n : int, optional
        Number of colors in the palette (if not returning a cmap)
    center : {"light", "dark"}, optional
        Whether the center of the palette is light or dark
    as_cmap : bool, optional
        If True, return a :class:`matplotlib.colors.LinearSegmentedColormap`.

    Returns
    -------
    palette
        list of RGB tuples or :class:`matplotlib.colors.LinearSegmentedColormap`

    Raises
    ------
    KeyError
        If ``center`` is neither "light" nor "dark".

    See Also
    --------
    dark_palette : Create a sequential palette with dark values.
    light_palette : Create a sequential palette with light values.

    Examples
    --------
    .. include:: ../docstrings/diverging_palette.rst

    """
    palfunc = dict(dark=dark_palette, light=light_palette)[center]
    # Each half gets 128 anchors minus half the separator, so the halves plus
    # ``sep`` midpoint colors total roughly 256 anchors.
    n_half = int(128 - (sep // 2))
    neg = palfunc((h_neg, s, l), n_half, reverse=True, input="husl")
    pos = palfunc((h_pos, s, l), n_half, input="husl")
    midpoint = dict(light=[(.95, .95, .95)], dark=[(.133, .133, .133)])[center]
    mid = midpoint * sep
    pal = blend_palette(np.concatenate([neg, mid, pos]), n, as_cmap=as_cmap)
    return pal
def blend_palette(colors, n_colors=6, as_cmap=False, input="rgb"):
    """Make a palette that blends between a list of colors.
    Parameters
    ----------
    colors : sequence of colors in various formats interpreted by `input`
        hex code, html color name, or tuple in `input` space.
    n_colors : int, optional
        Number of colors in the palette. Ignored when ``as_cmap`` is True.
    as_cmap : bool, optional
        If True, return a :class:`matplotlib.colors.LinearSegmentedColormap`.
    input : str, optional
        Color space used to interpret tuple entries of ``colors``, e.g.
        ``"rgb"`` (the default) or ``"husl"``.
    Returns
    -------
    palette
        list of RGB tuples or
        :class:`matplotlib.colors.LinearSegmentedColormap`
    Examples
    --------
    .. include:: ../docstrings/blend_palette.rst
    """
    colors = [_color_to_rgb(color, input) for color in colors]
    name = "blend"
    # Linear interpolation between the anchor colors
    pal = mpl.colors.LinearSegmentedColormap.from_list(name, colors)
    if not as_cmap:
        # Sample evenly along the colormap and drop the alpha channel
        rgb_array = pal(np.linspace(0, 1, int(n_colors)))[:, :3]  # no alpha
        pal = _ColorPalette(map(tuple, rgb_array))
    return pal
def xkcd_palette(colors):
    """Make a palette with color names from the xkcd color survey.
    See xkcd for the full list of colors: https://xkcd.com/color/rgb/
    This is a thin wrapper that looks each name up in the
    `seaborn.xkcd_rgb` dictionary and passes the result to `color_palette`.
    Parameters
    ----------
    colors : list of strings
        Names to look up; each must be a key of `seaborn.xkcd_rgb`.
    Returns
    -------
    palette
        A list of colors as RGB tuples.
    See Also
    --------
    crayon_palette : Make a palette with Crayola crayon colors.
    """
    # Translate every survey name into its stored color value
    looked_up = list(map(xkcd_rgb.__getitem__, colors))
    # Ask for exactly as many colors as were named (no cycling/truncation)
    return color_palette(looked_up, len(looked_up))
def crayon_palette(colors):
    """Make a palette with color names from Crayola crayons.
    Colors are taken from here:
    https://en.wikipedia.org/wiki/List_of_Crayola_crayon_colors
    This is a thin wrapper that looks each name up in the
    `seaborn.crayons` dictionary and passes the result to `color_palette`.
    Parameters
    ----------
    colors : list of strings
        Names to look up; each must be a key of `seaborn.crayons`.
    Returns
    -------
    palette
        A list of colors as RGB tuples.
    See Also
    --------
    xkcd_palette : Make a palette with named colors from the XKCD color survey.
    """
    crayon_values = []
    for crayon_name in colors:
        crayon_values.append(crayons[crayon_name])
    # Request one palette entry per named crayon
    n_requested = len(crayon_values)
    return color_palette(crayon_values, n_requested)
def cubehelix_palette(n_colors=6, start=0, rot=.4, gamma=1.0, hue=0.8,
                      light=.85, dark=.15, reverse=False, as_cmap=False):
    """Make a sequential palette from the cubehelix system.
    This produces a colormap with linearly-decreasing (or increasing)
    brightness. That means that information will be preserved if printed to
    black and white or viewed by someone who is colorblind. "cubehelix" is
    also available as a matplotlib-based palette, but this function gives the
    user more control over the look of the palette and has a different set of
    defaults.
    In addition to using this function, it is also possible to generate a
    cubehelix palette generally in seaborn using a string starting with
    `ch:` and containing other parameters (e.g. `"ch:s=.25,r=-.5"`).
    Parameters
    ----------
    n_colors : int
        Number of colors in the palette.
    start : float, 0 <= start <= 3
        The hue value at the start of the helix.
    rot : float
        Rotations around the hue wheel over the range of the palette.
    gamma : float 0 <= gamma
        Nonlinearity to emphasize dark (gamma < 1) or light (gamma > 1) colors.
    hue : float, 0 <= hue <= 1
        Saturation of the colors.
    dark : float 0 <= dark <= 1
        Intensity of the darkest color in the palette.
    light : float 0 <= light <= 1
        Intensity of the lightest color in the palette.
    reverse : bool
        If True, the palette will go from dark to light.
    as_cmap : bool
        If True, return a :class:`matplotlib.colors.ListedColormap`.
    Returns
    -------
    palette
        list of RGB tuples or :class:`matplotlib.colors.ListedColormap`
    See Also
    --------
    choose_cubehelix_palette : Launch an interactive widget to select cubehelix
                               palette parameters.
    dark_palette : Create a sequential palette with dark low values.
    light_palette : Create a sequential palette with bright low values.
    References
    ----------
    Green, D. A. (2011). "A colour scheme for the display of astronomical
    intensity images". Bulletin of the Astronomical Society of India, Vol. 39,
    p. 289-295.
    Examples
    --------
    .. include:: ../docstrings/cubehelix_palette.rst
    """
    def get_color_function(p0, p1):
        # Copied from matplotlib because it lives in private module
        def color(x):
            # Apply gamma factor to emphasise low or high intensity values
            xg = x ** gamma
            # Calculate amplitude and angle of deviation from the black
            # to white diagonal in the plane of constant
            # perceived intensity.
            a = hue * xg * (1 - xg) / 2
            phi = 2 * np.pi * (start / 3 + rot * x)
            return xg + a * (p0 * np.cos(phi) + p1 * np.sin(phi))
        return color
    cdict = {
        "red": get_color_function(-0.14861, 1.78277),
        "green": get_color_function(-0.29227, -0.90649),
        "blue": get_color_function(1.97294, 0.0),
    }
    cmap = mpl.colors.LinearSegmentedColormap("cubehelix", cdict)
    # Sample from the light end toward the dark end by default
    x = np.linspace(light, dark, int(n_colors))
    pal = cmap(x)[:, :3].tolist()
    if reverse:
        pal = pal[::-1]
    if as_cmap:
        # A 256-entry lookup table gives a smooth continuous colormap
        x_256 = np.linspace(light, dark, 256)
        if reverse:
            x_256 = x_256[::-1]
        pal_256 = cmap(x_256)
        cmap = mpl.colors.ListedColormap(pal_256, "seaborn_cubehelix")
        return cmap
    else:
        return _ColorPalette(pal)
def _parse_cubehelix_args(argstr):
"""Turn stringified cubehelix params into args/kwargs."""
if argstr.startswith("ch:"):
argstr = argstr[3:]
if argstr.endswith("_r"):
reverse = True
argstr = argstr[:-2]
else:
reverse = False
if not argstr:
return [], {"reverse": reverse}
all_args = argstr.split(",")
args = [float(a.strip(" ")) for a in all_args if "=" not in a]
kwargs = [a.split("=") for a in all_args if "=" in a]
kwargs = {k.strip(" "): float(v.strip(" ")) for k, v in kwargs}
kwarg_map = dict(
s="start", r="rot", g="gamma",
h="hue", l="light", d="dark", # noqa: E741
)
kwargs = {kwarg_map.get(k, k): v for k, v in kwargs.items()}
if reverse:
kwargs["reverse"] = True
return args, kwargs
def set_color_codes(palette="deep"):
    """Change how matplotlib color shorthands are interpreted.
    Calling this will change how shorthand codes like "b" or "g"
    are interpreted by matplotlib in subsequent plots.
    Parameters
    ----------
    palette : {deep, muted, pastel, dark, bright, colorblind, reset}
        Named seaborn palette to use as the source of colors. ``"reset"``
        restores matplotlib's base colors (e.g. "b" -> (0, 0, 1)).
    Raises
    ------
    TypeError
        If ``palette`` is not a string.
    ValueError
        If ``palette`` is not ``"reset"`` or a named seaborn palette.
    See Also
    --------
    set : Color codes can be set through the high-level seaborn style
          manager.
    set_palette : Color codes can also be set through the function that
                  sets the matplotlib color cycle.
    """
    if palette == "reset":
        colors = [
            (0., 0., 1.),
            (0., .5, 0.),
            (1., 0., 0.),
            (.75, 0., .75),
            (.75, .75, 0.),
            (0., .75, .75),
            (0., 0., 0.)
        ]
    elif not isinstance(palette, str):
        err = "set_color_codes requires a named seaborn palette"
        raise TypeError(err)
    elif palette in SEABORN_PALETTES:
        # Use the six-color variant so the codes b, g, r, m, y, c each get
        # one color; "k" gets a near-black appended below.
        if not palette.endswith("6"):
            palette = palette + "6"
        colors = SEABORN_PALETTES[palette] + [(.1, .1, .1)]
    else:
        err = f"Cannot set colors with palette '{palette}'"
        raise ValueError(err)
    # Overwrite matplotlib's global shorthand -> color mapping in place
    for code, color in zip("bgrmyck", colors):
        rgb = mpl.colors.colorConverter.to_rgb(color)
        mpl.colors.colorConverter.colors[code] = rgb
================================================
FILE: seaborn/rcmod.py
================================================
"""Control plot style and scaling using the matplotlib rcParams interface."""
import functools
import matplotlib as mpl
from cycler import cycler
from . import palettes
# Public API of this module
__all__ = ["set_theme", "set", "reset_defaults", "reset_orig",
           "axes_style", "set_style", "plotting_context", "set_context",
           "set_palette"]
# rcParams keys that make up a seaborn "style". axes_style() reads exactly
# these when style is None and discards any other key from styles/rc.
_style_keys = [
    "axes.facecolor",
    "axes.edgecolor",
    "axes.grid",
    "axes.axisbelow",
    "axes.labelcolor",
    "figure.facecolor",
    "grid.color",
    "grid.linestyle",
    "text.color",
    "xtick.color",
    "ytick.color",
    "xtick.direction",
    "ytick.direction",
    "lines.solid_capstyle",
    "patch.edgecolor",
    "patch.force_edgecolor",
    "image.cmap",
    "font.family",
    "font.sans-serif",
    "xtick.bottom",
    "xtick.top",
    "ytick.left",
    "ytick.right",
    "axes.spines.left",
    "axes.spines.bottom",
    "axes.spines.right",
    "axes.spines.top",
]
# rcParams keys that make up a seaborn plotting "context" (sizes/widths).
# plotting_context() reads these when context is None and filters rc to them.
_context_keys = [
    "font.size",
    "axes.labelsize",
    "axes.titlesize",
    "xtick.labelsize",
    "ytick.labelsize",
    "legend.fontsize",
    "legend.title_fontsize",
    "axes.linewidth",
    "grid.linewidth",
    "lines.linewidth",
    "lines.markersize",
    "patch.linewidth",
    "xtick.major.width",
    "ytick.major.width",
    "xtick.minor.width",
    "ytick.minor.width",
    "xtick.major.size",
    "ytick.major.size",
    "xtick.minor.size",
    "ytick.minor.size",
]
def set_theme(context="notebook", style="darkgrid", palette="deep",
              font="sans-serif", font_scale=1, color_codes=True, rc=None):
    """
    Set aspects of the visual theme for all matplotlib and seaborn plots.
    This function changes the global defaults for all plots using the
    matplotlib rcParams system. The theming is decomposed into several distinct
    sets of parameter values.
    The options are illustrated in the :doc:`aesthetics <../tutorial/aesthetics>`
    and :doc:`color palette <../tutorial/color_palettes>` tutorials.
    Parameters
    ----------
    context : string or dict
        Scaling parameters, see :func:`plotting_context`.
    style : string or dict
        Axes style parameters, see :func:`axes_style`.
    palette : string or sequence
        Color palette, see :func:`color_palette`.
    font : string
        Font family, see matplotlib font manager.
    font_scale : float, optional
        Separate scaling factor to independently scale the size of the
        font elements.
    color_codes : bool
        If ``True`` and ``palette`` is a seaborn palette, remap the shorthand
        color codes (e.g. "b", "g", "r", etc.) to the colors from this palette.
    rc : dict or None
        Dictionary of rc parameter mappings to override the above. These are
        applied last and directly to the global rcParams, so they are not
        restricted to style or context keys.
    Examples
    --------
    .. include:: ../docstrings/set_theme.rst
    """
    # Apply context, then style (carrying the requested font family), then
    # the color cycle; ``rc`` goes last so it overrides all of them.
    set_context(context, font_scale)
    set_style(style, rc={"font.family": font})
    set_palette(palette, color_codes=color_codes)
    if rc is not None:
        mpl.rcParams.update(rc)
def set(*args, **kwargs):
    """
    Alias for :func:`set_theme`, which is the preferred interface.

    Every positional and keyword argument is forwarded unchanged.
    This function may be removed in the future.
    """
    set_theme(*args, **kwargs)
def reset_defaults():
    """Restore every rc parameter to matplotlib's default settings."""
    defaults = mpl.rcParamsDefault
    mpl.rcParams.update(defaults)
def reset_orig():
    """Restore every rc parameter to the original settings (respects custom rc)."""
    # The snapshot is read from the package namespace at call time
    # (presumably deferred to avoid a circular import — confirm).
    from . import _orig_rc_params as original_params
    mpl.rcParams.update(original_params)
def axes_style(style=None, rc=None):
    """
    Get the parameters that control the general style of the plots.
    The style parameters control properties like the color of the background and
    whether a grid is enabled by default. This is accomplished using the
    matplotlib rcParams system.
    The options are illustrated in the
    :doc:`aesthetics tutorial <../tutorial/aesthetics>`.
    This function can also be used as a context manager to temporarily
    alter the global defaults. See :func:`set_theme` or :func:`set_style`
    to modify the global defaults for all plots.
    Parameters
    ----------
    style : None, dict, or one of {darkgrid, whitegrid, dark, white, ticks}
        A dictionary of parameters or the name of a preconfigured style.
        ``None`` returns the current global values of the style parameters.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        style dictionaries. This only updates parameters that are
        considered part of the style definition.
    Examples
    --------
    .. include:: ../docstrings/axes_style.rst
    """
    if style is None:
        # Snapshot the current global value of each style parameter
        style_dict = {key: mpl.rcParams[key] for key in _style_keys}
    elif isinstance(style, dict):
        style_dict = style
    else:
        known_styles = ["white", "dark", "whitegrid", "darkgrid", "ticks"]
        if style not in known_styles:
            raise ValueError(f"style must be one of {', '.join(known_styles)}")
        dark_gray, light_gray = ".15", ".8"
        # Background / spine / grid colors per style family. After the
        # validation above, the final branch covers "white" and "ticks".
        if style.startswith("dark"):
            facecolor, edgecolor, gridcolor = "#EAEAF2", "white", "white"
        elif style == "whitegrid":
            facecolor, edgecolor, gridcolor = "white", light_gray, light_gray
        else:
            facecolor, edgecolor, gridcolor = "white", dark_gray, light_gray
        # Only the "ticks" style draws bottom/left tick marks
        draw_ticks = style == "ticks"
        style_dict = {
            "figure.facecolor": "white",
            "axes.labelcolor": dark_gray,
            "xtick.direction": "out",
            "ytick.direction": "out",
            "xtick.color": dark_gray,
            "ytick.color": dark_gray,
            "axes.axisbelow": True,
            "grid.linestyle": "-",
            "text.color": dark_gray,
            "font.family": ["sans-serif"],
            "font.sans-serif": ["Arial", "DejaVu Sans", "Liberation Sans",
                                "Bitstream Vera Sans", "sans-serif"],
            "lines.solid_capstyle": "round",
            "patch.edgecolor": "w",
            "patch.force_edgecolor": True,
            "image.cmap": "rocket",
            "xtick.top": False,
            "ytick.right": False,
            "axes.grid": "grid" in style,
            "axes.facecolor": facecolor,
            "axes.edgecolor": edgecolor,
            "grid.color": gridcolor,
            "axes.spines.left": True,
            "axes.spines.bottom": True,
            "axes.spines.right": True,
            "axes.spines.top": True,
            "xtick.bottom": draw_ticks,
            "ytick.left": draw_ticks,
        }
    # Keep only recognised style keys; this builds a new dict, so a
    # user-supplied ``style`` mapping is never modified.
    style_dict = {key: val for key, val in style_dict.items() if key in _style_keys}
    if rc is not None:
        style_dict.update(
            (key, val) for key, val in rc.items() if key in _style_keys
        )
    # The wrapper lets the result double as a ``with`` context manager
    return _AxesStyle(style_dict)
def set_style(style=None, rc=None):
    """
    Set the parameters that control the general style of the plots.
    The style parameters control properties like the color of the background and
    whether a grid is enabled by default. The values are written into the
    global matplotlib rcParams.
    The options are illustrated in the
    :doc:`aesthetics tutorial <../tutorial/aesthetics>`.
    See :func:`axes_style` to get the parameter values.
    Parameters
    ----------
    style : dict, or one of {darkgrid, whitegrid, dark, white, ticks}
        A dictionary of parameters or the name of a preconfigured style.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        style dictionaries. This only updates parameters that are
        considered part of the style definition.
    Examples
    --------
    .. include:: ../docstrings/set_style.rst
    """
    # axes_style validates and filters; push the result into the globals
    mpl.rcParams.update(axes_style(style, rc))
def plotting_context(context=None, font_scale=1, rc=None):
    """
    Get the parameters that control the scaling of plot elements.
    These parameters correspond to label size, line thickness, etc. For more
    information, see the :doc:`aesthetics tutorial <../tutorial/aesthetics>`.
    The base context is "notebook", and the other contexts are "paper", "talk",
    and "poster", which are versions of the notebook parameters scaled by
    different values. Font elements can also be scaled independently of (but
    relative to) the other values.
    This function can also be used as a context manager to temporarily
    alter the global defaults. See :func:`set_theme` or :func:`set_context`
    to modify the global defaults for all plots.
    Parameters
    ----------
    context : None, dict, or one of {paper, notebook, talk, poster}
        A dictionary of parameters or the name of a preconfigured set.
        ``None`` returns the current global values of the context parameters.
        A dictionary is copied, never modified.
    font_scale : float, optional
        Separate scaling factor to independently scale the size of the
        font elements. Only applied when ``context`` is a preset name.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        context dictionaries. This only updates parameters that are
        considered part of the context definition.
    Examples
    --------
    .. include:: ../docstrings/plotting_context.rst
    """
    if context is None:
        context_dict = {k: mpl.rcParams[k] for k in _context_keys}
    elif isinstance(context, dict):
        # Copy so that applying ``rc`` below never mutates the caller's dict
        context_dict = dict(context)
    else:
        contexts = ["paper", "notebook", "talk", "poster"]
        if context not in contexts:
            raise ValueError(f"context must be in {', '.join(contexts)}")
        # Set up dictionary of default parameters
        texts_base_context = {
            "font.size": 12,
            "axes.labelsize": 12,
            "axes.titlesize": 12,
            "xtick.labelsize": 11,
            "ytick.labelsize": 11,
            "legend.fontsize": 11,
            "legend.title_fontsize": 12,
        }
        base_context = {
            "axes.linewidth": 1.25,
            "grid.linewidth": 1,
            "lines.linewidth": 1.5,
            "lines.markersize": 6,
            "patch.linewidth": 1,
            "xtick.major.width": 1.25,
            "ytick.major.width": 1.25,
            "xtick.minor.width": 1,
            "ytick.minor.width": 1,
            "xtick.major.size": 6,
            "ytick.major.size": 6,
            "xtick.minor.size": 4,
            "ytick.minor.size": 4,
        }
        base_context.update(texts_base_context)
        # Scale all the parameters by the same factor depending on the context
        scaling = dict(paper=.8, notebook=1, talk=1.5, poster=2)[context]
        context_dict = {k: v * scaling for k, v in base_context.items()}
        # Now independently scale the fonts
        font_keys = texts_base_context.keys()
        font_dict = {k: context_dict[k] * font_scale for k in font_keys}
        context_dict.update(font_dict)
    # Override these settings with the provided rc dictionary
    if rc is not None:
        rc = {k: v for k, v in rc.items() if k in _context_keys}
        context_dict.update(rc)
    # Wrap in a _PlottingContext object so this can be used in a with statement
    context_object = _PlottingContext(context_dict)
    return context_object
def set_context(context=None, font_scale=1, rc=None):
    """
    Set the parameters that control the scaling of plot elements.
    These parameters correspond to label size, line thickness, etc.
    Calling this function modifies the global matplotlib `rcParams`. For more
    information, see the :doc:`aesthetics tutorial <../tutorial/aesthetics>`.
    The base context is "notebook", and the other contexts are "paper", "talk",
    and "poster", which are versions of the notebook parameters scaled by
    different values. Font elements can also be scaled independently of (but
    relative to) the other values.
    See :func:`plotting_context` to get the parameter values.
    Parameters
    ----------
    context : dict, or one of {paper, notebook, talk, poster}
        A dictionary of parameters or the name of a preconfigured set.
    font_scale : float, optional
        Separate scaling factor to independently scale the size of the
        font elements.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        context dictionaries. This only updates parameters that are
        considered part of the context definition.
    Examples
    --------
    .. include:: ../docstrings/set_context.rst
    """
    # plotting_context computes the values; write them into the globals
    mpl.rcParams.update(plotting_context(context, font_scale, rc))
class _RCAesthetics(dict):
    """A dict of rc parameters that can be applied temporarily.

    Subclasses supply ``_keys`` (the rc keys to snapshot and restore) and
    ``_set`` (a callable that applies a mapping to the global rcParams).
    Instances work as ``with`` context managers and as function decorators.
    """
    def __enter__(self):
        current = mpl.rcParams
        # Remember the present values so __exit__ can put them back
        self._orig = dict((key, current[key]) for key in self._keys)
        self._set(self)
    def __exit__(self, exc_type, exc_value, exc_tb):
        # Restore the snapshot; exceptions are not suppressed
        self._set(self._orig)
    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Run the wrapped function with these params temporarily active
            with self:
                return func(*args, **kwargs)
        return wrapper
class _AxesStyle(_RCAesthetics):
    """Light wrapper on a dict to set style temporarily."""
    # Style keys snapshotted on __enter__ and restored on __exit__
    _keys = _style_keys
    # Applies a mapping to rcParams; staticmethod so it is not bound to self
    _set = staticmethod(set_style)
class _PlottingContext(_RCAesthetics):
    """Light wrapper on a dict to set context temporarily."""
    # Context keys snapshotted on __enter__ and restored on __exit__
    _keys = _context_keys
    # Applies a mapping to rcParams; staticmethod so it is not bound to self
    _set = staticmethod(set_context)
def set_palette(palette, n_colors=None, desat=None, color_codes=False):
    """Set the matplotlib color cycle using a seaborn palette.
    Parameters
    ----------
    palette : seaborn color palette | matplotlib colormap | hls | husl
        Palette definition. Should be something :func:`color_palette` can process.
    n_colors : int
        Number of colors in the cycle. The default number of colors will depend
        on the format of ``palette``, see the :func:`color_palette`
        documentation for more information.
    desat : float
        Proportion to desaturate each color by.
    color_codes : bool
        If ``True`` and ``palette`` is a seaborn palette, remap the shorthand
        color codes (e.g. "b", "g", "r", etc.) to the colors from this palette.
        Palettes that cannot be used for this are silently ignored.
    See Also
    --------
    color_palette : build a color palette or set the color cycle temporarily
                    in a ``with`` statement.
    set_context : set parameters to scale plot elements
    set_style : set the default parameters for figure style
    """
    prop_cycle = cycler('color', palettes.color_palette(palette, n_colors, desat))
    mpl.rcParams['axes.prop_cycle'] = prop_cycle
    if not color_codes:
        return
    # Best effort: only named seaborn palettes can remap the shorthand
    # codes, so anything set_color_codes rejects is skipped on purpose.
    try:
        palettes.set_color_codes(palette)
    except (ValueError, TypeError):
        pass
================================================
FILE: seaborn/regression.py
================================================
"""Plotting functions for linear models (broadly construed)."""
import copy
from textwrap import dedent
import warnings
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
try:
import statsmodels
assert statsmodels
_has_statsmodels = True
except ImportError:
_has_statsmodels = False
from . import utils
from . import algorithms as algo
from .axisgrid import FacetGrid, _facet_docs
# Public names exported by this module
__all__ = ["lmplot", "regplot", "residplot"]
class _LinearPlotter:
"""Base class for plotting relational data in tidy format.
To get anything useful done you'll have to inherit from this, but setup
code that can be abstracted out should be put here.
"""
def establish_variables(self, data, **kws):
"""Extract variables from data or use directly."""
self.data = data
# Validate the inputs
any_strings = any([isinstance(v, str) for v in kws.values()])
if any_strings and data is None:
raise ValueError("Must pass `data` if using named variables.")
# Set the variables
for var, val in kws.items():
if isinstance(val, str):
vector = data[val]
elif isinstance(val, list):
vector = np.asarray(val)
else:
vector = val
if vector is not None and vector.shape != (1,):
vector = np.squeeze(vector)
if np.ndim(vector) > 1:
err = "regplot inputs must be 1d"
raise ValueError(err)
setattr(self, var, vector)
def dropna(self, *vars):
"""Remove observations with missing data."""
vals = [getattr(self, var) for var in vars]
vals = [v for v in vals if v is not None]
not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1)
for var in vars:
val = getattr(self, var)
if val is not None:
setattr(self, var, val[not_na])
def plot(self, ax):
raise NotImplementedError
class _RegressionPlotter(_LinearPlotter):
    """Plotter for numeric independent variables with regression model.
    This does the computations and drawing for the `regplot` function, and
    is thus also used indirectly by `lmplot`.
    """
    def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
                 x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
                 units=None, seed=None, order=1, logistic=False, lowess=False,
                 robust=False, logx=False, x_partial=None, y_partial=None,
                 truncate=False, dropna=True, x_jitter=None, y_jitter=None,
                 color=None, label=None):
        # Set member attributes
        self.x_estimator = x_estimator
        self.ci = ci
        # "ci" means: reuse the regression CI size for the per-x estimates
        self.x_ci = ci if x_ci == "ci" else x_ci
        self.n_boot = n_boot
        self.seed = seed
        self.scatter = scatter
        self.fit_reg = fit_reg
        self.order = order
        self.logistic = logistic
        self.lowess = lowess
        self.robust = robust
        self.logx = logx
        self.truncate = truncate
        self.x_jitter = x_jitter
        self.y_jitter = y_jitter
        self.color = color
        self.label = label
        # Validate the regression options:
        if sum((order > 1, logistic, robust, lowess, logx)) > 1:
            raise ValueError("Mutually exclusive regression options.")
        # Extract the data vals from the arguments or passed dataframe
        self.establish_variables(data, x=x, y=y, units=units,
                                 x_partial=x_partial, y_partial=y_partial)
        # Drop null observations
        if dropna:
            self.dropna("x", "y", "units", "x_partial", "y_partial")
        # Regress nuisance variables out of the data
        if self.x_partial is not None:
            self.x = self.regress_out(self.x, self.x_partial)
        if self.y_partial is not None:
            self.y = self.regress_out(self.y, self.y_partial)
        # Possibly bin the predictor variable, which implies a point estimate
        if x_bins is not None:
            self.x_estimator = np.mean if x_estimator is None else x_estimator
            x_discrete, x_bins = self.bin_predictor(x_bins)
            self.x_discrete = x_discrete
        else:
            self.x_discrete = self.x
        # Disable regression in case of singleton inputs
        if len(self.x) <= 1:
            self.fit_reg = False
        # Save the range of the x variable for the grid later
        if self.fit_reg:
            self.x_range = self.x.min(), self.x.max()
    @property
    def scatter_data(self):
        """Data where each observation is a point (with optional jitter)."""
        x_j = self.x_jitter
        if x_j is None:
            x = self.x
        else:
            x = self.x + np.random.uniform(-x_j, x_j, len(self.x))
        y_j = self.y_jitter
        if y_j is None:
            y = self.y
        else:
            y = self.y + np.random.uniform(-y_j, y_j, len(self.y))
        return x, y
    @property
    def estimate_data(self):
        """Data with a point estimate and CI for each discrete x value.

        Returns ``(vals, points, cis)``: the sorted unique x values, the
        ``x_estimator`` result for each, and a (low, high) pair or None.
        """
        x, y = self.x_discrete, self.y
        vals = sorted(np.unique(x))
        points, cis = [], []
        for val in vals:
            # Get the point estimate of the y variable
            _y = y[x == val]
            est = self.x_estimator(_y)
            points.append(est)
            # Compute the confidence interval for this estimate
            if self.x_ci is None:
                cis.append(None)
            else:
                units = None
                if self.x_ci == "sd":
                    sd = np.std(_y)
                    _ci = est - sd, est + sd
                else:
                    if self.units is not None:
                        units = self.units[x == val]
                    boots = algo.bootstrap(_y,
                                           func=self.x_estimator,
                                           n_boot=self.n_boot,
                                           units=units,
                                           seed=self.seed)
                    _ci = utils.ci(boots, self.x_ci)
                cis.append(_ci)
        return vals, points, cis
    def _check_statsmodels(self):
        """Check whether statsmodels is installed if any boolean options require it."""
        options = "logistic", "robust", "lowess"
        err = "`{}=True` requires statsmodels, an optional dependency, to be installed."
        for option in options:
            if getattr(self, option) and not _has_statsmodels:
                raise RuntimeError(err.format(option))
    def fit_regression(self, ax=None, x_range=None, grid=None):
        """Fit the regression model.

        Returns ``(grid, yhat, err_bands)``; ``err_bands`` is None when no
        confidence interval is computed (``ci`` is None or lowess is used).
        """
        self._check_statsmodels()
        # Create the grid for the regression
        if grid is None:
            if self.truncate:
                x_min, x_max = self.x_range
            else:
                if ax is None:
                    x_min, x_max = x_range
                else:
                    x_min, x_max = ax.get_xlim()
            grid = np.linspace(x_min, x_max, 100)
        ci = self.ci
        # Fit the regression
        if self.order > 1:
            yhat, yhat_boots = self.fit_poly(grid, self.order)
        elif self.logistic:
            from statsmodels.genmod.generalized_linear_model import GLM
            from statsmodels.genmod.families import Binomial
            yhat, yhat_boots = self.fit_statsmodels(grid, GLM,
                                                    family=Binomial())
        elif self.lowess:
            # lowess supplies its own grid and no bootstrap samples
            ci = None
            grid, yhat = self.fit_lowess()
        elif self.robust:
            from statsmodels.robust.robust_linear_model import RLM
            yhat, yhat_boots = self.fit_statsmodels(grid, RLM)
        elif self.logx:
            yhat, yhat_boots = self.fit_logx(grid)
        else:
            yhat, yhat_boots = self.fit_fast(grid)
        # Compute the confidence interval at each grid point
        if ci is None:
            err_bands = None
        else:
            err_bands = utils.ci(yhat_boots, ci, axis=0)
        return grid, yhat, err_bands
    def fit_fast(self, grid):
        """Low-level regression and prediction using linear algebra."""
        def reg_func(_x, _y):
            return np.linalg.pinv(_x).dot(_y)
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]
        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None
        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots
    def fit_poly(self, grid, order):
        """Regression using numpy polyfit for higher-order trends."""
        def reg_func(_x, _y):
            return np.polyval(np.polyfit(_x, _y, order), grid)
        x, y = self.x, self.y
        yhat = reg_func(x, y)
        if self.ci is None:
            return yhat, None
        yhat_boots = algo.bootstrap(x, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots
    def fit_statsmodels(self, grid, model, **kwargs):
        """More general regression function using statsmodels objects."""
        import statsmodels.tools.sm_exceptions as sme
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]
        def reg_func(_x, _y):
            err_classes = (sme.PerfectSeparationError,)
            try:
                with warnings.catch_warnings():
                    if hasattr(sme, "PerfectSeparationWarning"):
                        # statsmodels>=0.14.0
                        warnings.simplefilter("error", sme.PerfectSeparationWarning)
                        err_classes = (*err_classes, sme.PerfectSeparationWarning)
                    yhat = model(_y, _x, **kwargs).fit().predict(grid)
            except err_classes:
                # Perfectly separated data: return an all-NaN prediction
                yhat = np.empty(len(grid))
                yhat.fill(np.nan)
            return yhat
        yhat = reg_func(X, y)
        if self.ci is None:
            return yhat, None
        yhat_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots
    def fit_lowess(self):
        """Fit a locally-weighted regression, which returns its own grid."""
        from statsmodels.nonparametric.smoothers_lowess import lowess
        grid, yhat = lowess(self.y, self.x).T
        return grid, yhat
    def fit_logx(self, grid):
        """Fit the model in log-space."""
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), np.log(grid)]
        def reg_func(_x, _y):
            _x = np.c_[_x[:, 0], np.log(_x[:, 1])]
            return np.linalg.pinv(_x).dot(_y)
        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None
        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots
    def bin_predictor(self, bins):
        """Discretize a predictor by assigning value to closest bin.

        An integer ``bins`` places that many bin centers at evenly spaced
        interior percentiles of x; otherwise ``bins`` gives the centers.
        Returns ``(x_binned, bins)``.
        """
        x = np.asarray(self.x)
        if np.isscalar(bins):
            percentiles = np.linspace(0, 100, bins + 2)[1:-1]
            bins = np.percentile(x, percentiles)
        else:
            bins = np.ravel(bins)
        dist = np.abs(np.subtract.outer(x, bins))
        x_binned = bins[np.argmin(dist, axis=1)].ravel()
        return x_binned, bins
    def regress_out(self, a, b):
        """Regress b from a keeping a's original mean."""
        a_mean = a.mean()
        a = a - a_mean
        b = b - b.mean()
        b = np.c_[b]
        a_prime = a - b.dot(np.linalg.pinv(b).dot(a))
        return np.asarray(a_prime + a_mean).reshape(a.shape)
    def plot(self, ax, scatter_kws, line_kws):
        """Draw the full plot."""
        # Insert the plot label into the correct set of keyword arguments
        if self.scatter:
            scatter_kws["label"] = self.label
        else:
            line_kws["label"] = self.label
        # Use the current color cycle state as a default
        if self.color is None:
            lines, = ax.plot([], [])
            color = lines.get_color()
            lines.remove()
        else:
            color = self.color
        # Ensure that color is hex to avoid matplotlib weirdness
        color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color))
        # Let color in keyword arguments override overall plot color
        scatter_kws.setdefault("color", color)
        line_kws.setdefault("color", color)
        # Draw the constituent plots
        if self.scatter:
            self.scatterplot(ax, scatter_kws)
        if self.fit_reg:
            self.lineplot(ax, line_kws)
        # Label the axes
        if hasattr(self.x, "name"):
            ax.set_xlabel(self.x.name)
        if hasattr(self.y, "name"):
            ax.set_ylabel(self.y.name)
    def scatterplot(self, ax, kws):
        """Draw the data."""
        # Treat the line-based markers specially, explicitly setting larger
        # linewidth than is provided by the seaborn style defaults.
        # This would ideally be handled better in matplotlib (i.e., distinguish
        # between edgewidth for solid glyphs and linewidth for line glyphs
        # but this should do for now.
        line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"]
        if self.x_estimator is None:
            if "marker" in kws and kws["marker"] in line_markers:
                lw = mpl.rcParams["lines.linewidth"]
            else:
                lw = mpl.rcParams["lines.markeredgewidth"]
            kws.setdefault("linewidths", lw)
            # Add default transparency only when the color has no alpha
            # channel. Array colors may be 1-d (one color) or 2-d (one row
            # per point); channels live on the last axis in both cases.
            color = kws["color"]
            has_alpha = (hasattr(color, "shape") and len(color.shape) > 0
                         and color.shape[-1] >= 4)
            if not has_alpha:
                kws.setdefault("alpha", .8)
            x, y = self.scatter_data
            ax.scatter(x, y, **kws)
        else:
            # TODO abstraction
            ci_kws = {"color": kws["color"]}
            if "alpha" in kws:
                ci_kws["alpha"] = kws["alpha"]
            ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75
            kws.setdefault("s", 50)
            xs, ys, cis = self.estimate_data
            if [ci for ci in cis if ci is not None]:
                for x, ci in zip(xs, cis):
                    ax.plot([x, x], ci, **ci_kws)
            ax.scatter(xs, ys, **kws)
    def lineplot(self, ax, kws):
        """Draw the model."""
        # Fit the regression model
        grid, yhat, err_bands = self.fit_regression(ax)
        edges = grid[0], grid[-1]
        # Get set default aesthetics
        fill_color = kws["color"]
        lw = kws.pop("lw", mpl.rcParams["lines.linewidth"] * 1.5)
        kws.setdefault("linewidth", lw)
        # Draw the regression line and confidence interval
        line, = ax.plot(grid, yhat, **kws)
        if not self.truncate:
            line.sticky_edges.x[:] = edges  # Prevent mpl from adding margin
        if err_bands is not None:
            ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)
# Shared numpydoc snippets for the regression functions, substituted by key
# into the lmplot/regplot docstring templates. Every snippet ends with a
# line-continuation backslash before the closing quotes so that none carries
# a trailing newline; the templates control spacing between parameters.
_regression_docs = dict(

    model_api=dedent("""\
There are a number of mutually exclusive options for estimating the
regression model. See the :ref:`tutorial ` for more
information.\
"""),
    regplot_vs_lmplot=dedent("""\
The :func:`regplot` and :func:`lmplot` functions are closely related, but
the former is an axes-level function while the latter is a figure-level
function that combines :func:`regplot` and :class:`FacetGrid`.\
"""),
    x_estimator=dedent("""\
x_estimator : callable that maps vector -> scalar, optional
Apply this function to each unique value of ``x`` and plot the
resulting estimate. This is useful when ``x`` is a discrete variable.
If ``x_ci`` is given, this estimate will be bootstrapped and a
confidence interval will be drawn.\
"""),
    x_bins=dedent("""\
x_bins : int or vector, optional
Bin the ``x`` variable into discrete bins and then estimate the central
tendency and a confidence interval. This binning only influences how
the scatterplot is drawn; the regression is still fit to the original
data. This parameter is interpreted either as the number of
evenly-sized (not necessarily spaced) bins or the positions of the bin
centers. When this parameter is used, it implies that the default of
``x_estimator`` is ``numpy.mean``.\
"""),
    x_ci=dedent("""\
x_ci : "ci", "sd", int in [0, 100] or None, optional
Size of the confidence interval used when plotting a central tendency
for discrete values of ``x``. If ``"ci"``, defer to the value of the
``ci`` parameter. If ``"sd"``, skip bootstrapping and show the
standard deviation of the observations in each bin.\
"""),
    scatter=dedent("""\
scatter : bool, optional
If ``True``, draw a scatterplot with the underlying observations (or
the ``x_estimator`` values).\
"""),
    fit_reg=dedent("""\
fit_reg : bool, optional
If ``True``, estimate and plot a regression model relating the ``x``
and ``y`` variables.\
"""),
    ci=dedent("""\
ci : int in [0, 100] or None, optional
Size of the confidence interval for the regression estimate. This will
be drawn using translucent bands around the regression line. The
confidence interval is estimated using a bootstrap; for large
datasets, it may be advisable to avoid that computation by setting
this parameter to None.\
"""),
    n_boot=dedent("""\
n_boot : int, optional
Number of bootstrap resamples used to estimate the ``ci``. The default
value attempts to balance time and stability; you may want to increase
this value for "final" versions of plots.\
"""),
    units=dedent("""\
units : variable name in ``data``, optional
If the ``x`` and ``y`` observations are nested within sampling units,
those can be specified here. This will be taken into account when
computing the confidence intervals by performing a multilevel bootstrap
that resamples both units and observations (within unit). This does not
otherwise influence how the regression is estimated or drawn.\
"""),
    seed=dedent("""\
seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
Seed or random number generator for reproducible bootstrapping.\
"""),
    order=dedent("""\
order : int, optional
If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a
polynomial regression.\
"""),
    logistic=dedent("""\
logistic : bool, optional
If ``True``, assume that ``y`` is a binary variable and use
``statsmodels`` to estimate a logistic regression model. Note that this
is substantially more computationally intensive than linear regression,
so you may wish to decrease the number of bootstrap resamples
(``n_boot``) or set ``ci`` to None.\
"""),
    lowess=dedent("""\
lowess : bool, optional
If ``True``, use ``statsmodels`` to estimate a nonparametric lowess
model (locally weighted linear regression). Note that confidence
intervals cannot currently be drawn for this kind of model.\
"""),
    robust=dedent("""\
robust : bool, optional
If ``True``, use ``statsmodels`` to estimate a robust regression. This
will de-weight outliers. Note that this is substantially more
computationally intensive than standard linear regression, so you may
wish to decrease the number of bootstrap resamples (``n_boot``) or set
``ci`` to None.\
"""),
    logx=dedent("""\
logx : bool, optional
If ``True``, estimate a linear regression of the form y ~ log(x), but
plot the scatterplot and regression model in the input space. Note that
``x`` must be positive for this to work.\
"""),
    xy_partial=dedent("""\
{x,y}_partial : strings in ``data`` or matrices
Confounding variables to regress out of the ``x`` or ``y`` variables
before plotting.\
"""),
    truncate=dedent("""\
truncate : bool, optional
If ``True``, the regression line is bounded by the data limits. If
``False``, it extends to the ``x`` axis limits.\
"""),
    dropna=dedent("""\
dropna : bool, optional
If ``True``, remove observations with missing data from the plot.\
"""),
    xy_jitter=dedent("""\
{x,y}_jitter : floats, optional
Add uniform random noise of this size to either the ``x`` or ``y``
variables. The noise is added to a copy of the data after fitting the
regression, and only influences the look of the scatterplot. This can
be helpful when plotting variables that take discrete values.\
"""),
    scatter_line_kws=dedent("""\
{scatter,line}_kws : dictionaries
Additional keyword arguments to pass to ``plt.scatter`` and
``plt.plot``.\
"""),
)

# Merge in the FacetGrid snippets; presumably these supply the keys such as
# ``data``/``height``/``share_xy`` that the lmplot template references.
_regression_docs.update(_facet_docs)
def lmplot(
    data, *,
    x=None, y=None, hue=None, col=None, row=None,
    palette=None, col_wrap=None, height=5, aspect=1, markers="o",
    sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None,
    legend=True, legend_out=None, x_estimator=None, x_bins=None,
    x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
    units=None, seed=None, order=1, logistic=False, lowess=False,
    robust=False, logx=False, x_partial=None, y_partial=None,
    truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
    line_kws=None, facet_kws=None,
):
    # (The public docstring is attached separately via ``lmplot.__doc__``.)

    # Work on a copy so that folding the deprecated parameters in below never
    # mutates a dict owned by the caller.
    facet_kws = {} if facet_kws is None else dict(facet_kws)

    def facet_kw_deprecation(key, val):
        # Route a deprecated top-level FacetGrid option into facet_kws,
        # warning only when the caller actually supplied a value.
        msg = (
            f"{key} is deprecated from the `lmplot` function signature. "
            "Please update your code to pass it using `facet_kws`."
        )
        if val is not None:
            warnings.warn(msg, UserWarning)
            facet_kws[key] = val

    facet_kw_deprecation("sharex", sharex)
    facet_kw_deprecation("sharey", sharey)
    facet_kw_deprecation("legend_out", legend_out)

    if data is None:
        raise TypeError("Missing required keyword argument `data`.")

    # Reduce the dataframe to only needed columns. dict.fromkeys deduplicates
    # while keeping the labels' original types; np.unique would coerce mixed
    # labels (e.g. 0 and "b") to strings and break the selection below.
    need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
    cols = list(dict.fromkeys(a for a in need_cols if a is not None))
    data = data[cols]

    # Initialize the grid
    facets = FacetGrid(
        data, row=row, col=col, hue=hue,
        palette=palette,
        row_order=row_order, col_order=col_order, hue_order=hue_order,
        height=height, aspect=aspect, col_wrap=col_wrap,
        **facet_kws,
    )

    # Add the markers here as FacetGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    n_markers = 1 if facets.hue_names is None else len(facets.hue_names)
    if not isinstance(markers, list):
        markers = [markers] * n_markers
    if len(markers) != n_markers:
        raise ValueError("markers must be a singleton or a list of markers "
                         "for each level of the hue variable")
    facets.hue_kws = {"marker": markers}

    def update_datalim(data, x, y, ax, **kws):
        # Expand the x data limits to every facet's data (y is left alone)
        xys = data[[x, y]].to_numpy().astype(float)
        ax.update_datalim(xys, updatey=False)
        ax.autoscale_view(scaley=False)

    facets.map_dataframe(update_datalim, x=x, y=y)

    # Draw the regression plot on each facet
    regplot_kws = dict(
        x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
        scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units,
        seed=seed, order=order, logistic=logistic, lowess=lowess,
        robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
        truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
        scatter_kws=scatter_kws, line_kws=line_kws,
    )
    facets.map_dataframe(regplot, x=x, y=y, **regplot_kws)
    facets.set_axis_labels(x, y)

    # Add a legend only when hue is not already encoded by the facets
    if legend and (hue is not None) and (hue not in [col, row]):
        facets.add_legend()
    return facets
# Assemble lmplot's public docstring from the shared snippets in
# _regression_docs; "{{hue,col,row}}" is an escaped literal brace pair.
lmplot.__doc__ = dedent("""\
Plot data and regression model fits across a FacetGrid.
This function combines :func:`regplot` and :class:`FacetGrid`. It is
intended as a convenient interface to fit regression models across
conditional subsets of a dataset.
When thinking about how to assign variables to different facets, a general
rule is that it makes sense to use ``hue`` for the most important
comparison, followed by ``col`` and ``row``. However, always think about
your particular dataset and the goals of the visualization you are
creating.
{model_api}
The parameters to this function span most of the options in
:class:`FacetGrid`, although there may be occasional cases where you will
want to use that class and :func:`regplot` directly.
Parameters
----------
{data}
x, y : strings, optional
Input variables; these should be column names in ``data``.
hue, col, row : strings
Variables that define subsets of the data, which will be drawn on
separate facets in the grid. See the ``*_order`` parameters to control
the order of levels of this variable.
{palette}
{col_wrap}
{height}
{aspect}
markers : matplotlib marker code or list of marker codes, optional
Markers for the scatterplot. If a list, each marker in the list will be
used for each level of the ``hue`` variable.
{share_xy}
.. deprecated:: 0.12.0
Pass using the `facet_kws` dictionary.
{{hue,col,row}}_order : lists, optional
Order for the levels of the faceting variables. By default, this will
be the order that the levels appear in ``data`` or, if the variables
are pandas categoricals, the category order.
legend : bool, optional
If ``True`` and there is a ``hue`` variable, add a legend.
{legend_out}
.. deprecated:: 0.12.0
Pass using the `facet_kws` dictionary.
{x_estimator}
{x_bins}
{x_ci}
{scatter}
{fit_reg}
{ci}
{n_boot}
{units}
{seed}
{order}
{logistic}
{lowess}
{robust}
{logx}
{xy_partial}
{truncate}
{xy_jitter}
{scatter_line_kws}
facet_kws : dict
Dictionary of keyword arguments for :class:`FacetGrid`.
Returns
-------
:class:`FacetGrid`
The :class:`FacetGrid` object with the plot on it for further tweaking.
See Also
--------
regplot : Plot data and a conditional model fit.
FacetGrid : Subplot grid for plotting conditional relationships.
pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
``kind="reg"``).
Notes
-----
{regplot_vs_lmplot}
Examples
--------
.. include:: ../docstrings/lmplot.rst
""").format_map(_regression_docs)
def regplot(
    data=None, *, x=None, y=None,
    x_estimator=None, x_bins=None, x_ci="ci",
    scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
    seed=None, order=1, logistic=False, lowess=False, robust=False,
    logx=False, x_partial=None, y_partial=None,
    truncate=True, dropna=True, x_jitter=None, y_jitter=None,
    label=None, color=None, marker="o",
    scatter_kws=None, line_kws=None, ax=None
):
    # (The public docstring is attached separately via ``regplot.__doc__``.)
    # All estimation and drawing is delegated to _RegressionPlotter; this
    # wrapper only resolves the target Axes and prepares the keyword dicts.
    plotter = _RegressionPlotter(
        x, y, data, x_estimator, x_bins, x_ci,
        scatter, fit_reg, ci, n_boot, units, seed,
        order, logistic, lowess, robust, logx,
        x_partial, y_partial, truncate, dropna,
        x_jitter, y_jitter, color, label,
    )

    ax = plt.gca() if ax is None else ax

    # Shallow-copy the user's dicts so setting `marker` here does not modify
    # the caller's objects.
    scatter_kws = copy.copy(scatter_kws) if scatter_kws is not None else {}
    scatter_kws["marker"] = marker
    line_kws = copy.copy(line_kws) if line_kws is not None else {}

    plotter.plot(ax, scatter_kws, line_kws)
    return ax
# Assemble regplot's public docstring from the shared snippets in
# _regression_docs.
regplot.__doc__ = dedent("""\
Plot data and a linear regression model fit.
{model_api}
Parameters
----------
x, y : string, series, or vector array
Input variables. If strings, these should correspond with column names
in ``data``. When pandas objects are used, axes will be labeled with
the series name.
{data}
{x_estimator}
{x_bins}
{x_ci}
{scatter}
{fit_reg}
{ci}
{n_boot}
{units}
{seed}
{order}
{logistic}
{lowess}
{robust}
{logx}
{xy_partial}
{truncate}
{dropna}
{xy_jitter}
label : string
Label to apply to either the scatterplot or regression line (if
``scatter`` is ``False``) for use in a legend.
color : matplotlib color
Color to apply to all plot elements; will be superseded by colors
passed in ``scatter_kws`` or ``line_kws``.
marker : matplotlib marker code
Marker to use for the scatterplot glyphs.
{scatter_line_kws}
ax : matplotlib Axes, optional
Axes object to draw the plot onto, otherwise uses the current Axes.
Returns
-------
ax : matplotlib Axes
The Axes object containing the plot.
See Also
--------
lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple
linear relationships in a dataset.
jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with
``kind="reg"``).
pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
``kind="reg"``).
residplot : Plot the residuals of a linear regression model.
Notes
-----
{regplot_vs_lmplot}
It's also easy to combine :func:`regplot` and :class:`JointGrid` or
:class:`PairGrid` through the :func:`jointplot` and :func:`pairplot`
functions, although these do not directly accept all of :func:`regplot`'s
parameters.
Examples
--------
.. include:: ../docstrings/regplot.rst
""").format_map(_regression_docs)
def residplot(
    data=None, *, x=None, y=None,
    x_partial=None, y_partial=None, lowess=False,
    order=1, robust=False, dropna=True, label=None, color=None,
    scatter_kws=None, line_kws=None, ax=None
):
    """Plot the residuals of a linear regression.

    This function will regress y on x (possibly as a robust or polynomial
    regression) and then draw a scatterplot of the residuals. You can
    optionally fit a lowess smoother to the residual plot, which can
    help in determining if there is structure to the residuals.

    Parameters
    ----------
    data : DataFrame, optional
        DataFrame to use if `x` and `y` are column names.
    x : vector or string
        Data or column name in `data` for the predictor variable.
    y : vector or string
        Data or column name in `data` for the response variable.
    {x, y}_partial : vectors or string(s), optional
        These variables are treated as confounding and are removed from
        the `x` or `y` variables before plotting.
    lowess : boolean, optional
        Fit a lowess smoother to the residual scatterplot.
    order : int, optional
        Order of the polynomial to fit when calculating the residuals.
    robust : boolean, optional
        Fit a robust linear regression when calculating the residuals.
    dropna : boolean, optional
        If True, ignore observations with missing data when fitting and
        plotting.
    label : string, optional
        Label that will be used in any plot legends.
    color : matplotlib color, optional
        Color to use for all elements of the plot.
    {scatter, line}_kws : dictionaries, optional
        Additional keyword arguments passed to scatter() and plot() for
        drawing the components of the plot.
    ax : matplotlib Axes, optional
        Plot into this Axes, otherwise grab the current Axes or make a new
        one if none exists.

    Returns
    -------
    ax : matplotlib Axes
        Axes with the regression plot.

    See Also
    --------
    regplot : Plot a simple linear regression model.
    jointplot : Draw a :func:`residplot` with univariate marginal distributions
        (when used with ``kind="resid"``).

    Examples
    --------
    .. include:: ../docstrings/residplot.rst
    """
    plotter = _RegressionPlotter(
        x, y, data, ci=None,
        order=order, robust=robust,
        x_partial=x_partial, y_partial=y_partial,
        dropna=dropna, color=color, label=label,
    )

    ax = plt.gca() if ax is None else ax

    # Evaluate the fit at the observed x values and keep what is left over.
    # (Rebind rather than subtract in place so the input array is untouched.)
    _, fitted, _ = plotter.fit_regression(grid=plotter.x)
    plotter.y = plotter.y - fitted

    # Either overlay a lowess smoother on the residuals or draw no model
    if not lowess:
        plotter.fit_reg = False
    else:
        plotter.lowess = True

    # Dotted reference line at zero residual
    ax.axhline(0, ls=":", c=".2")

    scatter_kws = scatter_kws.copy() if scatter_kws is not None else {}
    line_kws = line_kws.copy() if line_kws is not None else {}
    plotter.plot(ax, scatter_kws, line_kws)
    return ax
================================================
FILE: seaborn/relational.py
================================================
from functools import partial
import warnings
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.cbook import normalize_kwargs
from ._base import (
VectorPlotter,
)
from .utils import (
adjust_legend_subtitles,
_default_color,
_deprecate_ci,
_get_transform_functions,
_scatter_legend_artist,
)
from ._compat import groupby_apply_include_groups
from ._statistics import EstimateAggregator, WeightedAggregator
from .axisgrid import FacetGrid, _facet_docs
from ._docstrings import DocstringComponents, _core_docs
__all__ = ["relplot", "scatterplot", "lineplot"]

# Narrative paragraphs shared by the relational function docstrings
_relational_narrative = DocstringComponents({

    # --- Introductory prose
    "main_api": """
The relationship between `x` and `y` can be shown for different subsets
of the data using the `hue`, `size`, and `style` parameters. These
parameters control what visual semantics are used to identify the different
subsets. It is possible to show up to three dimensions independently by
using all three semantic types, but this style of plot can be hard to
interpret and is often ineffective. Using redundant semantics (i.e. both
`hue` and `style` for the same variable) can be helpful for making
graphics more accessible.
See the :ref:`tutorial ` for more information.
""",

    "relational_semantic": """
The default treatment of the `hue` (and to a lesser extent, `size`)
semantic, if present, depends on whether the variable is inferred to
represent "numeric" or "categorical" data. In particular, numeric variables
are represented with a sequential colormap by default, and the legend
entries show regular "ticks" with values that may or may not exist in the
data. This behavior can be controlled through various parameters, as
described and illustrated below.
""",
})
# Parameter documentation snippets shared by the relational functions.
# Parallel entries (hue_/size_/style_order, hue_/size_norm) are worded
# identically apart from the variable name.
_relational_docs = dict(

    # --- Shared function parameters
    data_vars="""
x, y : names of variables in `data` or vector data
Input data variables; must be numeric. Can pass data directly or
reference columns in `data`.
""",
    data="""
data : DataFrame, array, or list of arrays
Input data structure. If `x` and `y` are specified as names, this
should be a "long-form" DataFrame containing those columns. Otherwise
it is treated as "wide-form" data and grouping variables are ignored.
See the examples for the various ways this parameter can be specified
and the different effects of each.
""",
    palette="""
palette : string, list, dict, or matplotlib colormap
An object that determines how colors are chosen when `hue` is used.
It can be the name of a seaborn palette or matplotlib colormap, a list
of colors (anything matplotlib understands), a dict mapping levels
of the `hue` variable to colors, or a matplotlib colormap object.
""",
    hue_order="""
hue_order : list
Specified order for the appearance of the `hue` variable levels,
otherwise they are determined from the data. Not relevant when the
`hue` variable is numeric.
""",
    hue_norm="""
hue_norm : tuple or :class:`matplotlib.colors.Normalize` object
Normalization in data units for colormap applied to the `hue`
variable when it is numeric. Not relevant if `hue` is categorical.
""",
    sizes="""
sizes : list, dict, or tuple
An object that determines how sizes are chosen when `size` is used.
List or dict arguments should provide a size for each unique data value,
which forces a categorical interpretation. The argument may also be a
min, max tuple.
""",
    size_order="""
size_order : list
Specified order for appearance of the `size` variable levels,
otherwise they are determined from the data. Not relevant when the
`size` variable is numeric.
""",
    size_norm="""
size_norm : tuple or :class:`matplotlib.colors.Normalize` object
Normalization in data units for scaling plot objects when the
`size` variable is numeric.
""",
    dashes="""
dashes : boolean, list, or dictionary
Object determining how to draw the lines for different levels of the
`style` variable. Setting to `True` will use default dash codes, or
you can pass a list of dash codes or a dictionary mapping levels of the
`style` variable to dash codes. Setting to `False` will use solid
lines for all subsets. Dashes are specified as in matplotlib: a tuple
of `(segment, gap)` lengths, or an empty string to draw a solid line.
""",
    markers="""
markers : boolean, list, or dictionary
Object determining how to draw the markers for different levels of the
`style` variable. Setting to `True` will use default markers, or
you can pass a list of markers or a dictionary mapping levels of the
`style` variable to markers. Setting to `False` will draw
marker-less lines. Markers are specified as in matplotlib.
""",
    style_order="""
style_order : list
Specified order for appearance of the `style` variable levels,
otherwise they are determined from the data. Not relevant when the
`style` variable is numeric.
""",
    units="""
units : vector or key in `data`
Grouping variable identifying sampling units. When used, a separate
line will be drawn for each unit with appropriate semantics, but no
legend entry will be added. Useful for showing distribution of
experimental replicates when exact identities are not needed.
""",
    estimator="""
estimator : name of pandas method or callable or None
Method for aggregating across multiple observations of the `y`
variable at the same `x` level. If `None`, all observations will
be drawn.
""",
    ci="""
ci : int or "sd" or None
Size of the confidence interval to draw when aggregating.
.. deprecated:: 0.12.0
Use the new `errorbar` parameter for more flexibility.
""",
    n_boot="""
n_boot : int
Number of bootstraps to use for computing the confidence interval.
""",
    seed="""
seed : int, numpy.random.Generator, or numpy.random.RandomState
Seed or random number generator for reproducible bootstrapping.
""",
    legend="""
legend : "auto", "brief", "full", or False
How to draw the legend. If "brief", numeric `hue` and `size`
variables will be represented with a sample of evenly spaced values.
If "full", every group will get an entry in the legend. If "auto",
choose between brief or full representation based on number of levels.
If `False`, no legend data is added and no legend is drawn.
""",
    ax_in="""
ax : matplotlib Axes
Axes object to draw the plot onto, otherwise uses the current Axes.
""",
    ax_out="""
ax : matplotlib Axes
Returns the Axes object with the plot drawn onto it.
""",
)
# Nested parameter-documentation namespace, grouped by source:
#   core  -> _core_docs["params"]
#   facets -> the FacetGrid snippets (_facet_docs)
#   rel   -> the relational snippets in _relational_docs
#   stat  -> built from EstimateAggregator.__init__ (presumably by parsing its
#            docstring's parameter section; see DocstringComponents)
# Presumably this is what the `{params.core.data}`-style fields in the
# function docstrings below are formatted against — TODO confirm.
_param_docs = DocstringComponents.from_nested_components(
core=_core_docs["params"],
facets=DocstringComponents(_facet_docs),
rel=DocstringComponents(_relational_docs),
stat=DocstringComponents.from_function_params(EstimateAggregator.__init__),
)
class _RelationalPlotter(VectorPlotter):
    """Common base for the relational plotters (_LinePlotter, _ScatterPlotter)."""

    # TODO where best to define default parameters?
    sort = True

    # How wide-form input maps onto semantic variables (presumably consumed
    # by VectorPlotter when it parses wide-form data)
    wide_structure = dict(
        x="@index",
        y="@values",
        hue="@columns",
        style="@columns",
    )
class _LinePlotter(_RelationalPlotter):
    """Plotter behind :func:`lineplot`.

    Draws one line per semantic subset (or per unit), optionally aggregating
    repeated observations with an estimator and showing the uncertainty as a
    shaded band or as error bars.
    """

    # Line properties this plot type varies between legend entries
    _legend_attributes = ["color", "linewidth", "marker", "dashes"]

    def __init__(
        self, *,
        data=None, variables={},
        estimator=None, n_boot=None, seed=None, errorbar=None,
        sort=True, orient="x", err_style=None, err_kws=None, legend=None
    ):

        # TODO this is messy, we want the mapping to be agnostic about
        # the kind of plot to draw, but for the time being we need to set
        # this information so the SizeMapping can use it
        self._default_size_range = (
            np.r_[.5, 2] * mpl.rcParams["lines.linewidth"]
        )

        super().__init__(data=data, variables=variables)

        self.estimator = estimator
        self.errorbar = errorbar
        self.n_boot = n_boot
        self.seed = seed
        self.sort = sort
        self.orient = orient
        self.err_style = err_style
        self.err_kws = {} if err_kws is None else err_kws

        self.legend = legend

    def plot(self, ax, kws):
        """Draw the plot onto an axes, passing matplotlib kwargs.

        Raises ValueError for an unknown ``err_style`` or ``orient``, or when
        an estimator is combined with ``units``.
        """

        # Draw a test plot, using the passed in kwargs. The goal here is to
        # honor both (a) the current state of the plot cycler and (b) the
        # specified kwargs on all the lines we will draw, overriding when
        # relevant with the data semantics. Note that we won't cycle
        # internally; in other words, if `hue` is not used, all elements will
        # have the same color, but they will have the color that you would have
        # gotten from the corresponding matplotlib function, and calling the
        # function will advance the axes property cycle.

        kws = normalize_kwargs(kws, mpl.lines.Line2D)
        kws.setdefault("markeredgewidth", 0.75)
        kws.setdefault("markeredgecolor", "w")

        # Set default error kwargs
        err_kws = self.err_kws.copy()
        if self.err_style == "band":
            err_kws.setdefault("alpha", .2)
        elif self.err_style == "bars":
            pass
        elif self.err_style is not None:
            err = "`err_style` must be 'band' or 'bars', not {}"
            raise ValueError(err.format(self.err_style))

        # Initialize the aggregation object
        weighted = "weight" in self.plot_data
        agg = (WeightedAggregator if weighted else EstimateAggregator)(
            self.estimator, self.errorbar, n_boot=self.n_boot, seed=self.seed,
        )

        # TODO abstract variable to aggregate over here-ish. Better name?
        orient = self.orient
        if orient not in {"x", "y"}:
            err = f"`orient` must be either 'x' or 'y', not {orient!r}."
            raise ValueError(err)
        other = {"x": "y", "y": "x"}[orient]

        # TODO How to handle NA? We don't want NA to propagate through to the
        # estimate/CI when some values are present, but we would also like
        # matplotlib to show "gaps" in the line when all values are missing.
        # This is straightforward absent aggregation, but complicated with it.
        # If we want to use nas, we need to conditionalize dropna in iter_data.

        # Loop over the semantic subsets and add to the plot
        grouping_vars = "hue", "size", "style"
        for sub_vars, sub_data in self.iter_data(grouping_vars, from_comp_data=True):

            if self.sort:
                sort_vars = ["units", orient, other]
                sort_cols = [var for var in sort_vars if var in self.variables]
                sub_data = sub_data.sort_values(sort_cols)

            if (
                self.estimator is not None
                and sub_data[orient].value_counts().max() > 1
            ):
                if "units" in self.variables:
                    # TODO eventually relax this constraint
                    err = "estimator must be None when specifying units"
                    raise ValueError(err)
                grouped = sub_data.groupby(orient, sort=self.sort)
                # Could pass as_index=False instead of reset_index,
                # but that fails on a corner case with older pandas.
                sub_data = (
                    grouped
                    .apply(agg, other, **groupby_apply_include_groups(False))
                    .reset_index()
                )
            else:
                sub_data[f"{other}min"] = np.nan
                sub_data[f"{other}max"] = np.nan

            # Apply inverse axis scaling
            for var in "xy":
                _, inv = _get_transform_functions(ax, var)
                for col in sub_data.filter(regex=f"^{var}"):
                    sub_data[col] = inv(sub_data[col])

            # --- Draw the main line(s)

            if "units" in self.variables:   # XXX why not add to grouping variables?
                lines = []
                for _, unit_data in sub_data.groupby("units"):
                    lines.extend(ax.plot(unit_data["x"], unit_data["y"], **kws))
            else:
                lines = ax.plot(sub_data["x"], sub_data["y"], **kws)

            if not lines:
                # Nothing was drawn for this subset (e.g. the units groupby
                # produced no groups), so there is no line whose properties
                # the error artists could reuse; `line` would be unbound or
                # left over from a previous subset.
                continue

            for line in lines:

                if "hue" in sub_vars:
                    line.set_color(self._hue_map(sub_vars["hue"]))

                if "size" in sub_vars:
                    line.set_linewidth(self._size_map(sub_vars["size"]))

                if "style" in sub_vars:
                    attributes = self._style_map(sub_vars["style"])
                    if "dashes" in attributes:
                        line.set_dashes(attributes["dashes"])
                    if "marker" in attributes:
                        line.set_marker(attributes["marker"])

            # All lines of a subset share semantics, so the last one is
            # representative for styling the error artists.
            line_color = line.get_color()
            line_alpha = line.get_alpha()
            line_capstyle = line.get_solid_capstyle()

            # --- Draw the confidence intervals

            if self.estimator is not None and self.errorbar is not None:

                # TODO handling of orientation will need to happen here

                if self.err_style == "band":

                    func = {"x": ax.fill_between, "y": ax.fill_betweenx}[orient]
                    # Line-derived defaults first so explicit err_kws win
                    # instead of raising a duplicate-keyword TypeError.
                    band_kws = {"color": line_color, **err_kws}
                    func(
                        sub_data[orient],
                        sub_data[f"{other}min"], sub_data[f"{other}max"],
                        **band_kws,
                    )

                elif self.err_style == "bars":

                    error_param = {
                        f"{other}err": (
                            sub_data[other] - sub_data[f"{other}min"],
                            sub_data[f"{other}max"] - sub_data[other],
                        )
                    }
                    # Normalize aliases (e.g. "c", "ls") so user values can
                    # override the defaults below without a keyword clash.
                    bar_kws = {
                        "linestyle": "", "color": line_color, "alpha": line_alpha,
                        **normalize_kwargs(err_kws, mpl.lines.Line2D),
                    }
                    ebars = ax.errorbar(
                        sub_data["x"], sub_data["y"], **error_param, **bar_kws
                    )

                    # Set the capstyle properly on the error bars
                    for obj in ebars.get_children():
                        if isinstance(obj, mpl.collections.LineCollection):
                            obj.set_capstyle(line_capstyle)

        # Finalize the axes details
        self._add_axis_labels(ax)
        if self.legend:
            legend_artist = partial(mpl.lines.Line2D, xdata=[], ydata=[])
            attrs = {"hue": "color", "size": "linewidth", "style": None}
            self.add_legend_data(ax, legend_artist, kws, attrs)
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                legend = ax.legend(title=self.legend_title)
                adjust_legend_subtitles(legend)
class _ScatterPlotter(_RelationalPlotter):
    """Draw a scatter plot whose hue/size/style semantics map onto point
    face colors, sizes, and marker paths."""

    # Collection properties this plot type varies between legend entries
    _legend_attributes = ["color", "s", "marker"]

    def __init__(self, *, data=None, variables={}, legend=None):

        # TODO this is messy, we want the mapping to be agnostic about
        # the kind of plot to draw, but for the time being we need to set
        # this information so the SizeMapping can use it
        self._default_size_range = (
            np.r_[.5, 2] * np.square(mpl.rcParams["lines.markersize"])
        )

        super().__init__(data=data, variables=variables)

        self.legend = legend

    def plot(self, ax, kws):
        """Draw the points onto ``ax``, passing ``kws`` to ``ax.scatter``.

        Rows with any missing value are dropped; nothing is drawn if none
        remain.
        """

        # --- Determine the visual attributes of the plot

        data = self.comp_data.dropna()
        if data.empty:
            return

        kws = normalize_kwargs(kws, mpl.collections.PathCollection)

        # Define the vectors of x and y positions
        empty = np.full(len(data), np.nan)
        x = data.get("x", empty)
        y = data.get("y", empty)

        # Apply inverse scaling to the coordinate variables
        _, inv_x = _get_transform_functions(ax, "x")
        _, inv_y = _get_transform_functions(ax, "y")
        x, y = inv_x(x), inv_y(y)

        if "style" in self.variables:
            # Use a representative marker so scatter sets the edgecolor
            # properly for line art markers. We currently enforce either
            # all or none line art so this works.
            example_level = self._style_map.levels[0]
            example_marker = self._style_map(example_level, "marker")
            kws.setdefault("marker", example_marker)

        # Conditionally set the marker edgecolor based on whether the marker is "filled"
        # See https://github.com/matplotlib/matplotlib/issues/17849 for context
        # Resolve the marker as ax.scatter does: an absent or None marker falls
        # back to rcParams["scatter.marker"]. (There is no bare "marker"
        # rcParam, so that lookup always yielded "o" and could give line-art
        # default markers such as "x" invisible white edges.)
        m = kws.get("marker")
        if m is None:
            m = mpl.rcParams.get("scatter.marker", "o")
        if not isinstance(m, mpl.markers.MarkerStyle):
            # TODO in more recent matplotlib (which?) can pass a MarkerStyle here
            m = mpl.markers.MarkerStyle(m)
        if m.is_filled():
            kws.setdefault("edgecolor", "w")

        # Draw the scatter plot
        points = ax.scatter(x=x, y=y, **kws)

        # Apply the mapping from semantic variables to artist attributes

        if "hue" in self.variables:
            points.set_facecolors(self._hue_map(data["hue"]))

        if "size" in self.variables:
            points.set_sizes(self._size_map(data["size"]))

        if "style" in self.variables:
            p = [self._style_map(val, "path") for val in data["style"]]
            points.set_paths(p)

        # Apply dependent default attributes: scale the edge width with the
        # (10th-percentile) marker size, and record it for the legend
        if "linewidth" not in kws:
            sizes = points.get_sizes()
            linewidth = .08 * np.sqrt(np.percentile(sizes, 10))
            points.set_linewidths(linewidth)
            kws["linewidth"] = linewidth

        # Finalize the axes details
        self._add_axis_labels(ax)
        if self.legend:
            attrs = {"hue": "color", "size": "s", "style": None}
            self.add_legend_data(ax, _scatter_legend_artist, kws, attrs)
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                legend = ax.legend(title=self.legend_title)
                adjust_legend_subtitles(legend)
def lineplot(
    data=None, *,
    x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
    palette=None, hue_order=None, hue_norm=None,
    sizes=None, size_order=None, size_norm=None,
    dashes=True, markers=None, style_order=None,
    estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None,
    orient="x", sort=True, err_style="band", err_kws=None,
    legend="auto", ci="deprecated", ax=None, **kwargs
):
    # Axes-level line plot. The user-facing docstring is attached after the
    # definition through ``lineplot.__doc__``.

    # Fold the deprecated ``ci`` argument into ``errorbar``
    errorbar = _deprecate_ci(errorbar, ci)

    semantic_vars = {
        "x": x, "y": y, "hue": hue, "size": size, "style": style,
        "units": units, "weight": weights,
    }
    plotter = _LinePlotter(
        data=data,
        variables=semantic_vars,
        estimator=estimator, n_boot=n_boot, seed=seed, errorbar=errorbar,
        sort=sort, orient=orient, err_style=err_style, err_kws=err_kws,
        legend=legend,
    )

    # Build the hue / size / style mappings
    plotter.map_hue(palette=palette, order=hue_order, norm=hue_norm)
    plotter.map_size(sizes=sizes, order=size_order, norm=size_norm)
    plotter.map_style(markers=markers, dashes=dashes, order=style_order)

    ax = plt.gca() if ax is None else ax

    # Without a style mapping, forward ``dashes`` to matplotlib unless the
    # caller already supplied a linestyle; None/bool dashes become "".
    linestyle_given = not {"ls", "linestyle"}.isdisjoint(kwargs)
    if "style" not in plotter.variables and not linestyle_given:  # XXX
        if dashes is None or isinstance(dashes, bool):
            kwargs["dashes"] = ""
        else:
            kwargs["dashes"] = dashes

    # Nothing to draw without x/y data; hand back the (possibly new) axes
    if not plotter.has_xy_data:
        return ax

    plotter._attach(ax)

    # ``color`` wins over ``c``, but both are always removed from kwargs.
    # Other functions have color as an explicit param,
    # and we should probably do that here too
    c_value = kwargs.pop("c", None)
    requested_color = kwargs.pop("color", c_value)
    kwargs["color"] = _default_color(ax.plot, hue, requested_color, kwargs)

    plotter.plot(ax, kwargs)
    return ax
# Public docstring for ``lineplot``; placeholders are filled from the shared
# seaborn docstring components.
lineplot.__doc__ = """\
Draw a line plot with the possibility of several semantic groupings.

{narrative.main_api}

{narrative.relational_semantic}

By default, the plot aggregates over multiple `y` values at each value of
`x` and shows an estimate of the central tendency and a confidence
interval for that estimate.

Parameters
----------
{params.core.data}
{params.core.xy}
hue : vector or key in `data`
    Grouping variable that will produce lines with different colors.
    Can be either categorical or numeric, although color mapping will
    behave differently in the latter case.
size : vector or key in `data`
    Grouping variable that will produce lines with different widths.
    Can be either categorical or numeric, although size mapping will
    behave differently in the latter case.
style : vector or key in `data`
    Grouping variable that will produce lines with different dashes
    and/or markers. Can have a numeric dtype but will always be treated
    as categorical.
{params.rel.units}
weights : vector or key in `data`
    Data values or column used to compute weighted estimation.
    Note that use of weights currently limits the choice of statistics
    to a 'mean' estimator and 'ci' errorbar.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
{params.rel.sizes}
{params.rel.size_order}
{params.rel.size_norm}
{params.rel.dashes}
{params.rel.markers}
{params.rel.style_order}
{params.rel.estimator}
{params.stat.errorbar}
{params.rel.n_boot}
{params.rel.seed}
orient : "x" or "y"
    Dimension along which the data are sorted / aggregated. Equivalently,
    the "independent variable" of the resulting function.
sort : boolean
    If True, the data will be sorted by the x and y variables, otherwise
    lines will connect points in the order they appear in the dataset.
err_style : "band" or "bars"
    Whether to draw the confidence intervals with translucent error bands
    or discrete error bars.
err_kws : dict of keyword arguments
    Additional parameters to control the aesthetics of the error bars. The
    kwargs are passed either to :meth:`matplotlib.axes.Axes.fill_between`
    or :meth:`matplotlib.axes.Axes.errorbar`, depending on `err_style`.
{params.rel.legend}
{params.rel.ci}
{params.core.ax}
kwargs : key, value mappings
    Other keyword arguments are passed down to
    :meth:`matplotlib.axes.Axes.plot`.

Returns
-------
{returns.ax}

See Also
--------
{seealso.scatterplot}
{seealso.pointplot}

Examples
--------

.. include:: ../docstrings/lineplot.rst

""".format(
    narrative=_relational_narrative,
    params=_param_docs,
    returns=_core_docs["returns"],
    seealso=_core_docs["seealso"],
)
def scatterplot(
    data=None, *,
    x=None, y=None, hue=None, size=None, style=None,
    palette=None, hue_order=None, hue_norm=None,
    sizes=None, size_order=None, size_norm=None,
    markers=True, style_order=None, legend="auto", ax=None,
    **kwargs
):
    # Axes-level scatter plot. The user-facing docstring is attached after
    # the definition through ``scatterplot.__doc__``.
    semantic_vars = {"x": x, "y": y, "hue": hue, "size": size, "style": style}
    plotter = _ScatterPlotter(data=data, variables=semantic_vars, legend=legend)

    # Build the hue / size / style mappings
    plotter.map_hue(palette=palette, order=hue_order, norm=hue_norm)
    plotter.map_size(sizes=sizes, order=size_order, norm=size_norm)
    plotter.map_style(markers=markers, order=style_order)

    ax = plt.gca() if ax is None else ax

    # Nothing to draw without x/y data; hand back the (possibly new) axes
    if not plotter.has_xy_data:
        return ax

    plotter._attach(ax)

    # Pick a default color unless hue is assigned (see _default_color)
    requested_color = kwargs.pop("color", None)
    kwargs["color"] = _default_color(ax.scatter, hue, requested_color, kwargs)

    plotter.plot(ax, kwargs)
    return ax
# Public docstring for ``scatterplot``; placeholders are filled from the
# shared seaborn docstring components.
scatterplot.__doc__ = """\
Draw a scatter plot with the possibility of several semantic groupings.

{narrative.main_api}

{narrative.relational_semantic}

Parameters
----------
{params.core.data}
{params.core.xy}
hue : vector or key in `data`
    Grouping variable that will produce points with different colors.
    Can be either categorical or numeric, although color mapping will
    behave differently in the latter case.
size : vector or key in `data`
    Grouping variable that will produce points with different sizes.
    Can be either categorical or numeric, although size mapping will
    behave differently in the latter case.
style : vector or key in `data`
    Grouping variable that will produce points with different markers.
    Can have a numeric dtype but will always be treated as categorical.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
{params.rel.sizes}
{params.rel.size_order}
{params.rel.size_norm}
{params.rel.markers}
{params.rel.style_order}
{params.rel.legend}
{params.core.ax}
kwargs : key, value mappings
    Other keyword arguments are passed down to
    :meth:`matplotlib.axes.Axes.scatter`.

Returns
-------
{returns.ax}

See Also
--------
{seealso.lineplot}
{seealso.stripplot}
{seealso.swarmplot}

Examples
--------

.. include:: ../docstrings/scatterplot.rst

""".format(
    narrative=_relational_narrative,
    params=_param_docs,
    returns=_core_docs["returns"],
    seealso=_core_docs["seealso"],
)
def relplot(
    data=None, *,
    x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
    row=None, col=None, col_wrap=None, row_order=None, col_order=None,
    palette=None, hue_order=None, hue_norm=None,
    sizes=None, size_order=None, size_norm=None,
    markers=None, dashes=None, style_order=None,
    legend="auto", kind="scatter", height=5, aspect=1, facet_kws=None,
    **kwargs
):
    # Figure-level relational plot: maps semantics on the full dataset, then
    # draws `scatterplot` or `lineplot` on each facet of a FacetGrid and
    # returns the grid. The public docstring is attached after the definition
    # through ``relplot.__doc__``.

    # Choose the plotter class and axes-level function for this kind, and
    # resolve the kind-specific default for markers/dashes
    if kind == "scatter":
        Plotter = _ScatterPlotter
        func = scatterplot
        markers = True if markers is None else markers
    elif kind == "line":
        Plotter = _LinePlotter
        func = lineplot
        dashes = True if dashes is None else dashes
    else:
        err = f"Plot kind {kind} not recognized"
        raise ValueError(err)

    # Check for attempt to plot onto specific axes and warn
    if "ax" in kwargs:
        msg = (
            "relplot is a figure-level function and does not accept "
            "the `ax` parameter. You may wish to try {}".format(kind + "plot")
        )
        warnings.warn(msg, UserWarning)
        kwargs.pop("ax")

    # Use the full dataset to map the semantics
    # (units/weights only apply to lines; for scatter they are dropped
    # with a warning)
    variables = dict(x=x, y=y, hue=hue, size=size, style=style)
    if kind == "line":
        variables["units"] = units
        variables["weight"] = weights
    else:
        if units is not None:
            msg = "The `units` parameter has no effect with kind='scatter'."
            warnings.warn(msg, stacklevel=2)
        if weights is not None:
            msg = "The `weights` parameter has no effect with kind='scatter'."
            warnings.warn(msg, stacklevel=2)
    p = Plotter(
        data=data,
        variables=variables,
        legend=legend,
    )
    p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
    p.map_size(sizes=sizes, order=size_order, norm=size_norm)
    p.map_style(markers=markers, dashes=dashes, order=style_order)

    # Extract the semantic mappings so every facet reuses the mapping
    # computed on the full dataset
    if "hue" in p.variables:
        palette = p._hue_map.lookup_table
        hue_order = p._hue_map.levels
        hue_norm = p._hue_map.norm
    else:
        palette = hue_order = hue_norm = None

    # NOTE(review): unlike hue/style, there is no else-branch here, so the
    # caller's sizes/size_order/size_norm pass through unchanged when size
    # is not assigned.
    if "size" in p.variables:
        sizes = p._size_map.lookup_table
        size_order = p._size_map.levels
        size_norm = p._size_map.norm

    if "style" in p.variables:
        style_order = p._style_map.levels
        # Expand truthy markers/dashes into explicit per-level dicts
        if markers:
            markers = {k: p._style_map(k, "marker") for k in style_order}
        else:
            markers = None
        if dashes:
            dashes = {k: p._style_map(k, "dashes") for k in style_order}
        else:
            dashes = None
    else:
        markers = dashes = style_order = None

    # Now extract the data that would be used to draw a single plot
    # (plot_data is restored onto `p` later, before building the legend)
    variables = p.variables
    plot_data = p.plot_data

    # Define the common plotting parameters
    # (per-facet legends are disabled; one figure legend is added below)
    plot_kws = dict(
        palette=palette, hue_order=hue_order, hue_norm=hue_norm,
        sizes=sizes, size_order=size_order, size_norm=size_norm,
        markers=markers, dashes=dashes, style_order=style_order,
        legend=False,
    )
    plot_kws.update(kwargs)
    if kind == "scatter":
        plot_kws.pop("dashes")

    # Add the grid semantics onto the plotter
    # (this re-assigns p.plot_data / p.variables to include row and col)
    grid_variables = dict(
        x=x, y=y, row=row, col=col, hue=hue, size=size, style=style,
    )
    if kind == "line":
        grid_variables.update(units=units, weights=weights)
    p.assign_variables(data, grid_variables)

    # Define the named variables for plotting on each facet
    # Rename the variables with a leading underscore to avoid
    # collisions with faceting variable names
    plot_variables = {v: f"_{v}" for v in variables}
    # The semantic is stored as "weight", but lineplot's parameter is "weights"
    if "weight" in plot_variables:
        plot_variables["weights"] = plot_variables.pop("weight")
    plot_kws.update(plot_variables)

    # Pass the row/col variables to FacetGrid with their original
    # names so that the axes titles render correctly
    for var in ["row", "col"]:
        # Handle faceting variables that lack name information
        if var in p.variables and p.variables[var] is None:
            p.variables[var] = f"_{var}_"
    grid_kws = {v: p.variables.get(v) for v in ["row", "col"]}

    # Rename the columns of the plot_data structure appropriately
    new_cols = plot_variables.copy()
    new_cols.update(grid_kws)
    full_data = p.plot_data.rename(columns=new_cols)

    # Set up the FacetGrid object
    # (all-null columns, i.e. unassigned variables, are dropped first)
    facet_kws = {} if facet_kws is None else facet_kws.copy()
    g = FacetGrid(
        data=full_data.dropna(axis=1, how="all"),
        **grid_kws,
        col_wrap=col_wrap, row_order=row_order, col_order=col_order,
        height=height, aspect=aspect, dropna=False,
        **facet_kws
    )

    # Draw the plot
    g.map_dataframe(func, **plot_kws)

    # Label the axes, using the original variables
    # Pass "" when the variable name is None to overwrite internal variables
    g.set_axis_labels(variables.get("x") or "", variables.get("y") or "")

    if legend:
        # Replace the original plot data so the legend uses numeric data with
        # the correct type, since we force a categorical mapping above.
        p.plot_data = plot_data

        # Handle the additional non-semantic keyword arguments out here.
        # We're selective because some kwargs may be seaborn function specific
        # and not relevant to the matplotlib artists going into the legend.
        # Ideally, we will have a better solution where we don't need to re-make
        # the legend out here and will have parity with the axes-level functions.
        keys = ["c", "color", "alpha", "m", "marker"]
        if kind == "scatter":
            legend_artist = _scatter_legend_artist
            keys += ["s", "facecolor", "fc", "edgecolor", "ec", "linewidth", "lw"]
        else:
            legend_artist = partial(mpl.lines.Line2D, xdata=[], ydata=[])
            keys += [
                "markersize", "ms",
                "markeredgewidth", "mew",
                "markeredgecolor", "mec",
                "linestyle", "ls",
                "linewidth", "lw",
            ]
        common_kws = {k: v for k, v in kwargs.items() if k in keys}
        # Artist attribute each semantic maps onto in the legend
        attrs = {"hue": "color", "style": None}
        if kind == "scatter":
            attrs["size"] = "s"
        elif kind == "line":
            attrs["size"] = "linewidth"
        p.add_legend_data(g.axes.flat[0], legend_artist, common_kws, attrs)
        if p.legend_data:
            g.add_legend(legend_data=p.legend_data,
                         label_order=p.legend_order,
                         title=p.legend_title,
                         adjust_subtitles=True)

    # Rename the columns of the FacetGrid's `data` attribute
    # to match the original column names
    orig_cols = {
        f"_{k}": f"_{k}_" if v is None else v for k, v in variables.items()
    }
    grid_data = g.data.rename(columns=orig_cols)
    # When the user passed data, expose their original columns plus any
    # columns relplot derived that the input data does not already have
    if data is not None and (x is not None or y is not None):
        if not isinstance(data, pd.DataFrame):
            data = pd.DataFrame(data)
        g.data = pd.merge(
            data,
            grid_data[grid_data.columns.difference(data.columns)],
            left_index=True,
            right_index=True,
        )
    else:
        g.data = grid_data

    return g
# Public docstring for ``relplot``; placeholders are filled from the shared
# seaborn docstring components.
relplot.__doc__ = """\
Figure-level interface for drawing relational plots onto a FacetGrid.

This function provides access to several different axes-level functions
that show the relationship between two variables with semantic mappings
of subsets. The `kind` parameter selects the underlying axes-level
function to use:

- :func:`scatterplot` (with `kind="scatter"`; the default)
- :func:`lineplot` (with `kind="line"`)

Extra keyword arguments are passed to the underlying function, so you
should refer to the documentation for each to see kind-specific options.

{narrative.main_api}

{narrative.relational_semantic}

After plotting, the :class:`FacetGrid` with the plot is returned and can
be used directly to tweak supporting plot details or add other layers.

Parameters
----------
{params.core.data}
{params.core.xy}
hue : vector or key in `data`
    Grouping variable that will produce elements with different colors.
    Can be either categorical or numeric, although color mapping will
    behave differently in the latter case.
size : vector or key in `data`
    Grouping variable that will produce elements with different sizes.
    Can be either categorical or numeric, although size mapping will
    behave differently in the latter case.
style : vector or key in `data`
    Grouping variable that will produce elements with different styles.
    Can have a numeric dtype but will always be treated as categorical.
{params.rel.units}
weights : vector or key in `data`
    Data values or column used to compute weighted estimation.
    Note that use of weights currently limits the choice of statistics
    to a 'mean' estimator and 'ci' errorbar.
{params.facets.rowcol}
{params.facets.col_wrap}
row_order, col_order : lists of strings
    Order to organize the rows and/or columns of the grid in, otherwise the
    orders are inferred from the data objects.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
{params.rel.sizes}
{params.rel.size_order}
{params.rel.size_norm}
{params.rel.style_order}
{params.rel.dashes}
{params.rel.markers}
{params.rel.legend}
kind : string
    Kind of plot to draw, corresponding to a seaborn relational plot.
    Options are `"scatter"` or `"line"`.
{params.facets.height}
{params.facets.aspect}
facet_kws : dict
    Dictionary of other keyword arguments to pass to :class:`FacetGrid`.
kwargs : key, value pairings
    Other keyword arguments are passed through to the underlying plotting
    function.

Returns
-------
{returns.facetgrid}

Examples
--------

.. include:: ../docstrings/relplot.rst

""".format(
    narrative=_relational_narrative,
    params=_param_docs,
    returns=_core_docs["returns"],
)
================================================
FILE: seaborn/utils.py
================================================
"""Utility functions, mostly for internal use."""
import os
import inspect
import warnings
import colorsys
from contextlib import contextmanager
from urllib.request import urlopen, urlretrieve
from types import ModuleType
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib.colors import to_rgb
import matplotlib.pyplot as plt
from matplotlib.cbook import normalize_kwargs
from seaborn._core.typing import deprecated
from seaborn.external.version import Version
from seaborn.external.appdirs import user_cache_dir
# Names exported by ``from seaborn.utils import *``
__all__ = ["desaturate", "saturate", "set_hls_values", "move_legend",
           "despine", "get_dataset_names", "get_data_home", "load_dataset"]

# Base URL for the example datasets: load_dataset fetches
# f"{DATASET_SOURCE}/{name}.csv", and get_dataset_names reads the
# newline-separated index at DATASET_NAMES_URL.
DATASET_SOURCE = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master"
DATASET_NAMES_URL = f"{DATASET_SOURCE}/dataset_names.txt"
def ci_to_errsize(cis, heights):
    """Express confidence intervals as error sizes relative to plot heights.

    Parameters
    ----------
    cis : 2 x n sequence
        Lower (first row) and upper (second row) confidence interval limits.
    heights : n sequence
        Heights of the plotted elements the intervals belong to.

    Returns
    -------
    errsize : 2 x n array
        Distance below (first row) and above (second row) each height, in
        the format accepted as an error argument by ``plt.bar``.

    """
    limits = np.atleast_2d(cis).reshape(2, -1)
    heights = np.atleast_1d(heights)
    # One [below, above] pair per interval, transposed into 2 x n at the end
    pairs = [
        [heights[idx] - lower, upper - heights[idx]]
        for idx, (lower, upper) in enumerate(np.transpose(limits))
    ]
    return np.asarray(pairs).T
def _draw_figure(fig):
    """Render a matplotlib figure, redrawing it directly if still stale.

    Some matplotlib versions leave the figure marked stale after a canvas
    draw; in that case the figure is drawn again with the canvas renderer.
    Canvases without a renderer accessor are tolerated.
    """
    # See https://github.com/matplotlib/matplotlib/issues/19197 for context
    fig.canvas.draw()
    if not fig.stale:
        return
    try:
        renderer = fig.canvas.get_renderer()
        fig.draw(renderer)
    except AttributeError:
        # Best effort: leave the figure as the canvas draw left it
        pass
def _default_color(method, hue, color, kws, saturation=1):
    """If needed, get a default color by using the matplotlib property cycle.

    When no explicit color is given, a throwaway "scout" artist is drawn with
    ``method`` (using the relevant entries of ``kws``) so that matplotlib
    resolves the color it would use for a real artist; the scout is then
    removed and its color returned.

    Parameters
    ----------
    method : callable
        Plotting method used to draw the scout. Dispatch is on
        ``method.__name__``: "plot", "scatter", "bar", or "fill_between".
        Presumably a bound matplotlib Axes method (the "bar" branch reads
        ``method.__self__.containers``) -- confirm against callers.
    hue : object or None
        When not None, no color is computed and None is returned.
    color : matplotlib color or None
        Explicit color, returned directly (desaturated if requested).
    kws : dict
        Keyword arguments destined for the real plot. Only a copy (with
        ``label`` removed) is used, so the caller's dict is not modified.
    saturation : float
        When < 1, the resulting color is passed through :func:`desaturate`.

    Returns
    -------
    color : matplotlib color or None
        None when ``hue`` is set. Also None when the scatter branch cannot
        identify a single color, or when ``method.__name__`` matches none
        of the handled names.

    """
    if hue is not None:
        # This warning is probably user-friendly, but it's currently triggered
        # in a FacetGrid context and I don't want to mess with that logic right now
        #  if color is not None:
        #      msg = "`color` is ignored when `hue` is assigned."
        #      warnings.warn(msg)
        return None

    # Work on a copy so popping "label" does not affect the caller
    kws = kws.copy()
    kws.pop("label", None)

    if color is not None:
        if saturation < 1:
            color = desaturate(color, saturation)
        return color

    elif method.__name__ == "plot":

        # Honor a color given through any Line2D alias (e.g. "c")
        color = normalize_kwargs(kws, mpl.lines.Line2D).get("color")
        scout, = method([], [], scalex=False, scaley=False, color=color)
        color = scout.get_color()
        scout.remove()

    elif method.__name__ == "scatter":

        # Matplotlib will raise if the size of x/y don't match s/c,
        # and the latter might be in the kws dict
        scout_size = max(
            np.atleast_1d(kws.get(key, [])).shape[0]
            for key in ["s", "c", "fc", "facecolor", "facecolors"]
        )
        scout_x = scout_y = np.full(scout_size, np.nan)

        scout = method(scout_x, scout_y, **kws)
        facecolors = scout.get_facecolors()

        if not len(facecolors):
            # Handle bug in matplotlib <= 3.2 (I think)
            # This will limit the ability to use non color= kwargs to specify
            # a color in versions of matplotlib with the bug, but trying to
            # work out what the user wanted by re-implementing the broken logic
            # of inspecting the kwargs is probably too brittle.
            single_color = False
        else:
            single_color = np.unique(facecolors, axis=0).shape[0] == 1

        # Allow the user to specify an array of colors through various kwargs
        # (otherwise `color` stays None)
        if "c" not in kws and single_color:
            color = to_rgb(facecolors[0])

        scout.remove()

    elif method.__name__ == "bar":

        # bar() needs masked, not empty data, to generate a patch
        scout, = method([np.nan], [np.nan], **kws)
        color = to_rgb(scout.get_facecolor())
        scout.remove()
        # Axes.bar adds both a patch and a container
        method.__self__.containers.pop(-1)

    elif method.__name__ == "fill_between":

        kws = normalize_kwargs(kws, mpl.collections.PolyCollection)
        scout = method([], [], **kws)
        facecolor = scout.get_facecolor()
        color = to_rgb(facecolor[0])
        scout.remove()

    # NOTE(review): if `color` is still None here (unhandled method name or
    # an unresolved scatter color) and saturation < 1, desaturate receives
    # None -- confirm callers never combine these.
    if saturation < 1:
        color = desaturate(color, saturation)

    return color
def desaturate(color, prop):
    """Scale down the saturation of a color by a proportion.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    prop : float
        factor in [0, 1] applied to the color's saturation channel

    Returns
    -------
    new_color : rgb tuple
        desaturated color code in RGB tuple representation

    Raises
    ------
    ValueError
        If ``prop`` is outside [0, 1].

    """
    if not 0 <= prop <= 1:
        raise ValueError("prop must be between 0 and 1")

    rgb = to_rgb(color)

    # A factor of 1 is a no-op; return early to avoid HLS round-trip error
    if prop == 1:
        return rgb

    hue, lightness, saturation = colorsys.rgb_to_hls(*rgb)
    return colorsys.hls_to_rgb(hue, lightness, saturation * prop)
def saturate(color):
    """Push a color to full saturation while keeping its hue and lightness.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name

    Returns
    -------
    new_color : rgb tuple
        saturated color code in RGB tuple representation

    """
    full_saturation = 1
    return set_hls_values(color, s=full_saturation)
def set_hls_values(color, h=None, l=None, s=None):  # noqa
    """Replace any of the hue, lightness, or saturation channels of a color.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    h, l, s : floats between 0 and 1, or None
        replacement value for each HLS channel; None keeps the current value

    Returns
    -------
    new_color : rgb tuple
        new color code in RGB tuple representation

    """
    current = colorsys.rgb_to_hls(*to_rgb(color))
    overrides = (h, l, s)
    updated = [
        old if new is None else new
        for old, new in zip(current, overrides)
    ]
    return colorsys.hls_to_rgb(*updated)
def axlabel(xlabel, ylabel, **kwargs):
    """Set the x and y labels of the current axes.

    DEPRECATED: will be removed in a future version. Emits a FutureWarning
    on every call; extra keyword arguments go to both label setters.
    """
    warnings.warn(
        "This function is deprecated and will be removed in a future version",
        FutureWarning,
    )
    current_ax = plt.gca()
    for setter, text in ((current_ax.set_xlabel, xlabel),
                         (current_ax.set_ylabel, ylabel)):
        setter(text, **kwargs)
def remove_na(vector):
    """Drop null entries from a data vector.

    Parameters
    ----------
    vector : vector object
        Anything supporting boolean-mask indexing with ``[]``.

    Returns
    -------
    clean_clean : same type as ``vector``
        The non-null values; may be a copy or a view.

    """
    keep = pd.notnull(vector)
    return vector[keep]
def get_color_cycle():
    """Return the colors of the current matplotlib property cycle.

    Parameters
    ----------
    None

    Returns
    -------
    colors : list
        Colors from ``rcParams['axes.prop_cycle']``, or ``[".15"]`` (dark
        gray) when the cycle defines no colors.

    """
    prop_cycle = mpl.rcParams['axes.prop_cycle']
    if 'color' not in prop_cycle.keys:
        return [".15"]
    return prop_cycle.by_key()['color']
def _despine_move_ticks(axis, position):
    """Move an axis' ticks to ``position``, keeping the prior tick visibility."""
    maj_on = any(t.tick1line.get_visible() for t in axis.majorTicks)
    min_on = any(t.tick1line.get_visible() for t in axis.minorTicks)
    axis.set_ticks_position(position)
    for t in axis.majorTicks:
        t.tick2line.set_visible(maj_on)
    for t in axis.minorTicks:
        t.tick2line.set_visible(min_on)


def _despine_trim(ticks, limits, set_ticks, spines):
    """Clip ``spines`` to the outermost major ticks inside ``limits``.

    The axis is left untouched when no major tick falls within the view
    limits, rather than failing on an empty selection.
    """
    ticks = np.asarray(ticks)
    if not ticks.size:
        return
    lower, upper = min(limits), max(limits)
    at_or_above = ticks[ticks >= lower]
    at_or_below = ticks[ticks <= upper]
    if not at_or_above.size or not at_or_below.size:
        return
    firsttick, lasttick = at_or_above[0], at_or_below[-1]
    if firsttick > lasttick:
        # No tick inside the view interval; trimming would invert the spine
        return
    for spine in spines:
        spine.set_bounds(firsttick, lasttick)
    newticks = ticks.compress(ticks <= lasttick)
    newticks = newticks.compress(newticks >= firsttick)
    set_ticks(newticks)


def despine(fig=None, ax=None, top=True, right=True, left=False,
            bottom=False, offset=None, trim=False):
    """Remove the top and right spines from plot(s).

    Parameters
    ----------
    fig : matplotlib figure, optional
        Figure to despine all axes of, defaults to the current figure.
    ax : matplotlib axes, optional
        Specific axes object to despine. Ignored if fig is provided.
    top, right, left, bottom : boolean, optional
        If True, remove that spine.
    offset : int or dict, optional
        Absolute distance, in points, spines should be moved away
        from the axes (negative values move spines inward). A single value
        applies to all spines; a dict can be used to set offset values per
        side.
    trim : bool, optional
        If True, limit spines to the smallest and largest major tick
        on each non-despined axis. An axis with no major tick inside its
        view limits is left untrimmed.

    Returns
    -------
    None

    """
    # Get references to the axes we want
    if fig is not None:
        axes = fig.axes
    elif ax is not None:
        axes = [ax]
    else:
        axes = plt.gcf().axes

    # Which spines the caller asked to remove, in a fixed processing order
    removed = {"top": top, "right": right, "left": left, "bottom": bottom}

    for ax_i in axes:
        for side, is_removed in removed.items():
            # Toggle the spine objects
            is_visible = not is_removed
            ax_i.spines[side].set_visible(is_visible)
            if offset is not None and is_visible:
                try:
                    val = offset.get(side, 0)
                except AttributeError:
                    # A scalar offset applies to every visible spine
                    val = offset
                ax_i.spines[side].set_position(('outward', val))

        # Potentially move the ticks to the side whose spine remains
        if left and not right:
            _despine_move_ticks(ax_i.yaxis, "right")
        if bottom and not top:
            _despine_move_ticks(ax_i.xaxis, "top")

        if trim:
            # clip off the parts of the spines that extend past major ticks
            _despine_trim(
                ax_i.get_xticks(), ax_i.get_xlim(), ax_i.set_xticks,
                (ax_i.spines['bottom'], ax_i.spines['top']),
            )
            _despine_trim(
                ax_i.get_yticks(), ax_i.get_ylim(), ax_i.set_yticks,
                (ax_i.spines['left'], ax_i.spines['right']),
            )
def move_legend(obj, loc, **kwargs):
    """
    Recreate a plot's legend at a new location.

    The name is a slight misnomer. Matplotlib legends do not expose public
    control over their position parameters. So this function creates a new legend,
    copying over the data from the original object, which is then removed.

    Parameters
    ----------
    obj : the object with the plot
        This argument can be either a seaborn or matplotlib object:

        - :class:`seaborn.FacetGrid` or :class:`seaborn.PairGrid`
        - :class:`matplotlib.axes.Axes` or :class:`matplotlib.figure.Figure`

    loc : str or int
        Location argument, as in :meth:`matplotlib.axes.Axes.legend`.

    kwargs
        Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.legend`.
        ``labels`` replaces the entry labels (same length required), ``title``
        replaces the title text, and ``title_*`` keys are applied to the
        existing title artist.

    Raises
    ------
    TypeError
        If ``obj`` is not a seaborn Grid, matplotlib Axes, or Figure.
    ValueError
        If ``obj`` has no legend, or ``labels`` has the wrong length.

    Examples
    --------

    .. include:: ../docstrings/move_legend.rst

    """
    # This is a somewhat hackish solution that will hopefully be obviated by
    # upstream improvements to matplotlib legends that make them easier to
    # modify after creation.

    from seaborn.axisgrid import Grid  # Avoid circular import

    # Locate the legend object and a method to recreate the legend
    if isinstance(obj, Grid):
        old_legend = obj.legend
        legend_func = obj.figure.legend
    elif isinstance(obj, mpl.axes.Axes):
        old_legend = obj.legend_
        legend_func = obj.legend
    elif isinstance(obj, mpl.figure.Figure):
        # A figure can hold several legends; the most recent one is moved
        if obj.legends:
            old_legend = obj.legends[-1]
        else:
            old_legend = None
        legend_func = obj.legend
    else:
        err = "`obj` must be a seaborn Grid or matplotlib Axes or Figure instance."
        raise TypeError(err)

    if old_legend is None:
        err = f"{obj} has no legend attached."
        raise ValueError(err)

    # Extract the components of the legend we need to reuse
    # Import here to avoid a circular import
    from seaborn._compat import get_legend_handles
    handles = get_legend_handles(old_legend)
    labels = [t.get_text() for t in old_legend.get_texts()]

    # Handle the case where the user is trying to override the labels
    if (new_labels := kwargs.pop("labels", None)) is not None:
        if len(new_labels) != len(labels):
            err = "Length of new labels does not match existing legend."
            raise ValueError(err)
        labels = new_labels

    # Extract legend properties that can be passed to the recreation method
    # (Vexingly, these don't all round-trip)
    # Keep only properties whose names are also Legend constructor parameters
    legend_kws = inspect.signature(mpl.legend.Legend).parameters
    props = {k: v for k, v in old_legend.properties().items() if k in legend_kws}

    # Delegate default bbox_to_anchor rules to matplotlib
    props.pop("bbox_to_anchor")

    # Try to propagate the existing title and font properties; respect new ones too
    title = props.pop("title")
    if "title" in kwargs:
        title.set_text(kwargs.pop("title"))
    title_kwargs = {k: v for k, v in kwargs.items() if k.startswith("title_")}
    for key, val in title_kwargs.items():
        # key[6:] strips the "title_" prefix (6 characters)
        title.set(**{key[6:]: val})
        kwargs.pop(key)

    # Try to respect the frame visibility
    kwargs.setdefault("frameon", old_legend.legendPatch.get_visible())

    # Remove the old legend and create the new one
    # (user kwargs override the copied properties)
    props.update(kwargs)
    old_legend.remove()
    new_legend = legend_func(handles, labels, loc=loc, **props)
    new_legend.set_title(title.get_text(), title.get_fontproperties())

    # Let the Grid object continue to track the correct legend object
    if isinstance(obj, Grid):
        obj._legend = new_legend
def _kde_support(data, bw, gridsize, cut, clip):
    """Return an evenly spaced grid on which to evaluate a KDE.

    The grid spans the data range extended by ``cut`` bandwidths on each
    side, clipped to ``clip[0]`` / ``clip[1]``, with ``gridsize`` points.
    """
    padding = bw * cut
    lower = max(data.min() - padding, clip[0])
    upper = min(data.max() + padding, clip[1])
    return np.linspace(lower, upper, gridsize)
def ci(a, which=95, axis=None):
    """Return the central ``which`` percent range of ``a``, ignoring NaNs."""
    half_width = which / 2
    bounds = (50 - half_width, 50 + half_width)
    return np.nanpercentile(a, bounds, axis=axis)
def get_dataset_names():
    """List the example datasets available online, e.g. for bug reports.

    Requires an internet connection.
    """
    with urlopen(DATASET_NAMES_URL) as response:
        payload = response.read()
    stripped = (line.strip() for line in payload.decode().split("\n"))
    # Blank lines (e.g. a trailing newline) are not dataset names
    return [entry for entry in stripped if entry]
def get_data_home(data_home=None):
    """Return a path to the cache directory for example datasets.

    This directory is used by :func:`load_dataset`.

    If the ``data_home`` argument is not provided, it will use a directory
    specified by the `SEABORN_DATA` environment variable (if it exists)
    or otherwise default to an OS-appropriate user cache location.

    The directory (and any missing parents) is created if it does not exist.

    Parameters
    ----------
    data_home : str, optional
        Explicit cache directory; ``~`` is expanded.

    Returns
    -------
    data_home : str
        The (user-expanded) cache directory path.

    """
    if data_home is None:
        data_home = os.environ.get("SEABORN_DATA")
        if data_home is None:
            # Only compute the platform default when it is actually needed
            data_home = user_cache_dir("seaborn")
    data_home = os.path.expanduser(data_home)
    if not os.path.exists(data_home):
        # exist_ok avoids a race with another process creating it first
        os.makedirs(data_home, exist_ok=True)
    return data_home
def load_dataset(name, cache=True, data_home=None, **kws):
    """Load an example dataset from the online repository (requires internet).

    This function provides quick access to a small number of example datasets
    that are useful for documenting seaborn or generating reproducible examples
    for bug reports. It is not necessary for normal usage.

    Note that some of the datasets have a small amount of preprocessing applied
    to define a proper ordering for categorical variables.

    Use :func:`get_dataset_names` to see a list of available datasets.

    Parameters
    ----------
    name : str
        Name of the dataset (``{name}.csv`` on
        https://github.com/mwaskom/seaborn-data).
    cache : boolean, optional
        If True, try to load from the local cache first, and save to the cache
        if a download is required.
    data_home : string, optional
        The directory in which to cache data; see :func:`get_data_home`.
    kws : keys and values, optional
        Additional keyword arguments are passed through to
        :func:`pandas.read_csv`.

    Returns
    -------
    df : :class:`pandas.DataFrame`
        Tabular data, possibly with some preprocessing applied.

    Raises
    ------
    TypeError
        If ``name`` is a DataFrame rather than a dataset name.
    ValueError
        If caching is enabled, the file is not cached yet, and ``name`` is
        not listed by :func:`get_dataset_names`.

    """
    # A common beginner mistake is to assume that one's personal data needs
    # to be passed through this function to be usable with seaborn.
    # Let's provide a more helpful error than you would otherwise get.
    if isinstance(name, pd.DataFrame):
        err = (
            "This function accepts only strings (the name of an example dataset). "
            "You passed a pandas DataFrame. If you have your own dataset, "
            "it is not necessary to use this function before plotting."
        )
        raise TypeError(err)

    url = f"{DATASET_SOURCE}/{name}.csv"

    if cache:
        cache_path = os.path.join(get_data_home(data_home), os.path.basename(url))
        if not os.path.exists(cache_path):
            # Validate the name before downloading into the cache
            if name not in get_dataset_names():
                raise ValueError(f"'{name}' is not one of the example datasets.")
            urlretrieve(url, cache_path)
        full_path = cache_path
    else:
        full_path = url

    df = pd.read_csv(full_path, **kws)

    # Drop a trailing all-null row. Guard on length: ``iloc[-1]`` raises
    # IndexError on an empty frame (e.g. when ``nrows=0`` is passed).
    if len(df) and df.iloc[-1].isnull().all():
        df = df.iloc[:-1]

    # Set some columns as a categorical type with ordered levels
    if name == "tips":
        df["day"] = pd.Categorical(df["day"], ["Thur", "Fri", "Sat", "Sun"])
        df["sex"] = pd.Categorical(df["sex"], ["Male", "Female"])
        df["time"] = pd.Categorical(df["time"], ["Lunch", "Dinner"])
        df["smoker"] = pd.Categorical(df["smoker"], ["Yes", "No"])

    elif name == "flights":
        months = df["month"].str[:3]
        df["month"] = pd.Categorical(months, months.unique())

    elif name == "exercise":
        df["time"] = pd.Categorical(df["time"], ["1 min", "15 min", "30 min"])
        df["kind"] = pd.Categorical(df["kind"], ["rest", "walking", "running"])
        df["diet"] = pd.Categorical(df["diet"], ["no fat", "low fat"])

    elif name == "titanic":
        df["class"] = pd.Categorical(df["class"], ["First", "Second", "Third"])
        df["deck"] = pd.Categorical(df["deck"], list("ABCDEFG"))

    elif name == "penguins":
        df["sex"] = df["sex"].str.title()

    elif name == "diamonds":
        df["color"] = pd.Categorical(
            df["color"], ["D", "E", "F", "G", "H", "I", "J"],
        )
        df["clarity"] = pd.Categorical(
            df["clarity"], ["IF", "VVS1", "VVS2", "VS1", "VS2", "SI1", "SI2", "I1"],
        )
        df["cut"] = pd.Categorical(
            df["cut"], ["Ideal", "Premium", "Very Good", "Good", "Fair"],
        )

    elif name == "taxis":
        df["pickup"] = pd.to_datetime(df["pickup"])
        df["dropoff"] = pd.to_datetime(df["dropoff"])

    elif name == "seaice":
        df["Date"] = pd.to_datetime(df["Date"])

    elif name == "dowjones":
        df["Date"] = pd.to_datetime(df["Date"])

    return df
def axis_ticklabels_overlap(labels):
    """Report whether any tick labels in ``labels`` overlap one another.

    Parameters
    ----------
    labels : list of matplotlib ticklabels

    Returns
    -------
    overlap : boolean
        True if any of the labels overlap.

    """
    if not labels:
        return False

    try:
        extents = [label.get_window_extent() for label in labels]
        # The list passed to count_overlaps includes the extent itself,
        # hence the "> 1" threshold for "overlaps something else".
        return max(extent.count_overlaps(extents) for extent in extents) > 1
    except RuntimeError:
        # Issue on macos backend raises an error in the above code
        return False
def axes_ticklabels_overlap(ax):
    """Check both axes of ``ax`` for overlapping tick labels.

    Parameters
    ----------
    ax : matplotlib Axes

    Returns
    -------
    x_overlap, y_overlap : booleans
        True when the labels on that axis overlap.

    """
    x_overlap = axis_ticklabels_overlap(ax.get_xticklabels())
    y_overlap = axis_ticklabels_overlap(ax.get_yticklabels())
    return x_overlap, y_overlap
def locator_to_legend_entries(locator, limits, dtype):
    """Return levels and formatted levels for brief numeric legends.

    ``limits`` is the (vmin, vmax) data range. Returns the list of tick
    values that fall inside it and the matching list of formatted labels.
    """
    ticks = locator.tick_values(*limits).astype(dtype)

    # The locator may place ticks beyond the limits; drop those.
    vmin, vmax = limits[0], limits[1]
    raw_levels = [tick for tick in ticks if vmin <= tick <= vmax]

    class _ViewOnlyAxis:
        """Minimal axis stand-in that only reports the view interval."""

        def get_view_interval(self):
            return limits

    if not isinstance(locator, mpl.ticker.LogLocator):
        formatter = mpl.ticker.ScalarFormatter()
        # The legend has no way to show an offset or scientific-notation
        # multiplier, so both are switched off.
        formatter.set_useOffset(False)
        formatter.set_scientific(False)
    else:
        formatter = mpl.ticker.LogFormatter()

    formatter.axis = _ViewOnlyAxis()
    return raw_levels, formatter.format_ticks(raw_levels)
def relative_luminance(color):
    """Compute relative luminance of one or more colors (W3C definition).

    Parameters
    ----------
    color : matplotlib color or sequence of matplotlib colors
        Hex code, rgb-tuple, or html color name.

    Returns
    -------
    luminance : float(s) between 0 and 1
        A Python float when a single color was given, otherwise an array.

    """
    channels = mpl.colors.colorConverter.to_rgba_array(color)[:, :3]
    # Linearize the sRGB channel values before weighting them.
    linear = np.where(
        channels <= .03928,
        channels / 12.92,
        ((channels + .055) / 1.055) ** 2.4,
    )
    lum = linear.dot([.2126, .7152, .0722])
    # ``.item()`` only succeeds for a size-1 array.
    return lum.item() if lum.size == 1 else lum
def to_utf8(obj):
    """Convert ``obj`` to a ``str``.

    ``str`` input is returned as-is, anything with a ``decode`` method
    (e.g. ``bytes``) is decoded as UTF-8, and every other object is
    converted with ``str()``.

    Parameters
    ----------
    obj : object
        Any Python object

    Returns
    -------
    s : str
        UTF-8-decoded string representation of ``obj``

    """
    if not isinstance(obj, str):
        try:
            obj = obj.decode(encoding="utf-8")
        except AttributeError:  # no ``decode``: not bytes-like
            obj = str(obj)
    return obj
def _check_argument(param, options, value, prefix=False):
"""Raise if value for param is not in options."""
if prefix and value is not None:
failure = not any(value.startswith(p) for p in options if isinstance(p, str))
else:
failure = value not in options
if failure:
raise ValueError(
f"The value for `{param}` must be one of {options}, "
f"but {repr(value)} was passed."
)
return value
def _assign_default_kwargs(kws, call_func, source_func):
"""Assign default kwargs for call_func using values from source_func."""
# This exists so that axes-level functions and figure-level functions can
# both call a Plotter method while having the default kwargs be defined in
# the signature of the axes-level function.
# An alternative would be to have a decorator on the method that sets its
# defaults based on those defined in the axes-level function.
# Then the figure-level function would not need to worry about defaults.
# I am not sure which is better.
needed = inspect.signature(call_func).parameters
defaults = inspect.signature(source_func).parameters
for param in needed:
if param in defaults and param not in kws:
kws[param] = defaults[param].default
return kws
def adjust_legend_subtitles(legend):
    """
    Make invisible-handle "subtitles" entries look more like titles.

    Rows whose handles are not all visible get a zero-width handle area and,
    when ``legend.title_fontsize`` is set in rcParams, that font size.

    Note: This function is not part of the public API and may be changed or removed.
    """
    # Legend title not in rcParams until 3.0
    title_size = plt.rcParams.get("legend.title_fontsize", None)
    root = legend.findobj(mpl.offsetbox.VPacker)[0]
    for row in root.get_children():
        handle_box, label_box = row.get_children()
        if all(artist.get_visible() for artist in handle_box.get_children()):
            continue
        handle_box.set_width(0)
        if title_size is None:
            continue
        for label in label_box.get_children():
            label.set_size(title_size)
def _deprecate_ci(errorbar, ci):
    """
    Warn on usage of ci= and convert to appropriate errorbar= arg.

    When ``ci`` was not explicitly provided (it is the ``deprecated`` sentinel
    or the string "deprecated"), ``errorbar`` is returned unchanged.
    Otherwise ``ci`` is translated (None -> None, "sd" -> "sd",
    anything else -> ("ci", ci)) and a FutureWarning is emitted.

    ci was deprecated when errorbar was added in 0.12. It should not be removed
    completely for some time, but it can be moved out of function definitions
    (and extracted from kwargs) after one cycle.
    """
    if ci is deprecated or ci == "deprecated":
        return errorbar

    if ci is None:
        errorbar = None
    elif ci == "sd":
        errorbar = "sd"
    else:
        errorbar = ("ci", ci)

    msg = (
        "\n\nThe `ci` parameter is deprecated. "
        f"Use `errorbar={repr(errorbar)}` for the same effect.\n"
    )
    # stacklevel=3 points past this helper and its caller to user code.
    warnings.warn(msg, FutureWarning, stacklevel=3)
    return errorbar
def _get_transform_functions(ax, axis):
"""Return the forward and inverse transforms for a given axis."""
axis_obj = getattr(ax, f"{axis}axis")
transform = axis_obj.get_transform()
return transform.transform, transform.inverted().transform
@contextmanager
def _disable_autolayout():
    """Temporarily set ``figure.autolayout`` to False, restoring it on exit."""
    # Workaround for a matplotlib issue; see
    # https://github.com/mwaskom/seaborn/issues/2914
    # This rcParam only sets the default for layout= in plt.figure, so passing
    # that directly would also work. But then we would need to handle the
    # transition from tight_layout=True to layout="tight" ourselves. This is
    # simpler, and can be dropped if that becomes easier on the matplotlib
    # side or if the layout algorithms learn to handle figure legends.
    saved_value = mpl.rcParams["figure.autolayout"]
    try:
        mpl.rcParams["figure.autolayout"] = False
        yield
    finally:
        mpl.rcParams["figure.autolayout"] = saved_value
def _version_predates(lib: ModuleType, version: str) -> bool:
    """Return True if ``lib.__version__`` is older than ``version``."""
    installed = Version(lib.__version__)
    return installed < Version(version)
def _scatter_legend_artist(**kws):
    """Return a Line2D proxy that mimics a scatter point in a legend.

    Keyword arguments are first normalized against PathCollection aliases,
    then translated to their Line2D marker equivalents (``s`` ->
    ``markersize``, ``facecolor`` -> ``markerfacecolor``, ``linewidth`` ->
    ``markeredgewidth``). Any remaining keywords are passed to Line2D as-is.
    """
    kws = normalize_kwargs(kws, mpl.collections.PathCollection)

    edgecolor = kws.pop("edgecolor", None)
    rc = mpl.rcParams
    line_kws = {
        "linestyle": "",
        "marker": kws.pop("marker", "o"),
        # ``s`` defaults to markersize**2 and is square-rooted, i.e. it is
        # treated as an area while Line2D's markersize is a linear size.
        "markersize": np.sqrt(kws.pop("s", rc["lines.markersize"] ** 2)),
        "markerfacecolor": kws.pop("facecolor", kws.get("color")),
        "markeredgewidth": kws.pop("linewidth", 0),
        **kws,
    }

    if edgecolor is not None:
        # Only compare against "face" when edgecolor is a string. An RGB(A)
        # array compared with ``==`` gives an elementwise result whose truth
        # value is ambiguous, which used to raise here.
        if isinstance(edgecolor, str) and edgecolor.lower() == "face":
            line_kws["markeredgecolor"] = line_kws["markerfacecolor"]
        else:
            line_kws["markeredgecolor"] = edgecolor

    return mpl.lines.Line2D([], [], **line_kws)
def _get_patch_legend_artist(fill):
def legend_artist(**kws):
color = kws.pop("color", None)
if color is not None:
if fill:
kws["facecolor"] = color
else:
kws["edgecolor"] = color
kws["facecolor"] = "none"
return mpl.patches.Rectangle((0, 0), 0, 0, **kws)
return legend_artist
================================================
FILE: seaborn/widgets.py
================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
try:
    from ipywidgets import interact, FloatSlider, IntSlider
except ImportError:
    def interact(*args, **kwargs):
        """Stand-in that explains why the interactive widgets are unavailable."""
        msg = "Interactive palettes require `ipywidgets`, which is not installed."
        raise ImportError(msg)

    # The slider constructors are evaluated as default argument values when
    # each ``@interact``-decorated function is defined, which happens before
    # the ``interact`` stand-in is ever called. Without these stand-ins that
    # evaluation fails with an unhelpful NameError instead of the message above.
    FloatSlider = IntSlider = interact
from .miscplot import palplot
from .palettes import (color_palette, dark_palette, light_palette,
diverging_palette, cubehelix_palette)
# Public interactive palette choosers exported by this module.
__all__ = [
    "choose_colorbrewer_palette",
    "choose_cubehelix_palette",
    "choose_dark_palette",
    "choose_light_palette",
    "choose_diverging_palette",
]
def _init_mutable_colormap():
    """Build a grey colormap whose lookup table the widgets overwrite in place."""
    seed_colors = color_palette("Greys", 256)
    mutable = LinearSegmentedColormap.from_list("interactive", seed_colors)
    # Materialize the LUT now so it can be mutated directly later.
    mutable._init()
    mutable._set_extremes()
    return mutable
def _update_lut(cmap, colors):
"""Change the LUT values in a matplotlib colormap in-place."""
cmap._lut[:256] = colors
cmap._set_extremes()
def _show_cmap(cmap):
    """Draw ``cmap`` as a thin horizontal gradient strip."""
    from .rcmod import axes_style  # Avoid circular import
    with axes_style("white"):
        _, strip_ax = plt.subplots(figsize=(8.25, .75))
    strip_ax.set(xticks=[], yticks=[])
    gradient = np.linspace(0, 1, 256).reshape(1, -1)
    strip_ax.pcolormesh(gradient, cmap=cmap)
def choose_colorbrewer_palette(data_type, as_cmap=False):
    """Select a palette from the ColorBrewer set.

    These palettes are built into matplotlib and can be used by name in
    many seaborn functions, or by passing the object returned by this function.

    Parameters
    ----------
    data_type : {'sequential', 'diverging', 'qualitative'}
        This describes the kind of data you want to visualize. See the seaborn
        color palette docs for more information about how to choose this value.
        Only the first letter is checked, so you can pass abbreviations
        (e.g. 'q' for 'qualitative').
    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions.

    Raises
    ------
    ValueError
        If ``data_type`` does not start with 's', 'd', or 'q', or if a
        qualitative palette is requested with ``as_cmap=True``.

    See Also
    --------
    dark_palette : Create a sequential palette with dark low values.
    light_palette : Create a sequential palette with bright low values.
    diverging_palette : Create a diverging palette from selected colors.
    cubehelix_palette : Create a sequential palette or colormap using the
                        cubehelix system.

    """
    # Reject unknown kinds up front. They used to fall through every branch
    # below and silently return an empty palette.
    if not data_type.startswith(("s", "d", "q")):
        raise ValueError(
            "`data_type` must be one of 'sequential', 'diverging', or "
            f"'qualitative', not {data_type!r}."
        )
    if data_type.startswith("q") and as_cmap:
        raise ValueError("Qualitative palettes cannot be colormaps.")

    # Both containers are mutated in place by the widget callbacks, so the
    # returned object tracks the user's latest selection.
    pal = []
    if as_cmap:
        cmap = _init_mutable_colormap()

    if data_type.startswith("s"):
        opts = ["Greys", "Reds", "Greens", "Blues", "Oranges", "Purples",
                "BuGn", "BuPu", "GnBu", "OrRd", "PuBu", "PuRd", "RdPu", "YlGn",
                "PuBuGn", "YlGnBu", "YlOrBr", "YlOrRd"]
        variants = ["regular", "reverse", "dark"]

        @interact
        def choose_sequential(name=opts, n=(2, 18),
                              desat=FloatSlider(min=0, max=1, value=1),
                              variant=variants):
            if variant == "reverse":
                name += "_r"
            elif variant == "dark":
                name += "_d"

            if as_cmap:
                colors = color_palette(name, 256, desat)
                # color_palette gives RGB; the LUT stores RGBA.
                _update_lut(cmap, np.c_[colors, np.ones(256)])
                _show_cmap(cmap)
            else:
                pal[:] = color_palette(name, n, desat)
                palplot(pal)

    elif data_type.startswith("d"):
        opts = ["RdBu", "RdGy", "PRGn", "PiYG", "BrBG",
                "RdYlBu", "RdYlGn", "Spectral"]
        variants = ["regular", "reverse"]

        @interact
        def choose_diverging(name=opts, n=(2, 16),
                             desat=FloatSlider(min=0, max=1, value=1),
                             variant=variants):
            if variant == "reverse":
                name += "_r"
            if as_cmap:
                colors = color_palette(name, 256, desat)
                _update_lut(cmap, np.c_[colors, np.ones(256)])
                _show_cmap(cmap)
            else:
                pal[:] = color_palette(name, n, desat)
                palplot(pal)

    elif data_type.startswith("q"):
        opts = ["Set1", "Set2", "Set3", "Paired", "Accent",
                "Pastel1", "Pastel2", "Dark2"]

        @interact
        def choose_qualitative(name=opts, n=(2, 16),
                               desat=FloatSlider(min=0, max=1, value=1)):
            pal[:] = color_palette(name, n, desat)
            palplot(pal)

    if as_cmap:
        return cmap
    return pal
def choose_dark_palette(input="husl", as_cmap=False):
    """Launch an interactive widget to create a dark sequential palette.

    This corresponds with the :func:`dark_palette` function. This kind
    of palette is good for data that range between relatively uninteresting
    low values and interesting high values.

    Requires IPython 2+ and must be used in the notebook.

    Parameters
    ----------
    input : {'husl', 'hls', 'rgb'}
        Color space for defining the seed value. Note that the default is
        different than the default input for :func:`dark_palette`.
    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions.

    Raises
    ------
    ValueError
        If ``input`` is not one of the supported color spaces.

    See Also
    --------
    dark_palette : Create a sequential palette with dark low values.
    light_palette : Create a sequential palette with bright low values.
    cubehelix_palette : Create a sequential palette or colormap using the
                        cubehelix system.

    """
    # Unsupported color spaces used to match no branch and silently return
    # an empty palette.
    if input not in ("rgb", "hls", "husl"):
        raise ValueError(
            f"`input` must be one of 'rgb', 'hls', or 'husl', not {input!r}."
        )

    # Mutated in place by the widget callbacks below.
    pal = []
    if as_cmap:
        cmap = _init_mutable_colormap()

    if input == "rgb":
        @interact
        def choose_dark_palette_rgb(r=(0., 1.),
                                    g=(0., 1.),
                                    b=(0., 1.),
                                    n=(3, 17)):
            color = r, g, b
            if as_cmap:
                colors = dark_palette(color, 256, input="rgb")
                _update_lut(cmap, colors)
                _show_cmap(cmap)
            else:
                pal[:] = dark_palette(color, n, input="rgb")
                palplot(pal)

    elif input == "hls":
        @interact
        def choose_dark_palette_hls(h=(0., 1.),
                                    l=(0., 1.),  # noqa: E741
                                    s=(0., 1.),
                                    n=(3, 17)):
            color = h, l, s
            if as_cmap:
                colors = dark_palette(color, 256, input="hls")
                _update_lut(cmap, colors)
                _show_cmap(cmap)
            else:
                pal[:] = dark_palette(color, n, input="hls")
                palplot(pal)

    elif input == "husl":
        @interact
        def choose_dark_palette_husl(h=(0, 359),
                                     s=(0, 99),
                                     l=(0, 99),  # noqa: E741
                                     n=(3, 17)):
            color = h, s, l
            if as_cmap:
                colors = dark_palette(color, 256, input="husl")
                _update_lut(cmap, colors)
                _show_cmap(cmap)
            else:
                pal[:] = dark_palette(color, n, input="husl")
                palplot(pal)

    if as_cmap:
        return cmap
    return pal
def choose_light_palette(input="husl", as_cmap=False):
    """Launch an interactive widget to create a light sequential palette.

    This corresponds with the :func:`light_palette` function. This kind
    of palette is good for data that range between relatively uninteresting
    low values and interesting high values.

    Requires IPython 2+ and must be used in the notebook.

    Parameters
    ----------
    input : {'husl', 'hls', 'rgb'}
        Color space for defining the seed value. Note that the default is
        different than the default input for :func:`light_palette`.
    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions.

    Raises
    ------
    ValueError
        If ``input`` is not one of the supported color spaces.

    See Also
    --------
    light_palette : Create a sequential palette with bright low values.
    dark_palette : Create a sequential palette with dark low values.
    cubehelix_palette : Create a sequential palette or colormap using the
                        cubehelix system.

    """
    # Unsupported color spaces used to match no branch and silently return
    # an empty palette.
    if input not in ("rgb", "hls", "husl"):
        raise ValueError(
            f"`input` must be one of 'rgb', 'hls', or 'husl', not {input!r}."
        )

    # Mutated in place by the widget callbacks below.
    pal = []
    if as_cmap:
        cmap = _init_mutable_colormap()

    if input == "rgb":
        @interact
        def choose_light_palette_rgb(r=(0., 1.),
                                     g=(0., 1.),
                                     b=(0., 1.),
                                     n=(3, 17)):
            color = r, g, b
            if as_cmap:
                colors = light_palette(color, 256, input="rgb")
                _update_lut(cmap, colors)
                _show_cmap(cmap)
            else:
                pal[:] = light_palette(color, n, input="rgb")
                palplot(pal)

    elif input == "hls":
        @interact
        def choose_light_palette_hls(h=(0., 1.),
                                     l=(0., 1.),  # noqa: E741
                                     s=(0., 1.),
                                     n=(3, 17)):
            color = h, l, s
            if as_cmap:
                colors = light_palette(color, 256, input="hls")
                _update_lut(cmap, colors)
                _show_cmap(cmap)
            else:
                pal[:] = light_palette(color, n, input="hls")
                palplot(pal)

    elif input == "husl":
        @interact
        def choose_light_palette_husl(h=(0, 359),
                                      s=(0, 99),
                                      l=(0, 99),  # noqa: E741
                                      n=(3, 17)):
            color = h, s, l
            if as_cmap:
                colors = light_palette(color, 256, input="husl")
                _update_lut(cmap, colors)
                _show_cmap(cmap)
            else:
                pal[:] = light_palette(color, n, input="husl")
                palplot(pal)

    if as_cmap:
        return cmap
    return pal
def choose_diverging_palette(as_cmap=False):
    """Launch an interactive widget to choose a diverging color palette.

    This corresponds with the :func:`diverging_palette` function. This kind
    of palette is good for data that range between interesting low values
    and interesting high values with a meaningful midpoint. (For example,
    change scores relative to some baseline value).

    Requires IPython 2+ and must be used in the notebook.

    Parameters
    ----------
    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions.

    See Also
    --------
    diverging_palette : Create a diverging color palette or colormap.
    choose_colorbrewer_palette : Interactively choose palettes from the
                                 colorbrewer set, including diverging palettes.

    """
    # Both containers are updated in place each time a widget changes.
    pal = []
    if as_cmap:
        cmap = _init_mutable_colormap()

    # The parameter names below double as widget labels; keep them stable.
    @interact
    def _on_change(
        h_neg=IntSlider(min=0, max=359, value=220),
        h_pos=IntSlider(min=0, max=359, value=10),
        s=IntSlider(min=0, max=99, value=74),
        l=IntSlider(min=0, max=99, value=50),  # noqa: E741
        sep=IntSlider(min=1, max=50, value=10),
        n=(2, 16),
        center=["light", "dark"],
    ):
        n_colors = 256 if as_cmap else n
        colors = diverging_palette(h_neg, h_pos, s, l, sep, n_colors, center)
        if as_cmap:
            _update_lut(cmap, colors)
            _show_cmap(cmap)
        else:
            pal[:] = colors
            palplot(pal)

    return cmap if as_cmap else pal
def choose_cubehelix_palette(as_cmap=False):
    """Launch an interactive widget to create a sequential cubehelix palette.

    This corresponds with the :func:`cubehelix_palette` function. This kind
    of palette is good for data that range between relatively uninteresting
    low values and interesting high values. The cubehelix system allows the
    palette to have more hue variance across the range, which can be helpful
    for distinguishing a wider range of values.

    Requires IPython 2+ and must be used in the notebook.

    Parameters
    ----------
    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions.

    See Also
    --------
    cubehelix_palette : Create a sequential palette or colormap using the
                        cubehelix system.

    """
    # Both containers are updated in place each time a widget changes.
    pal = []
    if as_cmap:
        cmap = _init_mutable_colormap()

    # The parameter names below double as widget labels; keep them stable.
    @interact
    def choose_cubehelix(n_colors=IntSlider(min=2, max=16, value=9),
                         start=FloatSlider(min=0, max=3, value=0),
                         rot=FloatSlider(min=-1, max=1, value=.4),
                         gamma=FloatSlider(min=0, max=5, value=1),
                         hue=FloatSlider(min=0, max=1, value=.8),
                         light=FloatSlider(min=0, max=1, value=.85),
                         dark=FloatSlider(min=0, max=1, value=.15),
                         reverse=False):
        size = 256 if as_cmap else n_colors
        colors = cubehelix_palette(size, start, rot, gamma,
                                   hue, light, dark, reverse)
        if as_cmap:
            # cubehelix colors are RGB; append an opaque alpha column.
            _update_lut(cmap, np.c_[colors, np.ones(256)])
            _show_cmap(cmap)
        else:
            pal[:] = colors
            palplot(pal)

    return cmap if as_cmap else pal
================================================
FILE: setup.cfg
================================================
[flake8]
max-line-length = 88
exclude = seaborn/cm.py,seaborn/external
ignore = E741,F522,W503
[mypy]
# Currently this ignores pandas and matplotlib
# We may want to make custom stub files for the parts we use
# I have found the available third party stubs to be less
# complete than they would need to be useful
ignore_missing_imports = True
[coverage:run]
omit =
seaborn/widgets.py
seaborn/external/*
seaborn/colors/*
seaborn/cm.py
seaborn/conftest.py
[coverage:report]
exclude_lines =
pragma: no cover
if TYPE_CHECKING:
raise NotImplementedError
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/_core/__init__.py
================================================
================================================
FILE: tests/_core/test_data.py
================================================
import functools
import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_array_equal
from pandas.testing import assert_series_equal
from seaborn._core.data import PlotData
# Series comparison that ignores the ``name`` attribute. PlotData frame columns
# are keyed by plot variable (e.g. "color") while the expected series keep
# their source column names (e.g. "a"), so names would never match.
assert_vector_equal = functools.partial(assert_series_equal, check_names=False)
class TestPlotData:
    """Tests for PlotData: resolving plot variables against a data source."""

    @pytest.fixture
    def long_variables(self):
        """Mapping of plot variables to column names in ``long_df``."""
        variables = dict(x="x", y="y", color="a", size="z", style="s_cat")
        return variables

    def test_named_vectors(self, long_df, long_variables):

        p = PlotData(long_df, long_variables)
        assert p.source_data is long_df
        assert p.source_vars is long_variables
        for key, val in long_variables.items():
            assert p.names[key] == val
            assert_vector_equal(p.frame[key], long_df[val])

    def test_named_and_given_vectors(self, long_df, long_variables):

        long_variables["y"] = long_df["b"]
        long_variables["size"] = long_df["z"].to_numpy()

        p = PlotData(long_df, long_variables)

        assert_vector_equal(p.frame["color"], long_df[long_variables["color"]])
        assert_vector_equal(p.frame["y"], long_df["b"])
        assert_vector_equal(p.frame["size"], long_df["z"])

        assert p.names["color"] == long_variables["color"]
        assert p.names["y"] == "b"
        # A bare array carries no name, so only an id-based identifier exists.
        assert p.names["size"] is None

        assert p.ids["color"] == long_variables["color"]
        assert p.ids["y"] == "b"
        assert p.ids["size"] == id(long_variables["size"])

    def test_index_as_variable(self, long_df, long_variables):

        index = pd.Index(np.arange(len(long_df)) * 2 + 10, name="i", dtype=int)
        long_variables["x"] = "i"
        p = PlotData(long_df.set_index(index), long_variables)

        assert p.names["x"] == p.ids["x"] == "i"
        assert_vector_equal(p.frame["x"], pd.Series(index, index))

    def test_multiindex_as_variables(self, long_df, long_variables):

        index_i = pd.Index(np.arange(len(long_df)) * 2 + 10, name="i", dtype=int)
        index_j = pd.Index(np.arange(len(long_df)) * 3 + 5, name="j", dtype=int)
        index = pd.MultiIndex.from_arrays([index_i, index_j])
        long_variables.update({"x": "i", "y": "j"})

        p = PlotData(long_df.set_index(index), long_variables)
        assert_vector_equal(p.frame["x"], pd.Series(index_i, index))
        assert_vector_equal(p.frame["y"], pd.Series(index_j, index))

    def test_int_as_variable_key(self, rng):

        df = pd.DataFrame(rng.uniform(size=(10, 3)))

        var = "x"
        key = 2

        p = PlotData(df, {var: key})
        assert_vector_equal(p.frame[var], df[key])
        assert p.names[var] == p.ids[var] == str(key)

    def test_int_as_variable_value(self, long_df):

        # 0 is not a column of long_df, so it is used as a constant value.
        p = PlotData(long_df, {"x": 0, "y": "y"})
        assert (p.frame["x"] == 0).all()
        assert p.names["x"] is None
        assert p.ids["x"] == id(0)

    def test_tuple_as_variable_key(self, rng):

        cols = pd.MultiIndex.from_product([("a", "b", "c"), ("x", "y")])
        df = pd.DataFrame(rng.uniform(size=(10, 6)), columns=cols)

        var = "color"
        key = ("b", "y")
        p = PlotData(df, {var: key})
        assert_vector_equal(p.frame[var], df[key])
        assert p.names[var] == p.ids[var] == str(key)

    def test_dict_as_data(self, long_dict, long_variables):

        p = PlotData(long_dict, long_variables)
        assert p.source_data is long_dict
        for key, val in long_variables.items():
            assert_vector_equal(p.frame[key], pd.Series(long_dict[val]))

    @pytest.mark.parametrize(
        "vector_type",
        ["series", "numpy", "list"],
    )
    def test_vectors_various_types(self, long_df, long_variables, vector_type):

        variables = {key: long_df[val] for key, val in long_variables.items()}
        if vector_type == "numpy":
            variables = {key: val.to_numpy() for key, val in variables.items()}
        elif vector_type == "list":
            variables = {key: val.to_list() for key, val in variables.items()}

        p = PlotData(None, variables)

        assert list(p.names) == list(long_variables)
        if vector_type == "series":
            assert p.source_vars is variables
            assert p.names == p.ids == {key: val.name for key, val in variables.items()}
        else:
            assert p.names == {key: None for key in variables}
            assert p.ids == {key: id(val) for key, val in variables.items()}

        for key, val in long_variables.items():
            if vector_type == "series":
                assert_vector_equal(p.frame[key], long_df[val])
            else:
                assert_array_equal(p.frame[key], long_df[val])

    def test_none_as_variable_value(self, long_df):

        p = PlotData(long_df, {"x": "z", "y": None})
        assert list(p.frame.columns) == ["x"]
        assert p.names == p.ids == {"x": "z"}

    def test_frame_and_vector_mismatched_lengths(self, long_df):

        vector = np.arange(len(long_df) * 2)
        with pytest.raises(ValueError):
            PlotData(long_df, {"x": "x", "y": vector})

    @pytest.mark.parametrize(
        "arg", [{}, pd.DataFrame()],
    )
    def test_empty_data_input(self, arg):

        p = PlotData(arg, {})
        assert p.frame.empty
        assert not p.names

        if not isinstance(arg, pd.DataFrame):
            p = PlotData(None, dict(x=arg, y=arg))
            assert p.frame.empty
            assert not p.names

    def test_index_alignment_series_to_dataframe(self):

        x = [1, 2, 3]
        x_index = pd.Index(x, dtype=int)

        y_values = [3, 4, 5]
        y_index = pd.Index(y_values, dtype=int)
        y = pd.Series(y_values, y_index, name="y")

        data = pd.DataFrame(dict(x=x), index=x_index)

        p = PlotData(data, {"x": "x", "y": y})

        # Columns are aligned on the union of both indexes.
        x_col_expected = pd.Series([1, 2, 3, np.nan, np.nan], np.arange(1, 6))
        y_col_expected = pd.Series([np.nan, np.nan, 3, 4, 5], np.arange(1, 6))
        assert_vector_equal(p.frame["x"], x_col_expected)
        assert_vector_equal(p.frame["y"], y_col_expected)

    def test_index_alignment_between_series(self):

        x_index = [1, 2, 3]
        x_values = [10, 20, 30]
        x = pd.Series(x_values, x_index, name="x")

        y_index = [3, 4, 5]
        y_values = [300, 400, 500]
        y = pd.Series(y_values, y_index, name="y")

        p = PlotData(None, {"x": x, "y": y})

        idx_expected = [1, 2, 3, 4, 5]
        x_col_expected = pd.Series([10, 20, 30, np.nan, np.nan], idx_expected)
        y_col_expected = pd.Series([np.nan, np.nan, 300, 400, 500], idx_expected)
        assert_vector_equal(p.frame["x"], x_col_expected)
        assert_vector_equal(p.frame["y"], y_col_expected)

    def test_key_not_in_data_raises(self, long_df):

        var = "x"
        key = "what"
        msg = f"Could not interpret value `{key}` for `{var}`. An entry with this name"
        with pytest.raises(ValueError, match=msg):
            PlotData(long_df, {var: key})

    def test_key_with_no_data_raises(self):

        var = "x"
        key = "what"
        msg = f"Could not interpret value `{key}` for `{var}`. Value is a string,"
        with pytest.raises(ValueError, match=msg):
            PlotData(None, {var: key})

    def test_data_vector_different_lengths_raises(self, long_df):

        vector = np.arange(len(long_df) - 5)
        msg = "Length of ndarray vectors must match length of `data`"
        with pytest.raises(ValueError, match=msg):
            PlotData(long_df, {"y": vector})

    def test_undefined_variables_raise(self, long_df):

        with pytest.raises(ValueError):
            PlotData(long_df, dict(x="not_in_df"))

        with pytest.raises(ValueError):
            PlotData(long_df, dict(x="x", y="not_in_df"))

        with pytest.raises(ValueError):
            PlotData(long_df, dict(x="x", y="y", color="not_in_df"))

    def test_contains_operation(self, long_df):

        # Membership is by plot variable ("x"), not by source column ("y").
        p = PlotData(long_df, {"x": "y", "color": long_df["a"]})
        assert "x" in p
        assert "y" not in p
        assert "color" in p

    def test_join_add_variable(self, long_df):

        v1 = {"x": "x", "y": "f"}
        v2 = {"color": "a"}

        p1 = PlotData(long_df, v1)
        p2 = p1.join(None, v2)

        for var, key in dict(**v1, **v2).items():
            assert var in p2
            assert p2.names[var] == key
            assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_replace_variable(self, long_df):

        v1 = {"x": "x", "y": "y"}
        v2 = {"y": "s"}

        p1 = PlotData(long_df, v1)
        p2 = p1.join(None, v2)

        variables = v1.copy()
        variables.update(v2)

        for var, key in variables.items():
            assert var in p2
            assert p2.names[var] == key
            assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_remove_variable(self, long_df):

        variables = {"x": "x", "y": "f"}
        drop_var = "y"

        p1 = PlotData(long_df, variables)
        p2 = p1.join(None, {drop_var: None})

        assert drop_var in p1
        assert drop_var not in p2
        assert drop_var not in p2.frame
        assert drop_var not in p2.names

    def test_join_all_operations(self, long_df):

        v1 = {"x": "x", "y": "y", "color": "a"}
        v2 = {"y": "s", "size": "s", "color": None}

        p1 = PlotData(long_df, v1)
        p2 = p1.join(None, v2)

        for var, key in v2.items():
            if key is None:
                assert var not in p2
            else:
                assert p2.names[var] == key
                assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_all_operations_same_data(self, long_df):

        v1 = {"x": "x", "y": "y", "color": "a"}
        v2 = {"y": "s", "size": "s", "color": None}

        p1 = PlotData(long_df, v1)
        p2 = p1.join(long_df, v2)

        for var, key in v2.items():
            if key is None:
                assert var not in p2
            else:
                assert p2.names[var] == key
                assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_add_variable_new_data(self, long_df):

        d1 = long_df[["x", "y"]]
        d2 = long_df[["a", "s"]]

        v1 = {"x": "x", "y": "y"}
        v2 = {"color": "a"}

        p1 = PlotData(d1, v1)
        p2 = p1.join(d2, v2)

        for var, key in dict(**v1, **v2).items():
            assert p2.names[var] == key
            assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_replace_variable_new_data(self, long_df):

        d1 = long_df[["x", "y"]]
        d2 = long_df[["a", "s"]]

        v1 = {"x": "x", "y": "y"}
        v2 = {"x": "a"}

        p1 = PlotData(d1, v1)
        p2 = p1.join(d2, v2)

        variables = v1.copy()
        variables.update(v2)

        for var, key in variables.items():
            assert p2.names[var] == key
            assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_add_variable_different_index(self, long_df):

        # d1 and d2 overlap on rows 30-69 only.
        d1 = long_df.iloc[:70]
        d2 = long_df.iloc[30:]

        v1 = {"x": "a"}
        v2 = {"y": "z"}

        p1 = PlotData(d1, v1)
        p2 = p1.join(d2, v2)

        (var1, key1), = v1.items()
        (var2, key2), = v2.items()

        assert_vector_equal(p2.frame.loc[d1.index, var1], d1[key1])
        assert_vector_equal(p2.frame.loc[d2.index, var2], d2[key2])

        assert p2.frame.loc[d2.index.difference(d1.index), var1].isna().all()
        assert p2.frame.loc[d1.index.difference(d2.index), var2].isna().all()

    def test_join_replace_variable_different_index(self, long_df):

        d1 = long_df.iloc[:70]
        d2 = long_df.iloc[30:]

        var = "x"
        k1, k2 = "a", "z"
        v1 = {var: k1}
        v2 = {var: k2}

        p1 = PlotData(d1, v1)
        p2 = p1.join(d2, v2)

        # The replacement comes from d2, so rows present only in d1 must be
        # missing rather than keeping values from the original variable.
        assert_vector_equal(p2.frame.loc[d2.index, var], d2[k2])
        assert p2.frame.loc[d1.index.difference(d2.index), var].isna().all()

    def test_join_subset_data_inherit_variables(self, long_df):

        sub_df = long_df[long_df["a"] == "b"]

        var = "y"
        p1 = PlotData(long_df, {var: var})
        p2 = p1.join(sub_df, None)

        assert_vector_equal(p2.frame.loc[sub_df.index, var], sub_df[var])
        assert p2.frame.loc[long_df.index.difference(sub_df.index), var].isna().all()

    def test_join_multiple_inherits_from_orig(self, rng):

        d1 = pd.DataFrame(dict(a=rng.normal(0, 1, 100), b=rng.normal(0, 1, 100)))
        d2 = pd.DataFrame(dict(a=rng.normal(0, 1, 100)))

        # The final join has no data, so "y" resolves against the original d1.
        p = PlotData(d1, {"x": "a"}).join(d2, {"y": "a"}).join(None, {"y": "a"})
        assert_vector_equal(p.frame["x"], d1["a"])
        assert_vector_equal(p.frame["y"], d1["a"])

    def test_bad_type(self, flat_list):

        err = "Data source must be a DataFrame or Mapping"
        with pytest.raises(TypeError, match=err):
            PlotData(flat_list, {})

    @pytest.mark.skipif(
        condition=not hasattr(pd.api, "interchange"),
        reason="Tests behavior assuming support for dataframe interchange"
    )
    def test_data_interchange(self, mock_long_df, long_df):

        variables = {"x": "x", "y": "z", "color": "a"}
        p = PlotData(mock_long_df, variables)
        for var, col in variables.items():
            assert_vector_equal(p.frame[var], long_df[col])

        p = PlotData(mock_long_df, {**variables, "color": long_df["a"]})
        for var, col in variables.items():
            assert_vector_equal(p.frame[var], long_df[col])

    def test_data_interchange_failure(self, mock_long_df):

        mock_long_df._data = None  # Break to_pandas()
        with pytest.raises(RuntimeError, match="Encountered an exception"):
            PlotData(mock_long_df, {"x": "x"})

    @pytest.mark.skipif(
        condition=hasattr(pd.api, "interchange"),
        reason="Tests graceful failure without support for dataframe interchange"
    )
    def test_data_interchange_support_test(self, mock_long_df):

        with pytest.raises(TypeError, match="Support for non-pandas DataFrame"):
            PlotData(mock_long_df, {"x": "x"})
================================================
FILE: tests/_core/test_groupby.py
================================================
import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_array_equal
from seaborn._core.groupby import GroupBy
@pytest.fixture
def df():
    """Five-row frame with two string columns (a, b) and two numeric (x, y)."""
    return pd.DataFrame({
        "a": ["a", "b", "a", "a", "b"],
        "b": ["g", "h", "f", "h", "f"],
        "x": [1, 3, 2, 1, 2],
        "y": [.2, .5, .8, .3, .4],
    })
def test_init_from_list():
    # A list of groupers becomes an order mapping with no explicit levels.
    groupers = ["a", "c", "b"]
    assert GroupBy(groupers).order == dict.fromkeys(groupers)
def test_init_from_dict():
    expected = dict(a=[3, 2, 1], c=None, b=["x", "y", "z"])
    assert GroupBy(expected).order == expected
def test_init_requires_order():
    no_groupers = []
    with pytest.raises(ValueError, match="GroupBy requires at least one"):
        GroupBy(no_groupers)
def test_at_least_one_grouping_variable_required(df):
    # "z" is not a column of the fixture, so no grouper is usable.
    absent = GroupBy(["z"])
    with pytest.raises(ValueError, match="No grouping variables are present"):
        absent.agg(df, x="mean")
def test_agg_one_grouper(df):
    res = GroupBy(["a"]).agg(df, {"y": "max"})
    expected = {"a": ["a", "b"], "y": [.8, .5]}

    assert_array_equal(res.index, [0, 1])
    assert_array_equal(res.columns, list(expected))
    for column, values in expected.items():
        assert_array_equal(res[column], values)
def test_agg_two_groupers(df):
    """Two groupers produce every level combination; unobserved ones are NaN."""
    result = GroupBy(["a", "x"]).agg(df, {"y": "min"})

    assert_array_equal(result.index, list(range(6)))
    assert_array_equal(result.columns, ["a", "x", "y"])
    assert_array_equal(result["a"], ["a"] * 3 + ["b"] * 3)
    assert_array_equal(result["x"], [1, 2, 3] * 2)
    # ("a", 3) and ("b", 1) never occur in the fixture, hence the NaNs.
    assert_array_equal(result["y"], [.2, .8, np.nan, np.nan, .4, .5])
def test_agg_two_groupers_ordered(df):
    """Explicit level orders determine the row order of the aggregated result."""
    levels = {"b": ["h", "g", "f"], "x": [3, 2, 1]}
    result = GroupBy(levels).agg(df, {"a": "min", "y": lambda x: x.iloc[0]})

    assert_array_equal(result.index, list(range(9)))
    assert_array_equal(result.columns, ["a", "b", "x", "y"])
    assert_array_equal(result["b"], ["h"] * 3 + ["g"] * 3 + ["f"] * 3)
    assert_array_equal(result["x"], [3, 2, 1] * 3)

    # Combinations absent from the data produce missing aggregates.
    expected_missing = [
        False, True, False, True, True, False, True, False, True,
    ]
    assert_array_equal(result["a"].isna(), expected_missing)
    assert_array_equal(result["a"].dropna(), ["b", "a", "a", "a"])
    assert_array_equal(result["y"].dropna(), [.5, .3, .2, .8])
def test_apply_no_grouper(df):
    """Without a grouping column present, the function sees the whole frame."""
    numeric = df[["x", "y"]]
    result = GroupBy(["a"]).apply(numeric, lambda frame: frame.sort_values("x"))

    assert_array_equal(result.columns, ["x", "y"])
    assert_array_equal(result["x"], numeric["x"].sort_values())
    assert_array_equal(result["y"], numeric.loc[np.argsort(numeric["x"]), "y"])
def test_apply_one_grouper(df):
    """The function is applied per group and results are concatenated in order."""
    result = GroupBy(["a"]).apply(df, lambda frame: frame.sort_values("x"))

    assert_array_equal(result.index, list(range(5)))
    assert_array_equal(result.columns, ["a", "b", "x", "y"])
    assert_array_equal(result["a"], ["a", "a", "a", "b", "b"])
    assert_array_equal(result["b"], ["g", "h", "f", "f", "h"])
    assert_array_equal(result["x"], [1, 1, 2, 2, 3])
def test_apply_mutate_columns(df):
    """A function returning a new frame per group is stacked with its group key."""
    grid = np.arange(0, 5)
    predictions = []

    def polyfit(group):
        # Fit a line to the group and evaluate it on the shared grid.
        coefs = np.polyfit(group["x"], group["y"], 1)
        yhat = np.polyval(coefs, grid)
        predictions.append(yhat)
        return pd.DataFrame({"x": grid, "y": yhat})

    result = GroupBy(["a"]).apply(df, polyfit)
    size = grid.size

    assert_array_equal(result.index, np.arange(size * 2))
    assert_array_equal(result.columns, ["a", "x", "y"])
    assert_array_equal(result["a"], ["a"] * size + ["b"] * size)
    assert_array_equal(result["x"], grid.tolist() * 2)
    assert_array_equal(result["y"], np.concatenate(predictions))
def test_apply_replace_columns(df):
    """A function may return entirely different non-grouping columns."""
    def add_sorted_cumsum(group):
        # Sort by x and accumulate y in that order, emitting a new "z" column.
        x_sorted = group["x"].sort_values()
        running = group.loc[x_sorted.index, "y"].cumsum()
        return pd.DataFrame({"x": x_sorted.values, "z": running.values})

    result = GroupBy(["a"]).apply(df, add_sorted_cumsum)

    assert_array_equal(result.index, df.index)
    assert_array_equal(result.columns, ["a", "x", "z"])
    assert_array_equal(result["a"], ["a", "a", "a", "b", "b"])
    assert_array_equal(result["x"], [1, 1, 2, 2, 3])
    assert_array_equal(result["z"], [.2, .5, 1.3, .4, .9])
================================================
FILE: tests/_core/test_moves.py
================================================
from itertools import product
import numpy as np
import pandas as pd
from pandas.testing import assert_series_equal
from numpy.testing import assert_array_equal, assert_array_almost_equal
from seaborn._core.moves import Dodge, Jitter, Shift, Stack, Norm
from seaborn._core.rules import categorical_order
from seaborn._core.groupby import GroupBy
import pytest
class MoveFixtures:
    """Mixin supplying small DataFrame fixtures for the Move test classes."""

    @pytest.fixture
    def df(self, rng):
        """Random 50-row frame with x/y, two grouping columns, width and baseline."""
        n = 50
        return pd.DataFrame({
            "x": rng.choice([0., 1., 2., 3.], n),
            "y": rng.normal(0, 1, n),
            "grp2": rng.choice(["a", "b"], n),
            "grp3": rng.choice(["x", "y", "z"], n),
            "width": 0.8,
            "baseline": 0,
        })

    @pytest.fixture
    def toy_df(self):
        """Three rows: two groups overlap at x=0, one lone row at x=1."""
        return pd.DataFrame({
            "x": [0, 0, 1],
            "y": [1, 2, 3],
            "grp": ["a", "b", "b"],
            "width": .8,
            "baseline": 0,
        })

    @pytest.fixture
    def toy_df_widths(self, toy_df):
        """``toy_df`` with per-row widths (the ``toy_df`` object is modified in place)."""
        toy_df["width"] = [.8, .2, .4]
        return toy_df

    @pytest.fixture
    def toy_df_facets(self):
        """Six rows split across two facet columns ``x`` and ``y``."""
        return pd.DataFrame({
            "x": [0, 0, 1, 0, 1, 2],
            "y": [1, 2, 3, 1, 2, 3],
            "grp": ["a", "b", "a", "b", "a", "b"],
            "col": ["x", "x", "x", "y", "y", "y"],
            "width": .8,
            "baseline": 0,
        })
class TestJitter(MoveFixtures):
    """Tests for the Jitter move."""

    def get_groupby(self, data, orient):
        """Group by every column except the value variable and ``width``."""
        other = {"x": "y", "y": "x"}[orient]
        excluded = (other, "width")
        return GroupBy([name for name in data if name not in excluded])

    def check_same(self, res, df, *cols):
        """Assert the named columns were left untouched by the move."""
        for name in cols:
            assert_series_equal(res[name], df[name])

    def check_pos(self, res, df, var, limit):
        """Assert every ``var`` value moved, but by strictly less than ``limit / 2``."""
        moved, original = res[var], df[var]
        half = limit / 2
        assert (moved != original).all()
        assert (moved < original + half).all()
        assert (moved > original - half).all()

    def test_default(self, df):
        orient = "x"
        groupby = self.get_groupby(df, orient)
        jittered = Jitter()(df, groupby, orient, {})
        self.check_same(jittered, df, "y", "grp2", "width")
        # Default jitter spans 20% of each element's width.
        self.check_pos(jittered, df, "x", 0.2 * df["width"])
        assert (jittered["x"] - df["x"]).abs().min() > 0

    def test_width(self, df):
        orient = "x"
        groupby = self.get_groupby(df, orient)
        fraction = .4
        jittered = Jitter(width=fraction)(df, groupby, orient, {})
        self.check_same(jittered, df, "y", "grp2", "width")
        self.check_pos(jittered, df, "x", fraction * df["width"])

    def test_x(self, df):
        orient = "x"
        groupby = self.get_groupby(df, orient)
        amount = .2
        jittered = Jitter(x=amount)(df, groupby, orient, {})
        self.check_same(jittered, df, "y", "grp2", "width")
        self.check_pos(jittered, df, "x", amount)

    def test_y(self, df):
        orient = "x"
        groupby = self.get_groupby(df, orient)
        amount = .2
        jittered = Jitter(y=amount)(df, groupby, orient, {})
        self.check_same(jittered, df, "x", "grp2", "width")
        self.check_pos(jittered, df, "y", amount)

    def test_seed(self, df):
        """Two Jitter instances with the same seed produce identical output."""
        kws = dict(width=.2, y=.1, seed=0)
        orient = "x"
        groupby = self.get_groupby(df, orient)
        first = Jitter(**kws)(df, groupby, orient, {})
        second = Jitter(**kws)(df, groupby, orient, {})
        for var in ["x", "y"]:
            assert_series_equal(first[var], second[var])
class TestDodge(MoveFixtures):
    """Tests for the Dodge move, which places overlapping groups side by side."""

    # First some very simple toy examples

    def test_default(self, toy_df):
        """Width is split among groups; a lone group keeps its own slot."""
        groupby = GroupBy(["x", "grp"])
        res = Dodge()(toy_df, groupby, "x", {})
        # NOTE: the original had a stray trailing comma here (and in test_fill),
        # which wrapped the assertion call in a discarded tuple.
        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1.2])
        assert_array_almost_equal(res["width"], [.4, .4, .4])

    def test_fill(self, toy_df):
        """With empty="fill", a lone group expands to the full width."""
        groupby = GroupBy(["x", "grp"])
        res = Dodge(empty="fill")(toy_df, groupby, "x", {})
        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1])
        assert_array_almost_equal(res["width"], [.4, .4, .8])

    def test_drop(self, toy_df):
        """With "drop", a lone group is centered but keeps the dodged width."""
        groupby = GroupBy(["x", "grp"])
        res = Dodge("drop")(toy_df, groupby, "x", {})
        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1])
        assert_array_almost_equal(res["width"], [.4, .4, .4])

    def test_gap(self, toy_df):
        """A gap shrinks each dodged element without moving its center."""
        groupby = GroupBy(["x", "grp"])
        res = Dodge(gap=.25)(toy_df, groupby, "x", {})
        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1.2])
        assert_array_almost_equal(res["width"], [.3, .3, .3])

    def test_widths_default(self, toy_df_widths):
        groupby = GroupBy(["x", "grp"])
        res = Dodge()(toy_df_widths, groupby, "x", {})
        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.08, .32, 1.1])
        assert_array_almost_equal(res["width"], [.64, .16, .2])

    def test_widths_fill(self, toy_df_widths):
        groupby = GroupBy(["x", "grp"])
        res = Dodge(empty="fill")(toy_df_widths, groupby, "x", {})
        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.08, .32, 1])
        assert_array_almost_equal(res["width"], [.64, .16, .4])

    def test_widths_drop(self, toy_df_widths):
        groupby = GroupBy(["x", "grp"])
        res = Dodge(empty="drop")(toy_df_widths, groupby, "x", {})
        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.08, .32, 1])
        assert_array_almost_equal(res["width"], [.64, .16, .2])

    def test_faceted_default(self, toy_df_facets):
        groupby = GroupBy(["x", "grp", "col"])
        res = Dodge()(toy_df_facets, groupby, "x", {})
        assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, .8, .2, .8, 2.2])
        assert_array_almost_equal(res["width"], [.4] * 6)

    def test_faceted_fill(self, toy_df_facets):
        groupby = GroupBy(["x", "grp", "col"])
        res = Dodge(empty="fill")(toy_df_facets, groupby, "x", {})
        assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1, 0, 1, 2])
        assert_array_almost_equal(res["width"], [.4, .4, .8, .8, .8, .8])

    def test_faceted_drop(self, toy_df_facets):
        groupby = GroupBy(["x", "grp", "col"])
        res = Dodge(empty="drop")(toy_df_facets, groupby, "x", {})
        assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1, 0, 1, 2])
        assert_array_almost_equal(res["width"], [.4] * 6)

    def test_orient(self, toy_df):
        """Dodging along y mirrors the x-oriented behavior."""
        df = toy_df.assign(x=toy_df["y"], y=toy_df["x"])
        groupby = GroupBy(["y", "grp"])
        res = Dodge("drop")(df, groupby, "y", {})
        assert_array_equal(res["x"], [1, 2, 3])
        assert_array_almost_equal(res["y"], [-.2, .2, 1])
        assert_array_almost_equal(res["width"], [.4, .4, .4])

    # Now tests with slightly more complicated data

    @pytest.mark.parametrize("grp", ["grp2", "grp3"])
    def test_single_semantic(self, df, grp):
        """Each level is shifted by an evenly spaced, zero-centered offset."""
        groupby = GroupBy(["x", grp])
        res = Dodge()(df, groupby, "x", {})

        levels = categorical_order(df[grp])
        w, n = 0.8, len(levels)

        # Evenly spaced offsets, centered on zero.
        shifts = np.linspace(0, w - w / n, n)
        shifts -= shifts.mean()

        assert_series_equal(res["y"], df["y"])
        assert_series_equal(res["width"], df["width"] / n)

        for val, shift in zip(levels, shifts):
            rows = df[grp] == val
            assert_series_equal(res.loc[rows, "x"], df.loc[rows, "x"] + shift)

    def test_two_semantics(self, df):
        """Two groupers dodge over the cross product of their levels."""
        groupby = GroupBy(["x", "grp2", "grp3"])
        res = Dodge()(df, groupby, "x", {})

        levels = categorical_order(df["grp2"]), categorical_order(df["grp3"])
        w, n = 0.8, len(levels[0]) * len(levels[1])

        shifts = np.linspace(0, w - w / n, n)
        shifts -= shifts.mean()

        assert_series_equal(res["y"], df["y"])
        assert_series_equal(res["width"], df["width"] / n)

        for (v2, v3), shift in zip(product(*levels), shifts):
            rows = (df["grp2"] == v2) & (df["grp3"] == v3)
            assert_series_equal(res.loc[rows, "x"], df.loc[rows, "x"] + shift)
class TestStack(MoveFixtures):
    """Tests for the Stack move."""

    def test_basic(self, toy_df):
        """Rows sharing an x position are stacked on top of one another."""
        groupby = GroupBy(["color", "group"])
        stacked = Stack()(toy_df, groupby, "x", {})
        assert_array_equal(stacked["x"], [0, 0, 1])
        assert_array_equal(stacked["y"], [1, 3, 3])
        assert_array_equal(stacked["baseline"], [0, 1, 0])

    def test_faceted(self, toy_df_facets):
        """Stacking does not cross facet boundaries."""
        groupby = GroupBy(["color", "group"])
        stacked = Stack()(toy_df_facets, groupby, "x", {})
        assert_array_equal(stacked["x"], [0, 0, 1, 0, 1, 2])
        assert_array_equal(stacked["y"], [1, 3, 3, 1, 2, 3])
        assert_array_equal(stacked["baseline"], [0, 1, 0, 0, 0, 0])

    def test_misssing_data(self, toy_df):
        """A NaN value stays NaN and does not contribute to later stacks."""
        frame = pd.DataFrame({
            "x": [0, 0, 0],
            "y": [2, np.nan, 1],
            "baseline": [0, 0, 0],
        })
        stacked = Stack()(frame, None, "x", {})
        assert_array_equal(stacked["y"], [2, np.nan, 3])
        assert_array_equal(stacked["baseline"], [0, np.nan, 2])

    def test_baseline_homogeneity_check(self, toy_df):
        """Heterogeneous input baselines are rejected."""
        toy_df["baseline"] = [0, 1, 2]
        groupby = GroupBy(["color", "group"])
        stacker = Stack()
        expected_msg = "Stack move cannot be used when baselines"
        with pytest.raises(RuntimeError, match=expected_msg):
            stacker(toy_df, groupby, "x", {})
class TestShift(MoveFixtures):
    """Tests for the Shift move."""

    def test_default(self, toy_df):
        """With no offsets, every column passes through unchanged."""
        groupby = GroupBy(["color", "group"])
        shifted = Shift()(toy_df, groupby, "x", {})
        for column in toy_df:
            assert_series_equal(toy_df[column], shifted[column])

    @pytest.mark.parametrize("x,y", [(.3, 0), (0, .2), (.1, .3)])
    def test_moves(self, toy_df, x, y):
        """Offsets are added to the x and y columns."""
        groupby = GroupBy(["color", "group"])
        shifted = Shift(x=x, y=y)(toy_df, groupby, "x", {})
        assert_array_equal(shifted["x"], toy_df["x"] + x)
        assert_array_equal(shifted["y"], toy_df["y"] + y)
class TestNorm(MoveFixtures):
    """Tests for the Norm move."""

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_default_no_groups(self, df, orient):
        """By default the non-orient variable is scaled to a maximum of 1."""
        value_var = {"x": "y", "y": "x"}[orient]
        normed = Norm()(df, GroupBy(["null"]), orient, {})
        assert normed[value_var].max() == pytest.approx(1)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_default_groups(self, df, orient):
        """Normalization is applied separately within each group."""
        value_var = {"x": "y", "y": "x"}[orient]
        normed = Norm()(df, GroupBy(["grp2"]), orient, {})
        for _, group in normed.groupby("grp2"):
            assert group[value_var].max() == pytest.approx(1)

    def test_sum(self, df):
        """The "sum" function rescales values so that they sum to 1."""
        normed = Norm("sum")(df, GroupBy(["null"]), "x", {})
        assert normed["y"].sum() == pytest.approx(1)

    def test_where(self, df):
        """``where`` restricts which rows define the normalizing value."""
        normed = Norm(where="x == 2")(df, GroupBy(["null"]), "x", {})
        assert normed.loc[normed["x"] == 2, "y"].max() == pytest.approx(1)

    def test_percent(self, df):
        """``percent=True`` expresses the result on a 0-100 scale."""
        normed = Norm(percent=True)(df, GroupBy(["null"]), "x", {})
        assert normed["y"].max() == pytest.approx(100)
================================================
FILE: tests/_core/test_plot.py
================================================
import io
import xml
import functools
import itertools
import warnings
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from PIL import Image
import pytest
from pandas.testing import assert_frame_equal, assert_series_equal
from numpy.testing import assert_array_equal, assert_array_almost_equal
from seaborn._core.plot import Plot, PlotConfig, Default
from seaborn._core.scales import Continuous, Nominal, Temporal
from seaborn._core.moves import Move, Shift, Dodge
from seaborn._core.rules import categorical_order
from seaborn._core.exceptions import PlotSpecError
from seaborn._marks.base import Mark
from seaborn._stats.base import Stat
from seaborn._marks.dot import Dot
from seaborn._stats.aggregation import Agg
from seaborn.utils import _version_predates
# Series comparison that ignores both the Series names and the dtypes.
# TODO do we care about int/float dtype consistency?
# Eventually most variables become floats ... but does it matter when?
# (Or rather, does it matter if it happens too early?)
_VECTOR_EQUAL_KWS = {"check_names": False, "check_dtype": False}
assert_vector_equal = functools.partial(assert_series_equal, **_VECTOR_EQUAL_KWS)
def assert_gridspec_shape(ax, nrows=1, ncols=1):
    """Assert that the gridspec holding ``ax`` has ``nrows`` rows and ``ncols`` columns."""
    spec = ax.get_gridspec()
    assert (spec.nrows, spec.ncols) == (nrows, ncols)
class MockMark(Mark):
    """Mark stub that records what it receives instead of drawing anything."""

    _grouping_props = ["color"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One entry per split, appended in the order the splits are yielded.
        self.passed_keys = []
        self.passed_data = []
        self.passed_axes = []
        # Arguments from the most recent _plot call.
        self.passed_scales = None
        self.passed_orient = None
        self.n_splits = 0

    def _plot(self, split_gen, scales, orient):
        """Consume every split from ``split_gen`` and record it."""
        for keys, data, ax in split_gen():
            self.n_splits += 1
            self.passed_keys.append(keys)
            self.passed_data.append(data)
            self.passed_axes.append(ax)
        self.passed_scales = scales
        self.passed_orient = orient

    def _legend_artist(self, variables, value, scales):
        """Return an empty Line2D tagged with the legend variables and value."""
        artist = mpl.lines.Line2D([], [])
        artist.variables = variables
        artist.value = value
        return artist
class TestInit:
    """Tests for how Plot() interprets its data and variable arguments."""

    def test_empty(self):
        p = Plot()
        assert p._data.source_data is None
        assert p._data.source_vars == {}

    def test_data_only(self, long_df):
        p = Plot(long_df)
        assert p._data.source_data is long_df
        assert p._data.source_vars == {}

    def test_df_and_named_variables(self, long_df):
        # String values are resolved as column names of the data frame.
        variables = {"x": "a", "y": "z"}
        p = Plot(long_df, **variables)
        for var, col in variables.items():
            assert_vector_equal(p._data.frame[var], long_df[col])
        assert p._data.source_data is long_df
        assert p._data.source_vars.keys() == variables.keys()

    def test_df_and_mixed_variables(self, long_df):
        # Column names and vectors may be mixed in a single call.
        variables = {"x": "a", "y": long_df["z"]}
        p = Plot(long_df, **variables)
        for var, col in variables.items():
            if isinstance(col, str):
                assert_vector_equal(p._data.frame[var], long_df[col])
            else:
                assert_vector_equal(p._data.frame[var], col)
        assert p._data.source_data is long_df
        assert p._data.source_vars.keys() == variables.keys()

    def test_vector_variables_only(self, long_df):
        variables = {"x": long_df["a"], "y": long_df["z"]}
        p = Plot(**variables)
        for var, col in variables.items():
            assert_vector_equal(p._data.frame[var], col)
        assert p._data.source_data is None
        assert p._data.source_vars.keys() == variables.keys()

    def test_vector_variables_no_index(self, long_df):
        # Unnamed array/list inputs get a default index and no variable name.
        variables = {"x": long_df["a"].to_numpy(), "y": long_df["z"].to_list()}
        p = Plot(**variables)
        for var, col in variables.items():
            assert_vector_equal(p._data.frame[var], pd.Series(col))
            assert p._data.names[var] is None
        assert p._data.source_data is None
        assert p._data.source_vars.keys() == variables.keys()

    def test_data_only_named(self, long_df):
        p = Plot(data=long_df)
        assert p._data.source_data is long_df
        assert p._data.source_vars == {}

    def test_positional_and_named_data(self, long_df):
        err = "`data` given by both name and position"
        with pytest.raises(TypeError, match=err):
            Plot(long_df, data=long_df)

    @pytest.mark.parametrize("var", ["x", "y"])
    def test_positional_and_named_xy(self, long_df, var):
        err = f"`{var}` given by both name and position"
        with pytest.raises(TypeError, match=err):
            Plot(long_df, "a", "b", **{var: "c"})

    def test_positional_data_x_y(self, long_df):
        p = Plot(long_df, "a", "b")
        assert p._data.source_data is long_df
        assert list(p._data.source_vars) == ["x", "y"]

    def test_positional_x_y(self, long_df):
        # Without a frame, the first two positionals are treated as x and y.
        p = Plot(long_df["a"], long_df["b"])
        assert p._data.source_data is None
        assert list(p._data.source_vars) == ["x", "y"]

    def test_positional_data_x(self, long_df):
        p = Plot(long_df, "a")
        assert p._data.source_data is long_df
        assert list(p._data.source_vars) == ["x"]

    def test_positional_x(self, long_df):
        p = Plot(long_df["a"])
        assert p._data.source_data is None
        assert list(p._data.source_vars) == ["x"]

    @pytest.mark.skipif(
        condition=not hasattr(pd.api, "interchange"),
        reason="Tests behavior assuming support for dataframe interchange"
    )
    def test_positional_interchangeable_dataframe(self, mock_long_df, long_df):
        # An interchange-protocol object is converted to an equal pandas frame.
        p = Plot(mock_long_df, x="x")
        assert_frame_equal(p._data.source_data, long_df)

    def test_positional_too_many(self, long_df):
        err = r"Plot\(\) accepts no more than 3 positional arguments \(data, x, y\)"
        with pytest.raises(TypeError, match=err):
            Plot(long_df, "x", "y", "z")

    def test_unknown_keywords(self, long_df):
        err = r"Plot\(\) got unexpected keyword argument\(s\): bad"
        with pytest.raises(TypeError, match=err):
            Plot(long_df, bad="x")
class TestLayerAddition:
    """Tests for Plot.add(): per-layer data, variables, stats, moves and orient."""

    def test_without_data(self, long_df):
        # A layer with no data of its own reuses the plot-level data.
        p = Plot(long_df, x="x", y="y").add(MockMark()).plot()
        layer, = p._layers
        assert_frame_equal(p._data.frame, layer["data"].frame, check_dtype=False)

    def test_with_new_variable_by_name(self, long_df):
        p = Plot(long_df, x="x").add(MockMark(), y="y").plot()
        layer, = p._layers
        assert layer["data"].frame.columns.to_list() == ["x", "y"]
        for var in "xy":
            assert_vector_equal(layer["data"].frame[var], long_df[var])

    def test_with_new_variable_by_vector(self, long_df):
        p = Plot(long_df, x="x").add(MockMark(), y=long_df["y"]).plot()
        layer, = p._layers
        assert layer["data"].frame.columns.to_list() == ["x", "y"]
        for var in "xy":
            assert_vector_equal(layer["data"].frame[var], long_df[var])

    def test_with_late_data_definition(self, long_df):
        p = Plot().add(MockMark(), data=long_df, x="x", y="y").plot()
        layer, = p._layers
        assert layer["data"].frame.columns.to_list() == ["x", "y"]
        for var in "xy":
            assert_vector_equal(layer["data"].frame[var], long_df[var])

    def test_with_new_data_definition(self, long_df):
        long_df_sub = long_df.sample(frac=.5)

        p = Plot(long_df, x="x", y="y").add(MockMark(), data=long_df_sub).plot()
        layer, = p._layers
        assert layer["data"].frame.columns.to_list() == ["x", "y"]
        for var in "xy":
            # Expected values are the subsample reindexed onto the full index,
            # so rows missing from the subsample compare as NaN.
            assert_vector_equal(
                layer["data"].frame[var], long_df_sub[var].reindex(long_df.index)
            )

    def test_drop_variable(self, long_df):
        # Passing None for a variable in add() removes it from that layer.
        p = Plot(long_df, x="x", y="y").add(MockMark(), y=None).plot()
        layer, = p._layers
        assert layer["data"].frame.columns.to_list() == ["x"]
        assert_vector_equal(layer["data"].frame["x"], long_df["x"], check_dtype=False)

    @pytest.mark.xfail(reason="Need decision on default stat")
    def test_stat_default(self):

        class MarkWithDefaultStat(Mark):
            default_stat = Stat

        p = Plot().add(MarkWithDefaultStat())
        layer, = p._layers
        assert layer["stat"].__class__ is Stat

    def test_stat_nondefault(self):
        # An explicitly passed Stat is used instead of the mark's default_stat.

        class MarkWithDefaultStat(Mark):
            default_stat = Stat

        class OtherMockStat(Stat):
            pass

        p = Plot().add(MarkWithDefaultStat(), OtherMockStat())
        layer, = p._layers
        assert layer["stat"].__class__ is OtherMockStat

    @pytest.mark.parametrize(
        "arg,expected",
        [("x", "x"), ("y", "y"), ("v", "x"), ("h", "y")],
    )
    def test_orient(self, arg, expected):
        # The "v"/"h" aliases are expected to resolve to "x"/"y" before
        # reaching the stat and the move.

        class MockStatTrackOrient(Stat):
            def __call__(self, data, groupby, orient, scales):
                self.orient_at_call = orient
                return data

        class MockMoveTrackOrient(Move):
            def __call__(self, data, groupby, orient, scales):
                self.orient_at_call = orient
                return data

        s = MockStatTrackOrient()
        m = MockMoveTrackOrient()
        Plot(x=[1, 2, 3], y=[1, 2, 3]).add(MockMark(), s, m, orient=arg).plot()

        assert s.orient_at_call == expected
        assert m.orient_at_call == expected

    def test_variable_list(self, long_df):
        # _variables lists variables in order of first appearance across the
        # constructor, each layer, and pair(); a None in add() does not remove
        # a plot-level variable from the list.

        p = Plot(long_df, x="x", y="y")
        assert p._variables == ["x", "y"]

        p = Plot(long_df).add(MockMark(), x="x", y="y")
        assert p._variables == ["x", "y"]

        p = Plot(long_df, y="x", color="a").add(MockMark(), x="y")
        assert p._variables == ["y", "color", "x"]

        p = Plot(long_df, x="x", y="y", color="a").add(MockMark(), color=None)
        assert p._variables == ["x", "y", "color"]

        p = (
            Plot(long_df, x="x", y="y")
            .add(MockMark(), color="a")
            .add(MockMark(), alpha="s")
        )
        assert p._variables == ["x", "y", "color", "alpha"]

        p = Plot(long_df, y="x").pair(x=["a", "b"])
        assert p._variables == ["y", "x0", "x1"]

    def test_type_checks(self):
        # add() expects a Mark instance plus at most one Stat instance.

        p = Plot()
        with pytest.raises(TypeError, match="mark must be a Mark instance"):
            p.add(MockMark)

        class MockStat(Stat):
            pass

        class MockMove(Move):
            pass

        err = "Transforms must have at most one Stat type"

        with pytest.raises(TypeError, match=err):
            p.add(MockMark(), MockStat)

        with pytest.raises(TypeError, match=err):
            p.add(MockMark(), MockMove(), MockStat())

        with pytest.raises(TypeError, match=err):
            p.add(MockMark(), MockMark(), MockStat())
class TestScaling:
    """Tests for scale inference, transforms and categorical handling in Plot."""

    def test_inference(self, long_df):
        """Scale type is inferred from each column's data type."""
        for col, scale_type in zip("zat", ["Continuous", "Nominal", "Temporal"]):
            p = Plot(long_df, x=col, y=col).add(MockMark()).plot()
            for var in "xy":
                assert p._scales[var].__class__.__name__ == scale_type

    def test_inference_from_layer_data(self):
        p = Plot().add(MockMark(), x=["a", "b", "c"]).plot()
        assert p._scales["x"]("b") == 1

    def test_inference_joins(self):
        # Data from both layers contributes to the inferred x scale.
        p = (
            Plot(y=pd.Series([1, 2, 3, 4]))
            .add(MockMark(), x=pd.Series([1, 2]))
            .add(MockMark(), x=pd.Series(["a", "b"], index=[2, 3]))
            .plot()
        )
        assert p._scales["x"]("a") == 2

    def test_inferred_categorical_converter(self):
        p = Plot(x=["b", "c", "a"]).add(MockMark()).plot()
        ax = p._figure.axes[0]
        assert ax.xaxis.convert_units("c") == 1

    def test_explicit_categorical_converter(self):
        p = Plot(y=[2, 1, 3]).scale(y=Nominal()).add(MockMark()).plot()
        ax = p._figure.axes[0]
        assert ax.yaxis.convert_units("3") == 2

    @pytest.mark.xfail(reason="Temporal auto-conversion not implemented")
    def test_categorical_as_datetime(self):
        dates = ["1970-01-03", "1970-01-02", "1970-01-04"]
        p = Plot(x=dates).scale(...).add(MockMark()).plot()
        p  # TODO
        ...

    def test_faceted_log_scale(self):
        p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale(y="log").plot()
        for ax in p._figure.axes:
            xfm = ax.yaxis.get_transform().transform
            assert_array_equal(xfm([1, 10, 100]), [0, 1, 2])

    def test_paired_single_log_scale(self):
        x0, x1 = [1, 2, 3], [1, 10, 100]
        p = Plot().pair(x=[x0, x1]).scale(x1="log").plot()
        ax_lin, ax_log = p._figure.axes
        xfm_lin = ax_lin.xaxis.get_transform().transform
        assert_array_equal(xfm_lin([1, 10, 100]), [1, 10, 100])
        xfm_log = ax_log.xaxis.get_transform().transform
        assert_array_equal(xfm_log([1, 10, 100]), [0, 1, 2])

    def test_paired_with_common_fallback(self):
        # "x" applies to paired axes lacking a specific entry; "x1" overrides.
        x0, x1 = [1, 2, 3], [1, 10, 100]
        p = Plot().pair(x=[x0, x1]).scale(x="pow", x1="log").plot()
        ax_pow, ax_log = p._figure.axes
        xfm_pow = ax_pow.xaxis.get_transform().transform
        assert_array_equal(xfm_pow([1, 2, 3]), [1, 4, 9])
        xfm_log = ax_log.xaxis.get_transform().transform
        assert_array_equal(xfm_log([1, 10, 100]), [0, 1, 2])

    @pytest.mark.xfail(reason="Custom log scale needs log name for consistency")
    def test_log_scale_name(self):
        p = Plot().scale(x="log").plot()
        ax = p._figure.axes[0]
        assert ax.get_xscale() == "log"
        assert ax.get_yscale() == "linear"

    def test_mark_data_log_transform_is_inverted(self, long_df):
        # Marks receive data in original units, not log-transformed units.
        col = "z"
        m = MockMark()
        Plot(long_df, x=col).scale(x="log").add(m).plot()
        assert_vector_equal(m.passed_data[0]["x"], long_df[col])

    def test_mark_data_log_transfrom_with_stat(self, long_df):
        # The stat runs in transformed (log) space; the mark gets the result
        # transformed back, i.e. a geometric mean per group.

        class Mean(Stat):
            group_by_orient = True

            def __call__(self, data, groupby, orient, scales):
                other = {"x": "y", "y": "x"}[orient]
                return groupby.agg(data, {other: "mean"})

        col = "z"
        grouper = "a"
        m = MockMark()
        s = Mean()

        Plot(long_df, x=grouper, y=col).scale(y="log").add(m, s).plot()

        expected = (
            long_df[col]
            .pipe(np.log)
            .groupby(long_df[grouper], sort=False)
            .mean()
            .pipe(np.exp)
            .reset_index(drop=True)
        )
        assert_vector_equal(m.passed_data[0]["y"], expected)

    def test_mark_data_from_categorical(self, long_df):
        # Categorical values reach the mark as float level positions.
        col = "a"
        m = MockMark()
        Plot(long_df, x=col).add(m).plot()

        levels = categorical_order(long_df[col])
        level_map = {x: float(i) for i, x in enumerate(levels)}
        assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(level_map))

    def test_mark_data_from_datetime(self, long_df):
        # Datetimes reach the mark as matplotlib date numbers.
        col = "t"
        m = MockMark()
        Plot(long_df, x=col).add(m).plot()

        expected = long_df[col].map(mpl.dates.date2num)
        assert_vector_equal(m.passed_data[0]["x"], expected)

    def test_computed_var_ticks(self, long_df):

        class Identity(Stat):
            def __call__(self, df, groupby, orient, scales):
                other = {"x": "y", "y": "x"}[orient]
                return df.assign(**{other: df[orient]})

        tick_locs = [1, 2, 5]
        scale = Continuous().tick(at=tick_locs)
        p = Plot(long_df, "x").add(MockMark(), Identity()).scale(y=scale).plot()
        ax = p._figure.axes[0]
        assert_array_equal(ax.get_yticks(), tick_locs)

    def test_computed_var_transform(self, long_df):

        class Identity(Stat):
            def __call__(self, df, groupby, orient, scales):
                other = {"x": "y", "y": "x"}[orient]
                return df.assign(**{other: df[orient]})

        p = Plot(long_df, "x").add(MockMark(), Identity()).scale(y="log").plot()
        ax = p._figure.axes[0]
        xfm = ax.yaxis.get_transform().transform
        assert_array_equal(xfm([1, 10, 100]), [0, 1, 2])

    def test_explicit_range_with_axis_scaling(self):
        x = [1, 2, 3]
        ymin = [10, 100, 1000]
        ymax = [20, 200, 2000]
        m = MockMark()
        Plot(x=x, ymin=ymin, ymax=ymax).add(m).scale(y="log").plot()
        assert_vector_equal(m.passed_data[0]["ymax"], pd.Series(ymax, dtype=float))

    def test_derived_range_with_axis_scaling(self):
        # The stat adds 1 in log space, so ymax comes back as y * 10.

        class AddOne(Stat):
            def __call__(self, df, *args):
                return df.assign(ymax=df["y"] + 1)

        x = y = [1, 10, 100]
        m = MockMark()
        Plot(x, y).add(m, AddOne()).scale(y="log").plot()
        assert_vector_equal(m.passed_data[0]["ymax"], pd.Series([10., 100., 1000.]))

    def test_facet_categories(self):
        # With shared x, every facet gets ticks for all three categories.
        m = MockMark()
        p = Plot(x=["a", "b", "a", "c"]).facet(col=["x", "x", "y", "y"]).add(m).plot()
        ax1, ax2 = p._figure.axes
        assert len(ax1.get_xticks()) == 3
        assert len(ax2.get_xticks()) == 3
        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [2, 3]))

    def test_facet_categories_unshared(self):
        # With share(x=False), each facet only has ticks for its own categories.
        m = MockMark()
        p = (
            Plot(x=["a", "b", "a", "c"])
            .facet(col=["x", "x", "y", "y"])
            .share(x=False)
            .add(m)
            .plot()
        )
        ax1, ax2 = p._figure.axes
        assert len(ax1.get_xticks()) == 2
        assert len(ax2.get_xticks()) == 2
        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [2, 3]))

    def test_facet_categories_single_dim_shared(self):
        data = [
            ("a", 1, 1), ("b", 1, 1),
            ("a", 1, 2), ("c", 1, 2),
            ("b", 2, 1), ("d", 2, 1),
            ("e", 2, 2), ("e", 2, 1),
        ]
        df = pd.DataFrame(data, columns=["x", "row", "col"]).assign(y=1)
        m = MockMark()
        p = (
            Plot(df, x="x")
            .facet(row="row", col="col")
            .add(m)
            .share(x="row")
            .plot()
        )

        axs = p._figure.axes
        for ax in axs:
            # assert_array_equal works whether get_xticks returns a list or an
            # ndarray; a bare `==` on an ndarray would raise an ambiguous-truth error.
            assert_array_equal(ax.get_xticks(), [0, 1, 2])

        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [2, 3]))
        assert_vector_equal(m.passed_data[2]["x"], pd.Series([0., 1., 2.], [4, 5, 7]))
        assert_vector_equal(m.passed_data[3]["x"], pd.Series([2.], [6]))

    def test_pair_categories(self):
        data = [("a", "a"), ("b", "c")]
        df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1)
        m = MockMark()
        p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).plot()

        ax1, ax2 = p._figure.axes
        assert_array_equal(ax1.get_xticks(), [0, 1])
        assert_array_equal(ax2.get_xticks(), [0, 1])
        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [0, 1]))

    def test_pair_categories_shared(self):
        data = [("a", "a"), ("b", "c")]
        df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1)
        m = MockMark()
        p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).share(x=True).plot()

        for ax in p._figure.axes:
            assert_array_equal(ax.get_xticks(), [0, 1, 2])
        # (A leftover debug `print(m.passed_data)` was removed here.)
        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [0, 1]))

    def test_identity_mapping_linewidth(self):
        # scale(linewidth=None) is expected to pass values through unchanged.
        m = MockMark()
        x = y = [1, 2, 3, 4, 5]
        lw = pd.Series([.5, .1, .1, .9, 3])
        Plot(x=x, y=y, linewidth=lw).scale(linewidth=None).add(m).plot()
        assert_vector_equal(m.passed_scales["linewidth"](lw), lw)

    def test_pair_single_coordinate_stat_orient(self, long_df):

        class MockStat(Stat):
            def __call__(self, data, groupby, orient, scales):
                self.orient = orient
                return data

        s = MockStat()
        Plot(long_df).pair(x=["x", "y"]).add(MockMark(), s).plot()
        assert s.orient == "x"

    def test_inferred_nominal_passed_to_stat(self):

        class MockStat(Stat):
            def __call__(self, data, groupby, orient, scales):
                self.scales = scales
                return data

        s = MockStat()
        y = ["a", "a", "b", "c"]
        Plot(y=y).add(MockMark(), s).plot()
        assert s.scales["y"].__class__.__name__ == "Nominal"

    # TODO where should RGB consistency be enforced?
    @pytest.mark.xfail(
        reason="Correct output representation for color with identity scale undefined"
    )
    def test_identity_mapping_color_strings(self):
        m = MockMark()
        x = y = [1, 2, 3]
        c = ["C0", "C2", "C1"]
        Plot(x=x, y=y, color=c).scale(color=None).add(m).plot()
        expected = mpl.colors.to_rgba_array(c)[:, :3]
        assert_array_equal(m.passed_scales["color"](c), expected)

    def test_identity_mapping_color_tuples(self):
        m = MockMark()
        x = y = [1, 2, 3]
        c = [(1, 0, 0), (0, 1, 0), (1, 0, 0)]
        Plot(x=x, y=y, color=c).scale(color=None).add(m).plot()
        expected = mpl.colors.to_rgba_array(c)[:, :3]
        assert_array_equal(m.passed_scales["color"](c), expected)

    @pytest.mark.xfail(
        reason="Need decision on what to do with scale defined for unused variable"
    )
    def test_undefined_variable_raises(self):
        p = Plot(x=[1, 2, 3], color=["a", "b", "c"]).scale(y=Continuous())
        err = r"No data found for variable\(s\) with explicit scale: {'y'}"
        with pytest.raises(RuntimeError, match=err):
            p.plot()

    def test_nominal_x_axis_tweaks(self):
        # Nominal axes get half-unit padding and hidden gridlines by default,
        # but an explicit limit() still wins.
        p = Plot(x=["a", "b", "c"], y=[1, 2, 3])
        ax1 = p.plot()._figure.axes[0]
        assert ax1.get_xlim() == (-.5, 2.5)
        assert not any(x.get_visible() for x in ax1.xaxis.get_gridlines())

        lim = (-1, 2.1)
        ax2 = p.limit(x=lim).plot()._figure.axes[0]
        assert ax2.get_xlim() == lim

    def test_nominal_y_axis_tweaks(self):
        # Same as the x case, but a nominal y axis is also inverted.
        p = Plot(x=[1, 2, 3], y=["a", "b", "c"])
        ax1 = p.plot()._figure.axes[0]
        assert ax1.get_ylim() == (2.5, -.5)
        assert not any(y.get_visible() for y in ax1.yaxis.get_gridlines())

        lim = (-1, 2.1)
        ax2 = p.limit(y=lim).plot()._figure.axes[0]
        assert ax2.get_ylim() == lim
class TestPlotting:
    """End-to-end tests for ``Plot.plot()``: matplotlib object creation, data
    splitting across layers/groups/facets/pairs, stats and moves, theming,
    output targets (pyplot, show, save, ``on``), layout, and labelling."""

    def test_matplotlib_object_creation(self):

        p = Plot().plot()
        assert isinstance(p._figure, mpl.figure.Figure)
        for sub in p._subplots:
            assert isinstance(sub["ax"], mpl.axes.Axes)

    def test_empty(self):

        m = MockMark()
        Plot().add(m).plot()
        assert m.n_splits == 0
        assert not m.passed_data

    def test_no_orient_variance(self):

        x, y = [0, 0], [1, 2]
        m = MockMark()
        Plot(x, y).add(m).plot()
        assert_array_equal(m.passed_data[0]["x"], x)
        assert_array_equal(m.passed_data[0]["y"], y)

    def test_single_split_single_layer(self, long_df):

        m = MockMark()
        p = Plot(long_df, x="f", y="z").add(m).plot()
        assert m.n_splits == 1

        assert m.passed_keys[0] == {}
        assert m.passed_axes == [sub["ax"] for sub in p._subplots]
        for col in p._data.frame:
            assert_series_equal(m.passed_data[0][col], p._data.frame[col])

    def test_single_split_multi_layer(self, long_df):

        vs = [{"color": "a", "linewidth": "z"}, {"color": "b", "pattern": "c"}]

        # A mark with no grouping props, so each layer is a single split.
        class NoGroupingMark(MockMark):
            _grouping_props = []

        ms = [NoGroupingMark(), NoGroupingMark()]
        Plot(long_df).add(ms[0], **vs[0]).add(ms[1], **vs[1]).plot()

        for m, v in zip(ms, vs):
            for var, col in v.items():
                assert_vector_equal(m.passed_data[0][var], long_df[col])

    def check_splits_single_var(
        self, data, mark, data_vars, split_var, split_col, split_keys
    ):
        """Assert ``mark`` was called once per level in ``split_keys`` of a single
        grouping variable, receiving the matching rows of ``data`` for each of
        ``data_vars``."""
        assert mark.n_splits == len(split_keys)
        assert mark.passed_keys == [{split_var: key} for key in split_keys]

        for i, key in enumerate(split_keys):

            split_data = data[data[split_col] == key]
            for var, col in data_vars.items():
                assert_array_equal(mark.passed_data[i][var], split_data[col])

    def check_splits_multi_vars(
        self, data, mark, data_vars, split_vars, split_cols, split_keys
    ):
        """Assert ``mark`` was called once per combination of levels across
        several grouping variables (in ``itertools.product`` order), receiving
        the matching rows of ``data`` for each of ``data_vars``."""
        assert mark.n_splits == np.prod([len(ks) for ks in split_keys])

        expected_keys = [
            dict(zip(split_vars, level_keys))
            for level_keys in itertools.product(*split_keys)
        ]
        assert mark.passed_keys == expected_keys

        for i, keys in enumerate(itertools.product(*split_keys)):

            # Select the rows matching every grouping level of this split.
            use_rows = pd.Series(True, data.index)
            for var, col, key in zip(split_vars, split_cols, keys):
                use_rows &= data[col] == key
            split_data = data[use_rows]
            for var, col in data_vars.items():
                assert_array_equal(mark.passed_data[i][var], split_data[col])

    @pytest.mark.parametrize(
        "split_var", [
            "color",  # explicitly declared on the Mark
            "group",  # implicitly used for all Mark classes
        ])
    def test_one_grouping_variable(self, long_df, split_var):

        split_col = "a"
        data_vars = {"x": "f", "y": "z", split_var: split_col}

        m = MockMark()
        p = Plot(long_df, **data_vars).add(m).plot()

        split_keys = categorical_order(long_df[split_col])
        sub, *_ = p._subplots
        assert m.passed_axes == [sub["ax"] for _ in split_keys]
        self.check_splits_single_var(
            long_df, m, data_vars, split_var, split_col, split_keys
        )

    def test_two_grouping_variables(self, long_df):

        split_vars = ["color", "group"]
        split_cols = ["a", "b"]
        data_vars = {"y": "z", **{var: col for var, col in zip(split_vars, split_cols)}}

        m = MockMark()
        p = Plot(long_df, **data_vars).add(m).plot()

        split_keys = [categorical_order(long_df[col]) for col in split_cols]
        sub, *_ = p._subplots
        assert m.passed_axes == [
            sub["ax"] for _ in itertools.product(*split_keys)
        ]
        self.check_splits_multi_vars(
            long_df, m, data_vars, split_vars, split_cols, split_keys
        )

    def test_specified_width(self, long_df):

        m = MockMark()
        Plot(long_df, x="x", y="y").add(m, width="z").plot()
        assert_array_almost_equal(m.passed_data[0]["width"], long_df["z"])

    def test_facets_no_subgroups(self, long_df):

        split_var = "col"
        split_col = "b"
        data_vars = {"x": "f", "y": "z"}

        m = MockMark()
        p = Plot(long_df, **data_vars).facet(**{split_var: split_col}).add(m).plot()

        split_keys = categorical_order(long_df[split_col])
        assert m.passed_axes == list(p._figure.axes)
        self.check_splits_single_var(
            long_df, m, data_vars, split_var, split_col, split_keys
        )

    def test_facets_one_subgroup(self, long_df):

        facet_var, facet_col = fx = "col", "a"
        group_var, group_col = gx = "group", "b"
        split_vars, split_cols = zip(*[fx, gx])
        data_vars = {"x": "f", "y": "z", group_var: group_col}

        m = MockMark()
        p = (
            Plot(long_df, **data_vars)
            .facet(**{facet_var: facet_col})
            .add(m)
            .plot()
        )

        split_keys = [categorical_order(long_df[col]) for col in [facet_col, group_col]]
        # Each facet axes is passed once per group level.
        assert m.passed_axes == [
            ax
            for ax in list(p._figure.axes)
            for _ in categorical_order(long_df[group_col])
        ]
        self.check_splits_multi_vars(
            long_df, m, data_vars, split_vars, split_cols, split_keys
        )

    def test_layer_specific_facet_disabling(self, long_df):

        axis_vars = {"x": "y", "y": "z"}
        row_var = "a"

        m = MockMark()
        p = Plot(long_df, **axis_vars).facet(row=row_var).add(m, row=None).plot()

        col_levels = categorical_order(long_df[row_var])
        assert len(p._figure.axes) == len(col_levels)

        # With row=None on the layer, every call receives the full dataset.
        for data in m.passed_data:
            for var, col in axis_vars.items():
                assert_vector_equal(data[var], long_df[col])

    def test_paired_variables(self, long_df):

        x = ["x", "y"]
        y = ["f", "z"]

        m = MockMark()
        Plot(long_df).pair(x, y).add(m).plot()

        var_product = itertools.product(x, y)

        for data, (x_i, y_i) in zip(m.passed_data, var_product):
            assert_vector_equal(data["x"], long_df[x_i].astype(float))
            assert_vector_equal(data["y"], long_df[y_i].astype(float))

    def test_paired_one_dimension(self, long_df):

        x = ["y", "z"]

        m = MockMark()
        Plot(long_df).pair(x).add(m).plot()

        for data, x_i in zip(m.passed_data, x):
            assert_vector_equal(data["x"], long_df[x_i].astype(float))

    def test_paired_variables_one_subset(self, long_df):

        x = ["x", "y"]
        y = ["f", "z"]
        group = "a"

        long_df["x"] = long_df["x"].astype(float)  # simplify vector comparison

        m = MockMark()
        Plot(long_df, group=group).pair(x, y).add(m).plot()

        groups = categorical_order(long_df[group])
        var_product = itertools.product(x, y, groups)

        for data, (x_i, y_i, g_i) in zip(m.passed_data, var_product):
            rows = long_df[group] == g_i
            assert_vector_equal(data["x"], long_df.loc[rows, x_i])
            assert_vector_equal(data["y"], long_df.loc[rows, y_i])

    def test_paired_and_faceted(self, long_df):

        x = ["y", "z"]
        y = "f"
        row = "c"

        m = MockMark()
        Plot(long_df, y=y).facet(row=row).pair(x).add(m).plot()

        facets = categorical_order(long_df[row])
        var_product = itertools.product(x, facets)

        for data, (x_i, f_i) in zip(m.passed_data, var_product):
            rows = long_df[row] == f_i
            assert_vector_equal(data["x"], long_df.loc[rows, x_i])
            assert_vector_equal(data["y"], long_df.loc[rows, y])

    def test_theme_default(self):

        p = Plot().plot()
        assert mpl.colors.same_color(p._figure.axes[0].get_facecolor(), "#EAEAF2")

    def test_theme_params(self):

        color = ".888"
        p = Plot().theme({"axes.facecolor": color}).plot()
        assert mpl.colors.same_color(p._figure.axes[0].get_facecolor(), color)

    def test_theme_error(self):

        p = Plot()
        with pytest.raises(TypeError, match=r"theme\(\) takes 2 positional"):
            p.theme("arg1", "arg2")

    def test_theme_validation(self):

        p = Plot()
        # You'd think matplotlib would raise a TypeError here, but it doesn't
        with pytest.raises(ValueError, match="Key axes.linewidth:"):
            p.theme({"axes.linewidth": "thick"})

        with pytest.raises(KeyError, match="not.a.key is not a valid rc"):
            p.theme({"not.a.key": True})

    def test_stat(self, long_df):

        orig_df = long_df.copy(deep=True)

        m = MockMark()
        Plot(long_df, x="a", y="z").add(m, Agg()).plot()

        expected = long_df.groupby("a", sort=False)["z"].mean().reset_index(drop=True)
        assert_vector_equal(m.passed_data[0]["y"], expected)

        assert_frame_equal(long_df, orig_df)  # Test data was not mutated

    def test_move(self, long_df):

        orig_df = long_df.copy(deep=True)

        m = MockMark()
        Plot(long_df, x="z", y="z").add(m, Shift(x=1)).plot()
        assert_vector_equal(m.passed_data[0]["x"], long_df["z"] + 1)
        assert_vector_equal(m.passed_data[0]["y"], long_df["z"])

        assert_frame_equal(long_df, orig_df)  # Test data was not mutated

    def test_stat_and_move(self, long_df):

        m = MockMark()
        Plot(long_df, x="a", y="z").add(m, Agg(), Shift(y=1)).plot()

        expected = long_df.groupby("a", sort=False)["z"].mean().reset_index(drop=True)
        assert_vector_equal(m.passed_data[0]["y"], expected + 1)

    def test_stat_log_scale(self, long_df):

        orig_df = long_df.copy(deep=True)

        m = MockMark()
        Plot(long_df, x="a", y="z").add(m, Agg()).scale(y="log").plot()

        # The stat operates in log space and the result is mapped back.
        x = long_df["a"]
        y = np.log10(long_df["z"])
        expected = y.groupby(x, sort=False).mean().reset_index(drop=True)
        assert_vector_equal(m.passed_data[0]["y"], 10 ** expected)

        assert_frame_equal(long_df, orig_df)  # Test data was not mutated

    def test_move_log_scale(self, long_df):

        m = MockMark()
        Plot(
            long_df, x="z", y="z"
        ).scale(x="log").add(m, Shift(x=-1)).plot()
        # A shift of -1 in log10 space divides the data by 10.
        assert_vector_equal(m.passed_data[0]["x"], long_df["z"] / 10)

    def test_multi_move(self, long_df):

        m = MockMark()
        move_stack = [Shift(1), Shift(2)]
        Plot(long_df, x="x", y="y").add(m, *move_stack).plot()
        assert_vector_equal(m.passed_data[0]["x"], long_df["x"] + 3)

    def test_multi_move_with_pairing(self, long_df):

        m = MockMark()
        move_stack = [Shift(1), Shift(2)]
        Plot(long_df, x="x").pair(y=["y", "z"]).add(m, *move_stack).plot()
        for frame in m.passed_data:
            assert_vector_equal(frame["x"], long_df["x"] + 3)

    def test_move_with_range(self, long_df):

        x = [0, 0, 1, 1, 2, 2]
        group = [0, 1, 0, 1, 0, 1]
        ymin = np.arange(6)
        ymax = np.arange(6) * 2

        m = MockMark()
        Plot(x=x, group=group, ymin=ymin, ymax=ymax).add(m, Dodge()).plot()

        # Group values are 0/1, so they double as indices into ``signs``.
        signs = [-1, +1]
        for i, df in m.passed_data[0].groupby("group"):
            assert_array_equal(df["x"], np.arange(3) + signs[i] * 0.2)

    def test_methods_clone(self, long_df):

        p1 = Plot(long_df, "x", "y")
        p2 = p1.add(MockMark()).facet("a")

        assert p1 is not p2
        assert not p1._layers
        assert not p1._facet_spec

    def test_default_is_no_pyplot(self):

        p = Plot().plot()

        assert not plt.get_fignums()
        assert isinstance(p._figure, mpl.figure.Figure)

    def test_with_pyplot(self):

        p = Plot().plot(pyplot=True)

        assert len(plt.get_fignums()) == 1
        fig = plt.gcf()
        assert p._figure is fig

    def test_show(self):

        p = Plot()

        with warnings.catch_warnings(record=True) as msg:
            # Record every warning regardless of filters or once-per-location
            # registries set up elsewhere, so the check below is reliable.
            warnings.simplefilter("always")
            out = p.show(block=False)
        assert out is None
        assert not hasattr(p, "_figure")

        assert len(plt.get_fignums()) == 1
        fig = plt.gcf()

        gui_backend = (
            # From https://github.com/matplotlib/matplotlib/issues/20281
            fig.canvas.manager.show != mpl.backend_bases.FigureManagerBase.show
        )
        if not gui_backend:
            assert msg

    def test_save(self):

        buf = io.BytesIO()

        p = Plot().save(buf)
        assert isinstance(p, Plot)
        img = Image.open(buf)
        assert img.format == "PNG"

        buf = io.StringIO()
        Plot().save(buf, format="svg")
        tag = xml.etree.ElementTree.fromstring(buf.getvalue()).tag
        assert tag == "{http://www.w3.org/2000/svg}svg"

    def test_layout_size(self):

        size = (4, 2)
        p = Plot().layout(size=size).plot()
        assert tuple(p._figure.get_size_inches()) == size

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="mpl<3.6 does not have get_layout_engine",
    )
    def test_layout_extent(self):

        # extent is (left, bottom, right, top); the engine's rect holds
        # (left, bottom, right - left, top - bottom).
        p = Plot().layout(extent=(.1, .2, .6, 1)).plot()
        assert p._figure.get_layout_engine().get()["rect"] == [.1, .2, .5, .8]

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="mpl<3.6 does not have get_layout_engine",
    )
    def test_constrained_layout_extent(self):

        p = Plot().layout(engine="constrained", extent=(.1, .2, .6, 1)).plot()
        assert p._figure.get_layout_engine().get()["rect"] == [.1, .2, .5, .8]

    def test_base_layout_extent(self):

        p = Plot().layout(engine=None, extent=(.1, .2, .6, 1)).plot()
        assert p._figure.subplotpars.left == 0.1
        assert p._figure.subplotpars.right == 0.6
        assert p._figure.subplotpars.bottom == 0.2
        assert p._figure.subplotpars.top == 1

    def test_on_axes(self):

        ax = mpl.figure.Figure().subplots()
        m = MockMark()
        p = Plot([1], [2]).on(ax).add(m).plot()
        assert m.passed_axes == [ax]
        assert p._figure is ax.figure

    @pytest.mark.parametrize("facet", [True, False])
    def test_on_figure(self, facet):

        f = mpl.figure.Figure()
        m = MockMark()
        p = Plot([1, 2], [3, 4]).on(f).add(m)
        if facet:
            p = p.facet(["a", "b"])
        p = p.plot()
        assert m.passed_axes == f.axes
        assert p._figure is f

    @pytest.mark.parametrize("facet", [True, False])
    def test_on_subfigure(self, facet):

        sf1, sf2 = mpl.figure.Figure().subfigures(2)
        sf1.subplots()
        m = MockMark()
        p = Plot([1, 2], [3, 4]).on(sf2).add(m)
        if facet:
            p = p.facet(["a", "b"])
        p = p.plot()
        # The first axes in the parent figure belongs to sf1.
        assert m.passed_axes == sf2.figure.axes[1:]
        assert p._figure is sf2.figure

    def test_on_type_check(self):

        p = Plot()
        with pytest.raises(TypeError, match="The `Plot.on`.+"):
            p.on([])

    def test_on_axes_with_subplots_error(self):

        ax = mpl.figure.Figure().subplots()

        p1 = Plot().facet(["a", "b"]).on(ax)
        with pytest.raises(RuntimeError, match="Cannot create multiple subplots"):
            p1.plot()

        p2 = Plot().pair([["a", "b"], ["x", "y"]]).on(ax)
        with pytest.raises(RuntimeError, match="Cannot create multiple subplots"):
            p2.plot()

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Requires newer matplotlib layout engine API"
    )
    def test_on_layout_algo_default(self):

        class MockEngine(mpl.layout_engine.ConstrainedLayoutEngine):
            ...

        f = mpl.figure.Figure(layout=MockEngine())
        p = Plot().on(f).plot()
        layout_engine = p._figure.get_layout_engine()
        assert layout_engine.__class__.__name__ == "MockEngine"

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Requires newer matplotlib layout engine API"
    )
    def test_on_layout_algo_spec(self):

        f = mpl.figure.Figure(layout="constrained")
        p = Plot().on(f).layout(engine="tight").plot()
        layout_engine = p._figure.get_layout_engine()
        assert layout_engine.__class__.__name__ == "TightLayoutEngine"

    def test_axis_labels_from_constructor(self, long_df):

        ax, = Plot(long_df, x="a", y="b").plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == "b"

        ax, = Plot(x=long_df["a"], y=long_df["b"].to_numpy()).plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == ""

    def test_axis_labels_from_layer(self, long_df):

        m = MockMark()

        ax, = Plot(long_df).add(m, x="a", y="b").plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == "b"

        p = Plot().add(m, x=long_df["a"], y=long_df["b"].to_list())
        ax, = p.plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == ""

    def test_axis_labels_are_first_name(self, long_df):

        m = MockMark()
        p = (
            Plot(long_df, x=long_df["z"].to_list(), y="b")
            .add(m, x="a")
            .add(m, x="x", y="y")
        )
        ax, = p.plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == "b"

    def test_limits(self, long_df):

        limit = (-2, 24)
        p = Plot(long_df, x="x", y="y").limit(x=limit).plot()
        ax = p._figure.axes[0]
        assert ax.get_xlim() == limit

        limit = (np.datetime64("2005-01-01"), np.datetime64("2008-01-01"))
        p = Plot(long_df, x="d", y="y").limit(x=limit).plot()
        ax = p._figure.axes[0]
        assert ax.get_xlim() == tuple(mpl.dates.date2num(limit))

        limit = ("b", "c")
        p = Plot(x=["a", "b", "c", "d"], y=[1, 2, 3, 4]).limit(x=limit).plot()
        ax = p._figure.axes[0]
        assert ax.get_xlim() == (0.5, 2.5)

    def test_labels_axis(self, long_df):

        label = "Y axis"
        p = Plot(long_df, x="x", y="y").label(y=label).plot()
        ax = p._figure.axes[0]
        assert ax.get_ylabel() == label

        label = str.capitalize
        p = Plot(long_df, x="x", y="y").label(y=label).plot()
        ax = p._figure.axes[0]
        assert ax.get_ylabel() == "Y"

    def test_labels_legend(self, long_df):

        m = MockMark()

        label = "A"
        p = Plot(long_df, x="x", y="y", color="a").add(m).label(color=label).plot()
        assert p._figure.legends[0].get_title().get_text() == label

        # A callable label is applied to the automatic label (the column name
        # "a"); compare against that result instead of reusing the literal
        # label above, which only happens to be equal.
        func = str.capitalize
        p = Plot(long_df, x="x", y="y", color="a").add(m).label(color=func).plot()
        assert p._figure.legends[0].get_title().get_text() == func("a")

    def test_labels_facets(self):

        data = {"a": ["b", "c"], "x": ["y", "z"]}
        p = Plot(data).facet("a", "x").label(col=str.capitalize, row="$x$").plot()
        axs = np.reshape(p._figure.axes, (2, 2))
        for (i, j), ax in np.ndenumerate(axs):
            expected = f"A {data['a'][j]} | $x$ {data['x'][i]}"
            assert ax.get_title() == expected

    def test_title_single(self):

        label = "A"
        p = Plot().label(title=label).plot()
        assert p._figure.axes[0].get_title() == label

    def test_title_facet_function(self):

        titles = ["a", "b"]
        p = Plot().facet(titles).label(title=str.capitalize).plot()
        for i, ax in enumerate(p._figure.axes):
            assert ax.get_title() == titles[i].capitalize()

        # Axes iterate row-major: column index varies fastest.
        cols, rows = ["a", "b"], ["x", "y"]
        p = Plot().facet(cols, rows).label(title=str.capitalize).plot()
        for i, ax in enumerate(p._figure.axes):
            expected = " | ".join([
                cols[i % len(cols)].capitalize(),
                rows[i // len(cols)].capitalize(),
            ])
            assert ax.get_title() == expected
class TestExceptions:
    """Errors raised while compiling a plot surface as ``PlotSpecError`` naming
    the failing stage and variable, with the original error chained as
    ``__cause__``."""

    def test_scale_setup(self):

        x = y = color = ["a", "b"]
        bad_palette = "not_a_palette"
        p = Plot(x, y, color=color).add(MockMark()).scale(color=bad_palette)

        msg = "Scale setup failed for the `color` variable."
        with pytest.raises(PlotSpecError, match=msg) as err:
            p.plot()
        assert isinstance(err.value.__cause__, ValueError)
        assert bad_palette in str(err.value.__cause__)

    def test_coordinate_scaling(self):

        x = ["a", "b"]
        y = [1, 2]
        p = Plot(x, y).add(MockMark()).scale(x=Temporal())

        msg = "Scaling operation failed for the `x` variable."
        with pytest.raises(PlotSpecError, match=msg) as err:
            p.plot()
        # Don't test the cause contents b/c matplotlib owns them here.
        # Every exception object has a ``__cause__`` attribute (None by default),
        # so hasattr() would always pass; check a cause was actually chained.
        assert err.value.__cause__ is not None

    def test_semantic_scaling(self):

        # A scale whose pipeline always fails, to exercise error wrapping.
        class ErrorRaising(Continuous):

            def _setup(self, data, prop, axis=None):

                def f(x):
                    raise ValueError("This is a test")

                new = super()._setup(data, prop, axis)
                new._pipeline = [f]
                return new

        x = y = color = [1, 2]
        p = Plot(x, y, color=color).add(Dot()).scale(color=ErrorRaising())
        msg = "Scaling operation failed for the `color` variable."
        with pytest.raises(PlotSpecError, match=msg) as err:
            p.plot()
        assert isinstance(err.value.__cause__, ValueError)
        assert str(err.value.__cause__) == "This is a test"
class TestFacetInterface:
    """Tests for ``Plot.facet``: subplot grid structure and ordering, titles,
    layout engines, axis sharing, and wrapping."""

    @pytest.fixture(scope="class", params=["row", "col"])
    def dim(self, request):
        return request.param

    @pytest.fixture(scope="class", params=["reverse", "subset", "expand"])
    def reorder(self, request):
        return {
            "reverse": lambda x: x[::-1],
            "subset": lambda x: x[:-1],
            "expand": lambda x: x + ["z"],
        }[request.param]

    def check_facet_results_1d(self, p, df, dim, key, order=None):
        """Plot ``p`` and check it has one subplot per level of ``df[key]`` (or
        ``order``) laid out along ``dim``, each titled with its level."""
        p = p.plot()

        order = categorical_order(df[key], order)
        assert len(p._figure.axes) == len(order)

        other_dim = {"row": "col", "col": "row"}[dim]

        for subplot, level in zip(p._subplots, order):
            assert subplot[dim] == level
            assert subplot[other_dim] is None
            assert subplot["ax"].get_title() == f"{level}"
            assert_gridspec_shape(subplot["ax"], **{f"n{dim}s": len(order)})

    def test_1d(self, long_df, dim):

        key = "a"
        p = Plot(long_df).facet(**{dim: key})
        self.check_facet_results_1d(p, long_df, dim, key)

    def test_1d_as_vector(self, long_df, dim):

        key = "a"
        p = Plot(long_df).facet(**{dim: long_df[key]})
        self.check_facet_results_1d(p, long_df, dim, key)

    def test_1d_with_order(self, long_df, dim, reorder):

        key = "a"
        order = reorder(categorical_order(long_df[key]))
        p = Plot(long_df).facet(**{dim: key, "order": order})
        self.check_facet_results_1d(p, long_df, dim, key, order)

    def check_facet_results_2d(self, p, df, variables, order=None):
        """Plot ``p`` and check it forms a row x col grid of subplots, one per
        (row level, col level) pair in row-major order, each titled
        ``"{col} | {row}"``.

        ``variables`` maps "row"/"col" to column keys of ``df``; ``order``
        optionally maps the same dims to explicit level lists.
        """
        p = p.plot()

        if order is None:
            order = {dim: categorical_order(df[key]) for dim, key in variables.items()}

        # Materialize the level grid: a bare itertools.product iterator would be
        # exhausted by the length check, leaving the loop below with nothing
        # to iterate.
        levels = list(itertools.product(*[order[dim] for dim in ["row", "col"]]))
        assert len(p._subplots) == len(levels)

        for subplot, (row_level, col_level) in zip(p._subplots, levels):
            assert subplot["row"] == row_level
            assert subplot["col"] == col_level
            assert subplot["ax"].get_title() == f"{col_level} | {row_level}"
            assert_gridspec_shape(
                subplot["ax"], nrows=len(order["row"]), ncols=len(order["col"])
            )

    def test_2d(self, long_df):

        variables = {"row": "a", "col": "c"}
        p = Plot(long_df).facet(**variables)
        self.check_facet_results_2d(p, long_df, variables)

    def test_2d_with_order(self, long_df, reorder):

        variables = {"row": "a", "col": "c"}
        order = {
            dim: reorder(categorical_order(long_df[key]))
            for dim, key in variables.items()
        }

        p = Plot(long_df).facet(**variables, order=order)
        self.check_facet_results_2d(p, long_df, variables, order)

    @pytest.mark.parametrize("algo", ["tight", "constrained"])
    def test_layout_algo(self, algo):

        p = Plot().facet(["a", "b"]).limit(x=(.1, .9))

        p1 = p.layout(engine=algo).plot()
        p2 = p.layout(engine="none").plot()

        # Force a draw (we probably need a method for this)
        p1.save(io.BytesIO())
        p2.save(io.BytesIO())

        bb11, bb12 = [ax.get_position() for ax in p1._figure.axes]
        bb21, bb22 = [ax.get_position() for ax in p2._figure.axes]

        # Horizontal gap: left edge of the right panel minus right edge of the
        # left panel.
        sep1 = bb12.corners()[0, 0] - bb11.corners()[2, 0]
        sep2 = bb22.corners()[0, 0] - bb21.corners()[2, 0]
        assert sep1 <= sep2

    def test_axis_sharing(self, long_df):

        variables = {"row": "a", "col": "c"}

        p = Plot(long_df).facet(**variables)

        p1 = p.plot()
        root, *other = p1._figure.axes
        for axis in "xy":
            shareset = getattr(root, f"get_shared_{axis}_axes")()
            assert all(shareset.joined(root, ax) for ax in other)

        p2 = p.share(x=False, y=False).plot()
        root, *other = p2._figure.axes
        for axis in "xy":
            shareset = getattr(root, f"get_shared_{axis}_axes")()
            assert not any(shareset.joined(root, ax) for ax in other)

        p3 = p.share(x="col", y="row").plot()
        shape = (
            len(categorical_order(long_df[variables["row"]])),
            len(categorical_order(long_df[variables["col"]])),
        )
        axes_matrix = np.reshape(p3._figure.axes, shape)

        # Within a row, y is shared and x is not; within a column, the reverse.
        for (shared, unshared), vectors in zip(
            ["yx", "xy"], [axes_matrix, axes_matrix.T]
        ):
            for root, *other in vectors:
                shareset = {
                    axis: getattr(root, f"get_shared_{axis}_axes")() for axis in "xy"
                }
                assert all(shareset[shared].joined(root, ax) for ax in other)
                assert not any(shareset[unshared].joined(root, ax) for ax in other)

    def test_unshared_spacing(self):

        x = [1, 2, 10, 20]
        y = [1, 2, 3, 4]
        col = [1, 1, 2, 2]

        m = MockMark()
        Plot(x, y).facet(col).add(m).share(x=False).plot()
        # With unshared x, default widths follow each facet's own x spacing
        # (1 in the first facet, 10 in the second).
        assert_array_almost_equal(m.passed_data[0]["width"], [0.8, 0.8])
        assert_array_equal(m.passed_data[1]["width"], [8, 8])

    def test_col_wrapping(self):

        cols = list("abcd")
        wrap = 3
        p = Plot().facet(col=cols, wrap=wrap).plot()

        n_rows = (len(cols) + wrap - 1) // wrap  # ceiling division
        assert len(p._figure.axes) == len(cols)
        assert_gridspec_shape(p._figure.axes[0], n_rows, wrap)

        # TODO test axis labels and titles

    def test_row_wrapping(self):

        rows = list("abcd")
        wrap = 3
        p = Plot().facet(row=rows, wrap=wrap).plot()

        n_cols = (len(rows) + wrap - 1) // wrap  # ceiling division
        assert_gridspec_shape(p._figure.axes[0], wrap, n_cols)
        assert len(p._figure.axes) == len(rows)

        # TODO test axis labels and titles
class TestPairInterface:
    """Tests for ``Plot.pair``: subplot grid structure and labels, axis sharing,
    interaction with faceting, wrapping, orientation inference, and per-axis
    limits/labels."""

    def check_pair_grid(self, p, x, y):
        """Check that the already-plotted ``p`` is a len(y) x len(x) grid whose
        subplots, in row-major order, are labelled with the paired variable
        names; a ``None`` entry means that axis should have an empty label."""
        xys = itertools.product(y, x)

        for (y_i, x_j), subplot in zip(xys, p._subplots):

            ax = subplot["ax"]
            # The conditional must be parenthesized: without parentheses it
            # parses as ``(label == "") if ... else x_j`` and, for non-None
            # names, asserts only that the name is truthy.
            assert ax.get_xlabel() == ("" if x_j is None else x_j)
            assert ax.get_ylabel() == ("" if y_i is None else y_i)
            assert_gridspec_shape(subplot["ax"], len(y), len(x))

    @pytest.mark.parametrize("vector_type", [list, pd.Index])
    def test_all_numeric(self, long_df, vector_type):

        x, y = ["x", "y", "z"], ["s", "f"]
        p = Plot(long_df).pair(vector_type(x), vector_type(y)).plot()
        self.check_pair_grid(p, x, y)

    def test_single_variable_key_raises(self, long_df):

        p = Plot(long_df)
        err = "You must pass a sequence of variable keys to `y`"
        with pytest.raises(TypeError, match=err):
            p.pair(x=["x", "y"], y="z")

    @pytest.mark.parametrize("dim", ["x", "y"])
    def test_single_dimension(self, long_df, dim):

        variables = {"x": None, "y": None}
        variables[dim] = ["x", "y", "z"]
        p = Plot(long_df).pair(**variables).plot()
        # The unpaired dimension contributes one unlabelled row/column.
        variables = {k: [v] if v is None else v for k, v in variables.items()}
        self.check_pair_grid(p, **variables)

    def test_non_cross(self, long_df):

        x = ["x", "y"]
        y = ["f", "z"]

        p = Plot(long_df).pair(x, y, cross=False).plot()

        for i, subplot in enumerate(p._subplots):
            ax = subplot["ax"]
            assert ax.get_xlabel() == x[i]
            assert ax.get_ylabel() == y[i]
            assert_gridspec_shape(ax, 1, len(x))

        root, *other = p._figure.axes
        for axis in "xy":
            shareset = getattr(root, f"get_shared_{axis}_axes")()
            assert not any(shareset.joined(root, ax) for ax in other)

    def test_list_of_vectors(self, long_df):

        x_vars = ["x", "z"]
        p = Plot(long_df, y="y").pair(x=[long_df[x] for x in x_vars]).plot()
        assert len(p._figure.axes) == len(x_vars)
        for ax, x_i in zip(p._figure.axes, x_vars):
            assert ax.get_xlabel() == x_i

    def test_with_no_variables(self, long_df):

        p = Plot(long_df).pair().plot()
        assert len(p._figure.axes) == 1

    def test_with_facets(self, long_df):

        x = "x"
        y = ["y", "z"]
        col = "a"

        p = Plot(long_df, x=x).facet(col).pair(y=y).plot()

        facet_levels = categorical_order(long_df[col])
        dims = itertools.product(y, facet_levels)

        for (y_i, col_i), subplot in zip(dims, p._subplots):

            ax = subplot["ax"]
            assert ax.get_xlabel() == x
            assert ax.get_ylabel() == y_i
            assert ax.get_title() == f"{col_i}"
            assert_gridspec_shape(ax, len(y), len(facet_levels))

    @pytest.mark.parametrize("variables", [("rows", "y"), ("columns", "x")])
    def test_error_on_facet_overlap(self, long_df, variables):

        facet_dim, pair_axis = variables
        # facet_dim[:3] turns "rows"/"columns" into the facet kwarg "row"/"col".
        p = Plot(long_df).facet(**{facet_dim[:3]: "a"}).pair(**{pair_axis: ["x", "y"]})
        expected = f"Cannot facet the {facet_dim} while pairing on `{pair_axis}`."
        with pytest.raises(RuntimeError, match=expected):
            p.plot()

    @pytest.mark.parametrize("variables", [("columns", "y"), ("rows", "x")])
    def test_error_on_wrap_overlap(self, long_df, variables):

        facet_dim, pair_axis = variables
        p = (
            Plot(long_df)
            .facet(wrap=2, **{facet_dim[:3]: "a"})
            .pair(**{pair_axis: ["x", "y"]})
        )
        # NOTE(review): the doubled closing backtick presumably mirrors the
        # library's own error text; confirm there before changing it here.
        expected = f"Cannot wrap the {facet_dim} while pairing on `{pair_axis}``."
        with pytest.raises(RuntimeError, match=expected):
            p.plot()

    def test_axis_sharing(self, long_df):

        p = Plot(long_df).pair(x=["a", "b"], y=["y", "z"])
        shape = 2, 2

        p1 = p.plot()
        axes_matrix = np.reshape(p1._figure.axes, shape)

        for root, *other in axes_matrix:  # Test row-wise sharing
            x_shareset = root.get_shared_x_axes()
            assert not any(x_shareset.joined(root, ax) for ax in other)
            y_shareset = root.get_shared_y_axes()
            assert all(y_shareset.joined(root, ax) for ax in other)

        for root, *other in axes_matrix.T:  # Test col-wise sharing
            x_shareset = root.get_shared_x_axes()
            assert all(x_shareset.joined(root, ax) for ax in other)
            y_shareset = root.get_shared_y_axes()
            assert not any(y_shareset.joined(root, ax) for ax in other)

        p2 = p.share(x=False, y=False).plot()
        root, *other = p2._figure.axes
        for axis in "xy":
            shareset = getattr(root, f"get_shared_{axis}_axes")()
            assert not any(shareset.joined(root, ax) for ax in other)

    def test_axis_sharing_with_facets(self, long_df):

        p = Plot(long_df, y="y").pair(x=["a", "b"]).facet(row="c").plot()
        shape = 2, 2

        axes_matrix = np.reshape(p._figure.axes, shape)

        for root, *other in axes_matrix:  # Test row-wise sharing
            x_shareset = root.get_shared_x_axes()
            assert not any(x_shareset.joined(root, ax) for ax in other)
            y_shareset = root.get_shared_y_axes()
            assert all(y_shareset.joined(root, ax) for ax in other)

        for root, *other in axes_matrix.T:  # Test col-wise sharing
            x_shareset = root.get_shared_x_axes()
            assert all(x_shareset.joined(root, ax) for ax in other)
            y_shareset = root.get_shared_y_axes()
            assert all(y_shareset.joined(root, ax) for ax in other)

    def test_x_wrapping(self, long_df):

        x_vars = ["f", "x", "y", "z"]
        wrap = 3
        p = Plot(long_df, y="y").pair(x=x_vars, wrap=wrap).plot()

        n_rows = (len(x_vars) + wrap - 1) // wrap  # ceiling division
        assert_gridspec_shape(p._figure.axes[0], n_rows, wrap)
        assert len(p._figure.axes) == len(x_vars)
        for ax, var in zip(p._figure.axes, x_vars):
            label = ax.xaxis.get_label()
            assert label.get_visible()
            assert label.get_text() == var

    def test_y_wrapping(self, long_df):

        y_vars = ["f", "x", "y", "z"]
        wrap = 3
        p = Plot(long_df, x="x").pair(y=y_vars, wrap=wrap).plot()

        n_row, n_col = wrap, (len(y_vars) + wrap - 1) // wrap  # ceiling division
        assert_gridspec_shape(p._figure.axes[0], n_row, n_col)
        assert len(p._figure.axes) == len(y_vars)

        # Wrapped y variables fill the grid column-major ("F" order), while the
        # figure's axes are compared in row-major order; reorder the expected
        # labels to match and drop the unfilled cells.
        label_array = np.empty(n_row * n_col, object)
        label_array[:len(y_vars)] = y_vars
        label_array = label_array.reshape((n_row, n_col), order="F")
        label_array = [y for y in label_array.flat if y is not None]
        for i, ax in enumerate(p._figure.axes):
            label = ax.yaxis.get_label()
            assert label.get_visible()
            assert label.get_text() == label_array[i]

    def test_non_cross_wrapping(self, long_df):

        x_vars = ["a", "b", "c", "t"]
        y_vars = ["f", "x", "y", "z"]
        wrap = 3

        p = (
            Plot(long_df, x="x")
            .pair(x=x_vars, y=y_vars, wrap=wrap, cross=False)
            .plot()
        )

        n_rows = (len(x_vars) + wrap - 1) // wrap  # ceiling division
        assert_gridspec_shape(p._figure.axes[0], n_rows, wrap)
        assert len(p._figure.axes) == len(x_vars)

    def test_cross_mismatched_lengths(self, long_df):

        p = Plot(long_df)
        with pytest.raises(ValueError, match="Lengths of the `x` and `y`"):
            p.pair(x=["a", "b"], y=["x", "y", "z"], cross=False)

    def test_orient_inference(self, long_df):

        orient_list = []

        # A no-op move that records the orient it is called with.
        class CaptureOrientMove(Move):
            def __call__(self, data, groupby, orient, scales):
                orient_list.append(orient)
                return data

        (
            Plot(long_df, x="x")
            .pair(y=["b", "z"])
            .add(MockMark(), CaptureOrientMove())
            .plot()
        )

        assert orient_list == ["y", "x"]

    def test_computed_coordinate_orient_inference(self, long_df):

        # A stat that fills in the other coordinate from the orient one.
        class MockComputeStat(Stat):
            def __call__(self, df, groupby, orient, scales):
                other = {"x": "y", "y": "x"}[orient]
                return df.assign(**{other: df[orient] * 2})

        m = MockMark()
        Plot(long_df, y="y").add(m, MockComputeStat()).plot()
        assert m.passed_orient == "y"

    def test_two_variables_single_order_error(self, long_df):

        p = Plot(long_df)
        err = "When faceting on both col= and row=, passing `order`"
        with pytest.raises(RuntimeError, match=err):
            p.facet(col="a", row="b", order=["a", "b", "c"])

    def test_limits(self, long_df):

        lims = (-3, 10), (-2, 24)
        p = Plot(long_df, y="y").pair(x=["x", "z"]).limit(x=lims[0], x1=lims[1]).plot()
        for ax, lim in zip(p._figure.axes, lims):
            assert ax.get_xlim() == lim

    def test_labels(self, long_df):

        label = "zed"
        p = (
            Plot(long_df, y="y")
            .pair(x=["x", "z"])
            .label(x=str.capitalize, x1=label)
        )
        ax0, ax1 = p.plot()._figure.axes
        assert ax0.get_xlabel() == "X"
        assert ax1.get_xlabel() == label
class TestLabelVisibility:
def has_xaxis_labels(self, ax):
if _version_predates(mpl, "3.7"):
# mpl3.7 added a getter for tick params, but both yaxis and xaxis return
# the same entry of "labelleft" instead of "labelbottom" for xaxis
return len(ax.get_xticklabels()) > 0
elif _version_predates(mpl, "3.10"):
# Then I guess they made it labelbottom in 3.10?
return ax.xaxis.get_tick_params()["labelleft"]
else:
return ax.xaxis.get_tick_params()["labelbottom"]
def test_single_subplot(self, long_df):
x, y = "a", "z"
p = Plot(long_df, x=x, y=y).plot()
subplot, *_ = p._subplots
ax = subplot["ax"]
assert ax.xaxis.get_label().get_visible()
assert ax.yaxis.get_label().get_visible()
assert all(t.get_visible() for t in ax.get_xticklabels())
assert all(t.get_visible() for t in ax.get_yticklabels())
@pytest.mark.parametrize(
"facet_kws,pair_kws", [({"col": "b"}, {}), ({}, {"x": ["x", "y", "f"]})]
)
def test_1d_column(self, long_df, facet_kws, pair_kws):
x = None if "x" in pair_kws else "a"
y = "z"
p = Plot(long_df, x=x, y=y).plot()
first, *other = p._subplots
ax = first["ax"]
assert ax.xaxis.get_label().get_visible()
assert ax.yaxis.get_label().get_visible()
assert all(t.get_visible() for t in ax.get_xticklabels())
assert all(t.get_visible() for t in ax.get_yticklabels())
for s in other:
ax = s["ax"]
assert ax.xaxis.get_label().get_visible()
assert not ax.yaxis.get_label().get_visible()
assert all(t.get_visible() for t in ax.get_xticklabels())
assert not any(t.get_visible() for t in ax.get_yticklabels())
@pytest.mark.parametrize(
"facet_kws,pair_kws", [({"row": "b"}, {}), ({}, {"y": ["x", "y", "f"]})]
)
def test_1d_row(self, long_df, facet_kws, pair_kws):
x = "z"
y = None if "y" in pair_kws else "z"
p = Plot(long_df, x=x, y=y).plot()
first, *other = p._subplots
ax = first["ax"]
assert ax.xaxis.get_label().get_visible()
assert all(t.get_visible() for t in ax.get_xticklabels())
assert ax.yaxis.get_label().get_visible()
assert all(t.get_visible() for t in ax.get_yticklabels())
for s in other:
ax = s["ax"]
assert not ax.xaxis.get_label().get_visible()
assert ax.yaxis.get_label().get_visible()
assert not any(t.get_visible() for t in ax.get_xticklabels())
assert all(t.get_visible() for t in ax.get_yticklabels())
def test_1d_column_wrapped(self):
p = Plot().facet(col=["a", "b", "c", "d"], wrap=3).plot()
subplots = list(p._subplots)
for s in [subplots[0], subplots[-1]]:
ax = s["ax"]
assert ax.yaxis.get_label().get_visible()
assert all(t.get_visible() for t in ax.get_yticklabels())
for s in subplots[1:]:
ax = s["ax"]
assert ax.xaxis.get_label().get_visible()
assert self.has_xaxis_labels(ax)
assert all(t.get_visible() for t in ax.get_xticklabels())
for s in subplots[1:-1]:
ax = s["ax"]
assert not ax.yaxis.get_label().get_visible()
assert not any(t.get_visible() for t in ax.get_yticklabels())
ax = subplots[0]["ax"]
assert not ax.xaxis.get_label().get_visible()
assert not any(t.get_visible() for t in ax.get_xticklabels())
def test_1d_row_wrapped(self):
p = Plot().facet(row=["a", "b", "c", "d"], wrap=3).plot()
subplots = list(p._subplots)
for s in subplots[:-1]:
ax = s["ax"]
assert ax.yaxis.get_label().get_visible()
assert all(t.get_visible() for t in ax.get_yticklabels())
for s in subplots[-2:]:
ax = s["ax"]
assert ax.xaxis.get_label().get_visible()
assert self.has_xaxis_labels(ax)
assert all(t.get_visible() for t in ax.get_xticklabels())
for s in subplots[:-2]:
ax = s["ax"]
assert not ax.xaxis.get_label().get_visible()
assert not any(t.get_visible() for t in ax.get_xticklabels())
ax = subplots[-1]["ax"]
assert not ax.yaxis.get_label().get_visible()
assert not any(t.get_visible() for t in ax.get_yticklabels())
def test_1d_column_wrapped_non_cross(self, long_df):
    """Wrapped, non-crossed pairing: every subplot shows all of its labels."""
    plot = (
        Plot(long_df)
        .pair(x=["a", "b", "c"], y=["x", "y", "z"], wrap=2, cross=False)
        .plot()
    )
    for sub in plot._subplots:
        axes = sub["ax"]
        for name in "xy":
            axis = getattr(axes, f"{name}axis")
            assert axis.get_label().get_visible()
            ticklabels = getattr(axes, f"get_{name}ticklabels")()
            assert all(lbl.get_visible() for lbl in ticklabels)
def test_2d(self):
    """A 2x2 facet grid with shared axes shows labels only on its outer edges."""
    plot = Plot().facet(col=["a", "b"], row=["x", "y"]).plot()
    subs = list(plot._subplots)
    top_row, bottom_row = subs[:2], subs[2:]
    left_col, right_col = [subs[0], subs[2]], [subs[1], subs[3]]

    for sub in top_row:
        axes = sub["ax"]
        assert not axes.xaxis.get_label().get_visible()
        assert not any(lbl.get_visible() for lbl in axes.get_xticklabels())

    for sub in bottom_row:
        axes = sub["ax"]
        assert axes.xaxis.get_label().get_visible()
        assert all(lbl.get_visible() for lbl in axes.get_xticklabels())

    for sub in left_col:
        axes = sub["ax"]
        assert axes.yaxis.get_label().get_visible()
        assert all(lbl.get_visible() for lbl in axes.get_yticklabels())

    for sub in right_col:
        axes = sub["ax"]
        assert not axes.yaxis.get_label().get_visible()
        assert not any(lbl.get_visible() for lbl in axes.get_yticklabels())
def test_2d_unshared(self):
    """A 2x2 grid without sharing hides interior axis labels but keeps tick labels."""
    plot = (
        Plot()
        .facet(col=["a", "b"], row=["x", "y"])
        .share(x=False, y=False)
        .plot()
    )
    subs = list(plot._subplots)
    top_row, bottom_row = subs[:2], subs[2:]
    left_col, right_col = [subs[0], subs[2]], [subs[1], subs[3]]

    # Top row: no x axis label, but x tick labels stay visible
    for sub in top_row:
        axes = sub["ax"]
        assert not axes.xaxis.get_label().get_visible()
        assert all(lbl.get_visible() for lbl in axes.get_xticklabels())

    for sub in bottom_row:
        axes = sub["ax"]
        assert axes.xaxis.get_label().get_visible()
        assert all(lbl.get_visible() for lbl in axes.get_xticklabels())

    for sub in left_col:
        axes = sub["ax"]
        assert axes.yaxis.get_label().get_visible()
        assert all(lbl.get_visible() for lbl in axes.get_yticklabels())

    # Right column: no y axis label, but y tick labels stay visible
    for sub in right_col:
        axes = sub["ax"]
        assert not axes.yaxis.get_label().get_visible()
        assert all(lbl.get_visible() for lbl in axes.get_yticklabels())
class TestLegend:
    """Tests for the legend contents and figure legends that Plot builds."""

    @pytest.fixture
    def xy(self):
        """Minimal x/y coordinates shared by the legend tests."""
        return dict(x=[1, 2, 3, 4], y=[1, 2, 3, 4])

    def test_single_layer_single_variable(self, xy):
        """One layer mapping one named variable yields one legend entry."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        p = Plot(**xy).add(MockMark(), color=s).plot()
        e, = p._legend_contents

        labels = categorical_order(s)

        # Entries are read here as e[0] (title, identifier), e[1] (artists),
        # and e[-1] (labels)
        assert e[0] == (s.name, s.name)
        assert e[-1] == labels

        artists = e[1]
        assert len(artists) == len(labels)
        for a, label in zip(artists, labels):
            assert isinstance(a, mpl.artist.Artist)
            assert a.value == label
            assert a.variables == ["color"]

    def test_single_layer_common_variable(self, xy):
        """Two properties mapped from the same series share one legend entry."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        sem = dict(color=s, marker=s)
        p = Plot(**xy).add(MockMark(), **sem).plot()
        e, = p._legend_contents

        labels = categorical_order(s)

        assert e[0] == (s.name, s.name)
        assert e[-1] == labels

        artists = e[1]
        assert len(artists) == len(labels)
        for a, label in zip(artists, labels):
            assert isinstance(a, mpl.artist.Artist)
            assert a.value == label
            assert a.variables == list(sem)

    def test_single_layer_common_unnamed_variable(self, xy):
        """An unnamed array gets an empty title and is keyed by its id()."""
        s = np.array(["a", "b", "a", "c"])
        sem = dict(color=s, marker=s)
        p = Plot(**xy).add(MockMark(), **sem).plot()

        e, = p._legend_contents

        labels = list(np.unique(s))  # assumes sorted order

        assert e[0] == ("", id(s))
        assert e[-1] == labels

        artists = e[1]
        assert len(artists) == len(labels)
        for a, label in zip(artists, labels):
            assert isinstance(a, mpl.artist.Artist)
            assert a.value == label
            assert a.variables == list(sem)

    def test_single_layer_multi_variable(self, xy):
        """Distinct series mapped in one layer produce one entry each."""
        s1 = pd.Series(["a", "b", "a", "c"], name="s1")
        s2 = pd.Series(["m", "m", "p", "m"], name="s2")
        sem = dict(color=s1, marker=s2)
        p = Plot(**xy).add(MockMark(), **sem).plot()
        e1, e2 = p._legend_contents

        # Map each series name back to the property it was assigned to
        variables = {v.name: k for k, v in sem.items()}

        for e, s in zip([e1, e2], [s1, s2]):
            assert e[0] == (s.name, s.name)

            labels = categorical_order(s)
            assert e[-1] == labels

            artists = e[1]
            assert len(artists) == len(labels)
            for a, label in zip(artists, labels):
                assert isinstance(a, mpl.artist.Artist)
                assert a.value == label
                assert a.variables == [variables[s.name]]

    def test_multi_layer_single_variable(self, xy):
        """A variable shared by two layers yields one entry per layer."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        p = Plot(**xy, color=s).add(MockMark()).add(MockMark()).plot()
        e1, e2 = p._legend_contents

        # Both entries describe the same series, so compute its labels once
        labels = categorical_order(s)

        for e in [e1, e2]:
            assert e[0] == (s.name, s.name)
            assert e[-1] == labels

            artists = e[1]
            assert len(artists) == len(labels)
            for a, label in zip(artists, labels):
                assert isinstance(a, mpl.artist.Artist)
                assert a.value == label
                assert a.variables == ["color"]

    def test_multi_layer_multi_variable(self, xy):
        """Different variables in different layers yield separate entries."""
        s1 = pd.Series(["a", "b", "a", "c"], name="s1")
        s2 = pd.Series(["m", "m", "p", "m"], name="s2")
        sem = dict(color=s1), dict(marker=s2)
        variables = {"s1": "color", "s2": "marker"}
        p = Plot(**xy).add(MockMark(), **sem[0]).add(MockMark(), **sem[1]).plot()
        e1, e2 = p._legend_contents

        for e, s in zip([e1, e2], [s1, s2]):
            assert e[0] == (s.name, s.name)

            labels = categorical_order(s)
            assert e[-1] == labels

            artists = e[1]
            assert len(artists) == len(labels)
            for a, label in zip(artists, labels):
                assert isinstance(a, mpl.artist.Artist)
                assert a.value == label
                assert a.variables == [variables[s.name]]

    def test_multi_layer_different_artists(self, xy):
        """Layers with different legend artist types are combined in one legend."""

        class MockMark1(MockMark):
            def _legend_artist(self, variables, value, scales):
                return mpl.lines.Line2D([], [])

        class MockMark2(MockMark):
            def _legend_artist(self, variables, value, scales):
                return mpl.patches.Patch()

        s = pd.Series(["a", "b", "a", "c"], name="s")
        p = Plot(**xy, color=s).add(MockMark1()).add(MockMark2()).plot()

        legend, = p._figure.legends

        names = categorical_order(s)
        labels = [t.get_text() for t in legend.get_texts()]
        assert labels == names

        # The legend's child structure is only inspected on matplotlib >= 3.5
        if not _version_predates(mpl, "3.5"):
            contents = legend.get_children()[0]
            assert len(contents.findobj(mpl.lines.Line2D)) == len(names)
            assert len(contents.findobj(mpl.patches.Patch)) == len(names)

    def test_three_layers(self, xy):
        """Three layers sharing a variable produce one text per unique level."""

        class MockMarkLine(MockMark):
            def _legend_artist(self, variables, value, scales):
                return mpl.lines.Line2D([], [])

        s = pd.Series(["a", "b", "a", "c"], name="s")
        p = Plot(**xy, color=s)
        for _ in range(3):
            p = p.add(MockMarkLine())
        p = p.plot()
        texts = p._figure.legends[0].get_texts()
        assert len(texts) == len(s.unique())

    def test_identity_scale_ignored(self, xy):
        """An identity scale (``scale(color=None)``) adds no legend contents."""
        s = pd.Series(["r", "g", "b", "g"])
        p = Plot(**xy).add(MockMark(), color=s).scale(color=None).plot()
        assert not p._legend_contents

    def test_suppression_in_add_method(self, xy):
        """``legend=False`` in ``add`` suppresses that layer's legend contents."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        p = Plot(**xy).add(MockMark(), color=s, legend=False).plot()
        assert not p._legend_contents

    def test_anonymous_title(self, xy):
        """A list (no name) mapped to color gives an empty legend title."""
        p = Plot(**xy, color=["a", "b", "c", "d"]).add(MockMark()).plot()
        legend, = p._figure.legends
        assert legend.get_title().get_text() == ""

    def test_legendless_mark(self, xy):
        """A mark whose legend artist is None creates no figure legend."""

        class NoLegendMark(MockMark):
            def _legend_artist(self, variables, value, scales):
                return None

        p = Plot(**xy, color=["a", "b", "c", "d"]).add(NoLegendMark()).plot()
        assert not p._figure.legends

    def test_legend_has_no_offset(self, xy):
        """Large values are written in full in the legend text."""
        color = np.add(xy["x"], 1e8)
        p = Plot(**xy, color=color).add(MockMark()).plot()
        legend = p._figure.legends[0]
        assert legend.texts
        for text in legend.texts:
            assert float(text.get_text()) > 1e7

    def test_layer_legend(self, xy):
        """Each labeled layer contributes exactly one legend text, in order."""
        p = Plot(**xy).add(MockMark(), label="a").add(MockMark(), label="b").plot()
        legend = p._figure.legends[0]
        # Compare the full list: zip() would silently pass with missing entries
        assert [text.get_text() for text in legend.texts] == ["a", "b"]

    def test_layer_legend_with_scale_legend(self, xy):
        """Layer labels and scale levels appear together in one legend."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        p = Plot(**xy, color=s).add(MockMark(), label="x").plot()

        legend = p._figure.legends[0]
        texts = [t.get_text() for t in legend.findobj(mpl.text.Text)]
        assert "x" in texts
        for val in s.unique():
            assert val in texts

    def test_layer_legend_title(self, xy):
        """``label(legend=...)`` sets the title of the layer legend."""
        p = Plot(**xy).add(MockMark(), label="x").label(legend="layer").plot()
        assert p._figure.legends[0].get_title().get_text() == "layer"
class TestDefaultObject:
    """Tests for the ``Default`` sentinel object."""

    def test_default_repr(self):
        # The sentinel is expected to render as the marker "<default>"; an empty
        # expected string would accept a broken repr
        assert repr(Default()) == "<default>"
class TestThemeConfig:
    """Tests for the global theme stored in ``Plot.config.theme``."""

    @pytest.fixture(autouse=True)
    def reset_config(self):
        # Let the test run, then restore the theme so changes do not leak
        yield
        Plot.config.theme.reset()

    def test_default(self):
        """A plain plot uses the theme's current axes facecolor."""
        p = Plot().plot()
        ax = p._figure.axes[0]
        expected = Plot.config.theme["axes.facecolor"]
        assert mpl.colors.same_color(ax.get_facecolor(), expected)

    def test_setitem(self):
        """Item assignment on the theme affects subsequent plots."""
        color = "#CCC"
        Plot.config.theme["axes.facecolor"] = color
        p = Plot().plot()
        ax = p._figure.axes[0]
        assert mpl.colors.same_color(ax.get_facecolor(), color)

    def test_update(self):
        """``update`` on the theme affects subsequent plots."""
        color = "#DDD"
        Plot.config.theme.update({"axes.facecolor": color})
        p = Plot().plot()
        ax = p._figure.axes[0]
        assert mpl.colors.same_color(ax.get_facecolor(), color)

    def test_reset(self):
        """``reset`` undoes earlier theme changes."""
        orig = Plot.config.theme["axes.facecolor"]
        Plot.config.theme.update({"axes.facecolor": "#EEE"})
        Plot.config.theme.reset()
        p = Plot().plot()
        ax = p._figure.axes[0]
        assert mpl.colors.same_color(ax.get_facecolor(), orig)

    def test_copy(self):
        """Modifying a copy of the theme leaves the global theme untouched."""
        key, val = "axes.facecolor", ".95"
        orig = Plot.config.theme[key]
        theme = Plot.config.theme.copy()
        theme.update({key: val})
        assert Plot.config.theme[key] == orig

    def test_html_repr(self):
        """The HTML repr is balanced markup with one cell per theme key."""
        res = Plot.config.theme._repr_html_()
        # Every opening tag must have a matching closing tag
        for tag in ["div", "table", "tr", "td"]:
            assert res.count(f"<{tag}") == res.count(f"</{tag}>")

        # Each parameter name is rendered in its own table cell
        for key in Plot.config.theme:
            assert f"<td>{key}:</td>" in res
class TestDisplayConfig:
    """Tests for the rich-display options in ``Plot.config.display``."""

    @pytest.fixture(autouse=True)
    def reset_config(self):
        # After the test, copy a fresh PlotConfig's display options back in
        yield
        fresh_display = PlotConfig().display
        Plot.config.display.update(fresh_display)

    def test_png_format(self):
        """With PNG output selected, only the PNG repr hook produces data."""
        Plot.config.display["format"] = "png"

        assert Plot()._repr_svg_() is None
        assert Plot().plot()._repr_svg_() is None

        def check_png(plot):
            payload, meta = plot._repr_png_()
            image = Image.open(io.BytesIO(payload))
            assert image.format == "PNG"
            assert sorted(meta) == ["height", "width"]

        check_png(Plot())
        check_png(Plot().plot())

    def test_svg_format(self):
        """With SVG output selected, only the SVG repr hook produces data."""
        Plot.config.display["format"] = "svg"

        assert Plot()._repr_png_() is None
        assert Plot().plot()._repr_png_() is None

        def check_svg(plot):
            document = plot._repr_svg_()
            root = xml.etree.ElementTree.fromstring(document)
            assert root.tag == "{http://www.w3.org/2000/svg}svg"

        check_svg(Plot())
        check_svg(Plot().plot())

    def test_png_scaling(self):
        """Scaling halves the reported PNG size but not the pixel size."""
        Plot.config.display["scaling"] = 1.
        full_png, full_meta = Plot()._repr_png_()

        Plot.config.display["scaling"] = .5
        half_png, half_meta = Plot()._repr_png_()

        for dim in ("width", "height"):
            assert full_meta[dim] / 2 == half_meta[dim]

        full_img = Image.open(io.BytesIO(full_png))
        half_img = Image.open(io.BytesIO(half_png))
        assert full_img.size == half_img.size

    def test_svg_scaling(self):
        """Scaling halves the width/height attributes of the SVG root."""
        Plot.config.display["format"] = "svg"

        Plot.config.display["scaling"] = 1.
        full_svg = Plot()._repr_svg_()

        Plot.config.display["scaling"] = .5
        half_svg = Plot()._repr_svg_()

        full_root = xml.etree.ElementTree.fromstring(full_svg)
        half_root = xml.etree.ElementTree.fromstring(half_svg)

        def size_of(root, dim):
            # Drop the last two characters (presumably a unit such as "pt")
            raw = root.attrib[dim]
            return float(raw[:-2])

        for dim in ("width", "height"):
            assert size_of(full_root, dim) / 2 == size_of(half_root, dim)

    def test_png_hidpi(self):
        """Turning hidpi off keeps the reported size but halves the pixels."""
        hidpi_png, hidpi_meta = Plot()._repr_png_()

        Plot.config.display["hidpi"] = False
        lowdpi_png, lowdpi_meta = Plot()._repr_png_()

        for dim in ("width", "height"):
            assert hidpi_meta[dim] == lowdpi_meta[dim]

        hidpi_img = Image.open(io.BytesIO(hidpi_png))
        lowdpi_img = Image.open(io.BytesIO(lowdpi_png))
        for big, small in zip(hidpi_img.size, lowdpi_img.size):
            assert big // 2 == small
================================================
FILE: tests/_core/test_properties.py
================================================
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib.colors import same_color, to_rgb, to_rgba
from matplotlib.markers import MarkerStyle
import pytest
from numpy.testing import assert_array_equal
from seaborn._core.rules import categorical_order
from seaborn._core.scales import Nominal, Continuous, Boolean
from seaborn._core.properties import (
Alpha,
Color,
Coordinate,
EdgeWidth,
Fill,
LineStyle,
LineWidth,
Marker,
PointSize,
)
from seaborn._compat import get_colormap
from seaborn.palettes import color_palette
class DataFixtures:
    """Shared fixtures that extract typed vectors from the ``long_df`` fixture."""

    @pytest.fixture
    def num_vector(self, long_df):
        # Column used as the numeric test vector
        return long_df["s"]

    @pytest.fixture
    def num_order(self, num_vector):
        return categorical_order(num_vector)

    @pytest.fixture
    def cat_vector(self, long_df):
        # Column used as the categorical test vector
        return long_df["a"]

    @pytest.fixture
    def cat_order(self, cat_vector):
        return categorical_order(cat_vector)

    @pytest.fixture
    def dt_num_vector(self, long_df):
        return long_df["t"]

    @pytest.fixture
    def dt_cat_vector(self, long_df):
        return long_df["d"]

    @pytest.fixture
    def bool_vector(self, long_df):
        # Boolean vector obtained by thresholding the "x" column
        threshold = 10
        return long_df["x"] > threshold

    @pytest.fixture
    def vectors(self, num_vector, cat_vector, bool_vector):
        # Lookup by data-type name, used by parametrized tests
        return dict(num=num_vector, cat=cat_vector, bool=bool_vector)
class TestCoordinate(DataFixtures):
    """Argument validation for ``Coordinate.infer_scale``."""

    def test_bad_scale_arg_str(self, num_vector):
        coord = Coordinate("x")
        with pytest.raises(ValueError, match="Unknown magic arg for x scale: 'xxx'."):
            coord.infer_scale("xxx", num_vector)

    def test_bad_scale_arg_type(self, cat_vector):
        coord = Coordinate("x")
        with pytest.raises(TypeError, match="Magic arg for x scale must be str, not list."):
            coord.infer_scale([1, 2, 3], cat_vector)
class TestColor(DataFixtures):
    """Tests for the Color property: palettes, scale inference, standardization."""

    def assert_same_rgb(self, a, b):
        # Compare only the first three (RGB) columns of two color arrays
        assert_array_equal(a[:, :3], b[:, :3])

    def _assert_colors_match(self, actual, expected):
        # Pairwise check; zip stops at the shorter of the two sequences
        for have, want in zip(actual, expected):
            assert same_color(have, want)

    def test_nominal_default_palette(self, cat_vector, cat_order):
        mapping = Color().get_mapping(Nominal(), cat_vector)
        n_levels = len(cat_order)
        self._assert_colors_match(
            mapping(np.arange(n_levels)), color_palette(None, n_levels)
        )

    def test_nominal_default_palette_large(self):
        # With 26 levels the default palette is expected to be husl
        letters = pd.Series(list("abcdefghijklmnopqrstuvwxyz"))
        mapping = Color().get_mapping(Nominal(), letters)
        self._assert_colors_match(mapping(np.arange(26)), color_palette("husl", 26))

    def test_nominal_named_palette(self, cat_vector, cat_order):
        palette = "Blues"
        mapping = Color().get_mapping(Nominal(palette), cat_vector)
        n_levels = len(cat_order)
        self._assert_colors_match(
            mapping(np.arange(n_levels)), color_palette(palette, n_levels)
        )

    def test_nominal_list_palette(self, cat_vector, cat_order):
        palette = color_palette("Reds", len(cat_order))
        mapping = Color().get_mapping(Nominal(palette), cat_vector)
        self._assert_colors_match(mapping(np.arange(len(palette))), palette)

    def test_nominal_dict_palette(self, cat_vector, cat_order):
        colors = color_palette("Greens")
        palette = dict(zip(cat_order, colors))
        mapping = Color().get_mapping(Nominal(palette), cat_vector)
        self._assert_colors_match(mapping(np.arange(len(cat_order))), colors)

    def test_nominal_dict_with_missing_keys(self, cat_vector, cat_order):
        # The first level is left out of the dict on purpose
        incomplete = dict(zip(cat_order[1:], color_palette("Purples")))
        with pytest.raises(ValueError, match="No entry in color dict"):
            Color("color").get_mapping(Nominal(incomplete), cat_vector)

    def test_nominal_list_too_short(self, cat_vector, cat_order):
        n = len(cat_order) - 1
        palette = color_palette("Oranges", n)
        msg = rf"The edgecolor list has fewer values \({n}\) than needed \({n + 1}\)"
        with pytest.warns(UserWarning, match=msg):
            Color("edgecolor").get_mapping(Nominal(palette), cat_vector)

    def test_nominal_list_too_long(self, cat_vector, cat_order):
        n = len(cat_order) + 1
        palette = color_palette("Oranges", n)
        msg = rf"The edgecolor list has more values \({n}\) than needed \({n - 1}\)"
        with pytest.warns(UserWarning, match=msg):
            Color("edgecolor").get_mapping(Nominal(palette), cat_vector)

    def test_continuous_default_palette(self, num_vector):
        cmap = color_palette("ch:", as_cmap=True)
        mapping = Color().get_mapping(Continuous(), num_vector)
        self.assert_same_rgb(mapping(num_vector), cmap(num_vector))

    def test_continuous_named_palette(self, num_vector):
        name = "flare"
        cmap = color_palette(name, as_cmap=True)
        mapping = Color().get_mapping(Continuous(name), num_vector)
        self.assert_same_rgb(mapping(num_vector), cmap(num_vector))

    def test_continuous_tuple_palette(self, num_vector):
        endpoints = ("blue", "red")
        cmap = color_palette("blend:" + ",".join(endpoints), as_cmap=True)
        mapping = Color().get_mapping(Continuous(endpoints), num_vector)
        self.assert_same_rgb(mapping(num_vector), cmap(num_vector))

    def test_continuous_callable_palette(self, num_vector):
        cmap = get_colormap("viridis")
        mapping = Color().get_mapping(Continuous(cmap), num_vector)
        self.assert_same_rgb(mapping(num_vector), cmap(num_vector))

    def test_continuous_missing(self):
        # A missing input value maps to an all-NaN color
        data = pd.Series([1, 2, np.nan, 4])
        mapping = Color().get_mapping(Continuous(), data)
        assert np.isnan(mapping(data)[2]).all()

    def test_bad_scale_values_continuous(self, num_vector):
        with pytest.raises(TypeError, match="Scale values for color with a Continuous"):
            Color().get_mapping(Continuous(["r", "g", "b"]), num_vector)

    def test_bad_scale_values_nominal(self, cat_vector):
        with pytest.raises(TypeError, match="Scale values for color with a Nominal"):
            Color().get_mapping(Nominal(get_colormap("viridis")), cat_vector)

    def test_bad_inference_arg(self, cat_vector):
        with pytest.raises(TypeError, match="A single scale argument for color"):
            Color().infer_scale(123, cat_vector)

    @pytest.mark.parametrize(
        "data_type,scale_class",
        [("cat", Nominal), ("num", Continuous), ("bool", Boolean)]
    )
    def test_default(self, data_type, scale_class, vectors):
        assert isinstance(Color().default_scale(vectors[data_type]), scale_class)

    def test_default_numeric_data_category_dtype(self, num_vector):
        categorical = num_vector.astype("category")
        assert isinstance(Color().default_scale(categorical), Nominal)

    def test_default_binary_data(self):
        # 0/1 integers still get a Continuous default
        binary = pd.Series([0, 0, 1, 0, 1], dtype=int)
        assert isinstance(Color().default_scale(binary), Continuous)

    @pytest.mark.parametrize(
        "values,data_type,scale_class",
        [
            ("viridis", "cat", Nominal),  # Based on variable type
            ("viridis", "num", Continuous),  # Based on variable type
            ("viridis", "bool", Boolean),  # Based on variable type
            ("muted", "num", Nominal),  # Based on qualitative palette
            (["r", "g", "b"], "num", Nominal),  # Based on list palette
            ({2: "r", 4: "g", 8: "b"}, "num", Nominal),  # Based on dict palette
            (("r", "b"), "num", Continuous),  # Based on tuple / variable type
            (("g", "m"), "cat", Nominal),  # Based on tuple / variable type
            (("c", "y"), "bool", Boolean),  # Based on tuple / variable type
            (get_colormap("inferno"), "num", Continuous),  # Based on callable
        ]
    )
    def test_inference(self, values, data_type, scale_class, vectors):
        inferred = Color().infer_scale(values, vectors[data_type])
        assert isinstance(inferred, scale_class)
        assert inferred.values == values

    def test_standardization(self):
        standardize = Color().standardize
        cases = [
            ("C3", to_rgb("C3")),
            ("dodgerblue", to_rgb("dodgerblue")),
            ((.1, .2, .3), (.1, .2, .3)),
            ((.1, .2, .3, .4), (.1, .2, .3, .4)),
            ("#123456", to_rgb("#123456")),
            ("#12345678", to_rgba("#12345678")),
            ("#123", to_rgb("#123")),
            ("#1234", to_rgba("#1234")),
        ]
        for value, expected in cases:
            assert standardize(value) == expected
class ObjectPropertyBase(DataFixtures):
    """Shared tests for properties that map data onto discrete objects.

    Subclasses set ``prop`` (the property class), ``values`` (three raw scale
    values), and ``standardized_values`` (those values after standardization).
    """

    def assert_equal(self, a, b):
        assert self.unpack(a) == self.unpack(b)

    def unpack(self, x):
        # Hook for subclasses whose values need conversion before comparison
        return x

    @staticmethod
    def _expected_scale(data_type):
        # Boolean data gets a Boolean scale; everything else is Nominal
        return Boolean if data_type == "bool" else Nominal

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_default(self, data_type, vectors):
        scale = self.prop().default_scale(vectors[data_type])
        assert isinstance(scale, self._expected_scale(data_type))

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_inference_list(self, data_type, vectors):
        scale = self.prop().infer_scale(self.values, vectors[data_type])
        assert isinstance(scale, self._expected_scale(data_type))
        assert scale.values == self.values

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_inference_dict(self, data_type, vectors):
        data = vectors[data_type]
        lookup = dict(zip(categorical_order(data), self.values))
        scale = self.prop().infer_scale(lookup, data)
        assert isinstance(scale, self._expected_scale(data_type))
        assert scale.values == lookup

    def test_dict_missing(self, cat_vector):
        levels = categorical_order(cat_vector)
        # Leave the final level without an entry
        partial = dict(zip(levels, self.values[:-1]))
        prop_name = self.prop.__name__.lower()
        msg = f"No entry in {prop_name} dictionary for {repr(levels[-1])}"
        with pytest.raises(ValueError, match=msg):
            self.prop().get_mapping(Nominal(partial), cat_vector)

    @pytest.mark.parametrize("data_type", ["cat", "num"])
    def test_mapping_default(self, data_type, vectors):
        data = vectors[data_type]
        mapping = self.prop().get_mapping(Nominal(), data)
        defaults = self.prop()._default_values(data.nunique())
        for idx, expected in enumerate(defaults):
            actual, = mapping([idx])
            self.assert_equal(actual, expected)

    @pytest.mark.parametrize("data_type", ["cat", "num"])
    def test_mapping_from_list(self, data_type, vectors):
        data = vectors[data_type]
        mapping = self.prop().get_mapping(Nominal(self.values), data)
        for idx, expected in enumerate(self.standardized_values):
            actual, = mapping([idx])
            self.assert_equal(actual, expected)

    @pytest.mark.parametrize("data_type", ["cat", "num"])
    def test_mapping_from_dict(self, data_type, vectors):
        data = vectors[data_type]
        levels = categorical_order(data)
        # Reverse the values so the dict order differs from the list order
        lookup = dict(zip(levels, self.values[::-1]))
        expected_by_level = dict(zip(levels, self.standardized_values[::-1]))
        mapping = self.prop().get_mapping(Nominal(lookup), data)
        for idx, level in enumerate(levels):
            actual, = mapping([idx])
            self.assert_equal(actual, expected_by_level[level])

    def test_mapping_with_null_value(self, cat_vector):
        mapping = self.prop().get_mapping(Nominal(self.values), cat_vector)
        actual = mapping(np.array([0, np.nan, 2]))
        first, _, third = self.standardized_values
        # A NaN code maps to the property's null value
        expected = [first, self.prop.null_value, third]
        for got, want in zip(actual, expected):
            self.assert_equal(got, want)

    def test_unique_default_large_n(self):
        n = 24
        data = pd.Series(np.arange(n))
        mapping = self.prop().get_mapping(Nominal(), data)
        distinct = {self.unpack(item) for item in mapping(data)}
        assert len(distinct) == n

    def test_bad_scale_values(self, cat_vector):
        var_name = self.prop.__name__.lower()
        with pytest.raises(TypeError, match=f"Scale values for a {var_name} variable"):
            self.prop().get_mapping(Nominal(("o", "s")), cat_vector)
class TestMarker(ObjectPropertyBase):
    """Marker property tests; markers are compared via their MarkerStyle attributes."""

    prop = Marker
    values = ["o", (5, 2, 0), MarkerStyle("^")]
    standardized_values = [MarkerStyle(x) for x in values]

    def assert_equal(self, a, b):
        path_a, path_b = a.get_path(), b.get_path()
        # Path geometry must match exactly
        for attr in ("vertices", "codes"):
            assert_array_equal(getattr(path_a, attr), getattr(path_b, attr))
        for attr in ("simplify_threshold", "should_simplify"):
            assert getattr(path_a, attr) == getattr(path_b, attr)
        # Then the style-level attributes
        assert a.get_joinstyle() == b.get_joinstyle()
        assert a.get_transform().to_values() == b.get_transform().to_values()
        assert a.get_fillstyle() == b.get_fillstyle()

    def unpack(self, x):
        # Tuple of the attributes that distinguish one marker from another
        path = x.get_path()
        joinstyle = x.get_joinstyle()
        transform = x.get_transform().to_values()
        fillstyle = x.get_fillstyle()
        return (path, joinstyle, transform, fillstyle)
class TestLineStyle(ObjectPropertyBase):
    """LineStyle tests: the shared object-property checks plus input validation."""

    prop = LineStyle
    values = ["solid", "--", (1, .5)]
    standardized_values = [LineStyle._get_dash_pattern(x) for x in values]

    def test_bad_type(self):
        linestyle = LineStyle()
        with pytest.raises(TypeError, match="^Linestyle must be .+, not list.$"):
            linestyle.standardize([1, 2])

    def test_bad_style(self):
        linestyle = LineStyle()
        with pytest.raises(ValueError, match="^Linestyle string must be .+, not 'o'.$"):
            linestyle.standardize("o")

    def test_bad_dashes(self):
        linestyle = LineStyle()
        with pytest.raises(TypeError, match="^Invalid dash pattern"):
            linestyle.standardize((1, 2, "x"))
class TestFill(DataFixtures):
    """Tests for the Fill property, which maps data onto True/False values."""

    @pytest.fixture
    def vectors(self):
        # Small two-level vectors per data type (overrides DataFixtures.vectors)
        return {
            "cat": pd.Series(["a", "a", "b"]),
            "num": pd.Series([1, 1, 2]),
            "bool": pd.Series([True, True, False])
        }

    @pytest.fixture
    def cat_vector(self, vectors):
        return vectors["cat"]

    @pytest.fixture
    def num_vector(self, vectors):
        return vectors["num"]

    @staticmethod
    def _expected_scale(data_type):
        # Boolean data gets a Boolean scale; everything else is Nominal
        return Boolean if data_type == "bool" else Nominal

    @staticmethod
    def _check_mapping(mapping, expected):
        # Codes 0, 1, 0 must map to the given truth values
        assert_array_equal(mapping([0, 1, 0]), expected)

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_default(self, data_type, vectors):
        scale = Fill().default_scale(vectors[data_type])
        assert isinstance(scale, self._expected_scale(data_type))

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_inference_list(self, data_type, vectors):
        scale = Fill().infer_scale([True, False], vectors[data_type])
        assert isinstance(scale, self._expected_scale(data_type))
        assert scale.values == [True, False]

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_inference_dict(self, data_type, vectors):
        data = vectors[data_type]
        lookup = dict(zip(data.unique(), [True, False]))
        scale = Fill().infer_scale(lookup, data)
        assert isinstance(scale, self._expected_scale(data_type))
        assert scale.values == lookup

    def test_mapping_categorical_data(self, cat_vector):
        self._check_mapping(Fill().get_mapping(Nominal(), cat_vector), [True, False, True])

    def test_mapping_numeric_data(self, num_vector):
        self._check_mapping(Fill().get_mapping(Nominal(), num_vector), [True, False, True])

    def test_mapping_list(self, cat_vector):
        mapping = Fill().get_mapping(Nominal([False, True]), cat_vector)
        self._check_mapping(mapping, [False, True, False])

    def test_mapping_truthy_list(self, cat_vector):
        # Integer 0/1 values are treated as booleans
        mapping = Fill().get_mapping(Nominal([0, 1]), cat_vector)
        self._check_mapping(mapping, [False, True, False])

    def test_mapping_dict(self, cat_vector):
        lookup = dict(zip(cat_vector.unique(), [False, True]))
        mapping = Fill().get_mapping(Nominal(lookup), cat_vector)
        self._check_mapping(mapping, [False, True, False])

    def test_cycle_warning(self):
        # Three levels cannot be represented by two fill states without cycling
        three_levels = pd.Series(["a", "b", "c"])
        with pytest.warns(UserWarning, match="The variable assigned to fill"):
            Fill().get_mapping(Nominal(), three_levels)

    def test_values_error(self):
        two_levels = pd.Series(["a", "b"])
        with pytest.raises(TypeError, match="Scale values for fill must be"):
            Fill().get_mapping(Nominal("bad_values"), two_levels)
class IntervalBase(DataFixtures):
    """Shared tests for properties that map data onto a numeric interval.

    Subclasses set ``prop`` to the property class under test.
    """

    def norm(self, x):
        # Min-max rescale onto [0, 1]
        return (x - x.min()) / (x.max() - x.min())

    @pytest.mark.parametrize("data_type,scale_class", [
        ("cat", Nominal),
        ("num", Continuous),
        ("bool", Boolean),
    ])
    def test_default(self, data_type, scale_class, vectors):
        x = vectors[data_type]
        scale = self.prop().default_scale(x)
        assert isinstance(scale, scale_class)

    @pytest.mark.parametrize("arg,data_type,scale_class", [
        ((1, 3), "cat", Nominal),
        ((1, 3), "num", Continuous),
        ((1, 3), "bool", Boolean),
        ([1, 2, 3], "cat", Nominal),
        ([1, 2, 3], "num", Nominal),
        ([1, 3], "bool", Boolean),
        ({"a": 1, "b": 3, "c": 2}, "cat", Nominal),
        ({2: 1, 4: 3, 8: 2}, "num", Nominal),
        ({True: 4, False: 2}, "bool", Boolean),
    ])
    def test_inference(self, arg, data_type, scale_class, vectors):
        x = vectors[data_type]
        scale = self.prop().infer_scale(arg, x)
        assert isinstance(scale, scale_class)
        assert scale.values == arg

    def test_mapped_interval_numeric(self, num_vector):
        mapping = self.prop().get_mapping(Continuous(), num_vector)
        assert_array_equal(mapping([0, 1]), self.prop().default_range)

    def test_mapped_interval_categorical(self, cat_vector):
        mapping = self.prop().get_mapping(Nominal(), cat_vector)
        n = cat_vector.nunique()
        # Nominal levels span the default range in reverse: last code -> min
        assert_array_equal(mapping([n - 1, 0]), self.prop().default_range)

    def test_bad_scale_values_numeric_data(self, num_vector):
        prop_name = self.prop.__name__.lower()
        err_stem = (
            f"Values for {prop_name} variables with Continuous scale must be 2-tuple"
        )

        # A non-tuple value: the message is expected to name the offending type.
        # The bare "; not ." pattern only matched because "." is a regex wildcard.
        with pytest.raises(TypeError, match=f"{err_stem}; not <class 'str'>."):
            self.prop().get_mapping(Continuous("abc"), num_vector)

        with pytest.raises(TypeError, match=f"{err_stem}; not 3-tuple."):
            self.prop().get_mapping(Continuous((1, 2, 3)), num_vector)

    def test_bad_scale_values_categorical_data(self, cat_vector):
        prop_name = self.prop.__name__.lower()
        err_text = f"Values for {prop_name} variables with Nominal scale"
        with pytest.raises(TypeError, match=err_text):
            self.prop().get_mapping(Nominal("abc"), cat_vector)
class TestAlpha(IntervalBase):
    """Run the shared interval-property tests against Alpha."""

    prop = Alpha
class TestLineWidth(IntervalBase):
    """Interval tests for LineWidth, plus its rcParams-dependent default range."""

    prop = LineWidth

    def test_rcparam_default(self):
        # With lines.linewidth = 2 the default range should be (1, 4)
        with mpl.rc_context({"lines.linewidth": 2}):
            default_range = self.prop().default_range
            assert default_range == (1, 4)
class TestEdgeWidth(IntervalBase):
    """Interval tests for EdgeWidth, plus its rcParams-dependent default range."""

    prop = EdgeWidth

    def test_rcparam_default(self):
        # With patch.linewidth = 2 the default range should be (1, 4)
        with mpl.rc_context({"patch.linewidth": 2}):
            default_range = self.prop().default_range
            assert default_range == (1, 4)
class TestPointSize(IntervalBase):
    """Interval tests for PointSize, whose mapping is linear in area."""

    prop = PointSize

    def test_areal_scaling_numeric(self, num_vector):
        limits = 5, 10
        mapping = self.prop().get_mapping(Continuous(limits), num_vector)
        positions = np.linspace(0, 1, 6)
        # Interpolate linearly in squared size (area), then take the root
        squared_limits = np.square(limits)
        expected = np.sqrt(np.linspace(*squared_limits, num=len(positions)))
        assert_array_equal(mapping(positions), expected)

    def test_areal_scaling_categorical(self, cat_vector):
        limits = (2, 4)
        mapping = self.prop().get_mapping(Nominal(limits), cat_vector)
        # Three levels: largest first, midpoint in area, smallest last
        assert_array_equal(mapping(np.arange(3)), [4, np.sqrt(10), 2])
================================================
FILE: tests/_core/test_rules.py
================================================
import numpy as np
import pandas as pd
import pytest
from seaborn._core.rules import (
VarType,
variable_type,
categorical_order,
)
def test_vartype_object():
    """VarType equals its own name and rejects unknown type names."""
    numeric = VarType("numeric")
    assert numeric == "numeric"
    assert numeric != "categorical"

    # Comparing with, or constructing from, an unknown name is an error
    with pytest.raises(AssertionError):
        numeric == "number"
    with pytest.raises(AssertionError):
        VarType("date")
def test_variable_type():
    """variable_type classifies numeric, categorical, boolean, and datetime data."""

    # Floats, ints, and numbers stored with object dtype are numeric
    s = pd.Series([1., 2., 3.])
    assert variable_type(s) == "numeric"
    assert variable_type(s.astype(int)) == "numeric"
    assert variable_type(s.astype(object)) == "numeric"

    s = pd.Series([1, 2, 3, np.nan], dtype=object)
    assert variable_type(s) == "numeric"

    # All-missing data (np.nan or pd.NA) is treated as numeric
    s = pd.Series([np.nan, np.nan])
    assert variable_type(s) == "numeric"

    s = pd.Series([pd.NA, pd.NA])
    assert variable_type(s) == "numeric"

    # Nullable integers, with or without object dtype, are numeric
    s = pd.Series([1, 2, pd.NA], dtype="Int64")
    assert variable_type(s) == "numeric"

    s = pd.Series([1, 2, pd.NA], dtype=object)
    assert variable_type(s) == "numeric"

    # Numbers written as strings are categorical
    s = pd.Series(["1", "2", "3"])
    assert variable_type(s) == "categorical"

    # Booleans are numeric by default; boolean_type selects what is reported
    s = pd.Series([True, False, False])
    assert variable_type(s) == "numeric"
    assert variable_type(s, boolean_type="categorical") == "categorical"
    assert variable_type(s, boolean_type="boolean") == "boolean"

    # This should arguably be datetime, but we don't currently handle it correctly
    # Test is mainly asserting that this doesn't fail on the boolean check.
    s = pd.timedelta_range(1, periods=3, freq="D").to_series()
    assert variable_type(s) == "categorical"

    # Categorical dtype is reported as categorical whatever boolean_type says
    s_cat = s.astype("category")
    assert variable_type(s_cat, boolean_type="categorical") == "categorical"
    assert variable_type(s_cat, boolean_type="numeric") == "categorical"
    assert variable_type(s_cat, boolean_type="boolean") == "categorical"

    # 0/1 integers count as boolean unless strict_boolean is requested
    s = pd.Series([1, 0, 0])
    assert variable_type(s, boolean_type="boolean") == "boolean"
    assert variable_type(s, boolean_type="boolean", strict_boolean=True) == "numeric"

    # Timestamps, including ones stored with object dtype, are datetime
    s = pd.Series([pd.Timestamp(1), pd.Timestamp(2)])
    assert variable_type(s) == "datetime"
    assert variable_type(s.astype(object)) == "datetime"
def test_categorical_order():
    """categorical_order: appearance order, explicit orders, sorting, categories, NaN."""
    x = pd.Series(["a", "c", "c", "b", "a", "d"])
    y = pd.Series([3, 2, 5, 1, 4])
    order = ["a", "b", "c", "d"]

    # String levels keep their order of first appearance
    assert categorical_order(x) == ["a", "c", "b", "d"]

    # An explicit order is returned as given, even if it omits levels
    assert categorical_order(x, order) == order
    assert categorical_order(x, ["b", "a"]) == ["b", "a"]

    # Numeric levels are sorted
    assert categorical_order(y) == [1, 2, 3, 4, 5]
    assert categorical_order(pd.Series(y)) == [1, 2, 3, 4, 5]

    # A categorical dtype follows its declared category order
    y_cat = pd.Series(pd.Categorical(y, y))
    assert categorical_order(y_cat) == list(y)

    x_cat = pd.Series(x).astype("category")
    assert categorical_order(x_cat) == list(x_cat.cat.categories)
    assert categorical_order(x_cat, ["b", "a"]) == ["b", "a"]

    # Missing values are left out
    x_with_nan = pd.Series(["a", np.nan, "c", "c", "b", "a", "d"])
    assert categorical_order(x_with_nan) == ["a", "c", "b", "d"]
================================================
FILE: tests/_core/test_scales.py
================================================
import re
import numpy as np
import pandas as pd
import matplotlib as mpl
import pytest
from numpy.testing import assert_array_equal
from pandas.testing import assert_series_equal
from seaborn._core.plot import Plot
from seaborn._core.scales import (
Nominal,
Continuous,
Boolean,
Temporal,
PseudoAxis,
)
from seaborn._core.properties import (
IntervalProperty,
ObjectProperty,
Coordinate,
Alpha,
Color,
Fill,
)
from seaborn.palettes import color_palette
from seaborn.utils import _version_predates
class TestContinuous:
    """Tests for the ``Continuous`` scale: value mappings, tick locators, and labels."""

    @pytest.fixture
    def x(self):
        return pd.Series([1, 3, 9], name="x", dtype=float)

    def setup_ticks(self, x, *args, **kwargs):
        """Return a PseudoAxis for ``Continuous().tick(*args, **kwargs)`` set up
        on ``x``, with the view interval fixed to (0, 1)."""
        s = Continuous().tick(*args, **kwargs)._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(0, 1)
        return a

    def setup_labels(self, x, *args, **kwargs):
        """Return ``(axis, major_tick_locations)`` for
        ``Continuous().label(*args, **kwargs)`` set up on ``x``, with the view
        interval fixed to (0, 1)."""
        s = Continuous().label(*args, **kwargs)._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(0, 1)
        locs = a.major.locator()
        return a, locs

    def test_coordinate_defaults(self, x):
        s = Continuous()._setup(x, Coordinate())
        assert_series_equal(s(x), x)

    def test_coordinate_transform(self, x):
        s = Continuous(trans="log")._setup(x, Coordinate())
        assert_series_equal(s(x), np.log10(x))

    def test_coordinate_transform_with_parameter(self, x):
        s = Continuous(trans="pow3")._setup(x, Coordinate())
        assert_series_equal(s(x), np.power(x, 3))

    def test_coordinate_transform_error(self, x):
        s = Continuous(trans="bad")
        with pytest.raises(ValueError, match="Unknown value provided"):
            s._setup(x, Coordinate())

    def test_interval_defaults(self, x):
        s = Continuous()._setup(x, IntervalProperty())
        assert_array_equal(s(x), [0, .25, 1])

    def test_interval_with_range(self, x):
        s = Continuous((1, 3))._setup(x, IntervalProperty())
        assert_array_equal(s(x), [1, 1.5, 3])

    def test_interval_with_norm(self, x):
        s = Continuous(norm=(3, 7))._setup(x, IntervalProperty())
        assert_array_equal(s(x), [-.5, 0, 1.5])

    def test_interval_with_range_norm_and_transform(self, x):
        x = pd.Series([1, 10, 100])
        # TODO param order?
        s = Continuous((2, 3), (10, 100), "log")._setup(x, IntervalProperty())
        assert_array_equal(s(x), [1, 2, 3])

    def test_interval_with_bools(self):
        x = pd.Series([True, False, False])
        s = Continuous()._setup(x, IntervalProperty())
        assert_array_equal(s(x), [1, 0, 0])

    def test_color_defaults(self, x):
        cmap = color_palette("ch:", as_cmap=True)
        s = Continuous()._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .25, 1])[:, :3])  # FIXME RGBA

    def test_color_named_values(self, x):
        cmap = color_palette("viridis", as_cmap=True)
        s = Continuous("viridis")._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .25, 1])[:, :3])  # FIXME RGBA

    def test_color_tuple_values(self, x):
        cmap = color_palette("blend:b,g", as_cmap=True)
        s = Continuous(("b", "g"))._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .25, 1])[:, :3])  # FIXME RGBA

    def test_color_callable_values(self, x):
        cmap = color_palette("light:r", as_cmap=True)
        s = Continuous(cmap)._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .25, 1])[:, :3])  # FIXME RGBA

    def test_color_with_norm(self, x):
        cmap = color_palette("ch:", as_cmap=True)
        s = Continuous(norm=(3, 7))._setup(x, Color())
        assert_array_equal(s(x), cmap([-.5, 0, 1.5])[:, :3])  # FIXME RGBA

    def test_color_with_transform(self, x):
        x = pd.Series([1, 10, 100], name="x", dtype=float)
        cmap = color_palette("ch:", as_cmap=True)
        s = Continuous(trans="log")._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .5, 1])[:, :3])  # FIXME RGBA

    def test_tick_locator(self, x):
        locs = [.2, .5, .8]
        locator = mpl.ticker.FixedLocator(locs)
        a = self.setup_ticks(x, locator)
        assert_array_equal(a.major.locator(), locs)

    def test_tick_locator_input_check(self):
        # The error is expected to report the rejected object's type. Spell out
        # the tuple's repr so the pattern checks it, rather than ending in a
        # bare "." that matches any character.
        err = "Tick locator must be an instance of .*?, not <class 'tuple'>."
        with pytest.raises(TypeError, match=err):
            Continuous().tick((1, 2))

    def test_tick_upto(self, x):
        for n in [2, 5, 10]:
            a = self.setup_ticks(x, upto=n)
            assert len(a.major.locator()) <= (n + 1)

    def test_tick_every(self, x):
        for d in [.05, .2, .5]:
            a = self.setup_ticks(x, every=d)
            assert np.allclose(np.diff(a.major.locator()), d)

    def test_tick_every_between(self, x):
        lo, hi = .2, .8
        for d in [.05, .2, .5]:
            a = self.setup_ticks(x, every=d, between=(lo, hi))
            expected = np.arange(lo, hi + d, d)
            assert_array_equal(a.major.locator(), expected)

    def test_tick_at(self, x):
        locs = [.2, .5, .9]
        a = self.setup_ticks(x, at=locs)
        assert_array_equal(a.major.locator(), locs)

    def test_tick_count(self, x):
        n = 8
        a = self.setup_ticks(x, count=n)
        assert_array_equal(a.major.locator(), np.linspace(0, 1, n))

    def test_tick_count_between(self, x):
        n = 5
        lo, hi = .2, .7
        a = self.setup_ticks(x, count=n, between=(lo, hi))
        assert_array_equal(a.major.locator(), np.linspace(lo, hi, n))

    def test_tick_minor(self, x):
        n = 3
        a = self.setup_ticks(x, count=2, minor=n)
        expected = np.linspace(0, 1, n + 2)
        if _version_predates(mpl, "3.8.0rc1"):
            # I am not sure why matplotlib <3.8 minor ticks include the
            # largest major location but exclude the smallest one ...
            expected = expected[1:]
        assert_array_equal(a.minor.locator(), expected)

    def test_log_tick_default(self, x):
        s = Continuous(trans="log")._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(.5, 1050)
        ticks = a.major.locator()
        assert np.allclose(np.diff(np.log10(ticks)), 1)

    def test_log_tick_upto(self, x):
        n = 3
        s = Continuous(trans="log").tick(upto=n)._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        assert a.major.locator.numticks == n

    def test_log_tick_count(self, x):
        with pytest.raises(RuntimeError, match="`count` requires"):
            Continuous(trans="log").tick(count=4)

        s = Continuous(trans="log").tick(count=4, between=(1, 1000))
        a = PseudoAxis(s._setup(x, Coordinate())._matplotlib_scale)
        a.set_view_interval(.5, 1050)
        assert_array_equal(a.major.locator(), [1, 10, 100, 1000])

    def test_log_tick_format_disabled(self, x):
        s = Continuous(trans="log").label(base=None)._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(20, 20000)
        labels = a.major.formatter.format_ticks(a.major.locator())
        for text in labels:
            assert re.match(r"^\d+$", text)

    def test_log_tick_every(self):
        with pytest.raises(RuntimeError, match="`every` not supported"):
            Continuous(trans="log").tick(every=2)

    def test_symlog_tick_default(self, x):
        s = Continuous(trans="symlog")._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(-1050, 1050)
        ticks = a.major.locator()
        assert ticks[0] == -ticks[-1]
        pos_ticks = np.sort(np.unique(np.abs(ticks)))
        assert np.allclose(np.diff(np.log10(pos_ticks[1:])), 1)
        assert pos_ticks[0] == 0

    def test_label_formatter(self, x):
        fmt = mpl.ticker.FormatStrFormatter("%.3f")
        a, locs = self.setup_labels(x, fmt)
        labels = a.major.formatter.format_ticks(locs)
        for text in labels:
            assert re.match(r"^\d\.\d{3}$", text)

    def test_label_like_pattern(self, x):
        a, locs = self.setup_labels(x, like=".4f")
        labels = a.major.formatter.format_ticks(locs)
        for text in labels:
            assert re.match(r"^\d\.\d{4}$", text)

    def test_label_like_string(self, x):
        a, locs = self.setup_labels(x, like="x = {x:.1f}")
        labels = a.major.formatter.format_ticks(locs)
        for text in labels:
            assert re.match(r"^x = \d\.\d$", text)

    def test_label_like_function(self, x):
        a, locs = self.setup_labels(x, like="{:^5.1f}".format)
        labels = a.major.formatter.format_ticks(locs)
        for text in labels:
            assert re.match(r"^ \d\.\d $", text)

    def test_label_base(self, x):
        a, locs = self.setup_labels(100 * x, base=2)
        labels = a.major.formatter.format_ticks(locs)
        for text in labels[1:]:
            assert not text or "2^" in text

    def test_label_unit(self, x):
        a, locs = self.setup_labels(1000 * x, unit="g")
        labels = a.major.formatter.format_ticks(locs)
        for text in labels[1:-1]:
            assert re.match(r"^\d+ mg$", text)

    def test_label_unit_with_sep(self, x):
        a, locs = self.setup_labels(1000 * x, unit=("", "g"))
        labels = a.major.formatter.format_ticks(locs)
        for text in labels[1:-1]:
            assert re.match(r"^\d+mg$", text)

    def test_label_empty_unit(self, x):
        a, locs = self.setup_labels(1000 * x, unit="")
        labels = a.major.formatter.format_ticks(locs)
        for text in labels[1:-1]:
            assert re.match(r"^\d+m$", text)

    def test_label_base_from_transform(self, x):
        s = Continuous(trans="log")
        a = PseudoAxis(s._setup(x, Coordinate())._matplotlib_scale)
        a.set_view_interval(10, 1000)
        label, = a.major.formatter.format_ticks([100])
        assert r"10^{2}" in label

    def test_label_type_checks(self):
        s = Continuous()
        with pytest.raises(TypeError, match="Label formatter must be"):
            s.label("{x}")

        with pytest.raises(TypeError, match="`like` must be"):
            s.label(like=2)
class TestNominal:
    """Tests for the ``Nominal`` scale across coordinate and semantic properties."""

    @pytest.fixture
    def x(self):
        return pd.Series(["a", "c", "b", "c"], name="x")

    @pytest.fixture
    def y(self):
        return pd.Series([1, -1.5, 3, -1.5], name="y")

    def test_coordinate_defaults(self, x):
        s = Nominal()._setup(x, Coordinate())
        assert_array_equal(s(x), np.array([0, 1, 2, 1], float))

    def test_coordinate_with_order(self, x):
        s = Nominal(order=["a", "b", "c"])._setup(x, Coordinate())
        assert_array_equal(s(x), np.array([0, 2, 1, 2], float))

    def test_coordinate_with_subset_order(self, x):
        s = Nominal(order=["c", "a"])._setup(x, Coordinate())
        assert_array_equal(s(x), np.array([1, 0, np.nan, 0], float))

    def test_coordinate_axis(self, x):
        ax = mpl.figure.Figure().subplots()
        s = Nominal()._setup(x, Coordinate(), ax.xaxis)
        assert_array_equal(s(x), np.array([0, 1, 2, 1], float))
        f = ax.xaxis.get_major_formatter()
        assert f.format_ticks([0, 1, 2]) == ["a", "c", "b"]

    def test_coordinate_axis_with_order(self, x):
        order = ["a", "b", "c"]
        ax = mpl.figure.Figure().subplots()
        s = Nominal(order=order)._setup(x, Coordinate(), ax.xaxis)
        assert_array_equal(s(x), np.array([0, 2, 1, 2], float))
        f = ax.xaxis.get_major_formatter()
        assert f.format_ticks([0, 1, 2]) == order

    def test_coordinate_axis_with_subset_order(self, x):
        order = ["c", "a"]
        ax = mpl.figure.Figure().subplots()
        s = Nominal(order=order)._setup(x, Coordinate(), ax.xaxis)
        assert_array_equal(s(x), np.array([1, 0, np.nan, 0], float))
        f = ax.xaxis.get_major_formatter()
        assert f.format_ticks([0, 1, 2]) == [*order, ""]

    def test_coordinate_axis_with_category_dtype(self, x):
        order = ["b", "a", "d", "c"]
        x = x.astype(pd.CategoricalDtype(order))
        ax = mpl.figure.Figure().subplots()
        s = Nominal()._setup(x, Coordinate(), ax.xaxis)
        assert_array_equal(s(x), np.array([1, 3, 0, 3], float))
        f = ax.xaxis.get_major_formatter()
        assert f.format_ticks([0, 1, 2, 3]) == order

    def test_coordinate_numeric_data(self, y):
        ax = mpl.figure.Figure().subplots()
        s = Nominal()._setup(y, Coordinate(), ax.yaxis)
        assert_array_equal(s(y), np.array([1, 0, 2, 0], float))
        f = ax.yaxis.get_major_formatter()
        assert f.format_ticks([0, 1, 2]) == ["-1.5", "1.0", "3.0"]

    def test_coordinate_numeric_data_with_order(self, y):
        order = [1, 4, -1.5]
        ax = mpl.figure.Figure().subplots()
        s = Nominal(order=order)._setup(y, Coordinate(), ax.yaxis)
        assert_array_equal(s(y), np.array([0, 2, np.nan, 2], float))
        f = ax.yaxis.get_major_formatter()
        assert f.format_ticks([0, 1, 2]) == ["1.0", "4.0", "-1.5"]

    def test_color_defaults(self, x):
        s = Nominal()._setup(x, Color())
        cs = color_palette()
        assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]])

    def test_color_named_palette(self, x):
        pal = "flare"
        s = Nominal(pal)._setup(x, Color())
        cs = color_palette(pal, 3)
        assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]])

    def test_color_list_palette(self, x):
        cs = color_palette("crest", 3)
        s = Nominal(cs)._setup(x, Color())
        assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]])

    def test_color_dict_palette(self, x):
        cs = color_palette("crest", 3)
        pal = dict(zip("bac", cs))
        s = Nominal(pal)._setup(x, Color())
        assert_array_equal(s(x), [cs[1], cs[2], cs[0], cs[2]])

    def test_color_numeric_data(self, y):
        s = Nominal()._setup(y, Color())
        cs = color_palette()
        assert_array_equal(s(y), [cs[1], cs[0], cs[2], cs[0]])

    def test_color_numeric_with_order_subset(self, y):
        s = Nominal(order=[-1.5, 1])._setup(y, Color())
        c1, c2 = color_palette(n_colors=2)
        null = (np.nan, np.nan, np.nan)
        assert_array_equal(s(y), [c2, c1, null, c1])

    @pytest.mark.xfail(reason="Need to sort out float/int order")
    def test_color_numeric_int_float_mix(self):
        z = pd.Series([1, 2], name="z")
        s = Nominal(order=[1.0, 2])._setup(z, Color())
        c1, c2 = color_palette(n_colors=2)
        # One expected color per element of ``z``: 1 should match the float
        # level 1.0 and 2 the int level 2.
        assert_array_equal(s(z), [c1, c2])

    def test_color_alpha_in_palette(self, x):
        cs = [(.2, .2, .3, .5), (.1, .2, .3, 1), (.5, .6, .2, 0)]
        s = Nominal(cs)._setup(x, Color())
        assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]])

    def test_color_unknown_palette(self, x):
        pal = "not_a_palette"
        err = f"'{pal}' is not a valid palette name"
        with pytest.raises(ValueError, match=err):
            Nominal(pal)._setup(x, Color())

    def test_object_defaults(self, x):
        class MockProperty(ObjectProperty):
            def _default_values(self, n):
                return list("xyz"[:n])

        s = Nominal()._setup(x, MockProperty())
        assert s(x) == ["x", "y", "z", "y"]

    def test_object_list(self, x):
        vs = ["x", "y", "z"]
        s = Nominal(vs)._setup(x, ObjectProperty())
        assert s(x) == ["x", "y", "z", "y"]

    def test_object_dict(self, x):
        vs = {"a": "x", "b": "y", "c": "z"}
        s = Nominal(vs)._setup(x, ObjectProperty())
        assert s(x) == ["x", "z", "y", "z"]

    def test_object_order(self, x):
        vs = ["x", "y", "z"]
        s = Nominal(vs, order=["c", "a", "b"])._setup(x, ObjectProperty())
        assert s(x) == ["y", "x", "z", "x"]

    def test_object_order_subset(self, x):
        vs = ["x", "y"]
        s = Nominal(vs, order=["a", "c"])._setup(x, ObjectProperty())
        assert s(x) == ["x", "y", None, "y"]

    def test_objects_that_are_weird(self, x):
        vs = [("x", 1), (None, None, 0), {}]
        s = Nominal(vs)._setup(x, ObjectProperty())
        assert s(x) == [vs[0], vs[1], vs[2], vs[1]]

    def test_alpha_default(self, x):
        s = Nominal()._setup(x, Alpha())
        assert_array_equal(s(x), [.95, .625, .3, .625])

    def test_fill(self):
        x = pd.Series(["a", "a", "b", "a"], name="x")
        s = Nominal()._setup(x, Fill())
        assert_array_equal(s(x), [True, True, False, True])

    def test_fill_dict(self):
        x = pd.Series(["a", "a", "b", "a"], name="x")
        vs = {"a": False, "b": True}
        s = Nominal(vs)._setup(x, Fill())
        assert_array_equal(s(x), [False, False, True, False])

    def test_fill_nunique_warning(self):
        x = pd.Series(["a", "b", "c", "a", "b"], name="x")
        with pytest.warns(UserWarning, match="The variable assigned to fill"):
            s = Nominal()._setup(x, Fill())
        assert_array_equal(s(x), [True, False, True, True, False])

    def test_interval_defaults(self, x):
        class MockProperty(IntervalProperty):
            _default_range = (1, 2)

        s = Nominal()._setup(x, MockProperty())
        assert_array_equal(s(x), [2, 1.5, 1, 1.5])

    def test_interval_tuple(self, x):
        s = Nominal((1, 2))._setup(x, IntervalProperty())
        assert_array_equal(s(x), [2, 1.5, 1, 1.5])

    def test_interval_tuple_numeric(self, y):
        s = Nominal((1, 2))._setup(y, IntervalProperty())
        assert_array_equal(s(y), [1.5, 2, 1, 2])

    def test_interval_list(self, x):
        vs = [2, 5, 4]
        s = Nominal(vs)._setup(x, IntervalProperty())
        assert_array_equal(s(x), [2, 5, 4, 5])

    def test_interval_dict(self, x):
        vs = {"a": 3, "b": 4, "c": 6}
        s = Nominal(vs)._setup(x, IntervalProperty())
        assert_array_equal(s(x), [3, 6, 4, 6])

    def test_interval_with_transform(self, x):
        class MockProperty(IntervalProperty):
            _forward = np.square
            _inverse = np.sqrt

        s = Nominal((2, 4))._setup(x, MockProperty())
        assert_array_equal(s(x), [4, np.sqrt(10), 2, np.sqrt(10)])

    def test_empty_data(self):
        x = pd.Series([], dtype=object, name="x")
        s = Nominal()._setup(x, Coordinate())
        assert_array_equal(s(x), [])

    def test_finalize(self, x):
        ax = mpl.figure.Figure().subplots()
        s = Nominal()._setup(x, Coordinate(), ax.yaxis)
        s._finalize(Plot(), ax.yaxis)

        levels = x.unique()
        assert ax.get_ylim() == (len(levels) - .5, -.5)
        assert_array_equal(ax.get_yticks(), list(range(len(levels))))
        for i, expected in enumerate(levels):
            assert ax.yaxis.major.formatter(i) == expected
class TestTemporal:
    """Tests for the ``Temporal`` scale on datetime data."""

    @pytest.fixture
    def t(self):
        dates = pd.to_datetime(["1972-09-27", "1975-06-24", "1980-12-14"])
        return pd.Series(dates, name="x")

    @pytest.fixture
    def x(self, t):
        # The dates in ``t`` converted with matplotlib's ``date2num``
        return pd.Series(mpl.dates.date2num(t), name=t.name)

    def test_coordinate_defaults(self, t, x):
        s = Temporal()._setup(t, Coordinate())
        assert_array_equal(s(t), x)

    def test_interval_defaults(self, t, x):
        s = Temporal()._setup(t, IntervalProperty())
        normed = (x - x.min()) / (x.max() - x.min())
        assert_array_equal(s(t), normed)

    def test_interval_with_range(self, t, x):
        values = (1, 3)
        # Pass the same range object used to build ``expected`` so the two
        # cannot drift apart.
        s = Temporal(values)._setup(t, IntervalProperty())
        normed = (x - x.min()) / (x.max() - x.min())
        expected = normed * (values[1] - values[0]) + values[0]
        assert_array_equal(s(t), expected)

    def test_interval_with_norm(self, t, x):
        norm = t[1], t[2]
        s = Temporal(norm=norm)._setup(t, IntervalProperty())
        n = mpl.dates.date2num(norm)
        normed = (x - n[0]) / (n[1] - n[0])
        assert_array_equal(s(t), normed)

    def test_color_defaults(self, t, x):
        cmap = color_palette("ch:", as_cmap=True)
        s = Temporal()._setup(t, Color())
        normed = (x - x.min()) / (x.max() - x.min())
        assert_array_equal(s(t), cmap(normed)[:, :3])  # FIXME RGBA

    def test_color_named_values(self, t, x):
        name = "viridis"
        cmap = color_palette(name, as_cmap=True)
        s = Temporal(name)._setup(t, Color())
        normed = (x - x.min()) / (x.max() - x.min())
        assert_array_equal(s(t), cmap(normed)[:, :3])  # FIXME RGBA

    def test_coordinate_axis(self, t, x):
        ax = mpl.figure.Figure().subplots()
        s = Temporal()._setup(t, Coordinate(), ax.xaxis)
        assert_array_equal(s(t), x)
        locator = ax.xaxis.get_major_locator()
        formatter = ax.xaxis.get_major_formatter()
        assert isinstance(locator, mpl.dates.AutoDateLocator)
        assert isinstance(formatter, mpl.dates.AutoDateFormatter)

    def test_tick_locator(self, t):
        locator = mpl.dates.YearLocator(month=3, day=15)
        s = Temporal().tick(locator)
        a = PseudoAxis(s._setup(t, Coordinate())._matplotlib_scale)
        a.set_view_interval(0, 365)
        assert 73 in a.major.locator()

    def test_tick_upto(self, t):
        n = 8
        ax = mpl.figure.Figure().subplots()
        Temporal().tick(upto=n)._setup(t, Coordinate(), ax.xaxis)
        locator = ax.xaxis.get_major_locator()
        assert set(locator.maxticks.values()) == {n}

    def test_label_formatter(self, t):
        formatter = mpl.dates.DateFormatter("%Y")
        s = Temporal().label(formatter)
        a = PseudoAxis(s._setup(t, Coordinate())._matplotlib_scale)
        a.set_view_interval(10, 1000)
        label, = a.major.formatter.format_ticks([100])
        assert label == "1970"

    def test_label_concise(self, t):
        ax = mpl.figure.Figure().subplots()
        Temporal().label(concise=True)._setup(t, Coordinate(), ax.xaxis)
        formatter = ax.xaxis.get_major_formatter()
        assert isinstance(formatter, mpl.dates.ConciseDateFormatter)
class TestBoolean:
    """Tests for the ``Boolean`` scale across coordinate and semantic properties."""

    @pytest.fixture
    def x(self):
        return pd.Series([True, False, False, True], name="x", dtype=bool)

    def test_coordinate(self, x):
        scale = Boolean()._setup(x, Coordinate())
        assert_array_equal(scale(x), x.astype(float))

    def test_coordinate_axis(self, x):
        ax = mpl.figure.Figure().subplots()
        scale = Boolean()._setup(x, Coordinate(), ax.xaxis)
        assert_array_equal(scale(x), x.astype(float))
        formatter = ax.xaxis.get_major_formatter()
        assert formatter.format_ticks([0, 1]) == ["False", "True"]

    @pytest.mark.parametrize(
        "dtype,value",
        [
            (object, np.nan),
            (object, None),
            ("boolean", pd.NA),
        ]
    )
    def test_coordinate_missing(self, x, dtype, value):
        x = x.astype(dtype)
        x[2] = value
        scale = Boolean()._setup(x, Coordinate())
        assert_array_equal(scale(x), x.astype(float))

    def test_color_defaults(self, x):
        scale = Boolean()._setup(x, Color())
        colors = color_palette()
        # True takes the first palette entry, False the second
        expected = [colors[0] if val else colors[1] for val in x]
        assert_array_equal(scale(x), expected)

    def test_color_list_palette(self, x):
        colors = color_palette("crest", 2)
        scale = Boolean(colors)._setup(x, Color())
        expected = [colors[0] if val else colors[1] for val in x]
        assert_array_equal(scale(x), expected)

    def test_color_tuple_palette(self, x):
        colors = tuple(color_palette("crest", 2))
        scale = Boolean(colors)._setup(x, Color())
        expected = [colors[0] if val else colors[1] for val in x]
        assert_array_equal(scale(x), expected)

    def test_color_dict_palette(self, x):
        colors = color_palette("crest", 2)
        palette = {True: colors[0], False: colors[1]}
        scale = Boolean(palette)._setup(x, Color())
        expected = [palette[val] for val in x]
        assert_array_equal(scale(x), expected)

    def test_object_defaults(self, x):
        values = ["x", "y", "z"]

        class MockProperty(ObjectProperty):
            def _default_values(self, n):
                return values[:n]

        scale = Boolean()._setup(x, MockProperty())
        expected = [values[0] if val else values[1] for val in x]
        assert scale(x) == expected

    def test_object_list(self, x):
        values = ["x", "y"]
        scale = Boolean(values)._setup(x, ObjectProperty())
        expected = [values[0] if val else values[1] for val in x]
        assert scale(x) == expected

    def test_object_dict(self, x):
        values = {True: "x", False: "y"}
        scale = Boolean(values)._setup(x, ObjectProperty())
        expected = [values[val] for val in x]
        assert scale(x) == expected

    def test_fill(self, x):
        scale = Boolean()._setup(x, Fill())
        assert_array_equal(scale(x), x)

    def test_interval_defaults(self, x):
        values = (1, 2)

        class MockProperty(IntervalProperty):
            _default_range = values

        scale = Boolean()._setup(x, MockProperty())
        # For interval properties, True maps to the upper end of the range
        expected = [values[1] if val else values[0] for val in x]
        assert_array_equal(scale(x), expected)

    def test_interval_tuple(self, x):
        values = (3, 5)
        scale = Boolean(values)._setup(x, IntervalProperty())
        expected = [values[1] if val else values[0] for val in x]
        assert_array_equal(scale(x), expected)

    def test_finalize(self, x):
        ax = mpl.figure.Figure().subplots()
        scale = Boolean()._setup(x, Coordinate(), ax.xaxis)
        scale._finalize(Plot(), ax.xaxis)
        assert ax.get_xlim() == (1.5, -.5)
        assert_array_equal(ax.get_xticks(), [0, 1])
        major_fmt = ax.xaxis.major.formatter
        assert major_fmt(0) == "False"
        assert major_fmt(1) == "True"
================================================
FILE: tests/_core/test_subplots.py
================================================
import itertools
import numpy as np
import pytest
from seaborn._core.subplots import Subplots
class TestSpecificationChecks:
    """Incompatible facet/pair specifications must make ``Subplots`` raise."""

    def test_both_facets_and_wrap(self):
        facet_spec = {"wrap": 3, "variables": {"col": "a", "row": "b"}}
        with pytest.raises(
            RuntimeError,
            match="Cannot wrap facets when specifying both `col` and `row`.",
        ):
            Subplots({}, facet_spec, {})

    def test_cross_xy_pairing_and_wrap(self):
        pair_spec = {"wrap": 3, "structure": {"x": ["a", "b"], "y": ["y", "z"]}}
        with pytest.raises(
            RuntimeError,
            match="Cannot wrap subplots when pairing on both `x` and `y`.",
        ):
            Subplots({}, {}, pair_spec)

    def test_col_facets_and_x_pairing(self):
        facet_spec = {"variables": {"col": "a"}}
        pair_spec = {"structure": {"x": ["x", "y"]}}
        with pytest.raises(
            RuntimeError,
            match="Cannot facet the columns while pairing on `x`.",
        ):
            Subplots({}, facet_spec, pair_spec)

    def test_wrapped_columns_and_y_pairing(self):
        facet_spec = {"variables": {"col": "a"}, "wrap": 2}
        pair_spec = {"structure": {"y": ["x", "y"]}}
        with pytest.raises(
            RuntimeError,
            match="Cannot wrap the columns while pairing on `y`.",
        ):
            Subplots({}, facet_spec, pair_spec)

    def test_wrapped_x_pairing_and_facetd_rows(self):
        facet_spec = {"variables": {"row": "a"}}
        pair_spec = {"structure": {"x": ["x", "y"]}, "wrap": 2}
        with pytest.raises(
            RuntimeError,
            match="Cannot wrap the columns while faceting the rows.",
        ):
            Subplots({}, facet_spec, pair_spec)
class TestSubplotSpec:
    """Check the grid shape and axis-sharing spec that ``Subplots`` derives."""

    @staticmethod
    def _check_spec(subplots, *, n_subplots=None, ncols=None, nrows=None,
                    sharex, sharey):
        """Assert the grid size (where given) and sharing flags of ``subplots``.

        Boolean sharing values are compared by identity; others by equality.
        """
        spec = subplots.subplot_spec
        if n_subplots is not None:
            assert subplots.n_subplots == n_subplots
        if ncols is not None:
            assert spec["ncols"] == ncols
        if nrows is not None:
            assert spec["nrows"] == nrows
        for key, expected in (("sharex", sharex), ("sharey", sharey)):
            if isinstance(expected, bool):
                assert spec[key] is expected
            else:
                assert spec[key] == expected

    def test_single_subplot(self):
        s = Subplots({}, {}, {})
        self._check_spec(
            s, n_subplots=1, ncols=1, nrows=1, sharex=True, sharey=True,
        )

    def test_single_facet(self):
        levels = list("abc")
        spec = {"variables": {"col": "a"}, "structure": {"col": levels}}
        s = Subplots({}, spec, {})
        self._check_spec(
            s, n_subplots=len(levels), ncols=len(levels), nrows=1,
            sharex=True, sharey=True,
        )

    def test_two_facets(self):
        col_levels = list("xy")
        row_levels = list("xyz")
        spec = {
            "variables": {"col": "a", "row": "b"},
            "structure": {"col": col_levels, "row": row_levels},
        }
        s = Subplots({}, spec, {})
        self._check_spec(
            s,
            n_subplots=len(col_levels) * len(row_levels),
            ncols=len(col_levels),
            nrows=len(row_levels),
            sharex=True,
            sharey=True,
        )

    def test_col_facet_wrapped(self):
        wrap = 3
        levels = list("abcde")
        spec = {"variables": {"col": "b"}, "structure": {"col": levels}, "wrap": wrap}
        s = Subplots({}, spec, {})
        self._check_spec(
            s, n_subplots=len(levels), ncols=wrap, nrows=len(levels) // wrap + 1,
            sharex=True, sharey=True,
        )

    def test_row_facet_wrapped(self):
        wrap = 3
        levels = list("abcde")
        spec = {"variables": {"row": "b"}, "structure": {"row": levels}, "wrap": wrap}
        s = Subplots({}, spec, {})
        self._check_spec(
            s, n_subplots=len(levels), ncols=len(levels) // wrap + 1, nrows=wrap,
            sharex=True, sharey=True,
        )

    def test_col_facet_wrapped_single_row(self):
        levels = list("abc")
        wrap = len(levels) + 2
        spec = {"variables": {"col": "b"}, "structure": {"col": levels}, "wrap": wrap}
        s = Subplots({}, spec, {})
        self._check_spec(
            s, n_subplots=len(levels), ncols=len(levels), nrows=1,
            sharex=True, sharey=True,
        )

    def test_x_and_y_paired(self):
        x_vars = ["x", "y", "z"]
        y_vars = ["a", "b"]
        s = Subplots({}, {}, {"structure": {"x": x_vars, "y": y_vars}})
        self._check_spec(
            s,
            n_subplots=len(x_vars) * len(y_vars),
            ncols=len(x_vars),
            nrows=len(y_vars),
            sharex="col",
            sharey="row",
        )

    def test_x_paired(self):
        x_vars = ["x", "y", "z"]
        s = Subplots({}, {}, {"structure": {"x": x_vars}})
        self._check_spec(
            s, n_subplots=len(x_vars), ncols=len(x_vars), nrows=1,
            sharex="col", sharey=True,
        )

    def test_y_paired(self):
        y_vars = ["x", "y", "z"]
        s = Subplots({}, {}, {"structure": {"y": y_vars}})
        self._check_spec(
            s, n_subplots=len(y_vars), ncols=1, nrows=len(y_vars),
            sharex=True, sharey="row",
        )

    def test_x_paired_and_wrapped(self):
        x_vars = ["a", "b", "x", "y", "z"]
        wrap = 3
        s = Subplots({}, {}, {"structure": {"x": x_vars}, "wrap": wrap})
        self._check_spec(
            s, n_subplots=len(x_vars), ncols=wrap, nrows=len(x_vars) // wrap + 1,
            sharex=False, sharey=True,
        )

    def test_y_paired_and_wrapped(self):
        y_vars = ["a", "b", "x", "y", "z"]
        wrap = 2
        s = Subplots({}, {}, {"structure": {"y": y_vars}, "wrap": wrap})
        self._check_spec(
            s, n_subplots=len(y_vars), ncols=len(y_vars) // wrap + 1, nrows=wrap,
            sharex=True, sharey=False,
        )

    def test_y_paired_and_wrapped_single_row(self):
        y_vars = ["x", "y", "z"]
        wrap = 1
        s = Subplots({}, {}, {"structure": {"y": y_vars}, "wrap": wrap})
        self._check_spec(
            s, n_subplots=len(y_vars), ncols=len(y_vars), nrows=1,
            sharex=True, sharey=False,
        )

    def test_col_faceted_y_paired(self):
        y_vars = ["x", "y", "z"]
        levels = list("abc")
        facet_spec = {"variables": {"col": "a"}, "structure": {"col": levels}}
        pair_spec = {"structure": {"y": y_vars}}
        s = Subplots({}, facet_spec, pair_spec)
        self._check_spec(
            s,
            n_subplots=len(levels) * len(y_vars),
            ncols=len(levels),
            nrows=len(y_vars),
            sharex=True,
            sharey="row",
        )

    def test_row_faceted_x_paired(self):
        x_vars = ["f", "s"]
        levels = list("abc")
        facet_spec = {"variables": {"row": "a"}, "structure": {"row": levels}}
        pair_spec = {"structure": {"x": x_vars}}
        s = Subplots({}, facet_spec, pair_spec)
        self._check_spec(
            s,
            n_subplots=len(levels) * len(x_vars),
            ncols=len(x_vars),
            nrows=len(levels),
            sharex="col",
            sharey=True,
        )

    def test_x_any_y_paired_non_cross(self):
        x_vars = ["a", "b", "c"]
        y_vars = ["x", "y", "z"]
        spec = {"structure": {"x": x_vars, "y": y_vars}, "cross": False}
        s = Subplots({}, {}, spec)
        self._check_spec(
            s, n_subplots=len(x_vars), ncols=len(y_vars), nrows=1,
            sharex=False, sharey=False,
        )

    def test_x_any_y_paired_non_cross_wrapped(self):
        x_vars = ["a", "b", "c"]
        y_vars = ["x", "y", "z"]
        wrap = 2
        spec = {"structure": {"x": x_vars, "y": y_vars}, "cross": False, "wrap": wrap}
        s = Subplots({}, {}, spec)
        self._check_spec(
            s, n_subplots=len(x_vars), ncols=wrap, nrows=len(x_vars) // wrap + 1,
            sharex=False, sharey=False,
        )

    def test_forced_unshared_facets(self):
        s = Subplots({"sharex": False, "sharey": "row"}, {}, {})
        self._check_spec(s, sharex=False, sharey="row")
class TestSubplotElements:
def test_single_subplot(self):
    """A lone subplot touches every edge and carries no facet levels."""
    s = Subplots({}, {}, {})
    fig = s.init_figure({}, {})
    assert len(s) == 1
    for idx, elem in enumerate(s):
        assert all(elem[side] for side in ("left", "right", "bottom", "top"))
        assert elem["col"] is None and elem["row"] is None
        assert elem["x"] == "x" and elem["y"] == "y"
        assert elem["ax"] == fig.axes[idx]
@pytest.mark.parametrize("dim", ["col", "row"])
def test_single_facet_dim(self, dim):
    """Edge flags for a single unwrapped facet dimension."""
    levels = list("abc")
    spec = {"variables": {dim: "a"}, "structure": {dim: levels}}
    s = Subplots({}, spec, {})
    s.init_figure(spec, {})
    assert len(s) == len(levels)
    last = len(levels) - 1
    is_col = dim == "col"
    for i, elem in enumerate(s):
        assert elem[dim] == levels[i]
        for axis in "xy":
            assert elem[axis] == axis
        # Column facets form a single row (all touch top and bottom);
        # row facets form a single column (all touch left and right).
        assert elem["top"] == (is_col or i == 0)
        assert elem["bottom"] == (is_col or i == last)
        assert elem["left"] == (not is_col or i == 0)
        assert elem["right"] == (not is_col or i == last)
@pytest.mark.parametrize("dim", ["col", "row"])
def test_single_facet_dim_wrapped(self, dim):
    """Edge flags for a single facet dimension wrapped after ``wrap`` levels."""
    levels = list("abc")
    wrap = len(levels) - 1
    spec = {"variables": {dim: "b"}, "structure": {dim: levels}, "wrap": wrap}
    s = Subplots({}, spec, {})
    s.init_figure(spec, {})
    n = len(s)
    assert n == len(levels)
    # Side names ordered as: first line, last line, line start, line end
    side_names = (
        ["top", "bottom", "left", "right"] if dim == "col"
        else ["left", "right", "top", "bottom"]
    )
    for i, elem in enumerate(s):
        assert elem[dim] == levels[i]
        for axis in "xy":
            assert elem[axis] == axis
        expectations = (
            i < wrap,
            i >= wrap or i >= n % wrap,
            i % wrap == 0,
            i % wrap == wrap - 1 or i + 1 == n,
        )
        for side, expected in zip(side_names, expectations):
            assert elem[side] == expected
def test_both_facet_dims(self):
    """Edge flags and facet levels when faceting on both rows and columns."""
    col_levels = list("ab")
    row_levels = list("xyz")
    facet_spec = {
        "variables": {"col": "a", "row": "b"},
        "structure": {"col": col_levels, "row": row_levels},
    }
    s = Subplots({}, facet_spec, {})
    s.init_figure(facet_spec, {})
    ncol, nrow = len(col_levels), len(row_levels)
    assert len(s) == ncol * nrow
    elems = list(s)
    # Elements are expected in row-major order
    assert all(e["top"] for e in elems[:ncol])
    assert all(e["left"] for e in elems[::ncol])
    assert all(e["right"] for e in elems[ncol - 1::ncol])
    assert all(e["bottom"] for e in elems[-ncol:])
    expected_levels = itertools.product(row_levels, col_levels)
    for e, (row_level, col_level) in zip(elems, expected_levels):
        assert e["col"] == col_level
        assert e["row"] == row_level
    for e in elems:
        assert e["x"] == "x"
        assert e["y"] == "y"
@pytest.mark.parametrize("var", ["x", "y"])
def test_single_paired_var(self, var):
    """Edge flags and variable names when pairing on one coordinate."""
    other_var = "y" if var == "x" else "x"
    pairings = ["x", "y", "z"]
    names = [f"{var}{i}" for i in range(len(pairings))]
    pair_spec = {
        "variables": dict(zip(names, pairings)),
        "structure": {var: names},
    }
    s = Subplots({}, {}, pair_spec)
    s.init_figure(pair_spec)
    assert len(s) == len(pair_spec["structure"][var])
    n = len(s)
    # Pairing on x lays subplots out in one row; pairing on y, in one column
    side_names = {
        "x": ["left", "right", "top", "bottom"],
        "y": ["top", "bottom", "left", "right"],
    }[var]
    for i, elem in enumerate(s):
        assert elem[var] == f"{var}{i}"
        assert elem[other_var] == other_var
        assert elem["col"] is elem["row"] is None
        expectations = (i == 0, True, True, i == n - 1)
        for side, expected in zip(side_names, expectations):
            assert elem[side] == expected
@pytest.mark.parametrize("var", ["x", "y"])
def test_single_paired_var_wrapped(self, var):
    """Edge flags when pairing on one coordinate with wrapping."""
    other_var = "y" if var == "x" else "x"
    pairings = ["x", "y", "z", "a", "b"]
    wrap = len(pairings) - 2
    names = [f"{var}{i}" for i in range(len(pairings))]
    pair_spec = {
        "variables": dict(zip(names, pairings)),
        "structure": {var: names},
        "wrap": wrap
    }
    s = Subplots({}, {}, pair_spec)
    s.init_figure(pair_spec)
    n = len(s)
    assert n == len(pairings)
    # Side names ordered as: first line, last line, line start, line end
    side_names = (
        ["top", "bottom", "left", "right"] if var == "x"
        else ["left", "right", "top", "bottom"]
    )
    for i, elem in enumerate(s):
        assert elem[var] == f"{var}{i}"
        assert elem[other_var] == other_var
        assert elem["col"] is elem["row"] is None
        expectations = (
            i < wrap,
            i >= wrap or i >= n % wrap,
            i % wrap == 0,
            i % wrap == wrap - 1 or i + 1 == n,
        )
        for side, expected in zip(side_names, expectations):
            assert elem[side] == expected
def test_both_paired_variables(self):
    """Crossing paired x and y variables yields a full grid of subplots.

    x pairings run across the columns and y pairings run down the rows.
    No subplot has a facet level.
    """
    x = ["x0", "x1"]
    y = ["y0", "y1", "y2"]
    pair_spec = {"structure": {"x": x, "y": y}}
    s = Subplots({}, {}, pair_spec)
    s.init_figure(pair_spec)

    n_cols, n_rows = len(x), len(y)
    assert len(s) == n_cols * n_rows
    subplots = list(s)

    # Outer-edge flags for the first row, first column, last column and
    # last row of the row-major subplot list.
    edge_slices = {
        "top": slice(None, n_cols),
        "left": slice(None, None, n_cols),
        "right": slice(n_cols - 1, None, n_cols),
        "bottom": slice(-n_cols, None),
    }
    for side, selection in edge_slices.items():
        for e in subplots[selection]:
            assert e[side]

    for e in subplots:
        assert e["col"] is e["row"] is None

    # Row-major layout: row index picks y{i}, column index picks x{j}.
    grid = itertools.product(range(n_rows), range(n_cols))
    for (i, j), e in zip(grid, subplots):
        assert e["x"] == f"x{j}"
        assert e["y"] == f"y{i}"
def test_both_paired_non_cross(self):
    """Pairing x and y with ``cross=False`` zips them into a single row."""
    pair_spec = {
        "structure": {"x": ["x0", "x1", "x2"], "y": ["y0", "y1", "y2"]},
        "cross": False,
    }
    s = Subplots({}, {}, pair_spec)
    s.init_figure(pair_spec)

    last = len(s) - 1
    for i, e in enumerate(s):
        assert e["x"] == f"x{i}"
        assert e["y"] == f"y{i}"
        assert e["col"] is e["row"] is None
        # In one row every subplot touches top and bottom; only the two
        # ends touch left and right.
        assert e["left"] == (i == 0)
        assert e["right"] == (i == last)
        assert e["top"]
        assert e["bottom"]
@pytest.mark.parametrize("dim,var", [("col", "y"), ("row", "x")])
def test_one_facet_one_paired(self, dim, var):
    """Faceting on one dimension while pairing on the other fills a grid.

    Faceting on ``col`` with ``y`` paired puts facet levels across the
    columns and pairings down the rows. Faceting on ``row`` with ``x``
    paired is the transpose.
    """
    other_var = {"x": "y", "y": "x"}[var]
    other_dim = {"col": "row", "row": "col"}[dim]
    order = list("abc")
    facet_spec = {"variables": {dim: "s"}, "structure": {dim: order}}
    pairings = ["x", "y", "t"]
    pair_spec = {
        "variables": {f"{var}{i}": val for i, val in enumerate(pairings)},
        "structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]},
    }
    s = Subplots({}, facet_spec, pair_spec)
    s.init_figure(pair_spec)

    n_cols = len(order) if dim == "col" else len(pairings)
    n_rows = len(order) if dim == "row" else len(pairings)
    assert len(s) == len(order) * len(pairings)

    es = list(s)
    for e in es[:n_cols]:
        assert e["top"]
    for e in es[::n_cols]:
        assert e["left"]
    for e in es[n_cols - 1::n_cols]:
        assert e["right"]
    for e in es[-n_cols:]:
        assert e["bottom"]

    # Subplots are stored row-major. When faceting on rows, transpose so
    # that in both cases the facet level varies fastest and the pairing
    # varies slowest.
    if dim == "row":
        es = np.reshape(es, (n_rows, n_cols)).T.ravel()

    for i, e in enumerate(es):
        # The facet level cycles with period len(order). The previous
        # len(pairings) only worked because both lists have three entries.
        assert e[dim] == order[i % len(order)]
        assert e[other_dim] is None
        assert e[var] == f"{var}{i // len(order)}"
        assert e[other_var] == other_var
================================================
FILE: tests/_marks/__init__.py
================================================
================================================
FILE: tests/_marks/test_area.py
================================================
import matplotlib as mpl
from matplotlib.colors import to_rgba, to_rgba_array
from numpy.testing import assert_array_equal
from seaborn._core.plot import Plot
from seaborn._marks.area import Area, Band
class TestArea:
    """Tests for the Area mark: polygon geometry, colors and edge styling."""

    def test_single_defaults(self):
        x, y = [1, 2, 3], [1, 2, 1]
        plot = Plot(x=x, y=y).add(Area()).plot()
        poly = plot._figure.axes[0].patches[0]
        xs, ys = poly.get_path().vertices.T
        palette = plot._theme["axes.prop_cycle"].by_key()["color"]

        # The outline runs along the y=0 baseline, back over the data,
        # and closes at the starting point.
        assert_array_equal(xs, [1, 2, 3, 3, 2, 1, 1])
        assert_array_equal(ys, [0, 0, 0, 1, 2, 1, 0])

        # Default face is translucent; the edge is opaque and drawn at twice
        # the rc patch linewidth.
        assert_array_equal(poly.get_facecolor(), to_rgba(palette[0], .2))
        assert_array_equal(poly.get_edgecolor(), to_rgba(palette[0], 1))
        assert_array_equal(
            poly.get_linewidth(), mpl.rcParams["patch.linewidth"] * 2
        )

    def test_set_properties(self):
        x, y = [1, 2, 3], [1, 2, 1]
        mark = Area(
            color=".33",
            alpha=.3,
            edgecolor=".88",
            edgealpha=.8,
            edgewidth=2,
            edgestyle=(0, (2, 1)),
        )
        poly = Plot(x=x, y=y).add(mark).plot()._figure.axes[0].patches[0]

        assert_array_equal(
            poly.get_facecolor(), to_rgba(mark.color, mark.alpha)
        )
        assert_array_equal(
            poly.get_edgecolor(), to_rgba(mark.edgecolor, mark.edgealpha)
        )
        assert_array_equal(poly.get_linewidth(), mark.edgewidth * 2)

        # The expected dash lengths are the mark's dashes scaled by edgewidth / 4.
        on, off = mark.edgestyle[1]
        expected_dashes = (mark.edgewidth * on / 4, mark.edgewidth * off / 4)
        assert poly.get_linestyle() == (0, expected_dashes)

    def test_mapped_properties(self):
        x, y = [1, 2, 3, 2, 3, 4], [1, 2, 1, 1, 3, 2]
        g = ["a", "a", "a", "b", "b", "b"]
        cs = [".2", ".8"]
        p = Plot(x=x, y=y, color=g, edgewidth=g).scale(color=cs).add(Area()).plot()
        patches = p._figure.axes[0].patches

        expected_x = [1, 2, 3, 3, 2, 1, 1], [2, 3, 4, 4, 3, 2, 2]
        expected_y = [0, 0, 0, 1, 2, 1, 0], [0, 0, 0, 2, 3, 1, 0]
        for poly, ex, ey in zip(patches, expected_x, expected_y):
            xs, ys = poly.get_path().vertices.T
            assert_array_equal(xs, ex)
            assert_array_equal(ys, ey)

        facecolors = [patch.get_facecolor() for patch in patches]
        assert_array_equal(facecolors, to_rgba_array(cs, .2))
        edgecolors = [patch.get_edgecolor() for patch in patches]
        assert_array_equal(edgecolors, to_rgba_array(cs, 1))
        widths = [patch.get_linewidth() for patch in patches]
        assert widths[0] > widths[1]

    def test_unfilled(self):
        c = ".5"
        p = Plot(x=[1, 2, 3], y=[1, 2, 1]).add(Area(fill=False, color=c)).plot()
        poly = p._figure.axes[0].patches[0]
        # With fill disabled the face keeps the mark color but is fully transparent.
        assert poly.get_facecolor() == to_rgba(c, 0)
class TestBand:
    """Tests for the Band mark, which fills between lower and upper bounds."""

    def test_range(self):
        x, ymin, ymax = [1, 2, 4], [2, 1, 4], [3, 3, 5]
        p = Plot(x=x, ymin=ymin, ymax=ymax).add(Band()).plot()
        xs, ys = p._figure.axes[0].patches[0].get_path().vertices.T
        # The outline runs along ymin left to right, back along ymax, then closes.
        assert_array_equal(xs, [1, 2, 4, 4, 2, 1, 1])
        assert_array_equal(ys, [2, 1, 4, 5, 3, 3, 2])

    def test_auto_range(self):
        x = [1, 1, 2, 2, 2]
        y = [1, 2, 3, 4, 5]
        p = Plot(x=x, y=y).add(Band()).plot()
        xs, ys = p._figure.axes[0].patches[0].get_path().vertices.T
        # Without explicit bounds, each x spans the min..max of its y values.
        assert_array_equal(xs, [1, 2, 2, 1, 1])
        assert_array_equal(ys, [1, 3, 5, 2, 1])
================================================
FILE: tests/_marks/test_bar.py
================================================
import numpy as np
import pandas as pd
from matplotlib.colors import to_rgba, to_rgba_array
import pytest
from numpy.testing import assert_array_equal
from seaborn._core.plot import Plot
from seaborn._marks.bar import Bar, Bars
class TestBar:
    """Tests for the Bar mark (one Rectangle patch per bar)."""

    def plot_bars(self, variables, mark_kws, layer_kws):
        """Plot a Bar layer from the given specs and return every bar patch."""
        plot = Plot(**variables).add(Bar(**mark_kws), **layer_kws).plot()
        containers = plot._figure.axes[0].containers
        return [bar for container in containers for bar in container]

    def check_bar(self, bar, x, y, width, height):
        """Assert a bar's anchor and extent, up to floating-point tolerance."""
        checks = [
            (bar.get_x, x),
            (bar.get_y, y),
            (bar.get_width, width),
            (bar.get_height, height),
        ]
        for getter, value in checks:
            assert getter() == pytest.approx(value)

    def test_categorical_positions_vertical(self):
        x = ["a", "b"]
        y = [1, 2]
        w = .8
        half = w / 2
        bars = self.plot_bars({"x": x, "y": y}, {}, {})
        # Categories sit at integer positions and bars are centered on them.
        for i, bar in enumerate(bars):
            self.check_bar(bar, i - half, 0, w, y[i])

    def test_categorical_positions_horizontal(self):
        x = [1, 2]
        y = ["a", "b"]
        w = .8
        half = w / 2
        bars = self.plot_bars({"x": x, "y": y}, {}, {})
        for i, bar in enumerate(bars):
            self.check_bar(bar, 0, i - half, x[i], w)

    def test_numeric_positions_vertical(self):
        x = [1, 2]
        y = [3, 4]
        w = .8
        half = w / 2
        bars = self.plot_bars({"x": x, "y": y}, {}, {})
        for i, bar in enumerate(bars):
            self.check_bar(bar, x[i] - half, 0, w, y[i])

    def test_numeric_positions_horizontal(self):
        x = [1, 2]
        y = [3, 4]
        w = .8
        half = w / 2
        bars = self.plot_bars({"x": x, "y": y}, {}, {"orient": "h"})
        for i, bar in enumerate(bars):
            self.check_bar(bar, 0, y[i] - half, x[i], w)

    def test_set_properties(self):
        x = ["a", "b", "c"]
        y = [1, 3, 2]
        mark = Bar(
            color=".8",
            alpha=.5,
            edgecolor=".3",
            edgealpha=.9,
            edgestyle=(2, 1),
            edgewidth=1.5,
        )
        p = Plot(x, y).add(mark).plot()

        face = to_rgba(mark.color, mark.alpha)
        edge = to_rgba(mark.edgecolor, mark.edgealpha)
        # The plotting method doubles the linewidth and halves the dash
        # pattern (see its comments for why).
        width = mark.edgewidth * 2
        dashes = (mark.edgestyle[0] / 2, mark.edgestyle[1] / 2)

        for bar in p._figure.axes[0].patches:
            assert bar.get_facecolor() == face
            assert bar.get_edgecolor() == edge
            assert bar.get_linewidth() == width
            assert bar.get_linestyle() == (0, dashes)

    def test_mapped_properties(self):
        x = ["a", "b"]
        y = [1, 2]
        mark = Bar(alpha=.2)
        p = Plot(x, y, color=x, edgewidth=y).add(mark).plot()
        patches = p._figure.axes[0].patches
        palette = p._theme["axes.prop_cycle"].by_key()["color"]
        for i, bar in enumerate(patches):
            assert bar.get_facecolor() == to_rgba(palette[i], mark.alpha)
            assert bar.get_edgecolor() == to_rgba(palette[i], 1)
        assert patches[0].get_linewidth() < patches[1].get_linewidth()

    def test_zero_height_skipped(self):
        p = Plot(["a", "b", "c"], [1, 0, 2]).add(Bar()).plot()
        # The zero-height bar for "b" produces no patch.
        assert len(p._figure.axes[0].patches) == 2

    def test_artist_kws_clip(self):
        p = Plot(["a", "b"], [1, 2]).add(Bar({"clip_on": False})).plot()
        first_patch = p._figure.axes[0].patches[0]
        assert first_patch.clipbox is None
class TestBars:
    """Tests for the Bars mark (all bars drawn in one collection)."""

    @pytest.fixture
    def x(self):
        return pd.Series([4, 5, 6, 7, 8], name="x")

    @pytest.fixture
    def y(self):
        return pd.Series([2, 8, 3, 5, 9], name="y")

    @pytest.fixture
    def color(self):
        return pd.Series(["a", "b", "c", "a", "c"], name="color")

    def test_positions(self, x, y):
        p = Plot(x, y).add(Bars()).plot()
        paths = p._figure.axes[0].collections[0].get_paths()
        assert len(paths) == len(x)
        for i, path in enumerate(paths):
            corners = path.vertices
            # Vertex 0 is the bottom-left corner, vertex 1 lies on the right
            # edge and vertex 3 lies on the top edge.
            assert corners[0, 0] == pytest.approx(x[i] - .5)
            assert corners[1, 0] == pytest.approx(x[i] + .5)
            assert corners[0, 1] == 0
            assert corners[3, 1] == y[i]

    def test_positions_horizontal(self, x, y):
        p = Plot(x=y, y=x).add(Bars(), orient="h").plot()
        paths = p._figure.axes[0].collections[0].get_paths()
        assert len(paths) == len(x)
        for i, path in enumerate(paths):
            corners = path.vertices
            assert corners[0, 1] == pytest.approx(x[i] - .5)
            assert corners[3, 1] == pytest.approx(x[i] + .5)
            assert corners[0, 0] == 0
            assert corners[1, 0] == y[i]

    def test_width(self, x, y):
        p = Plot(x, y).add(Bars(width=.4)).plot()
        paths = p._figure.axes[0].collections[0].get_paths()
        for i, path in enumerate(paths):
            corners = path.vertices
            assert corners[0, 0] == pytest.approx(x[i] - .2)
            assert corners[1, 0] == pytest.approx(x[i] + .2)

    def test_mapped_color_direct_alpha(self, x, y, color):
        alpha = .5
        p = Plot(x, y, color=color).add(Bars(alpha=alpha)).plot()
        facecolors = p._figure.axes[0].collections[0].get_facecolors()
        C0, C1, C2, *_ = p._theme["axes.prop_cycle"].by_key()["color"]
        # Levels a, b, c map to the first three palette colors.
        assert_array_equal(facecolors, to_rgba_array([C0, C1, C2, C0, C2], alpha))

    def test_mapped_edgewidth(self, x, y):
        p = Plot(x, y, edgewidth=y).add(Bars()).plot()
        widths = p._figure.axes[0].collections[0].get_linewidths()
        # Edge widths follow the same ordering as the mapped data.
        assert_array_equal(np.argsort(widths), np.argsort(y))

    def test_auto_edgewidth(self):
        def edge_widths(n):
            data = np.arange(n)
            p = Plot(data, data).add(Bars()).plot()
            return p._figure.axes[0].collections[0].get_linewidths()

        few = edge_widths(10)
        many = edge_widths(1000)
        # Plotting many more bars should yield thinner automatic edges.
        assert (few > many).all()

    def test_unfilled(self, x, y):
        p = Plot(x, y).add(Bars(fill=False, edgecolor="C4")).plot()
        bars = p._figure.axes[0].collections[0]
        facecolors = bars.get_facecolors()
        edgecolors = bars.get_edgecolors()
        palette = p._theme["axes.prop_cycle"].by_key()["color"]
        n = len(x)
        assert_array_equal(facecolors, to_rgba_array([palette[0]] * n, 0))
        assert_array_equal(edgecolors, to_rgba_array([palette[4]] * n, 1))

    def test_log_scale(self):
        x = y = [1, 10, 100, 1000]
        p = Plot(x, y).add(Bars()).scale(x="log").plot()
        paths = p._figure.axes[0].collections[0].get_paths()
        # On a log x scale adjacent bars should touch: each right edge meets
        # the next bar's left edge.
        for left, right in zip(paths, paths[1:]):
            assert left.vertices[1, 0] == pytest.approx(right.vertices[0, 0])
================================================
FILE: tests/_marks/test_base.py
================================================
from dataclasses import dataclass
import numpy as np
import pandas as pd
import matplotlib as mpl
import pytest
from numpy.testing import assert_array_equal
from seaborn._marks.base import Mark, Mappable, resolve_color
class TestMappable:
    """Tests for ``Mappable`` defaults and how marks resolve their values."""

    def mark(self, **features):
        """Return a minimal Mark instance with a fixed set of Mappable fields.

        ``features`` override the dataclass defaults declared below.
        """
        @dataclass
        class MockMark(Mark):
            linewidth: float = Mappable(rc="lines.linewidth")
            pointsize: float = Mappable(4)
            color: str = Mappable("C0")
            fillcolor: str = Mappable(depend="color")
            alpha: float = Mappable(1)
            fillalpha: float = Mappable(depend="alpha")

        m = MockMark(**features)
        return m

    def test_repr(self):
        # Each kind of default has its own angle-bracketed repr; the rc,
        # depend and auto forms must be spelled out, not left empty.
        assert str(Mappable(.5)) == "<0.5>"
        assert str(Mappable("CO")) == "<'CO'>"
        assert str(Mappable(rc="lines.linewidth")) == "<rc:lines.linewidth>"
        assert str(Mappable(depend="color")) == "<depend:color>"
        assert str(Mappable(auto=True)) == "<auto>"

    def test_input_checks(self):
        # Unknown rc keys and unknown dependency names are rejected eagerly.
        with pytest.raises(AssertionError):
            Mappable(rc="bogus.parameter")
        with pytest.raises(AssertionError):
            Mappable(depend="nonexistent_feature")

    def test_value(self):
        val = 3
        m = self.mark(linewidth=val)
        assert m._resolve({}, "linewidth") == val
        df = pd.DataFrame(index=pd.RangeIndex(10))
        # Resolving against a DataFrame broadcasts the value to every row.
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val))

    def test_default(self):
        val = 3
        m = self.mark(linewidth=Mappable(val))
        assert m._resolve({}, "linewidth") == val
        df = pd.DataFrame(index=pd.RangeIndex(10))
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val))

    def test_rcparam(self):
        param = "lines.linewidth"
        val = mpl.rcParams[param]
        m = self.mark(linewidth=Mappable(rc=param))
        assert m._resolve({}, "linewidth") == val
        df = pd.DataFrame(index=pd.RangeIndex(10))
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val))

    def test_depends(self):
        val = 2
        df = pd.DataFrame(index=pd.RangeIndex(10))

        # The dependent field follows a Mappable default...
        m = self.mark(pointsize=Mappable(val), linewidth=Mappable(depend="pointsize"))
        assert m._resolve({}, "linewidth") == val
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val))

        # ...and also a directly assigned value.
        m = self.mark(pointsize=val * 2, linewidth=Mappable(depend="pointsize"))
        assert m._resolve({}, "linewidth") == val * 2
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val * 2))

    def test_mapped(self):
        values = {"a": 1, "b": 2, "c": 3}

        def f(x):
            return np.array([values[x_i] for x_i in x])

        m = self.mark(linewidth=Mappable(2))
        scales = {"linewidth": f}

        # Data present for the property is passed through its scale.
        assert m._resolve({"linewidth": "c"}, "linewidth", scales) == 3

        df = pd.DataFrame({"linewidth": ["a", "b", "c"]})
        expected = np.array([1, 2, 3], float)
        assert_array_equal(m._resolve(df, "linewidth", scales), expected)

    def test_color(self):
        c, a = "C1", .5
        m = self.mark(color=c, alpha=a)

        assert resolve_color(m, {}) == mpl.colors.to_rgba(c, a)

        df = pd.DataFrame(index=pd.RangeIndex(10))
        cs = [c] * len(df)
        assert_array_equal(resolve_color(m, df), mpl.colors.to_rgba_array(cs, a))

    def test_color_mapped_alpha(self):
        c = "r"
        values = {"a": .2, "b": .5, "c": .8}

        m = self.mark(color=c, alpha=Mappable(1))
        scales = {"alpha": lambda s: np.array([values[s_i] for s_i in s])}

        assert resolve_color(m, {"alpha": "b"}, "", scales) == mpl.colors.to_rgba(c, .5)

        df = pd.DataFrame({"alpha": list(values.keys())})

        # Do this in two steps for mpl 3.2 compat
        expected = mpl.colors.to_rgba_array([c] * len(df))
        expected[:, 3] = list(values.values())

        assert_array_equal(resolve_color(m, df, "", scales), expected)

    def test_color_scaled_as_strings(self):
        colors = ["C1", "dodgerblue", "#445566"]
        m = self.mark()
        scales = {"color": lambda s: colors}

        # A color scale may return color names/specs rather than RGBA values.
        actual = resolve_color(m, {"color": pd.Series(["a", "b", "c"])}, "", scales)
        expected = mpl.colors.to_rgba_array(colors)
        assert_array_equal(actual, expected)

    def test_fillcolor(self):
        c, a = "green", .8
        fa = .2
        m = self.mark(
            color=c, alpha=a,
            fillcolor=Mappable(depend="color"), fillalpha=Mappable(fa),
        )

        # The "fill" prefix resolves fillcolor/fillalpha instead of color/alpha.
        assert resolve_color(m, {}) == mpl.colors.to_rgba(c, a)
        assert resolve_color(m, {}, "fill") == mpl.colors.to_rgba(c, fa)

        df = pd.DataFrame(index=pd.RangeIndex(10))
        cs = [c] * len(df)
        assert_array_equal(resolve_color(m, df), mpl.colors.to_rgba_array(cs, a))
        assert_array_equal(
            resolve_color(m, df, "fill"), mpl.colors.to_rgba_array(cs, fa)
        )
================================================
FILE: tests/_marks/test_dot.py
================================================
from matplotlib.colors import to_rgba, to_rgba_array
import pytest
from numpy.testing import assert_array_equal
from seaborn.palettes import color_palette
from seaborn._core.plot import Plot
from seaborn._marks.dot import Dot, Dots
@pytest.fixture(autouse=True)
def default_palette():
    """Run every test in this module with the "deep" palette active."""
    palette_context = color_palette("deep")
    with palette_context:
        yield
class DotBase:
    """Shared assertion helpers for the Dot and Dots tests."""

    def check_offsets(self, points, x, y):
        """Assert that the scatter offsets match the given coordinates."""
        xs, ys = points.get_offsets().T
        assert_array_equal(xs, x)
        assert_array_equal(ys, y)

    def check_colors(self, part, points, colors, alpha=None):
        """Assert face or edge colors (``part`` is "face" or "edge")."""
        expected = to_rgba_array(colors, alpha)
        actual = getattr(points, f"get_{part}colors")()
        assert_array_equal(actual, expected)
class TestDot(DotBase):
    """Tests for the Dot mark (opaque filled markers by default)."""

    def test_simple(self):
        x = [1, 2, 3]
        y = [4, 5, 2]
        p = Plot(x=x, y=y).add(Dot()).plot()
        points, = p._figure.axes[0].collections
        first = p._theme["axes.prop_cycle"].by_key()["color"][0]

        self.check_offsets(points, x, y)
        self.check_colors("face", points, [first] * 3, 1)
        self.check_colors("edge", points, [first] * 3, 1)

    def test_filled_unfilled_mix(self):
        x = [1, 2]
        y = [4, 5]
        marker = ["a", "b"]
        shapes = ["o", "x"]
        mark = Dot(edgecolor="w", stroke=2, edgewidth=1)
        p = Plot(x=x, y=y).add(mark, marker=marker).scale(marker=shapes).plot()
        points, = p._figure.axes[0].collections
        first = p._theme["axes.prop_cycle"].by_key()["color"][0]

        self.check_offsets(points, x, y)
        # The filled "o" uses edgecolor/edgewidth; the unfilled "x" has a
        # transparent face and is stroked in the main color.
        self.check_colors("face", points, [first, to_rgba(first, 0)], None)
        self.check_colors("edge", points, ["w", first], 1)
        assert_array_equal(points.get_linewidths(), [mark.edgewidth, mark.stroke])

    def test_missing_coordinate_data(self):
        x = [1, float("nan"), 3]
        y = [5, 3, 4]
        p = Plot(x=x, y=y).add(Dot()).plot()
        points, = p._figure.axes[0].collections
        # The point with a missing x is dropped.
        self.check_offsets(points, [1, 3], [5, 4])

    @pytest.mark.parametrize("prop", ["color", "fill", "marker", "pointsize"])
    def test_missing_semantic_data(self, prop):
        x = [1, 2, 3]
        y = [5, 3, 4]
        z = ["a", float("nan"), "b"]
        semantics = {prop: z}
        p = Plot(x=x, y=y, **semantics).add(Dot()).plot()
        points, = p._figure.axes[0].collections
        # A missing semantic value also drops its point.
        self.check_offsets(points, [1, 3], [5, 4])
class TestDots(DotBase):
    """Tests for the Dots mark (translucent faces and opaque edges by default)."""

    def test_simple(self):
        x = [1, 2, 3]
        y = [4, 5, 2]
        p = Plot(x=x, y=y).add(Dots()).plot()
        points, = p._figure.axes[0].collections
        first = p._theme["axes.prop_cycle"].by_key()["color"][0]

        self.check_offsets(points, x, y)
        self.check_colors("face", points, [first] * 3, .2)
        self.check_colors("edge", points, [first] * 3, 1)

    def test_set_color(self):
        x = [1, 2, 3]
        y = [4, 5, 2]
        m = Dots(color=".25")
        p = Plot(x=x, y=y).add(m).plot()
        points, = p._figure.axes[0].collections
        repeated = [m.color] * 3

        self.check_offsets(points, x, y)
        self.check_colors("face", points, repeated, .2)
        self.check_colors("edge", points, repeated, 1)

    def test_map_color(self):
        x = [1, 2, 3]
        y = [4, 5, 2]
        c = ["a", "b", "a"]
        p = Plot(x=x, y=y, color=c).add(Dots()).plot()
        points, = p._figure.axes[0].collections
        C0, C1, *_ = p._theme["axes.prop_cycle"].by_key()["color"]
        mapped = [C0, C1, C0]

        self.check_offsets(points, x, y)
        self.check_colors("face", points, mapped, .2)
        self.check_colors("edge", points, mapped, 1)

    def test_fill(self):
        x = [1, 2, 3]
        y = [4, 5, 2]
        c = ["a", "b", "a"]
        p = Plot(x=x, y=y, color=c).add(Dots(fill=False)).plot()
        points, = p._figure.axes[0].collections
        C0, C1, *_ = p._theme["axes.prop_cycle"].by_key()["color"]
        mapped = [C0, C1, C0]

        self.check_offsets(points, x, y)
        # Unfilled dots keep fully transparent faces.
        self.check_colors("face", points, mapped, 0)
        self.check_colors("edge", points, mapped, 1)

    def test_pointsize(self):
        x = [1, 2, 3]
        y = [4, 5, 2]
        s = 3
        p = Plot(x=x, y=y).add(Dots(pointsize=s)).plot()
        points, = p._figure.axes[0].collections
        self.check_offsets(points, x, y)
        # Collection sizes are areas, so the pointsize is squared.
        assert_array_equal(points.get_sizes(), [s ** 2] * 3)

    def test_stroke(self):
        x = [1, 2, 3]
        y = [4, 5, 2]
        s = 3
        p = Plot(x=x, y=y).add(Dots(stroke=s)).plot()
        points, = p._figure.axes[0].collections
        self.check_offsets(points, x, y)
        assert_array_equal(points.get_linewidths(), [s] * 3)

    def test_filled_unfilled_mix(self):
        x = [1, 2]
        y = [4, 5]
        marker = ["a", "b"]
        shapes = ["o", "x"]
        mark = Dots(stroke=2)
        p = Plot(x=x, y=y).add(mark, marker=marker).scale(marker=shapes).plot()
        points, = p._figure.axes[0].collections
        palette = p._theme["axes.prop_cycle"].by_key()["color"]
        C0 = palette[0]

        self.check_offsets(points, x, y)
        # The filled "o" gets the translucent face; the unfilled "x" has none.
        self.check_colors("face", points, [to_rgba(C0, .2), to_rgba(C0, 0)], None)
        self.check_colors("edge", points, [C0, C0], 1)
        assert_array_equal(points.get_linewidths(), [mark.stroke] * 2)
================================================
FILE: tests/_marks/test_line.py
================================================
import numpy as np
import matplotlib as mpl
from matplotlib.colors import same_color, to_rgba
from numpy.testing import assert_array_equal, assert_array_almost_equal
from seaborn._core.plot import Plot
from seaborn._core.moves import Dodge
from seaborn._marks.line import Dash, Line, Path, Lines, Paths, Range
class TestPath:
    """Tests for the Path mark (one Line2D per group, data order preserved)."""

    def test_xy_data(self):
        x = [1, 5, 3, np.nan, 2]
        y = [1, 4, 2, 5, 3]
        g = [1, 2, 1, 1, 2]
        p = Plot(x=x, y=y, group=g).add(Path()).plot()
        first, second = p._figure.axes[0].get_lines()

        # Points keep their data order; a point with a missing coordinate
        # stays in the line as NaN.
        assert_array_equal(first.get_xdata(), [1, 3, np.nan])
        assert_array_equal(first.get_ydata(), [1, 2, np.nan])
        assert_array_equal(second.get_xdata(), [5, 2])
        assert_array_equal(second.get_ydata(), [4, 3])

    def test_shared_colors_direct(self):
        x = y = [1, 2, 3]
        color = ".44"
        p = Plot(x=x, y=y).add(Path(color=color)).plot()
        line, = p._figure.axes[0].get_lines()
        getters = (
            line.get_color, line.get_markeredgecolor, line.get_markerfacecolor
        )
        for getter in getters:
            assert same_color(getter(), color)

    def test_separate_colors_direct(self):
        x = [1, 2, 3]
        y = [1, 2, 3]
        m = Path(color=".22", edgecolor=".55", fillcolor=".77")
        p = Plot(x=x, y=y).add(m).plot()
        line, = p._figure.axes[0].get_lines()
        assert same_color(line.get_color(), m.color)
        assert same_color(line.get_markeredgecolor(), m.edgecolor)
        assert same_color(line.get_markerfacecolor(), m.fillcolor)

    def test_shared_colors_mapped(self):
        x = y = [1, 2, 3, 4]
        c = ["a", "a", "b", "b"]
        p = Plot(x=x, y=y, color=c).add(Path()).plot()
        palette = p._theme["axes.prop_cycle"].by_key()["color"]
        for i, line in enumerate(p._figure.axes[0].get_lines()):
            expected = palette[i]
            assert same_color(line.get_color(), expected)
            assert same_color(line.get_markeredgecolor(), expected)
            assert same_color(line.get_markerfacecolor(), expected)

    def test_separate_colors_mapped(self):
        x = y = [1, 2, 3, 4]
        c = ["a", "a", "b", "b"]
        d = ["x", "y", "x", "y"]
        p = Plot(x=x, y=y, color=c, fillcolor=d).add(Path()).plot()
        palette = p._theme["axes.prop_cycle"].by_key()["color"]
        # Lines vary by color in pairs (i // 2) and by fillcolor in
        # alternation (i % 2).
        for i, line in enumerate(p._figure.axes[0].get_lines()):
            assert same_color(line.get_color(), palette[i // 2])
            assert same_color(line.get_markeredgecolor(), palette[i // 2])
            assert same_color(line.get_markerfacecolor(), palette[i % 2])

    def test_color_with_alpha(self):
        x = y = [1, 2, 3]
        m = Path(color=(.4, .9, .2, .5), fillcolor=(.2, .2, .3, .9))
        p = Plot(x=x, y=y).add(m).plot()
        line, = p._figure.axes[0].get_lines()
        assert same_color(line.get_color(), m.color)
        assert same_color(line.get_markeredgecolor(), m.color)
        assert same_color(line.get_markerfacecolor(), m.fillcolor)

    def test_color_and_alpha(self):
        x = y = [1, 2, 3]
        m = Path(color=(.4, .9, .2), fillcolor=(.2, .2, .3), alpha=.5)
        p = Plot(x=x, y=y).add(m).plot()
        line, = p._figure.axes[0].get_lines()
        stroke = to_rgba(m.color, m.alpha)
        assert same_color(line.get_color(), stroke)
        assert same_color(line.get_markeredgecolor(), stroke)
        assert same_color(line.get_markerfacecolor(), to_rgba(m.fillcolor, m.alpha))

    def test_other_props_direct(self):
        x = y = [1, 2, 3]
        m = Path(marker="s", linestyle="--", linewidth=3, pointsize=10, edgewidth=1)
        p = Plot(x=x, y=y).add(m).plot()
        line, = p._figure.axes[0].get_lines()
        assert line.get_marker() == m.marker
        assert line.get_linestyle() == m.linestyle
        assert line.get_linewidth() == m.linewidth
        assert line.get_markersize() == m.pointsize
        assert line.get_markeredgewidth() == m.edgewidth

    def test_other_props_mapped(self):
        x = y = [1, 2, 3, 4]
        g = ["a", "a", "b", "b"]
        p = Plot(x=x, y=y, marker=g, linestyle=g, pointsize=g).add(Path()).plot()
        first, second = p._figure.axes[0].get_lines()
        assert first.get_marker() != second.get_marker()
        # Linestyles are not compared because of a Matplotlib bug in storing
        # the linestyle from a dash pattern.
        assert first.get_markersize() != second.get_markersize()

    def test_capstyle(self):
        x = y = [1, 2]
        rc = {"lines.solid_capstyle": "projecting", "lines.dash_capstyle": "round"}

        def only_line(mark):
            p = Plot(x, y).add(mark).theme(rc).plot()
            line, = p._figure.axes[0].get_lines()
            return line

        # A solid path reports the solid capstyle as its dash capstyle.
        assert only_line(Path()).get_dash_capstyle() == "projecting"
        assert only_line(Path(linestyle="--")).get_dash_capstyle() == "round"
        # Artist keywords take precedence over the theme.
        butt_line = only_line(Path({"solid_capstyle": "butt"}))
        assert butt_line.get_solid_capstyle() == "butt"
class TestLine:
    """Tests for the Line mark; most behavior is shared with Path (see TestPath)."""

    def test_xy_data(self):
        x = [1, 5, 3, np.nan, 2]
        y = [1, 4, 2, 5, 3]
        g = [1, 2, 1, 1, 2]
        p = Plot(x=x, y=y, group=g).add(Line()).plot()
        first, second = p._figure.axes[0].get_lines()

        # Unlike Path, points come back sorted by x and the missing point
        # is dropped.
        assert_array_equal(first.get_xdata(), [1, 3])
        assert_array_equal(first.get_ydata(), [1, 2])
        assert_array_equal(second.get_xdata(), [2, 5])
        assert_array_equal(second.get_ydata(), [3, 4])
class TestPaths:
    """Tests for the Paths mark (all groups drawn in a single collection)."""

    def test_xy_data(self):
        x = [1, 5, 3, np.nan, 2]
        y = [1, 4, 2, 5, 3]
        g = [1, 2, 1, 1, 2]
        p = Plot(x=x, y=y, group=g).add(Paths()).plot()
        lines, = p._figure.axes[0].collections
        paths = lines.get_paths()

        xs, ys = paths[0].vertices.T
        assert_array_equal(xs, [1, 3, np.nan])
        assert_array_equal(ys, [1, 2, np.nan])

        xs, ys = paths[1].vertices.T
        assert_array_equal(xs, [5, 2])
        assert_array_equal(ys, [4, 3])

    def test_set_properties(self):
        x = y = [1, 2, 3]
        m = Paths(color=".737", linewidth=1, linestyle=(3, 1))
        p = Plot(x=x, y=y).add(m).plot()
        lines, = p._figure.axes[0].collections
        assert same_color(lines.get_color().squeeze(), m.color)
        assert lines.get_linewidth().item() == m.linewidth
        assert lines.get_dashes()[0] == (0, list(m.linestyle))

    def test_mapped_properties(self):
        x = y = [1, 2, 3, 4]
        g = ["a", "a", "b", "b"]
        p = Plot(x=x, y=y, color=g, linewidth=g, linestyle=g).add(Paths()).plot()
        lines, = p._figure.axes[0].collections
        colors = lines.get_colors()
        widths = lines.get_linewidths()
        styles = lines.get_linestyle()
        assert not np.array_equal(colors[0], colors[1])
        assert widths[0] != widths[1]
        assert styles[0] != styles[1]

    def test_color_with_alpha(self):
        x = y = [1, 2, 3]
        m = Paths(color=(.2, .6, .9, .5))
        p = Plot(x=x, y=y).add(m).plot()
        lines, = p._figure.axes[0].collections
        assert same_color(lines.get_colors().squeeze(), m.color)

    def test_color_and_alpha(self):
        x = y = [1, 2, 3]
        m = Paths(color=(.2, .6, .9), alpha=.5)
        p = Plot(x=x, y=y).add(m).plot()
        lines, = p._figure.axes[0].collections
        assert same_color(lines.get_colors().squeeze(), to_rgba(m.color, m.alpha))

    def test_capstyle(self):
        x = y = [1, 2]
        rc = {"lines.solid_capstyle": "projecting"}

        def capstyle_of(mark):
            p = Plot(x, y).add(mark).plot()
            return p._figure.axes[0].collections[0].get_capstyle()

        with mpl.rc_context(rc):
            # The rc solid capstyle applies to both solid and dashed paths...
            assert capstyle_of(Paths()) == "projecting"
            assert capstyle_of(Paths(linestyle="--")) == "projecting"
            # ...unless overridden by an artist keyword.
            assert capstyle_of(Paths({"capstyle": "butt"})) == "butt"
class TestLines:
    """Tests for the Lines mark (sorted, collection-based counterpart of Line)."""

    def test_xy_data(self):
        x = [1, 5, 3, np.nan, 2]
        y = [1, 4, 2, 5, 3]
        g = [1, 2, 1, 1, 2]
        p = Plot(x=x, y=y, group=g).add(Lines()).plot()
        lines, = p._figure.axes[0].collections
        paths = lines.get_paths()

        # Each group is sorted by x and the missing point is dropped.
        xs, ys = paths[0].vertices.T
        assert_array_equal(xs, [1, 3])
        assert_array_equal(ys, [1, 2])

        xs, ys = paths[1].vertices.T
        assert_array_equal(xs, [2, 5])
        assert_array_equal(ys, [3, 4])

    def test_single_orient_value(self):
        x = [1, 1, 1]
        y = [1, 2, 3]
        p = Plot(x, y).add(Lines()).plot()
        lines, = p._figure.axes[0].collections
        xs, ys = lines.get_paths()[0].vertices.T
        # A constant x column still produces one path through every point.
        assert_array_equal(xs, x)
        assert_array_equal(ys, y)
class TestRange:
    """Tests for the Range mark (one vertical segment per x, between bounds)."""

    def test_xy_data(self):
        x = [1, 2]
        ymin = [1, 4]
        ymax = [2, 3]
        p = Plot(x=x, ymin=ymin, ymax=ymax).add(Range()).plot()
        lines, = p._figure.axes[0].collections
        for i, path in enumerate(lines.get_paths()):
            xs, ys = path.vertices.T
            assert_array_equal(xs, [x[i], x[i]])
            assert_array_equal(ys, [ymin[i], ymax[i]])

    def test_auto_range(self):
        x = [1, 1, 2, 2, 2]
        y = [1, 2, 3, 4, 5]
        p = Plot(x=x, y=y).add(Range()).plot()
        lines, = p._figure.axes[0].collections
        first, second = lines.get_paths()[0], lines.get_paths()[1]
        # Without explicit bounds, each x spans the min..max of its y values.
        assert_array_equal(first.vertices, [(1, 1), (1, 2)])
        assert_array_equal(second.vertices, [(2, 3), (2, 5)])

    def test_mapped_color(self):
        x = [1, 2, 1, 2]
        ymin = [1, 4, 3, 2]
        ymax = [2, 3, 1, 4]
        group = ["a", "a", "b", "b"]
        p = Plot(x=x, ymin=ymin, ymax=ymax, color=group).add(Range()).plot()
        lines, = p._figure.axes[0].collections
        palette = p._theme["axes.prop_cycle"].by_key()["color"]
        for i, path in enumerate(lines.get_paths()):
            xs, ys = path.vertices.T
            assert_array_equal(xs, [x[i], x[i]])
            assert_array_equal(ys, [ymin[i], ymax[i]])
            # Segments come in pairs per color level.
            assert same_color(lines.get_colors()[i], palette[i // 2])

    def test_direct_properties(self):
        x = [1, 2]
        ymin = [1, 4]
        ymax = [2, 3]
        m = Range(color=".654", linewidth=4)
        p = Plot(x=x, ymin=ymin, ymax=ymax).add(m).plot()
        lines, = p._figure.axes[0].collections
        colors = lines.get_colors()
        widths = lines.get_linewidths()
        for i, _ in enumerate(lines.get_paths()):
            assert same_color(colors[i], m.color)
            assert widths[i] == m.linewidth
class TestDash:
    """Tests for the Dash mark (short horizontal segments centered on x)."""

    def test_xy_data(self):
        x = [0, 0, 1, 2]
        y = [1, 2, 3, 4]
        p = Plot(x=x, y=y).add(Dash()).plot()
        lines, = p._figure.axes[0].collections
        # With the default width each dash extends .4 on either side of x.
        for i, path in enumerate(lines.get_paths()):
            xs, ys = path.vertices.T
            assert_array_almost_equal(xs, [x[i] - .4, x[i] + .4])
            assert_array_equal(ys, [y[i], y[i]])

    def test_xy_data_grouped(self):
        x = [0, 0, 1, 2]
        y = [1, 2, 3, 4]
        color = ["a", "b", "a", "b"]
        p = Plot(x=x, y=y, color=color).add(Dash()).plot()
        lines, = p._figure.axes[0].collections
        # Paths are emitted group by group: all "a" rows first, then "b".
        row_order = [0, 2, 1, 3]
        for i, path in zip(row_order, lines.get_paths()):
            xs, ys = path.vertices.T
            assert_array_almost_equal(xs, [x[i] - .4, x[i] + .4])
            assert_array_equal(ys, [y[i], y[i]])

    def test_set_properties(self):
        x = [0, 0, 1, 2]
        y = [1, 2, 3, 4]
        m = Dash(color=".8", linewidth=4)
        p = Plot(x=x, y=y).add(m).plot()
        lines, = p._figure.axes[0].collections
        for dash_color in lines.get_color():
            assert same_color(dash_color, m.color)
        for dash_width in lines.get_linewidth():
            assert dash_width == m.linewidth

    def test_mapped_properties(self):
        x = [0, 1]
        y = [1, 2]
        color = ["a", "b"]
        linewidth = [1, 2]
        p = Plot(x=x, y=y, color=color, linewidth=linewidth).add(Dash()).plot()
        lines, = p._figure.axes[0].collections
        palette = p._theme["axes.prop_cycle"].by_key()["color"]
        for expected, actual in zip(palette, lines.get_color()):
            assert same_color(expected, actual)
        widths = lines.get_linewidths()
        assert widths[1] > widths[0]

    def test_width(self):
        x = [0, 0, 1, 2]
        y = [1, 2, 3, 4]
        p = Plot(x=x, y=y).add(Dash(width=.4)).plot()
        lines, = p._figure.axes[0].collections
        for i, path in enumerate(lines.get_paths()):
            xs, ys = path.vertices.T
            assert_array_almost_equal(xs, [x[i] - .2, x[i] + .2])
            assert_array_equal(ys, [y[i], y[i]])

    def test_dodge(self):
        x = [0, 1]
        y = [1, 2]
        group = ["a", "b"]
        p = Plot(x=x, y=y, group=group).add(Dash(), Dodge()).plot()
        lines, = p._figure.axes[0].collections
        paths = lines.get_paths()

        # Dodging gives each group half of the slot: "a" on the left and
        # "b" on the right.
        first_x, first_y = paths[0].vertices.T
        assert_array_almost_equal(first_x, [-.4, 0])
        assert_array_equal(first_y, [y[0], y[0]])

        second_x, second_y = paths[1].vertices.T
        assert_array_almost_equal(second_x, [1, 1.4])
        assert_array_equal(second_y, [y[1], y[1]])
================================================
FILE: tests/_marks/test_text.py
================================================
import numpy as np
from matplotlib.colors import to_rgba
from matplotlib.text import Text as MPLText
from numpy.testing import assert_array_almost_equal
from seaborn._core.plot import Plot
from seaborn._marks.text import Text
class TestText:
def get_texts(self, ax):
    """Return the Text artists on ``ax``.

    If ``ax.texts`` is empty, the method instead scans ``ax.artists`` for
    Text instances. The original comment attributes this to compatibility
    with matplotlib < 3.5, with some uncertainty.
    """
    if not ax.texts:
        # Compatibility with matplotlib < 3.5 (I think)
        return [artist for artist in ax.artists if isinstance(artist, MPLText)]
    return list(ax.texts)
def test_simple(self):
x = y = [1, 2, 3]
s = list("abc")
p = Plot(x, y, text=s).add(Text()).plot()
ax = p._figure.axes[0]
for i, text in enumerate(self.get_texts(ax)):
x_, y_ = text.get_position()
assert x_ == x[i]
assert y_ == y[i]
assert text.get_text() == s[i]
assert text.get_horizontalalignment() == "center"
assert text.get_verticalalignment() == "center_baseline"
def test_set_properties(self):
x = y = [1, 2, 3]
s = list("abc")
color = "red"
alpha = .6
fontsize = 6
valign = "bottom"
m = Text(color=color, alpha=alpha, fontsize=fontsize, valign=valign)
p = Plot(x, y, text=s).add(m).plot()
ax = p._figure.axes[0]
for i, text in enumerate(self.get_texts(ax)):
assert text.get_text() == s[i]
assert text.get_color() == to_rgba(m.color, m.alpha)
assert text.get_fontsize() == m.fontsize
assert text.get_verticalalignment() == m.valign
def test_mapped_properties(self):
x = y = [1, 2, 3]
s = list("abc")
color = list("aab")
fontsize = [1, 2, 4]
p = Plot(x, y, color=color, fontsize=fontsize, text=s).add(Text()).plot()
ax = p._figure.axes[0]
texts = self.get_texts(ax)
assert texts[0].get_color() == texts[1].get_color()
assert texts[0].get_color() != texts[2].get_color()
assert (
texts[0].get_fontsize()
< texts[1].get_fontsize()
< texts[2].get_fontsize()
)
def test_mapped_alignment(self):
x = [1, 2]
p = Plot(x=x, y=x, halign=x, valign=x, text=x).add(Text()).plot()
ax = p._figure.axes[0]
t1, t2 = self.get_texts(ax)
assert t1.get_horizontalalignment() == "left"
assert t2.get_horizontalalignment() == "right"
assert t1.get_verticalalignment() == "top"
assert t2.get_verticalalignment() == "bottom"
def test_identity_fontsize(self):
x = y = [1, 2, 3]
s = list("abc")
fs = [5, 8, 12]
p = Plot(x, y, text=s, fontsize=fs).add(Text()).scale(fontsize=None).plot()
ax = p._figure.axes[0]
for i, text in enumerate(self.get_texts(ax)):
assert text.get_fontsize() == fs[i]
def test_offset_centered(self):
x = y = [1, 2, 3]
s = list("abc")
p = Plot(x, y, text=s).add(Text()).plot()
ax = p._figure.axes[0]
ax_trans = ax.transData.get_matrix()
for text in self.get_texts(ax):
assert_array_almost_equal(text.get_transform().get_matrix(), ax_trans)
def test_offset_valign(self):
x = y = [1, 2, 3]
s = list("abc")
m = Text(valign="bottom", fontsize=5, offset=.1)
p = Plot(x, y, text=s).add(m).plot()
ax = p._figure.axes[0]
expected_shift_matrix = np.zeros((3, 3))
expected_shift_matrix[1, -1] = m.offset * ax.figure.dpi / 72
ax_trans = ax.transData.get_matrix()
for text in self.get_texts(ax):
shift_matrix = text.get_transform().get_matrix() - ax_trans
assert_array_almost_equal(shift_matrix, expected_shift_matrix)
def test_offset_halign(self):
x = y = [1, 2, 3]
s = list("abc")
m = Text(halign="right", fontsize=10, offset=.5)
p = Plot(x, y, text=s).add(m).plot()
ax = p._figure.axes[0]
expected_shift_matrix = np.zeros((3, 3))
expected_shift_matrix[0, -1] = -m.offset * ax.figure.dpi / 72
ax_trans = ax.transData.get_matrix()
for text in self.get_texts(ax):
shift_matrix = text.get_transform().get_matrix() - ax_trans
assert_array_almost_equal(shift_matrix, expected_shift_matrix)
================================================
FILE: tests/_stats/__init__.py
================================================
================================================
FILE: tests/_stats/test_aggregation.py
================================================
import numpy as np
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal
from seaborn._core.groupby import GroupBy
from seaborn._stats.aggregation import Agg, Est
class AggregationFixtures:
    """Shared test data and grouping helper for the aggregation stats."""

    @pytest.fixture
    def df(self, rng):
        n = 30
        columns = {
            "x": rng.uniform(0, 7, n).round(),
            "y": rng.normal(size=n),
            "color": rng.choice(["a", "b", "c"], n),
            "group": rng.choice(["x", "y"], n),
        }
        return pd.DataFrame(columns)

    def get_groupby(self, df, orient):
        """Group by every column except the value (non-orient) coordinate."""
        value_var = {"x": "y", "y": "x"}[orient]
        return GroupBy([name for name in df if name != value_var])
class TestAgg(AggregationFixtures):
    """Tests for the Agg stat."""

    def test_default(self, df):
        orient = "x"
        data = df[["x", "y"]]
        groupby = self.get_groupby(data, orient)
        result = Agg()(data, groupby, orient, {})
        assert_frame_equal(result, data.groupby("x", as_index=False)["y"].mean())

    def test_default_multi(self, df):
        orient = "x"
        groupby = self.get_groupby(df, orient)
        result = Agg()(df, groupby, orient, {})

        grouping = ["x", "color", "group"]
        full_index = pd.MultiIndex.from_product(
            [sorted(df["x"].unique()), df["color"].unique(), df["group"].unique()],
            names=["x", "color", "group"],
        )
        # Aggregate, drop empty level combinations, and restore column order
        expected = (
            df.groupby(grouping)
            .agg("mean")
            .reindex(index=full_index)
            .dropna()
            .reset_index()
            .reindex(columns=df.columns)
        )
        assert_frame_equal(result, expected)

    @pytest.mark.parametrize("func", ["max", lambda x: float(len(x) % 2)])
    def test_func(self, df, func):
        orient = "x"
        data = df[["x", "y"]]
        groupby = self.get_groupby(data, orient)
        result = Agg(func)(data, groupby, orient, {})
        expected = data.groupby("x", as_index=False)["y"].agg(func)
        assert_frame_equal(result, expected)
class TestEst(AggregationFixtures):
    """Tests for the Est stat (estimate plus error interval)."""

    # Note: Most of the underlying code is exercised in tests/test_statistics

    @pytest.mark.parametrize("func", [np.mean, "mean"])
    def test_mean_sd(self, df, func):
        orient = "x"
        data = df[["x", "y"]]
        groupby = self.get_groupby(data, orient)
        result = Est(func, "sd")(data, groupby, orient, {})

        by_x = data.groupby("x", as_index=False)["y"]
        center = by_x.mean()
        spread = by_x.std().fillna(0)  # fillna needed only on pinned tests
        expected = center.assign(
            ymin=center["y"] - spread["y"],
            ymax=center["y"] + spread["y"],
        )
        assert_frame_equal(result, expected)

    def test_sd_single_obs(self):
        value = 1.5
        orient = "x"
        data = pd.DataFrame([{"x": "a", "y": value}])
        groupby = self.get_groupby(data, orient)
        result = Est("mean", "sd")(data, groupby, orient, {})
        # A single observation has a degenerate (zero-width) interval
        assert_frame_equal(result, data.assign(ymin=value, ymax=value))

    def test_median_pi(self, df):
        orient = "x"
        data = df[["x", "y"]]
        groupby = self.get_groupby(data, orient)
        result = Est("median", ("pi", 100))(data, groupby, orient, {})

        by_x = data.groupby("x", as_index=False)["y"]
        # A 100% percentile interval spans the group's min and max
        expected = by_x.median().assign(
            ymin=by_x.min()["y"],
            ymax=by_x.max()["y"],
        )
        assert_frame_equal(result, expected)

    def test_weighted_mean(self, df, rng):
        weights = rng.uniform(0, 5, len(df))
        groupby = self.get_groupby(df[["x", "y"]], "x")
        weighted = df.assign(weight=weights)
        result = Est("mean")(weighted, groupby, "x", {})
        for _, row in result.iterrows():
            subset = weighted[weighted["x"] == row["x"]]
            assert row["y"] == np.average(subset["y"], weights=subset["weight"])

    def test_seed(self, df):
        orient = "x"
        call_args = df, self.get_groupby(df, orient), orient, {}
        first = Est("mean", "ci", seed=99)(*call_args)
        second = Est("mean", "ci", seed=99)(*call_args)
        assert_frame_equal(first, second)
================================================
FILE: tests/_stats/test_counting.py
================================================
import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_array_equal
from seaborn._core.groupby import GroupBy
from seaborn._stats.counting import Hist, Count
class TestCount:
    """Tests for the Count stat."""

    @pytest.fixture
    def df(self, rng):
        n = 30
        return pd.DataFrame({
            "x": rng.uniform(0, 7, n).round(),
            "y": rng.normal(size=n),
            "color": rng.choice(["a", "b", "c"], n),
            "group": rng.choice(["x", "y"], n),
        })

    def get_groupby(self, df, orient):
        """Group by every column except the value (non-orient) coordinate."""
        excluded = {"x": "y", "y": "x"}[orient]
        return GroupBy([name for name in df if name != excluded])

    def test_single_grouper(self, df):
        orient = "x"
        data = df[["x"]]
        result = Count()(data, self.get_groupby(data, orient), orient, {})
        counts = data.groupby("x").size()
        assert_array_equal(result.sort_values("x")["y"], counts)

    def test_multiple_groupers(self, df):
        orient = "x"
        data = df[["x", "group"]].sort_values("group")
        result = Count()(data, self.get_groupby(data, orient), orient, {})
        counts = data.groupby(["x", "group"]).size()
        assert_array_equal(result.sort_values(["x", "group"])["y"], counts)
class TestHist:
    """Tests for the Hist stat."""

    @pytest.fixture
    def single_args(self):
        """Stat call args grouping by "group" (not a column of ``long_df``)."""
        groupby = GroupBy(["group"])

        class Scale:
            scale_type = "continuous"

        return groupby, "x", {"x": Scale()}

    @pytest.fixture
    def triple_args(self):
        """Stat call args grouping by "group", "a" and "s"."""
        groupby = GroupBy(["group", "a", "s"])

        class Scale:
            scale_type = "continuous"

        return groupby, "x", {"x": Scale()}

    def test_string_bins(self, long_df):

        h = Hist(bins="sqrt")
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        assert bin_kws["range"] == (long_df["x"].min(), long_df["x"].max())
        assert bin_kws["bins"] == int(np.sqrt(len(long_df)))

    def test_int_bins(self, long_df):

        n = 24
        h = Hist(bins=n)
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        assert bin_kws["range"] == (long_df["x"].min(), long_df["x"].max())
        assert bin_kws["bins"] == n

    def test_array_bins(self, long_df):

        bins = [-3, -2, 1, 2, 3]
        h = Hist(bins=bins)
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        assert_array_equal(bin_kws["bins"], bins)

    def test_binwidth(self, long_df):

        binwidth = .5
        h = Hist(binwidth=binwidth)
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        n_bins = bin_kws["bins"]
        left, right = bin_kws["range"]
        assert (right - left) / n_bins == pytest.approx(binwidth)

    def test_binrange(self, long_df):

        binrange = (-4, 4)
        h = Hist(binrange=binrange)
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        assert bin_kws["range"] == binrange

    def test_discrete_bins(self, long_df):

        h = Hist(discrete=True)
        x = long_df["x"].astype(int)
        bin_kws = h._define_bin_params(long_df.assign(x=x), "x", "continuous")
        # Discrete bins are centered on each integer value
        assert bin_kws["range"] == (x.min() - .5, x.max() + .5)
        assert bin_kws["bins"] == (x.max() - x.min() + 1)

    def test_discrete_bins_from_nominal_scale(self, rng):

        h = Hist()
        x = rng.randint(0, 5, 10)
        df = pd.DataFrame({"x": x})
        bin_kws = h._define_bin_params(df, "x", "nominal")
        assert bin_kws["range"] == (x.min() - .5, x.max() + .5)
        assert bin_kws["bins"] == (x.max() - x.min() + 1)

    def test_count_stat(self, long_df, single_args):

        h = Hist(stat="count")
        out = h(long_df, *single_args)
        assert out["y"].sum() == len(long_df)

    # Normalized stats are floating point sums, so compare with a tolerance
    # rather than exact equality.

    def test_probability_stat(self, long_df, single_args):

        h = Hist(stat="probability")
        out = h(long_df, *single_args)
        assert out["y"].sum() == pytest.approx(1)

    def test_proportion_stat(self, long_df, single_args):

        h = Hist(stat="proportion")
        out = h(long_df, *single_args)
        assert out["y"].sum() == pytest.approx(1)

    def test_percent_stat(self, long_df, single_args):

        h = Hist(stat="percent")
        out = h(long_df, *single_args)
        assert out["y"].sum() == pytest.approx(100)

    def test_density_stat(self, long_df, single_args):

        h = Hist(stat="density")
        out = h(long_df, *single_args)
        assert (out["y"] * out["space"]).sum() == pytest.approx(1)

    def test_frequency_stat(self, long_df, single_args):

        h = Hist(stat="frequency")
        out = h(long_df, *single_args)
        assert (out["y"] * out["space"]).sum() == pytest.approx(len(long_df))

    def test_invalid_stat(self):

        with pytest.raises(ValueError, match="The `stat` parameter for `Hist`"):
            Hist(stat="invalid")

    def test_cumulative_count(self, long_df, single_args):

        h = Hist(stat="count", cumulative=True)
        out = h(long_df, *single_args)
        assert out["y"].max() == len(long_df)

    def test_cumulative_proportion(self, long_df, single_args):

        h = Hist(stat="proportion", cumulative=True)
        out = h(long_df, *single_args)
        assert out["y"].max() == pytest.approx(1)

    def test_cumulative_density(self, long_df, single_args):

        h = Hist(stat="density", cumulative=True)
        out = h(long_df, *single_args)
        assert out["y"].max() == pytest.approx(1)

    def test_common_norm_default(self, long_df, triple_args):

        h = Hist(stat="percent")
        out = h(long_df, *triple_args)
        assert out["y"].sum() == pytest.approx(100)

    def test_common_norm_false(self, long_df, triple_args):

        h = Hist(stat="percent", common_norm=False)
        out = h(long_df, *triple_args)
        for _, out_part in out.groupby(["a", "s"]):
            assert out_part["y"].sum() == pytest.approx(100)

    def test_common_norm_subset(self, long_df, triple_args):

        h = Hist(stat="percent", common_norm=["a"])
        out = h(long_df, *triple_args)
        for _, out_part in out.groupby("a"):
            assert out_part["y"].sum() == pytest.approx(100)

    def test_common_norm_warning(self, long_df, triple_args):

        h = Hist(common_norm=["b"])
        with pytest.warns(UserWarning, match=r"Undefined variable\(s\)"):
            h(long_df, *triple_args)

    def test_common_bins_default(self, long_df, triple_args):

        h = Hist()
        out = h(long_df, *triple_args)
        bins = []
        for _, out_part in out.groupby(["a", "s"]):
            bins.append(tuple(out_part["x"]))
        assert len(set(bins)) == 1

    def test_common_bins_false(self, long_df, triple_args):

        h = Hist(common_bins=False)
        out = h(long_df, *triple_args)
        bins = []
        for _, out_part in out.groupby(["a", "s"]):
            bins.append(tuple(out_part["x"]))
        assert len(set(bins)) == len(out.groupby(["a", "s"]))

    def test_common_bins_subset(self, long_df, triple_args):

        # Bins should be shared across "s" within each level of "a", but
        # defined separately for the different levels of "a".
        h = Hist(common_bins=["a"])
        out = h(long_df, *triple_args)
        bins = []
        for _, out_part in out.groupby("a"):
            part_bins = {tuple(sub["x"]) for _, sub in out_part.groupby("s")}
            assert len(part_bins) == 1
            bins.append(part_bins.pop())
        assert len(set(bins)) > 1

    def test_common_bins_warning(self, long_df, triple_args):

        h = Hist(common_bins=["b"])
        with pytest.warns(UserWarning, match=r"Undefined variable\(s\)"):
            h(long_df, *triple_args)

    def test_histogram_single(self, long_df, single_args):

        h = Hist()
        out = h(long_df, *single_args)
        hist, edges = np.histogram(long_df["x"], bins="auto")
        assert_array_equal(out["y"], hist)
        assert_array_equal(out["space"], np.diff(edges))

    def test_histogram_multiple(self, long_df, triple_args):

        h = Hist()
        out = h(long_df, *triple_args)
        bins = np.histogram_bin_edges(long_df["x"], "auto")
        for (a, s), out_part in out.groupby(["a", "s"]):
            x = long_df.loc[(long_df["a"] == a) & (long_df["s"] == s), "x"]
            hist, edges = np.histogram(x, bins=bins)
            assert_array_equal(out_part["y"], hist)
            assert_array_equal(out_part["space"], np.diff(edges))
================================================
FILE: tests/_stats/test_density.py
================================================
import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
from seaborn._core.groupby import GroupBy
from seaborn._stats.density import KDE, _no_scipy
from seaborn._compat import groupby_apply_include_groups
class TestKDE:
    """Tests for the KDE stat."""

    @pytest.fixture
    def df(self, rng):

        n = 100
        return pd.DataFrame(dict(
            x=rng.uniform(0, 7, n).round(),
            y=rng.normal(size=n),
            color=rng.choice(["a", "b", "c"], n),
            alpha=rng.choice(["x", "y"], n),
        ))

    def get_groupby(self, df, orient):
        """Group by every non-orient column plus "group"."""
        cols = [c for c in df if c != orient]
        return GroupBy([*cols, "group"])

    def integrate(self, y, x):
        """Integrate ``y`` over ``x`` with the trapezoidal rule."""
        y = np.asarray(y)
        x = np.asarray(x)
        dx = np.diff(x)
        return (dx * y[:-1] + dx * y[1:]).sum() / 2

    @pytest.mark.parametrize("ori", ["x", "y"])
    def test_columns(self, df, ori):

        df = df[[ori, "alpha"]]
        gb = self.get_groupby(df, ori)
        res = KDE()(df, gb, ori, {})
        other = {"x": "y", "y": "x"}[ori]
        expected = [ori, "alpha", "density", other]
        assert list(res.columns) == expected

    @pytest.mark.parametrize("gridsize", [20, 30, None])
    def test_gridsize(self, df, gridsize):

        ori = "y"
        df = df[[ori]]
        gb = self.get_groupby(df, ori)
        res = KDE(gridsize=gridsize)(df, gb, ori, {})
        if gridsize is None:
            # Without a grid, the density is evaluated at the observations
            assert_array_equal(res[ori], df[ori])
        else:
            assert len(res) == gridsize

    @pytest.mark.parametrize("cut", [1, 2])
    def test_cut(self, df, cut):

        ori = "y"
        df = df[[ori]]
        gb = self.get_groupby(df, ori)
        res = KDE(cut=cut, bw_method=1)(df, gb, ori, {})

        vals = df[ori]
        bw = vals.std()
        assert res[ori].min() == pytest.approx(vals.min() - bw * cut, abs=1e-2)
        assert res[ori].max() == pytest.approx(vals.max() + bw * cut, abs=1e-2)

    @pytest.mark.parametrize("common_grid", [True, False])
    def test_common_grid(self, df, common_grid):

        ori = "y"
        df = df[[ori, "alpha"]]
        gb = self.get_groupby(df, ori)
        res = KDE(common_grid=common_grid)(df, gb, ori, {})

        vals = df["alpha"].unique()
        a = res.loc[res["alpha"] == vals[0], ori].to_numpy()
        b = res.loc[res["alpha"] == vals[1], ori].to_numpy()
        if common_grid:
            assert_array_equal(a, b)
        else:
            assert np.not_equal(a, b).all()

    @pytest.mark.parametrize("common_norm", [True, False])
    def test_common_norm(self, df, common_norm):

        ori = "y"
        df = df[[ori, "alpha"]]
        gb = self.get_groupby(df, ori)
        res = KDE(common_norm=common_norm)(df, gb, ori, {})

        areas = (
            res.groupby("alpha")
            .apply(
                lambda x: self.integrate(x["density"], x[ori]),
                **groupby_apply_include_groups(False),
            )
        )

        if common_norm:
            assert areas.sum() == pytest.approx(1, abs=1e-3)
        else:
            assert_array_almost_equal(areas, [1, 1], decimal=3)

    def test_common_norm_variables(self, df):

        ori = "y"
        df = df[[ori, "alpha", "color"]]
        gb = self.get_groupby(df, ori)
        res = KDE(common_norm=["alpha"])(df, gb, ori, {})

        def integrate_by_color_and_sum(x):
            return (
                x.groupby("color")
                .apply(
                    lambda y: self.integrate(y["density"], y[ori]),
                    **groupby_apply_include_groups(False)
                )
                .sum()
            )

        areas = (
            res
            .groupby("alpha")
            .apply(integrate_by_color_and_sum, **groupby_apply_include_groups(False))
        )
        assert_array_almost_equal(areas, [1, 1], decimal=3)

    @pytest.mark.parametrize("param", ["norm", "grid"])
    def test_common_input_checks(self, df, param):

        ori = "y"
        df = df[[ori, "alpha"]]
        gb = self.get_groupby(df, ori)
        msg = rf"Undefined variable\(s\) passed for KDE.common_{param}"
        with pytest.warns(UserWarning, match=msg):
            KDE(**{f"common_{param}": ["color", "alpha"]})(df, gb, ori, {})

        msg = f"KDE.common_{param} must be a boolean or list of strings"
        with pytest.raises(TypeError, match=msg):
            KDE(**{f"common_{param}": "alpha"})(df, gb, ori, {})

    def test_bw_adjust(self, df):

        ori = "y"
        df = df[[ori]]
        gb = self.get_groupby(df, ori)
        res1 = KDE(bw_adjust=0.5)(df, gb, ori, {})
        res2 = KDE(bw_adjust=2.0)(df, gb, ori, {})

        # A narrower bandwidth gives a wigglier (larger mean abs diff) curve
        mad1 = res1["density"].diff().abs().mean()
        mad2 = res2["density"].diff().abs().mean()
        assert mad1 > mad2

    def test_bw_method_scalar(self, df):

        ori = "y"
        df = df[[ori]]
        gb = self.get_groupby(df, ori)
        res1 = KDE(bw_method=0.5)(df, gb, ori, {})
        res2 = KDE(bw_method=2.0)(df, gb, ori, {})

        mad1 = res1["density"].diff().abs().mean()
        mad2 = res2["density"].diff().abs().mean()
        assert mad1 > mad2

    @pytest.mark.skipif(_no_scipy, reason="KDE.cumulative requires scipy")
    @pytest.mark.parametrize("common_norm", [True, False])
    def test_cumulative(self, df, common_norm):

        ori = "y"
        df = df[[ori, "alpha"]]
        gb = self.get_groupby(df, ori)
        res = KDE(cumulative=True, common_norm=common_norm)(df, gb, ori, {})

        for _, group_res in res.groupby("alpha"):
            assert (group_res["density"].diff().dropna() >= 0).all()
            if not common_norm:
                assert group_res["density"].max() == pytest.approx(1, abs=1e-3)

    # Only meaningful without scipy; skip (rather than pass vacuously) otherwise
    @pytest.mark.skipif(not _no_scipy, reason="Only relevant when scipy is missing")
    def test_cumulative_requires_scipy(self):

        err = "Cumulative KDE evaluation requires scipy"
        with pytest.raises(RuntimeError, match=err):
            KDE(cumulative=True)

    @pytest.mark.parametrize("vals", [[], [1], [1] * 5, [1929245168.06679] * 18])
    def test_singular(self, df, vals):

        df1 = pd.DataFrame({"y": vals, "alpha": ["z"] * len(vals)})
        gb = self.get_groupby(df1, "y")
        res = KDE()(df1, gb, "y", {})
        assert res.empty

        # Singular groups are dropped without affecting the valid ones
        df2 = pd.concat([df[["y", "alpha"]], df1], ignore_index=True)
        gb = self.get_groupby(df2, "y")
        res = KDE()(df2, gb, "y", {})
        assert set(res["alpha"]) == set(df["alpha"])

    @pytest.mark.parametrize("col", ["y", "weight"])
    def test_missing(self, df, col):

        val, ori = "xy"
        df["weight"] = 1
        df = df[[ori, "weight"]].astype(float)
        df.loc[:4, col] = np.nan
        gb = self.get_groupby(df, ori)
        res = KDE()(df, gb, ori, {})
        assert self.integrate(res[val], res[ori]) == pytest.approx(1, abs=1e-3)
================================================
FILE: tests/_stats/test_order.py
================================================
import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_array_equal
from seaborn._core.groupby import GroupBy
from seaborn._stats.order import Perc
from seaborn.utils import _version_predates
class Fixtures:
    """Shared data and grouping helper for the order-statistic tests."""

    @pytest.fixture
    def df(self, rng):
        return pd.DataFrame({"x": "", "y": rng.normal(size=30)})

    def get_groupby(self, df, orient):
        # TODO note, copied from aggregation
        value_col = {"x": "y", "y": "x"}[orient]
        keep = [name for name in df if name != value_col]
        return GroupBy(keep)
class TestPerc(Fixtures):
    """Tests for the Perc stat."""

    def test_int_k(self, df):
        orient = "x"
        result = Perc(3)(df, self.get_groupby(df, orient), orient, {})
        levels = [0, 50, 100]
        assert_array_equal(result["percentile"], levels)
        assert_array_equal(result["y"], np.percentile(df["y"], levels))

    def test_list_k(self, df):
        orient = "x"
        groupby = self.get_groupby(df, orient)
        levels = [0, 20, 100]
        result = Perc(k=levels)(df, groupby, orient, {})
        assert_array_equal(result["percentile"], levels)
        assert_array_equal(result["y"], np.percentile(df["y"], levels))

    def test_orientation(self, df):
        swapped = df.rename(columns={"x": "y", "y": "x"})
        orient = "y"
        result = Perc(k=3)(swapped, self.get_groupby(swapped, orient), orient, {})
        assert_array_equal(result["x"], np.percentile(swapped["x"], [0, 50, 100]))

    def test_method(self, df):
        orient = "x"
        groupby = self.get_groupby(df, orient)
        method = "nearest"
        result = Perc(k=5, method=method)(df, groupby, orient, {})
        levels = [0, 25, 50, 75, 100]
        # The numpy keyword was renamed from `interpolation` to `method` in 1.22
        if _version_predates(np, "1.22.0"):
            keyword = "interpolation"
        else:
            keyword = "method"
        expected = np.percentile(df["y"], levels, **{keyword: method})
        assert_array_equal(result["y"], expected)

    def test_grouped(self, df, rng):
        orient = "x"
        data = df.assign(x=rng.choice(["a", "b", "c"], len(df)))
        groupby = self.get_groupby(data, orient)
        levels = [10, 90]
        result = Perc(levels)(data, groupby, orient, {})
        for key, part in result.groupby("x"):
            assert_array_equal(part["percentile"], levels)
            subset = data.loc[data["x"] == key, "y"]
            assert_array_equal(part["y"], np.percentile(subset, levels))

    def test_with_na(self, df):
        orient = "x"
        df.loc[:5, "y"] = np.nan
        levels = [10, 90]
        result = Perc(levels)(df, self.get_groupby(df, orient), orient, {})
        assert_array_equal(result["y"], np.percentile(df["y"].dropna(), levels))
================================================
FILE: tests/_stats/test_regression.py
================================================
import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
from pandas.testing import assert_frame_equal
from seaborn._core.groupby import GroupBy
from seaborn._stats.regression import PolyFit
class TestPolyFit:
    """Tests for the PolyFit stat."""

    @pytest.fixture
    def df(self, rng):
        n = 100
        return pd.DataFrame({
            "x": rng.normal(0, 1, n),
            "y": rng.normal(0, 1, n),
            "color": rng.choice(["a", "b", "c"], n),
            "group": rng.choice(["x", "y"], n),
        })

    def test_no_grouper(self, df):
        groupby = GroupBy(["group"])
        result = PolyFit(order=1, gridsize=100)(df[["x", "y"]], groupby, "x", {})
        assert_array_equal(result.columns, ["x", "y"])

        grid = np.linspace(df["x"].min(), df["x"].max(), 100)
        assert_array_equal(result["x"], grid)
        # A first-order fit is a straight line: second differences vanish
        second_diff = result["y"].diff().diff().dropna()
        assert_array_almost_equal(second_diff, np.zeros(grid.size - 2))

    def test_one_grouper(self, df):
        gridsize = 50
        result = PolyFit(gridsize=gridsize)(df, GroupBy(["group"]), "x", {})
        assert result.columns.to_list() == ["x", "y", "group"]

        n_groups = df["group"].nunique()
        assert_array_equal(result.index, np.arange(n_groups * gridsize))

        for _, part in result.groupby("group"):
            grid = np.linspace(part["x"].min(), part["x"].max(), gridsize)
            assert_array_equal(part["x"], grid)
            # Nonzero second differences: the default fit is not a straight line
            assert part["y"].diff().diff().dropna().abs().gt(0).all()

    def test_missing_data(self, df):
        groupby = GroupBy(["group"])
        df.iloc[5:10] = np.nan
        with_missing = PolyFit()(df[["x", "y"]], groupby, "x", {})
        without_missing = PolyFit()(df[["x", "y"]].dropna(), groupby, "x", {})
        assert_frame_equal(with_missing, without_missing)
================================================
FILE: tests/conftest.py
================================================
import numpy as np
import pandas as pd
import pytest
@pytest.fixture(autouse=True)
def close_figs():
    """Close every open matplotlib figure once each test finishes."""
    yield
    from matplotlib import pyplot
    pyplot.close("all")
@pytest.fixture(autouse=True)
def random_seed():
    """Reseed numpy's global RNG before every test for reproducibility."""
    np.random.seed(sum(ord(char) for char in "seaborn random global"))
@pytest.fixture()
def rng():
    """Return a fresh RandomState with a fixed, string-derived seed."""
    object_seed = sum(ord(char) for char in "seaborn random object")
    return np.random.RandomState(object_seed)
@pytest.fixture
def wide_df(rng):
    """Wide-form data: normal draws in columns a, b, c over a named RangeIndex."""
    row_index = pd.RangeIndex(10, 50, 2, name="wide_index")
    col_names = ["a", "b", "c"]
    shape = (len(row_index), len(col_names))
    return pd.DataFrame(rng.normal(size=shape), index=row_index, columns=col_names)
@pytest.fixture
def wide_array(wide_df):
    """The values of ``wide_df`` as a plain 2D array."""
    values = wide_df.to_numpy()
    return values
# TODO s/flat/thin?
@pytest.fixture
def flat_series(rng):
    """A 20-element Series named "s" with a RangeIndex named "t" starting at 10."""
    return pd.Series(
        rng.normal(size=20),
        index=pd.RangeIndex(10, 30, name="t"),
        name="s",
    )
@pytest.fixture
def flat_array(flat_series):
    """The values of ``flat_series`` as an array (index dropped)."""
    values = flat_series.to_numpy()
    return values
@pytest.fixture
def flat_list(flat_series):
    """The values of ``flat_series`` as a Python list (index dropped)."""
    values = flat_series.to_list()
    return values
@pytest.fixture(params=["series", "array", "list"])
def flat_data(rng, request):
    """The same flat data as a Series, array, or list, depending on the param."""
    series = pd.Series(
        rng.normal(size=20), index=pd.RangeIndex(10, 30, name="t"), name="s",
    )
    kind = request.param
    if kind == "series":
        return series
    if kind == "array":
        return series.to_numpy()
    if kind == "list":
        return series.to_list()
@pytest.fixture
def wide_list_of_series(rng):
    """Two named Series of different lengths with partially overlapping indexes."""
    first = pd.Series(rng.normal(size=20), np.arange(20), name="a")
    second = pd.Series(rng.normal(size=10), np.arange(5, 15), name="b")
    return [first, second]
@pytest.fixture
def wide_list_of_arrays(wide_list_of_series):
    """``wide_list_of_series`` with each Series converted to an array."""
    return [series.to_numpy() for series in wide_list_of_series]
@pytest.fixture
def wide_list_of_lists(wide_list_of_series):
    """``wide_list_of_series`` with each Series converted to a list."""
    return [series.to_list() for series in wide_list_of_series]
@pytest.fixture
def wide_dict_of_series(wide_list_of_series):
    """``wide_list_of_series`` keyed by each Series' name."""
    return {series.name: series for series in wide_list_of_series}
@pytest.fixture
def wide_dict_of_arrays(wide_list_of_series):
    """Arrays of ``wide_list_of_series`` values, keyed by Series name."""
    return {series.name: series.to_numpy() for series in wide_list_of_series}
@pytest.fixture
def wide_dict_of_lists(wide_list_of_series):
    """Lists of ``wide_list_of_series`` values, keyed by Series name."""
    return {series.name: series.to_list() for series in wide_list_of_series}
@pytest.fixture
def long_df(rng):
    """Long-form test data: 100 rows of numeric, string, datetime and category columns.

    The order of the ``rng`` draws below determines every column's values,
    so adding, removing or reordering a draw changes the data all tests see.
    """
    n = 100
    df = pd.DataFrame(dict(
        x=rng.uniform(0, 20, n).round().astype("int"),
        y=rng.normal(size=n),
        z=rng.lognormal(size=n),
        a=rng.choice(list("abc"), n),
        b=rng.choice(list("mnop"), n),
        # The third positional parameter of RandomState.choice is `replace`,
        # not `p`, so the former `[.3, .7]` argument was only a truthy
        # `replace` and never weighted the draw. It is written out plainly
        # here, which keeps the random stream (and thus the data) unchanged.
        c=rng.choice([0, 1], n),
        d=rng.choice(np.arange("2004-07-30", "2007-07-30", dtype="datetime64[Y]"), n),
        t=rng.choice(np.arange("2004-07-30", "2004-07-31", dtype="datetime64[m]"), n),
        s=rng.choice([2, 4, 8], n),
        f=rng.choice([0.2, 0.3], n),
    ))

    # Categorical copy of "a" whose category order differs from sorted order
    a_cat = df["a"].astype("category")
    new_categories = np.roll(a_cat.cat.categories, 1)
    df["a_cat"] = a_cat.cat.reorder_categories(new_categories)

    df["s_cat"] = df["s"].astype("category")
    df["s_str"] = df["s"].astype(str)

    return df
@pytest.fixture
def long_dict(long_df):
    """``long_df`` converted with ``DataFrame.to_dict()`` (default orientation)."""
    as_dict = long_df.to_dict()
    return as_dict
@pytest.fixture
def repeated_df(rng):
    """Long-form data where each x value appears once for each level of ``u``."""
    n = 100
    half = n // 2
    return pd.DataFrame({
        "x": np.tile(np.arange(half), 2),
        "y": rng.normal(size=n),
        "a": rng.choice(list("abc"), n),
        "u": np.repeat(np.arange(2), half),
    })
@pytest.fixture
def null_df(rng, long_df):
    """Copy of ``long_df`` with 10 randomly chosen entries per column set to NaN."""
    df = long_df.copy()
    for name in df:
        # Upcast integer columns to float so they can store NaN
        if pd.api.types.is_integer_dtype(df[name]):
            df[name] = df[name].astype(float)
        missing_rows = rng.permutation(df.index)[:10]
        df.loc[missing_rows, name] = np.nan
    return df
@pytest.fixture
def object_df(rng, long_df):
    """Copy of ``long_df`` with the numeric columns c, s and f cast to object dtype."""
    objectified = long_df.copy()
    for name in ("c", "s", "f"):
        objectified[name] = objectified[name].astype(object)
    return objectified
@pytest.fixture
def null_series(flat_series):
    """An all-NaN float64 Series sharing the index of ``flat_series``."""
    return pd.Series(dtype="float64", index=flat_series.index)
class MockConvertibleDataFrame:
    """Stand-in for a non-pandas dataframe that exposes ``to_pandas``.

    Mock object that is not a pandas.DataFrame but that can be converted to
    one via the DataFrame exchange protocol.
    """

    def __init__(self, data):
        self._data = data

    def to_pandas(self, *args, **kwargs):
        """Return the wrapped data, or raise ValueError when it is None."""
        wrapped = self._data
        if wrapped is not None:
            return wrapped
        raise ValueError("Cannot convert to pandas")
@pytest.fixture
def mock_long_df(long_df):
    """``long_df`` wrapped in a MockConvertibleDataFrame."""
    wrapped = MockConvertibleDataFrame(long_df)
    return wrapped
================================================
FILE: tests/test_algorithms.py
================================================
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from seaborn import algorithms as algo
@pytest.fixture
def random():
    """Seed numpy's global RNG with a value derived from the module name."""
    module_seed = sum(ord(char) for char in "test_algorithms")
    np.random.seed(module_seed)
def test_bootstrap(random):
    """Bootstrapping constant data returns that constant for mean and median."""
    ones = np.ones(10)
    n_boot = 5
    expected = np.ones(n_boot)
    assert_array_equal(algo.bootstrap(ones, n_boot=n_boot), expected)
    assert_array_equal(algo.bootstrap(ones, n_boot=n_boot, func=np.median), expected)
def test_bootstrap_length(random):
    """The bootstrap output has 10000 entries by default, else ``n_boot``."""
    samples = np.random.randn(1000)
    assert len(algo.bootstrap(samples)) == 10000

    n_boot = 100
    assert len(algo.bootstrap(samples, n_boot=n_boot)) == n_boot
def test_bootstrap_range(random):
    """Bootstrapped means stay within the range of the original data."""
    samples = np.random.randn(1000)
    lowest, highest = samples.min(), samples.max()
    boots = algo.bootstrap(samples)
    assert lowest <= boots.min()
    assert highest >= boots.max()
def test_bootstrap_multiarg(random):
    """Multiple input arrays are resampled together and passed to ``func``."""
    x = np.tile([1, 10], (10, 1))
    y = np.tile([5, 5], (10, 1))

    def columnwise_max(first, second):
        return np.vstack((first, second)).max(axis=0)

    result = algo.bootstrap(x, y, n_boot=2, func=columnwise_max)
    assert_array_equal(result, np.array([[5, 10], [5, 10]]))
def test_bootstrap_axis(random):
    """Test axis kwarg to bootstrap function."""
    x = np.random.randn(10, 20)
    n_boot = 100

    # Without an axis, func reduces each resample to a scalar
    out_default = algo.bootstrap(x, n_boot=n_boot)
    assert out_default.shape == (n_boot,)

    # With axis=0, each resample is reduced column-wise
    out_axis = algo.bootstrap(x, n_boot=n_boot, axis=0)
    # Must compare with ``==``: ``assert a, b`` treats ``b`` as the failure
    # message and passes for any non-empty shape tuple.
    assert out_axis.shape == (n_boot, x.shape[1])
def test_bootstrap_seed(random):
    """Passing the same integer seed yields identical resamples."""
    samples = np.random.randn(50)
    first, second = (algo.bootstrap(samples, seed=42) for _ in range(2))
    assert_array_equal(first, second)
def test_bootstrap_ols(random):
    """Bootstrapped OLS coefficients vary more when the target is noisier."""
    def ols_fit(X, y):
        gram_inverse = np.linalg.inv(np.dot(X.T, X))
        return gram_inverse.dot(X.T).dot(y)

    X = np.column_stack((np.random.randn(50, 4), np.ones(50)))
    coefs = [2, 4, 0, 3, 5]
    signal = np.dot(X, coefs)
    y_noisy = signal + np.random.randn(50) * 20
    y_lownoise = signal + np.random.randn(50)

    n_boot = 500
    fits = {
        label: algo.bootstrap(X, target, n_boot=n_boot, func=ols_fit)
        for label, target in [("noisy", y_noisy), ("lownoise", y_lownoise)]
    }

    for dist in fits.values():
        assert dist.shape == (n_boot, 5)
    assert fits["noisy"].std() > fits["lownoise"].std()
def test_bootstrap_units(random):
    """Resampling by unit widens the distribution when units share an offset."""
    samples = np.random.randn(50)
    unit_ids = np.repeat(range(10), 5)
    # One random offset per unit, broadcast to that unit's observations
    unit_offsets = np.random.normal(0, 2, 10)[unit_ids]
    samples_rm = samples + unit_offsets

    seed = 77
    plain = algo.bootstrap(samples_rm, seed=seed)
    by_unit = algo.bootstrap(samples_rm, units=unit_ids, seed=seed)
    assert by_unit.std() > plain.std()
def test_bootstrap_arglength():
    """Arguments of mismatched length raise ValueError."""
    shorter, longer = np.arange(5), np.arange(10)
    with pytest.raises(ValueError):
        algo.bootstrap(shorter, longer)
def test_bootstrap_string_func():
    """A func given by name matches the equivalent numpy function."""
    samples = np.random.randn(100)
    for name, func in [("mean", np.mean), ("std", np.std)]:
        by_name = algo.bootstrap(samples, func=name, seed=0)
        by_callable = algo.bootstrap(samples, func=func, seed=0)
        assert np.array_equal(by_name, by_callable)

    with pytest.raises(AttributeError):
        algo.bootstrap(samples, func="not_a_method_name")
def test_bootstrap_reproducibility(random):
    """Integer seeds, RandomState seeds and legacy random_seed all reproduce."""
    samples = np.random.randn(50)

    first = algo.bootstrap(samples, seed=100)
    second = algo.bootstrap(samples, seed=100)
    assert_array_equal(first, second)

    first = algo.bootstrap(samples, seed=np.random.RandomState(200))
    second = algo.bootstrap(samples, seed=np.random.RandomState(200))
    assert_array_equal(first, second)

    with pytest.warns(UserWarning):
        # Deprecated, remove when removing random_seed
        first = algo.bootstrap(samples, random_seed=100)
        second = algo.bootstrap(samples, random_seed=100)
    assert_array_equal(first, second)
def test_nanaware_func_auto(random):
    """Bootstrapping "mean" over data containing a NaN yields no NaN results."""
    samples = np.random.normal(size=10)
    samples[0] = np.nan
    assert not np.isnan(algo.bootstrap(samples, func="mean")).any()
def test_nanaware_func_warning(random):
    """Bootstrapping "ptp" over data containing NaN warns and propagates NaN."""
    samples = np.random.normal(size=10)
    samples[0] = np.nan
    with pytest.warns(UserWarning, match="Data contain nans but"):
        result = algo.bootstrap(samples, func="ptp")
    assert np.isnan(result).any()
================================================
FILE: tests/test_axisgrid.py
================================================
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import pytest
import numpy.testing as npt
from numpy.testing import assert_array_equal, assert_array_almost_equal
import pandas.testing as tm
from seaborn._base import categorical_order
from seaborn import rcmod
from seaborn.palettes import color_palette
from seaborn.relational import scatterplot
from seaborn.distributions import histplot, kdeplot, distplot
from seaborn.categorical import pointplot
from seaborn.utils import _version_predates
from seaborn import axisgrid as ag
from seaborn._testing import (
assert_plots_equal,
assert_colors_equal,
)
from seaborn._compat import get_legend_handles
# Seeded random generator used to build the TestFacetGrid fixture data,
# so the synthetic frame is the same on every run.
rs = np.random.RandomState(seed=0)
class TestFacetGrid:
    """Tests for :class:`seaborn.axisgrid.FacetGrid`."""

    # Shared fixture data. This is a class attribute, visible to every test,
    # so tests must not modify it in place: copy it before adding columns.
    df = pd.DataFrame(dict(x=rs.normal(size=60),
                           y=rs.gamma(4, size=60),
                           a=np.repeat(list("abc"), 20),
                           b=np.tile(list("mn"), 30),
                           c=np.tile(list("tuv"), 20),
                           d=np.tile(list("abcdefghijkl"), 5)))

    def test_self_data(self):
        """The grid stores the very DataFrame object it was given."""
        g = ag.FacetGrid(self.df)
        assert g.data is self.df

    def test_self_figure(self):
        """``figure`` is a matplotlib Figure and is the same object as ``_figure``."""
        g = ag.FacetGrid(self.df)
        assert isinstance(g.figure, plt.Figure)
        assert g.figure is g._figure

    def test_self_axes(self):
        """Every facet in the grid is a matplotlib Axes."""
        g = ag.FacetGrid(self.df, row="a", col="b", hue="c")
        for ax in g.axes.flat:
            assert isinstance(ax, plt.Axes)

    def test_axes_array_size(self):
        """The axes array is (n_row_levels, n_col_levels); hue adds no axes."""
        g = ag.FacetGrid(self.df)
        assert g.axes.shape == (1, 1)

        g = ag.FacetGrid(self.df, row="a")
        assert g.axes.shape == (3, 1)

        g = ag.FacetGrid(self.df, col="b")
        assert g.axes.shape == (1, 2)

        g = ag.FacetGrid(self.df, hue="c")
        assert g.axes.shape == (1, 1)

        g = ag.FacetGrid(self.df, row="a", col="b", hue="c")
        assert g.axes.shape == (3, 2)
        for ax in g.axes.flat:
            assert isinstance(ax, plt.Axes)

    def test_single_axes(self):
        """``ax`` works for a one-facet grid and raises AttributeError otherwise."""
        g = ag.FacetGrid(self.df)
        assert isinstance(g.ax, plt.Axes)

        g = ag.FacetGrid(self.df, row="a")
        with pytest.raises(AttributeError):
            g.ax

        g = ag.FacetGrid(self.df, col="a")
        with pytest.raises(AttributeError):
            g.ax

        g = ag.FacetGrid(self.df, col="a", row="b")
        with pytest.raises(AttributeError):
            g.ax

    def test_col_wrap(self):
        """``col_wrap`` flattens the axes array and wraps columns onto new rows."""
        n = len(self.df.d.unique())

        g = ag.FacetGrid(self.df, col="d")
        assert g.axes.shape == (1, n)
        assert g.facet_axis(0, 8) is g.axes[0, 8]

        g_wrap = ag.FacetGrid(self.df, col="d", col_wrap=4)
        assert g_wrap.axes.shape == (n,)
        assert g_wrap.facet_axis(0, 8) is g_wrap.axes[8]
        assert g_wrap._ncol == 4
        assert g_wrap._nrow == (n / 4)

        # Combining col_wrap with a row variable is rejected
        with pytest.raises(ValueError):
            ag.FacetGrid(self.df, row="b", col="d", col_wrap=4)

        # A level whose rows are entirely NaN gets no facet
        df = self.df.copy()
        df.loc[df.d == "j"] = np.nan
        g_missing = ag.FacetGrid(df, col="d")
        assert g_missing.axes.shape == (1, n - 1)

        g_missing_wrap = ag.FacetGrid(df, col="d", col_wrap=4)
        assert g_missing_wrap.axes.shape == (n - 1,)

        g = ag.FacetGrid(self.df, col="d", col_wrap=1)
        assert len(list(g.facet_data())) == n

    def test_normal_axes(self):
        """Bottom/left/inner axes bookkeeping for grids without wrapping."""
        null = np.empty(0, object).flat

        g = ag.FacetGrid(self.df)
        npt.assert_array_equal(g._bottom_axes, g.axes.flat)
        npt.assert_array_equal(g._not_bottom_axes, null)
        npt.assert_array_equal(g._left_axes, g.axes.flat)
        npt.assert_array_equal(g._not_left_axes, null)
        npt.assert_array_equal(g._inner_axes, null)

        g = ag.FacetGrid(self.df, col="c")
        npt.assert_array_equal(g._bottom_axes, g.axes.flat)
        npt.assert_array_equal(g._not_bottom_axes, null)
        npt.assert_array_equal(g._left_axes, g.axes[:, 0].flat)
        npt.assert_array_equal(g._not_left_axes, g.axes[:, 1:].flat)
        npt.assert_array_equal(g._inner_axes, null)

        g = ag.FacetGrid(self.df, row="c")
        npt.assert_array_equal(g._bottom_axes, g.axes[-1, :].flat)
        npt.assert_array_equal(g._not_bottom_axes, g.axes[:-1, :].flat)
        npt.assert_array_equal(g._left_axes, g.axes.flat)
        npt.assert_array_equal(g._not_left_axes, null)
        npt.assert_array_equal(g._inner_axes, null)

        g = ag.FacetGrid(self.df, col="a", row="c")
        npt.assert_array_equal(g._bottom_axes, g.axes[-1, :].flat)
        npt.assert_array_equal(g._not_bottom_axes, g.axes[:-1, :].flat)
        npt.assert_array_equal(g._left_axes, g.axes[:, 0].flat)
        npt.assert_array_equal(g._not_left_axes, g.axes[:, 1:].flat)
        npt.assert_array_equal(g._inner_axes, g.axes[:-1, 1:].flat)

    def test_wrapped_axes(self):
        """Bottom/left/inner axes bookkeeping when columns are wrapped."""
        null = np.empty(0, object).flat

        g = ag.FacetGrid(self.df, col="a", col_wrap=2)
        npt.assert_array_equal(g._bottom_axes,
                               g.axes[np.array([1, 2])].flat)
        npt.assert_array_equal(g._not_bottom_axes, g.axes[:1].flat)
        npt.assert_array_equal(g._left_axes, g.axes[np.array([0, 2])].flat)
        npt.assert_array_equal(g._not_left_axes, g.axes[np.array([1])].flat)
        npt.assert_array_equal(g._inner_axes, null)

    def test_axes_dict(self):
        """``axes_dict`` maps facet level(s) to the matching Axes."""
        g = ag.FacetGrid(self.df)
        assert isinstance(g.axes_dict, dict)
        assert not g.axes_dict

        g = ag.FacetGrid(self.df, row="c")
        assert list(g.axes_dict.keys()) == g.row_names
        for (name, ax) in zip(g.row_names, g.axes.flat):
            assert g.axes_dict[name] is ax

        g = ag.FacetGrid(self.df, col="c")
        assert list(g.axes_dict.keys()) == g.col_names
        for (name, ax) in zip(g.col_names, g.axes.flat):
            assert g.axes_dict[name] is ax

        g = ag.FacetGrid(self.df, col="a", col_wrap=2)
        assert list(g.axes_dict.keys()) == g.col_names
        for (name, ax) in zip(g.col_names, g.axes.flat):
            assert g.axes_dict[name] is ax

        # With both row and col, keys are (row_level, col_level) tuples
        g = ag.FacetGrid(self.df, row="a", col="c")
        for (row_var, col_var), ax in g.axes_dict.items():
            i = g.row_names.index(row_var)
            j = g.col_names.index(col_var)
            assert g.axes[i, j] is ax

    def test_figure_size(self):
        """Figure size is (n_cols * height * aspect, n_rows * height)."""
        g = ag.FacetGrid(self.df, row="a", col="b")
        npt.assert_array_equal(g.figure.get_size_inches(), (6, 9))

        g = ag.FacetGrid(self.df, row="a", col="b", height=6)
        npt.assert_array_equal(g.figure.get_size_inches(), (12, 18))

        g = ag.FacetGrid(self.df, col="c", height=4, aspect=.5)
        npt.assert_array_equal(g.figure.get_size_inches(), (6, 4))

    def test_figure_size_with_legend(self):
        """By default add_legend widens the figure; legend_out=False does not."""
        g = ag.FacetGrid(self.df, col="a", hue="c", height=4, aspect=.5)
        npt.assert_array_equal(g.figure.get_size_inches(), (6, 4))
        g.add_legend()
        assert g.figure.get_size_inches()[0] > 6

        g = ag.FacetGrid(self.df, col="a", hue="c", height=4, aspect=.5,
                         legend_out=False)
        npt.assert_array_equal(g.figure.get_size_inches(), (6, 4))
        g.add_legend()
        npt.assert_array_equal(g.figure.get_size_inches(), (6, 4))

    def test_legend_data(self):
        """Legend title, entry colors and labels follow the hue variable."""
        g = ag.FacetGrid(self.df, hue="a")
        g.map(plt.plot, "x", "y")
        g.add_legend()
        palette = color_palette(n_colors=3)

        assert g._legend.get_title().get_text() == "a"

        a_levels = sorted(self.df.a.unique())

        lines = g._legend.get_lines()
        assert len(lines) == len(a_levels)

        for line, hue in zip(lines, palette):
            assert_colors_equal(line.get_color(), hue)

        labels = g._legend.get_texts()
        assert len(labels) == len(a_levels)

        for label, level in zip(labels, a_levels):
            assert label.get_text() == level

    def test_legend_data_missing_level(self):
        """A hue_order level absent from the data still gets a legend label."""
        g = ag.FacetGrid(self.df, hue="a", hue_order=list("azbc"))
        g.map(plt.plot, "x", "y")
        g.add_legend()

        # The second color belongs to the absent level "z", so the drawn
        # lines use the first, third and fourth colors.
        c1, c2, c3, c4 = color_palette(n_colors=4)
        palette = [c1, c3, c4]

        assert g._legend.get_title().get_text() == "a"

        a_levels = sorted(self.df.a.unique())

        lines = g._legend.get_lines()
        assert len(lines) == len(a_levels)

        for line, hue in zip(lines, palette):
            assert_colors_equal(line.get_color(), hue)

        labels = g._legend.get_texts()
        assert len(labels) == 4

        for label, level in zip(labels, list("azbc")):
            assert label.get_text() == level

    def test_get_boolean_legend_data(self):
        """A boolean hue variable yields legend labels equal to str(level)."""
        # Work on a copy: ``self.df`` is shared by every test in this class,
        # and adding a column to it would leak into other tests.
        df = self.df.copy()
        df["b_bool"] = df.b == "m"
        g = ag.FacetGrid(df, hue="b_bool")
        g.map(plt.plot, "x", "y")
        g.add_legend()
        palette = color_palette(n_colors=2)

        assert g._legend.get_title().get_text() == "b_bool"

        b_levels = list(map(str, categorical_order(df.b_bool)))

        lines = g._legend.get_lines()
        assert len(lines) == len(b_levels)

        for line, hue in zip(lines, palette):
            assert_colors_equal(line.get_color(), hue)

        labels = g._legend.get_texts()
        assert len(labels) == len(b_levels)

        for label, level in zip(labels, b_levels):
            assert label.get_text() == level

    def test_legend_tuples(self):
        """add_legend accepts legend data keyed by tuples of labels."""
        g = ag.FacetGrid(self.df, hue="a")
        g.map(plt.plot, "x", "y")

        handles, labels = g.ax.get_legend_handles_labels()
        label_tuples = [("", l) for l in labels]
        legend_data = dict(zip(label_tuples, handles))
        g.add_legend(legend_data, label_tuples)
        for entry, label in zip(g._legend.get_texts(), labels):
            assert entry.get_text() == label

    def test_legend_options(self):
        """Smoke test: add_legend with different legend_out/adjust_subtitles settings."""
        g = ag.FacetGrid(self.df, hue="b")
        g.map(plt.plot, "x", "y")
        g.add_legend()

        g1 = ag.FacetGrid(self.df, hue="b", legend_out=False)
        g1.add_legend(adjust_subtitles=True)

        g1 = ag.FacetGrid(self.df, hue="b", legend_out=False)
        g1.add_legend(adjust_subtitles=False)

    def test_legendout_with_colwrap(self):
        """Smoke test: an inside legend on a column-wrapped grid."""
        g = ag.FacetGrid(self.df, col="d", hue='b',
                         col_wrap=4, legend_out=False)
        g.map(plt.plot, "x", "y", linewidth=3)
        g.add_legend()

    def test_legend_tight_layout(self):
        """After tight_layout, the legend lies to the right of the axes."""
        g = ag.FacetGrid(self.df, hue='b')
        g.map(plt.plot, "x", "y", linewidth=3)
        g.add_legend()
        g.tight_layout()

        axes_right_edge = g.ax.get_window_extent().xmax
        legend_left_edge = g._legend.get_window_extent().xmin

        assert axes_right_edge < legend_left_edge

    def test_subplot_kws(self):
        """subplot_kws reach subplot creation (here: a polar projection)."""
        g = ag.FacetGrid(self.df, despine=False,
                         subplot_kws=dict(projection="polar"))
        for ax in g.axes.flat:
            assert "PolarAxes" in ax.__class__.__name__

    def test_gridspec_kws(self):
        """width_ratios passed via gridspec_kws set the relative facet widths."""
        ratios = [3, 1, 2]

        gskws = dict(width_ratios=ratios)
        g = ag.FacetGrid(self.df, col='c', row='a', gridspec_kws=gskws)

        for ax in g.axes.flat:
            ax.set_xticks([])
            ax.set_yticks([])

        g.figure.tight_layout()

        for (l, m, r) in g.axes:
            assert l.get_position().width > m.get_position().width
            assert r.get_position().width > m.get_position().width

    def test_gridspec_kws_col_wrap(self):
        """gridspec_kws combined with col_wrap emits a UserWarning."""
        ratios = [3, 1, 2, 1, 1]

        gskws = dict(width_ratios=ratios)
        with pytest.warns(UserWarning):
            ag.FacetGrid(self.df, col='d', col_wrap=5, gridspec_kws=gskws)

    def test_data_generator(self):
        """facet_data yields ((i, j, k) indices, data subset) pairs."""
        g = ag.FacetGrid(self.df, row="a")
        d = list(g.facet_data())
        assert len(d) == 3

        tup, data = d[0]
        assert tup == (0, 0, 0)
        assert (data["a"] == "a").all()

        tup, data = d[1]
        assert tup == (1, 0, 0)
        assert (data["a"] == "b").all()

        g = ag.FacetGrid(self.df, row="a", col="b")
        d = list(g.facet_data())
        assert len(d) == 6

        tup, data = d[0]
        assert tup == (0, 0, 0)
        assert (data["a"] == "a").all()
        assert (data["b"] == "m").all()

        tup, data = d[1]
        assert tup == (0, 1, 0)
        assert (data["a"] == "a").all()
        assert (data["b"] == "n").all()

        tup, data = d[2]
        assert tup == (1, 0, 0)
        assert (data["a"] == "b").all()
        assert (data["b"] == "m").all()

        g = ag.FacetGrid(self.df, hue="c")
        d = list(g.facet_data())
        assert len(d) == 3
        tup, data = d[1]
        assert tup == (0, 0, 1)
        assert (data["c"] == "u").all()

    def test_map(self):
        """map draws one artist per hue level with the facet's data subset."""
        g = ag.FacetGrid(self.df, row="a", col="b", hue="c")
        g.map(plt.plot, "x", "y", linewidth=3)

        lines = g.axes[0, 0].lines
        assert len(lines) == 3

        line1, _, _ = lines
        assert line1.get_linewidth() == 3
        x, y = line1.get_data()

        mask = (self.df.a == "a") & (self.df.b == "m") & (self.df.c == "t")
        npt.assert_array_equal(x, self.df.x[mask])
        npt.assert_array_equal(y, self.df.y[mask])

    def test_map_dataframe(self):
        """map_dataframe passes the facet subset as ``data`` plus column names."""
        g = ag.FacetGrid(self.df, row="a", col="b", hue="c")

        def plot(x, y, data=None, **kws):
            plt.plot(data[x], data[y], **kws)
        # Modify __module__ so this doesn't look like a seaborn function
        plot.__module__ = "test"

        g.map_dataframe(plot, "x", "y", linestyle="--")

        lines = g.axes[0, 0].lines
        assert len(g.axes[0, 0].lines) == 3

        line1, _, _ = lines
        assert line1.get_linestyle() == "--"
        x, y = line1.get_data()

        mask = (self.df.a == "a") & (self.df.b == "m") & (self.df.c == "t")
        npt.assert_array_equal(x, self.df.x[mask])
        npt.assert_array_equal(y, self.df.y[mask])

    def test_set(self):
        """``set`` applies axes properties to every facet."""
        g = ag.FacetGrid(self.df, row="a", col="b")
        xlim = (-2, 5)
        ylim = (3, 6)
        xticks = [-2, 0, 3, 5]
        yticks = [3, 4.5, 6]
        g.set(xlim=xlim, ylim=ylim, xticks=xticks, yticks=yticks)
        for ax in g.axes.flat:
            npt.assert_array_equal(ax.get_xlim(), xlim)
            npt.assert_array_equal(ax.get_ylim(), ylim)
            npt.assert_array_equal(ax.get_xticks(), xticks)
            npt.assert_array_equal(ax.get_yticks(), yticks)

    def test_set_titles(self):
        """Default and templated facet titles, with and without a row variable."""
        g = ag.FacetGrid(self.df, row="a", col="b")
        g.map(plt.plot, "x", "y")

        # Test the default titles
        assert g.axes[0, 0].get_title() == "a = a | b = m"
        assert g.axes[0, 1].get_title() == "a = a | b = n"
        assert g.axes[1, 0].get_title() == "a = b | b = m"

        # Test a provided title
        g.set_titles("{row_var} == {row_name} \\/ {col_var} == {col_name}")
        assert g.axes[0, 0].get_title() == "a == a \\/ b == m"
        assert g.axes[0, 1].get_title() == "a == a \\/ b == n"
        assert g.axes[1, 0].get_title() == "a == b \\/ b == m"

        # Test a single row
        g = ag.FacetGrid(self.df, col="b")
        g.map(plt.plot, "x", "y")

        # Test the default titles
        assert g.axes[0, 0].get_title() == "b = m"
        assert g.axes[0, 1].get_title() == "b = n"

        # test with dropna=False (smoke test: mapping must not raise)
        g = ag.FacetGrid(self.df, col="b", hue="b", dropna=False)
        g.map(plt.plot, 'x', 'y')

    def test_set_titles_margin_titles(self):
        """With margin_titles, row labels become texts on the rightmost axes."""
        g = ag.FacetGrid(self.df, row="a", col="b", margin_titles=True)
        g.map(plt.plot, "x", "y")

        # Test the default titles
        assert g.axes[0, 0].get_title() == "b = m"
        assert g.axes[0, 1].get_title() == "b = n"
        assert g.axes[1, 0].get_title() == ""

        # Test the row "titles"
        assert g.axes[0, 1].texts[0].get_text() == "a = a"
        assert g.axes[1, 1].texts[0].get_text() == "a = b"
        assert g.axes[0, 1].texts[0] is g._margin_titles_texts[0]

        # Test provided titles
        g.set_titles(col_template="{col_name}", row_template="{row_name}")
        assert g.axes[0, 0].get_title() == "m"
        assert g.axes[0, 1].get_title() == "n"
        assert g.axes[1, 0].get_title() == ""

        # Resetting titles replaces, rather than adds to, the margin texts
        assert len(g.axes[1, 1].texts) == 1
        assert g.axes[1, 1].texts[0].get_text() == "b"

    def test_set_ticklabels(self):
        """Tick labels can be set explicitly, thinned with ``step`` and rotated."""
        g = ag.FacetGrid(self.df, row="a", col="b")
        g.map(plt.plot, "x", "y")

        ax = g.axes[-1, 0]
        xlab = [l.get_text() + "h" for l in ax.get_xticklabels()]
        ylab = [l.get_text() + "i" for l in ax.get_yticklabels()]

        g.set_xticklabels(xlab)
        g.set_yticklabels(ylab)
        got_x = [l.get_text() for l in g.axes[-1, 1].get_xticklabels()]
        got_y = [l.get_text() for l in g.axes[0, 0].get_yticklabels()]
        npt.assert_array_equal(got_x, xlab)
        npt.assert_array_equal(got_y, ylab)

        x, y = np.arange(10), np.arange(10)
        df = pd.DataFrame(np.c_[x, y], columns=["x", "y"])
        g = ag.FacetGrid(df).map_dataframe(pointplot, x="x", y="y", order=x)
        g.set_xticklabels(step=2)
        got_x = [int(l.get_text()) for l in g.axes[0, 0].get_xticklabels()]
        npt.assert_array_equal(x[::2], got_x)

        g = ag.FacetGrid(self.df, col="d", col_wrap=5)
        g.map(plt.plot, "x", "y")
        g.set_xticklabels(rotation=45)
        g.set_yticklabels(rotation=75)
        for ax in g._bottom_axes:
            for l in ax.get_xticklabels():
                assert l.get_rotation() == 45
        for ax in g._left_axes:
            for l in ax.get_yticklabels():
                assert l.get_rotation() == 75

    def test_set_axis_labels(self):
        """Axis labels appear only on the bottom row and left column."""
        g = ag.FacetGrid(self.df, row="a", col="b")
        g.map(plt.plot, "x", "y")
        xlab = 'xx'
        ylab = 'yy'

        g.set_axis_labels(xlab, ylab)

        got_x = [ax.get_xlabel() for ax in g.axes[-1, :]]
        got_y = [ax.get_ylabel() for ax in g.axes[:, 0]]
        npt.assert_array_equal(got_x, xlab)
        npt.assert_array_equal(got_y, ylab)

        # Labels set on interior axes beforehand are cleared by set_axis_labels
        for ax in g.axes.flat:
            ax.set(xlabel="x", ylabel="y")

        g.set_axis_labels(xlab, ylab)
        for ax in g._not_bottom_axes:
            assert not ax.get_xlabel()
        for ax in g._not_left_axes:
            assert not ax.get_ylabel()

    def test_axis_lims(self):
        """xlim/ylim constructor arguments set the facet limits."""
        g = ag.FacetGrid(self.df, row="a", col="b", xlim=(0, 4), ylim=(-2, 3))
        assert g.axes[0, 0].get_xlim() == (0, 4)
        assert g.axes[0, 0].get_ylim() == (-2, 3)

    def test_data_orders(self):
        """Explicit *_order arguments control level order, even with unused levels."""
        g = ag.FacetGrid(self.df, row="a", col="b", hue="c")

        assert g.row_names == list("abc")
        assert g.col_names == list("mn")
        assert g.hue_names == list("tuv")
        assert g.axes.shape == (3, 2)

        g = ag.FacetGrid(self.df, row="a", col="b", hue="c",
                         row_order=list("bca"),
                         col_order=list("nm"),
                         hue_order=list("vtu"))

        assert g.row_names == list("bca")
        assert g.col_names == list("nm")
        assert g.hue_names == list("vtu")
        assert g.axes.shape == (3, 2)

        g = ag.FacetGrid(self.df, row="a", col="b", hue="c",
                         row_order=list("bcda"),
                         col_order=list("nom"),
                         hue_order=list("qvtu"))

        assert g.row_names == list("bcda")
        assert g.col_names == list("nom")
        assert g.hue_names == list("qvtu")
        assert g.axes.shape == (4, 3)

    def test_palette(self):
        """Hue colors from default, named, and dict palettes (dicts follow hue_order)."""
        rcmod.set()

        g = ag.FacetGrid(self.df, hue="c")
        assert g._colors == color_palette(n_colors=len(self.df.c.unique()))

        # With more levels (12 here) the default switches to "husl"
        g = ag.FacetGrid(self.df, hue="d")
        assert g._colors == color_palette("husl", len(self.df.d.unique()))

        g = ag.FacetGrid(self.df, hue="c", palette="Set2")
        assert g._colors == color_palette("Set2", len(self.df.c.unique()))

        dict_pal = dict(t="red", u="green", v="blue")
        list_pal = color_palette(["red", "green", "blue"], 3)
        g = ag.FacetGrid(self.df, hue="c", palette=dict_pal)
        assert g._colors == list_pal

        list_pal = color_palette(["green", "blue", "red"], 3)
        g = ag.FacetGrid(self.df, hue="c", hue_order=list("uvt"),
                         palette=dict_pal)
        assert g._colors == list_pal

    def test_hue_kws(self):
        """hue_kws values are applied to the hue levels in order."""
        kws = dict(marker=["o", "s", "D"])
        g = ag.FacetGrid(self.df, hue="c", hue_kws=kws)
        g.map(plt.plot, "x", "y")

        for line, marker in zip(g.axes[0, 0].lines, kws["marker"]):
            assert line.get_marker() == marker

    def test_dropna(self):
        """dropna controls whether rows with a NaN facet value are kept."""
        df = self.df.copy()
        hasna = pd.Series(np.tile(np.arange(6), 10), dtype=float)
        hasna[hasna == 5] = np.nan
        df["hasna"] = hasna
        g = ag.FacetGrid(df, dropna=False, row="hasna")
        assert g._not_na.sum() == 60

        g = ag.FacetGrid(df, dropna=True, row="hasna")
        assert g._not_na.sum() == 50

    def test_categorical_column_missing_categories(self):
        """Unobserved categories of a categorical column still get facets."""
        df = self.df.copy()
        df['a'] = df['a'].astype('category')

        g = ag.FacetGrid(df[df['a'] == 'a'], col="a", col_wrap=1)

        assert g.axes.shape == (len(df['a'].cat.categories),)

    def test_categorical_warning(self):
        """Mapping a categorical plot function via ``map`` emits a UserWarning."""
        g = ag.FacetGrid(self.df, col="b")
        with pytest.warns(UserWarning):
            g.map(pointplot, "b", "x")

    def test_refline(self):
        """refline draws dashed '.5' reference lines, with overridable styling."""
        g = ag.FacetGrid(self.df, row="a", col="b")
        # Without x or y, nothing is drawn
        g.refline()
        for ax in g.axes.flat:
            assert not ax.lines

        refx = refy = 0.5
        hline = np.array([[0, refy], [1, refy]])
        vline = np.array([[refx, 0], [refx, 1]])
        g.refline(x=refx, y=refy)
        for ax in g.axes.flat:
            assert ax.lines[0].get_color() == '.5'
            assert ax.lines[0].get_linestyle() == '--'
            assert len(ax.lines) == 2
            npt.assert_array_equal(ax.lines[0].get_xydata(), vline)
            npt.assert_array_equal(ax.lines[1].get_xydata(), hline)

        color, linestyle = 'red', '-'
        g.refline(x=refx, color=color, linestyle=linestyle)
        npt.assert_array_equal(g.axes[0, 0].lines[-1].get_xydata(), vline)
        assert g.axes[0, 0].lines[-1].get_color() == color
        assert g.axes[0, 0].lines[-1].get_linestyle() == linestyle

    def test_apply(self, long_df):
        """``apply`` calls func(grid, *args) and returns the grid itself."""

        def f(grid, color):
            grid.figure.set_facecolor(color)

        color = (.1, .6, .3, .9)
        g = ag.FacetGrid(long_df)
        res = g.apply(f, color)
        assert res is g
        assert g.figure.get_facecolor() == color

    def test_pipe(self, long_df):
        """``pipe`` calls func(grid, *args) and returns func's return value."""

        def f(grid, color):
            grid.figure.set_facecolor(color)
            return color

        color = (.1, .6, .3, .9)
        g = ag.FacetGrid(long_df)
        res = g.pipe(f, color)
        assert res == color
        assert g.figure.get_facecolor() == color

    def test_tick_params(self):
        """tick_params applies to the major ticks of both axes on every facet."""
        g = ag.FacetGrid(self.df, row="a", col="b")
        color = "blue"
        pad = 3
        g.tick_params(pad=pad, color=color)
        for ax in g.axes.flat:
            for axis in ["xaxis", "yaxis"]:
                for tick in getattr(ax, axis).get_major_ticks():
                    assert mpl.colors.same_color(tick.tick1line.get_color(), color)
                    assert mpl.colors.same_color(tick.tick2line.get_color(), color)
                    assert tick.get_pad() == pad

    @pytest.mark.skipif(
        condition=not hasattr(pd.api, "interchange"),
        reason="Tests behavior assuming support for dataframe interchange"
    )
    def test_data_interchange(self, mock_long_df, long_df):
        """A dataframe-interchange object facets like the equivalent pandas frame."""
        g = ag.FacetGrid(mock_long_df, col="a", row="b")
        g.map(scatterplot, "x", "y")

        assert g.axes.shape == (long_df["b"].nunique(), long_df["a"].nunique())
        for ax in g.axes.flat:
            assert len(ax.collections) == 1
class TestPairGrid:
rs = np.random.RandomState(sum(map(ord, "PairGrid")))
df = pd.DataFrame(dict(x=rs.normal(size=60),
y=rs.randint(0, 4, size=(60)),
z=rs.gamma(3, size=60),
a=np.repeat(list("abc"), 20),
b=np.repeat(list("abcdefghijkl"), 5)))
def test_self_data(self):
g = ag.PairGrid(self.df)
assert g.data is self.df
def test_ignore_datelike_data(self):
df = self.df.copy()
df['date'] = pd.date_range('2010-01-01', periods=len(df), freq='D')
result = ag.PairGrid(self.df).data
expected = df.drop('date', axis=1)
tm.assert_frame_equal(result, expected)
def test_self_figure(self):
g = ag.PairGrid(self.df)
assert isinstance(g.figure, plt.Figure)
assert g.figure is g._figure
def test_self_axes(self):
g = ag.PairGrid(self.df)
for ax in g.axes.flat:
assert isinstance(ax, plt.Axes)
def test_default_axes(self):
g = ag.PairGrid(self.df)
assert g.axes.shape == (3, 3)
assert g.x_vars == ["x", "y", "z"]
assert g.y_vars == ["x", "y", "z"]
assert g.square_grid
@pytest.mark.parametrize("vars", [["z", "x"], np.array(["z", "x"])])
def test_specific_square_axes(self, vars):
g = ag.PairGrid(self.df, vars=vars)
assert g.axes.shape == (len(vars), len(vars))
assert g.x_vars == list(vars)
assert g.y_vars == list(vars)
assert g.square_grid
def test_remove_hue_from_default(self):
hue = "z"
g = ag.PairGrid(self.df, hue=hue)
assert hue not in g.x_vars
assert hue not in g.y_vars
vars = ["x", "y", "z"]
g = ag.PairGrid(self.df, hue=hue, vars=vars)
assert hue in g.x_vars
assert hue in g.y_vars
@pytest.mark.parametrize(
"x_vars, y_vars",
[
(["x", "y"], ["z", "y", "x"]),
(["x", "y"], "z"),
(np.array(["x", "y"]), np.array(["z", "y", "x"])),
],
)
def test_specific_nonsquare_axes(self, x_vars, y_vars):
g = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
assert g.axes.shape == (len(y_vars), len(x_vars))
assert g.x_vars == list(x_vars)
assert g.y_vars == list(y_vars)
assert not g.square_grid
def test_corner(self):
plot_vars = ["x", "y", "z"]
g = ag.PairGrid(self.df, vars=plot_vars, corner=True)
corner_size = sum(i + 1 for i in range(len(plot_vars)))
assert len(g.figure.axes) == corner_size
g.map_diag(plt.hist)
assert len(g.figure.axes) == (corner_size + len(plot_vars))
for ax in np.diag(g.axes):
assert not ax.yaxis.get_visible()
plot_vars = ["x", "y", "z"]
g = ag.PairGrid(self.df, vars=plot_vars, corner=True)
g.map(scatterplot)
assert len(g.figure.axes) == corner_size
assert g.axes[0, 0].get_ylabel() == "x"
def test_size(self):
g1 = ag.PairGrid(self.df, height=3)
npt.assert_array_equal(g1.fig.get_size_inches(), (9, 9))
g2 = ag.PairGrid(self.df, height=4, aspect=.5)
npt.assert_array_equal(g2.fig.get_size_inches(), (6, 12))
g3 = ag.PairGrid(self.df, y_vars=["z"], x_vars=["x", "y"],
height=2, aspect=2)
npt.assert_array_equal(g3.fig.get_size_inches(), (8, 2))
def test_empty_grid(self):
with pytest.raises(ValueError, match="No variables found"):
ag.PairGrid(self.df[["a", "b"]])
def test_map(self):
vars = ["x", "y", "z"]
g1 = ag.PairGrid(self.df)
g1.map(plt.scatter)
for i, axes_i in enumerate(g1.axes):
for j, ax in enumerate(axes_i):
x_in = self.df[vars[j]]
y_in = self.df[vars[i]]
x_out, y_out = ax.collections[0].get_offsets().T
npt.assert_array_equal(x_in, x_out)
npt.assert_array_equal(y_in, y_out)
g2 = ag.PairGrid(self.df, hue="a")
g2.map(plt.scatter)
for i, axes_i in enumerate(g2.axes):
for j, ax in enumerate(axes_i):
x_in = self.df[vars[j]]
y_in = self.df[vars[i]]
for k, k_level in enumerate(self.df.a.unique()):
x_in_k = x_in[self.df.a == k_level]
y_in_k = y_in[self.df.a == k_level]
x_out, y_out = ax.collections[k].get_offsets().T
npt.assert_array_equal(x_in_k, x_out)
npt.assert_array_equal(y_in_k, y_out)
def test_map_nonsquare(self):
x_vars = ["x"]
y_vars = ["y", "z"]
g = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
g.map(plt.scatter)
x_in = self.df.x
for i, i_var in enumerate(y_vars):
ax = g.axes[i, 0]
y_in = self.df[i_var]
x_out, y_out = ax.collections[0].get_offsets().T
npt.assert_array_equal(x_in, x_out)
npt.assert_array_equal(y_in, y_out)
def test_map_lower(self):
vars = ["x", "y", "z"]
g = ag.PairGrid(self.df)
g.map_lower(plt.scatter)
for i, j in zip(*np.tril_indices_from(g.axes, -1)):
ax = g.axes[i, j]
x_in = self.df[vars[j]]
y_in = self.df[vars[i]]
x_out, y_out = ax.collections[0].get_offsets().T
npt.assert_array_equal(x_in, x_out)
npt.assert_array_equal(y_in, y_out)
for i, j in zip(*np.triu_indices_from(g.axes)):
ax = g.axes[i, j]
assert len(ax.collections) == 0
def test_map_upper(self):
vars = ["x", "y", "z"]
g = ag.PairGrid(self.df)
g.map_upper(plt.scatter)
for i, j in zip(*np.triu_indices_from(g.axes, 1)):
ax = g.axes[i, j]
x_in = self.df[vars[j]]
y_in = self.df[vars[i]]
x_out, y_out = ax.collections[0].get_offsets().T
npt.assert_array_equal(x_in, x_out)
npt.assert_array_equal(y_in, y_out)
for i, j in zip(*np.tril_indices_from(g.axes)):
ax = g.axes[i, j]
assert len(ax.collections) == 0
def test_map_mixed_funcsig(self):
vars = ["x", "y", "z"]
g = ag.PairGrid(self.df, vars=vars)
g.map_lower(scatterplot)
g.map_upper(plt.scatter)
for i, j in zip(*np.triu_indices_from(g.axes, 1)):
ax = g.axes[i, j]
x_in = self.df[vars[j]]
y_in = self.df[vars[i]]
x_out, y_out = ax.collections[0].get_offsets().T
npt.assert_array_equal(x_in, x_out)
npt.assert_array_equal(y_in, y_out)
def test_map_diag(self):
g = ag.PairGrid(self.df)
g.map_diag(plt.hist)
for var, ax in zip(g.diag_vars, g.diag_axes):
assert len(ax.patches) == 10
assert pytest.approx(ax.patches[0].get_x()) == self.df[var].min()
g = ag.PairGrid(self.df, hue="a")
g.map_diag(plt.hist)
for ax in g.diag_axes:
assert len(ax.patches) == 30
g = ag.PairGrid(self.df, hue="a")
g.map_diag(plt.hist, histtype='step')
for ax in g.diag_axes:
for ptch in ax.patches:
assert not ptch.fill
def test_map_diag_rectangular(self):
x_vars = ["x", "y"]
y_vars = ["x", "z", "y"]
g1 = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
g1.map_diag(plt.hist)
g1.map_offdiag(plt.scatter)
assert set(g1.diag_vars) == (set(x_vars) & set(y_vars))
for var, ax in zip(g1.diag_vars, g1.diag_axes):
assert len(ax.patches) == 10
assert pytest.approx(ax.patches[0].get_x()) == self.df[var].min()
for j, x_var in enumerate(x_vars):
for i, y_var in enumerate(y_vars):
ax = g1.axes[i, j]
if x_var == y_var:
diag_ax = g1.diag_axes[j] # because fewer x than y vars
assert ax.bbox.bounds == diag_ax.bbox.bounds
else:
x, y = ax.collections[0].get_offsets().T
assert_array_equal(x, self.df[x_var])
assert_array_equal(y, self.df[y_var])
g2 = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars, hue="a")
g2.map_diag(plt.hist)
g2.map_offdiag(plt.scatter)
assert set(g2.diag_vars) == (set(x_vars) & set(y_vars))
for ax in g2.diag_axes:
assert len(ax.patches) == 30
x_vars = ["x", "y", "z"]
y_vars = ["x", "z"]
g3 = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
g3.map_diag(plt.hist)
g3.map_offdiag(plt.scatter)
assert set(g3.diag_vars) == (set(x_vars) & set(y_vars))
for var, ax in zip(g3.diag_vars, g3.diag_axes):
assert len(ax.patches) == 10
assert pytest.approx(ax.patches[0].get_x()) == self.df[var].min()
for j, x_var in enumerate(x_vars):
for i, y_var in enumerate(y_vars):
ax = g3.axes[i, j]
if x_var == y_var:
diag_ax = g3.diag_axes[i] # because fewer y than x vars
assert ax.bbox.bounds == diag_ax.bbox.bounds
else:
x, y = ax.collections[0].get_offsets().T
assert_array_equal(x, self.df[x_var])
assert_array_equal(y, self.df[y_var])
def test_map_diag_color(self):
color = "red"
g1 = ag.PairGrid(self.df)
g1.map_diag(plt.hist, color=color)
for ax in g1.diag_axes:
for patch in ax.patches:
assert_colors_equal(patch.get_facecolor(), color)
g2 = ag.PairGrid(self.df)
g2.map_diag(kdeplot, color='red')
for ax in g2.diag_axes:
for line in ax.lines:
assert_colors_equal(line.get_color(), color)
def test_map_diag_palette(self):
palette = "muted"
pal = color_palette(palette, n_colors=len(self.df.a.unique()))
g = ag.PairGrid(self.df, hue="a", palette=palette)
g.map_diag(kdeplot)
for ax in g.diag_axes:
for line, color in zip(ax.lines[::-1], pal):
assert_colors_equal(line.get_color(), color)
def test_map_diag_and_offdiag(self):
vars = ["x", "y", "z"]
g = ag.PairGrid(self.df)
g.map_offdiag(plt.scatter)
g.map_diag(plt.hist)
for ax in g.diag_axes:
assert len(ax.patches) == 10
for i, j in zip(*np.triu_indices_from(g.axes, 1)):
ax = g.axes[i, j]
x_in = self.df[vars[j]]
y_in = self.df[vars[i]]
x_out, y_out = ax.collections[0].get_offsets().T
npt.assert_array_equal(x_in, x_out)
npt.assert_array_equal(y_in, y_out)
for i, j in zip(*np.tril_indices_from(g.axes, -1)):
ax = g.axes[i, j]
x_in = self.df[vars[j]]
y_in = self.df[vars[i]]
x_out, y_out = ax.collections[0].get_offsets().T
npt.assert_array_equal(x_in, x_out)
npt.assert_array_equal(y_in, y_out)
for i, j in zip(*np.diag_indices_from(g.axes)):
ax = g.axes[i, j]
assert len(ax.collections) == 0
def test_diag_sharey(self):
g = ag.PairGrid(self.df, diag_sharey=True)
g.map_diag(kdeplot)
for ax in g.diag_axes[1:]:
assert ax.get_ylim() == g.diag_axes[0].get_ylim()
def test_map_diag_matplotlib(self):
bins = 10
g = ag.PairGrid(self.df)
g.map_diag(plt.hist, bins=bins)
for ax in g.diag_axes:
assert len(ax.patches) == bins
levels = len(self.df["a"].unique())
g = ag.PairGrid(self.df, hue="a")
g.map_diag(plt.hist, bins=bins)
for ax in g.diag_axes:
assert len(ax.patches) == (bins * levels)
def test_palette(self):
rcmod.set()
g = ag.PairGrid(self.df, hue="a")
assert g.palette == color_palette(n_colors=len(self.df.a.unique()))
g = ag.PairGrid(self.df, hue="b")
assert g.palette == color_palette("husl", len(self.df.b.unique()))
g = ag.PairGrid(self.df, hue="a", palette="Set2")
assert g.palette == color_palette("Set2", len(self.df.a.unique()))
dict_pal = dict(a="red", b="green", c="blue")
list_pal = color_palette(["red", "green", "blue"])
g = ag.PairGrid(self.df, hue="a", palette=dict_pal)
assert g.palette == list_pal
list_pal = color_palette(["blue", "red", "green"])
g = ag.PairGrid(self.df, hue="a", hue_order=list("cab"),
palette=dict_pal)
assert g.palette == list_pal
def test_hue_kws(self):
kws = dict(marker=["o", "s", "d", "+"])
g = ag.PairGrid(self.df, hue="a", hue_kws=kws)
g.map(plt.plot)
for line, marker in zip(g.axes[0, 0].lines, kws["marker"]):
assert line.get_marker() == marker
g = ag.PairGrid(self.df, hue="a", hue_kws=kws,
hue_order=list("dcab"))
g.map(plt.plot)
for line, marker in zip(g.axes[0, 0].lines, kws["marker"]):
assert line.get_marker() == marker
def test_hue_order(self):
order = list("dcab")
g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map(plt.plot)
for line, level in zip(g.axes[1, 0].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "x"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "y"])
plt.close("all")
g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map_diag(plt.plot)
for line, level in zip(g.axes[0, 0].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "x"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "x"])
plt.close("all")
g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map_lower(plt.plot)
for line, level in zip(g.axes[1, 0].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "x"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "y"])
plt.close("all")
g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map_upper(plt.plot)
for line, level in zip(g.axes[0, 1].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "y"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "x"])
plt.close("all")
def test_hue_order_missing_level(self):
order = list("dcaeb")
g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map(plt.plot)
for line, level in zip(g.axes[1, 0].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "x"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "y"])
plt.close("all")
g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map_diag(plt.plot)
for line, level in zip(g.axes[0, 0].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "x"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "x"])
plt.close("all")
g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map_lower(plt.plot)
for line, level in zip(g.axes[1, 0].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "x"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "y"])
plt.close("all")
g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map_upper(plt.plot)
for line, level in zip(g.axes[0, 1].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "y"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "x"])
plt.close("all")
def test_hue_in_map(self, long_df):
    """A hue vector passed straight to ``map`` colors points by its levels."""
    grid = ag.PairGrid(long_df, vars=["x", "y"])
    grid.map(scatterplot, hue=long_df["a"])
    first_points = grid.axes.flat[0].collections[0]
    distinct_colors = {tuple(rgba) for rgba in first_points.get_facecolors()}
    assert len(distinct_colors) == 3
def test_nondefault_index(self):
    """Data with a non-default index is plotted with values unchanged,
    both without and with a hue variable."""
    indexed = self.df.copy().set_index("b")
    plot_vars = ["x", "y", "z"]

    no_hue = ag.PairGrid(indexed)
    no_hue.map(plt.scatter)
    for row, axes_row in enumerate(no_hue.axes):
        for col, ax in enumerate(axes_row):
            got_x, got_y = ax.collections[0].get_offsets().T
            npt.assert_array_equal(self.df[plot_vars[col]], got_x)
            npt.assert_array_equal(self.df[plot_vars[row]], got_y)

    with_hue = ag.PairGrid(indexed, hue="a")
    with_hue.map(plt.scatter)
    levels = self.df.a.unique()
    for row, axes_row in enumerate(with_hue.axes):
        for col, ax in enumerate(axes_row):
            want_x = self.df[plot_vars[col]]
            want_y = self.df[plot_vars[row]]
            # One scatter collection per hue level, in order of appearance.
            for idx, lvl in enumerate(levels):
                mask = self.df.a == lvl
                got_x, got_y = ax.collections[idx].get_offsets().T
                npt.assert_array_equal(want_x[mask], got_x)
                npt.assert_array_equal(want_y[mask], got_y)
@pytest.mark.parametrize("func", [scatterplot, plt.scatter])
def test_dropna(self, func):
    """With dropna=True, each panel shows only rows complete in its vars."""
    frame = self.df.copy()
    n_null = 20
    frame.loc[np.arange(n_null), "x"] = np.nan
    plot_vars = ["x", "y", "z"]

    grid = ag.PairGrid(frame, vars=plot_vars, dropna=True)
    grid.map(func)
    for row, axes_row in enumerate(grid.axes):
        for col, ax in enumerate(axes_row):
            # The product is null wherever either operand is null.
            pair_ok = (frame[plot_vars[col]] * frame[plot_vars[row]]).notnull()
            expected = pair_ok.sum()
            got_x, got_y = ax.collections[0].get_offsets().T
            assert len(got_x) == expected
            assert len(got_y) == expected

    grid.map_diag(histplot)
    for idx, ax in enumerate(grid.diag_axes):
        column = plot_vars[idx]
        total = sum(bar.get_height() for bar in ax.patches)
        assert total == frame[column].notna().sum()
def test_histplot_legend(self):
    """Legend handles are extracted from histplot artists, one per level."""
    # Exercises _extract_legend_handles
    grid = ag.PairGrid(self.df, vars=["x", "y"], hue="a")
    grid.map_offdiag(histplot)
    grid.add_legend()
    n_levels = len(self.df["a"].unique())
    assert len(get_legend_handles(grid._legend)) == n_levels
def test_pairplot(self):
    """Default pairplot: histogram patches on the diagonal, scatter data
    off the diagonal, and one diagonal collection per hue level."""
    plot_vars = ["x", "y", "z"]
    grid = ag.pairplot(self.df)

    for ax in grid.diag_axes:
        assert len(ax.patches) > 1

    upper = list(zip(*np.triu_indices_from(grid.axes, 1)))
    lower = list(zip(*np.tril_indices_from(grid.axes, -1)))
    for row, col in upper + lower:
        ax = grid.axes[row, col]
        got_x, got_y = ax.collections[0].get_offsets().T
        npt.assert_array_equal(self.df[plot_vars[col]], got_x)
        npt.assert_array_equal(self.df[plot_vars[row]], got_y)

    for row, col in zip(*np.diag_indices_from(grid.axes)):
        assert len(grid.axes[row, col].collections) == 0

    hue_grid = ag.pairplot(self.df, hue="a")
    n_levels = len(self.df.a.unique())
    for ax in hue_grid.diag_axes:
        assert len(ax.collections) == n_levels
def test_pairplot_reg(self):
    """kind="reg": each off-diagonal panel has the scatter data as its first
    collection, plus exactly one line and two collections in total."""
    plot_vars = ["x", "y", "z"]
    grid = ag.pairplot(self.df, diag_kind="hist", kind="reg")

    for ax in grid.diag_axes:
        assert len(ax.patches)

    upper = list(zip(*np.triu_indices_from(grid.axes, 1)))
    lower = list(zip(*np.tril_indices_from(grid.axes, -1)))
    for row, col in upper + lower:
        ax = grid.axes[row, col]
        got_x, got_y = ax.collections[0].get_offsets().T
        npt.assert_array_equal(self.df[plot_vars[col]], got_x)
        npt.assert_array_equal(self.df[plot_vars[row]], got_y)
        assert len(ax.lines) == 1
        assert len(ax.collections) == 2

    for row, col in zip(*np.diag_indices_from(grid.axes)):
        assert len(grid.axes[row, col].collections) == 0
def test_pairplot_reg_hue(self):
    """Hue levels in a kind="reg" pairplot get distinct colors and markers."""
    grid = ag.pairplot(self.df, kind="reg", hue="a", markers=["o", "s", "d"])
    ax = grid.axes[-1, 0]
    # Collections 0 and 2 are compared as the scatters of two different
    # levels (presumably each level contributes two collections).
    first, second = ax.collections[0], ax.collections[2]
    assert not np.array_equal(first.get_facecolor(), second.get_facecolor())
    first_verts = first.get_paths()[0].vertices
    second_verts = second.get_paths()[0].vertices
    assert not np.array_equal(first_verts, second_verts)
def test_pairplot_diag_kde(self):
    """diag_kind="kde": one collection on each diagonal axes, scatter data
    off the diagonal."""
    plot_vars = ["x", "y", "z"]
    grid = ag.pairplot(self.df, diag_kind="kde")

    for ax in grid.diag_axes:
        assert len(ax.collections) == 1

    upper = list(zip(*np.triu_indices_from(grid.axes, 1)))
    lower = list(zip(*np.tril_indices_from(grid.axes, -1)))
    for row, col in upper + lower:
        ax = grid.axes[row, col]
        got_x, got_y = ax.collections[0].get_offsets().T
        npt.assert_array_equal(self.df[plot_vars[col]], got_x)
        npt.assert_array_equal(self.df[plot_vars[row]], got_y)

    for row, col in zip(*np.diag_indices_from(grid.axes)):
        assert len(grid.axes[row, col].collections) == 0
def test_pairplot_kde(self):
    """kind="kde" off-diagonal panels match a direct bivariate kdeplot."""
    _, reference_ax = plt.subplots()
    kdeplot(data=self.df, x="x", y="y", ax=reference_ax)
    grid = ag.pairplot(self.df, kind="kde")
    assert_plots_equal(reference_ax, grid.axes[1, 0], labels=False)
def test_pairplot_hist(self):
    """kind="hist" off-diagonal panels match a direct bivariate histplot."""
    _, reference_ax = plt.subplots()
    histplot(data=self.df, x="x", y="y", ax=reference_ax)
    grid = ag.pairplot(self.df, kind="hist")
    assert_plots_equal(reference_ax, grid.axes[1, 0], labels=False)
@pytest.mark.skipif(_version_predates(mpl, "3.7.0"), reason="Matplotlib bug")
def test_pairplot_markers(self):
    """Hue levels get distinct legend markers; too few markers warns."""
    plot_vars = ["x", "y", "z"]
    markers = ["o", "X", "s"]
    grid = ag.pairplot(self.df, hue="a", vars=plot_vars, markers=markers)
    handles = get_legend_handles(grid._legend)
    assert handles[0].get_marker() != handles[1].get_marker()

    with pytest.warns(UserWarning):
        ag.pairplot(self.df, hue="a", vars=plot_vars, markers=markers[:-2])
def test_pairplot_column_multiindex(self):
    """MultiIndex column labels are accepted and used as diagonal vars."""
    columns = pd.MultiIndex.from_arrays([["x", "y"], [1, 2]])
    frame = self.df[["x", "y"]].set_axis(columns, axis=1)
    grid = ag.pairplot(frame)
    assert grid.diag_vars == list(columns)
def test_corner_despine(self):
    """despine=False keeps the top spine visible on a corner grid."""
    grid = ag.PairGrid(self.df, corner=True, despine=False)
    grid.map_diag(histplot)
    top_spine = grid.axes[0, 0].spines["top"]
    assert top_spine.get_visible()
def test_corner_set(self):
    """``set`` on a corner grid reaches the populated lower-left axes."""
    grid = ag.PairGrid(self.df, corner=True, despine=False)
    grid.set(xlim=(0, 10))
    bottom_left = grid.axes[-1, 0]
    assert bottom_left.get_xlim() == (0, 10)
def test_legend(self):
    """pairplot exposes a Legend only when a hue variable is assigned."""
    with_hue = ag.pairplot(self.df, hue="a")
    assert isinstance(with_hue.legend, mpl.legend.Legend)
    without_hue = ag.pairplot(self.df)
    assert without_hue.legend is None
def test_tick_params(self):
    """tick_params forwards color and pad to every major tick of every axes."""
    grid = ag.PairGrid(self.df)
    color, pad = "red", 3
    grid.tick_params(pad=pad, color=color)
    for ax in grid.axes.flat:
        for axis in (ax.xaxis, ax.yaxis):
            for tick in axis.get_major_ticks():
                assert mpl.colors.same_color(tick.tick1line.get_color(), color)
                assert mpl.colors.same_color(tick.tick2line.get_color(), color)
                assert tick.get_pad() == pad
@pytest.mark.skipif(
    condition=not hasattr(pd.api, "interchange"),
    reason="Tests behavior assuming support for dataframe interchange"
)
def test_data_interchange(self, mock_long_df, long_df):
    """PairGrid accepts an interchange-protocol object and plots every row."""
    grid = ag.PairGrid(mock_long_df, vars=["x", "y", "z"], hue="a")
    grid.map(scatterplot)
    assert grid.axes.shape == (3, 3)
    n_rows = len(long_df)
    for ax in grid.axes.flat:
        offsets = ax.collections[0].get_offsets()
        assert len(offsets) == n_rows
class TestJointGrid:
    """Tests for ``JointGrid``: input forms, labels, limits, marginal ticks,
    joint/marginal plotting, spacing, hue handling, and reference lines."""

    rs = np.random.RandomState(sum(map(ord, "JointGrid")))
    x = rs.randn(100)
    y = rs.randn(100)
    # Copy of x with two missing values, for dropna tests.
    x_na = x.copy()
    x_na[[10, 20]] = np.nan
    data = pd.DataFrame(dict(x=x, y=y, x_na=x_na))

    def _check_xy(self, grid):
        """Assert that ``grid`` holds the class-level x and y vectors."""
        npt.assert_array_equal(grid.x, self.x)
        npt.assert_array_equal(grid.y, self.y)

    def test_margin_grid_from_lists(self):
        self._check_xy(ag.JointGrid(x=self.x.tolist(), y=self.y.tolist()))

    def test_margin_grid_from_arrays(self):
        self._check_xy(ag.JointGrid(x=self.x, y=self.y))

    def test_margin_grid_from_series(self):
        self._check_xy(ag.JointGrid(x=self.data.x, y=self.data.y))

    def test_margin_grid_from_dataframe(self):
        self._check_xy(ag.JointGrid(x="x", y="y", data=self.data))

    def test_margin_grid_from_dataframe_bad_variable(self):
        with pytest.raises(ValueError):
            ag.JointGrid(x="x", y="bad_column", data=self.data)

    def test_margin_grid_axis_labels(self):
        grid = ag.JointGrid(x="x", y="y", data=self.data)
        joint = grid.ax_joint
        assert (joint.get_xlabel(), joint.get_ylabel()) == ("x", "y")
        grid.set_axis_labels("x variable", "y variable")
        labels = (joint.get_xlabel(), joint.get_ylabel())
        assert labels == ("x variable", "y variable")

    def test_dropna(self):
        kept = ag.JointGrid(x="x_na", y="y", data=self.data, dropna=False)
        assert len(kept.x) == len(self.x_na)
        dropped = ag.JointGrid(x="x_na", y="y", data=self.data, dropna=True)
        assert len(dropped.x) == pd.notnull(self.x_na).sum()

    def test_axlims(self):
        lim = (-3, 3)
        grid = ag.JointGrid(x="x", y="y", data=self.data, xlim=lim, ylim=lim)
        observed = (
            grid.ax_joint.get_xlim(),
            grid.ax_joint.get_ylim(),
            grid.ax_marg_x.get_xlim(),
            grid.ax_marg_y.get_ylim(),
        )
        for limits in observed:
            assert limits == lim

    def test_marginal_ticks(self):
        def n_visible(labels):
            return sum(label.get_visible() for label in labels)

        hidden = ag.JointGrid(marginal_ticks=False)
        assert not n_visible(hidden.ax_marg_x.get_yticklabels())
        assert not n_visible(hidden.ax_marg_y.get_xticklabels())

        shown = ag.JointGrid(marginal_ticks=True)
        assert n_visible(shown.ax_marg_x.get_yticklabels())
        assert n_visible(shown.ax_marg_y.get_xticklabels())

    def test_bivariate_plot(self):
        grid = ag.JointGrid(x="x", y="y", data=self.data)
        grid.plot_joint(plt.plot)
        line_x, line_y = grid.ax_joint.lines[0].get_xydata().T
        npt.assert_array_equal(line_x, self.x)
        npt.assert_array_equal(line_y, self.y)

    def test_univariate_plot(self):
        grid = ag.JointGrid(x="x", y="x", data=self.data)
        grid.plot_marginals(kdeplot)
        # Marginal y is drawn rotated, so its curve lives in column 0.
        density_x = grid.ax_marg_x.lines[0].get_xydata()[:, 1]
        density_y = grid.ax_marg_y.lines[0].get_xydata()[:, 0]
        npt.assert_array_equal(density_x, density_y)

    def test_univariate_plot_distplot(self):
        bins = 10
        grid = ag.JointGrid(x="x", y="x", data=self.data)
        with pytest.warns(UserWarning):
            grid.plot_marginals(distplot, bins=bins)
        bars_x, bars_y = grid.ax_marg_x.patches, grid.ax_marg_y.patches
        assert len(bars_x) == bins
        assert len(bars_y) == bins
        for bar_x, bar_y in zip(bars_x, bars_y):
            assert bar_x.get_height() == bar_y.get_width()

    def test_univariate_plot_matplotlib(self):
        bins = 10
        grid = ag.JointGrid(x="x", y="x", data=self.data)
        grid.plot_marginals(plt.hist, bins=bins)
        assert len(grid.ax_marg_x.patches) == bins
        assert len(grid.ax_marg_y.patches) == bins

    def test_plot(self):
        grid = ag.JointGrid(x="x", y="x", data=self.data)
        grid.plot(plt.plot, kdeplot)
        joint_x, joint_y = grid.ax_joint.lines[0].get_xydata().T
        npt.assert_array_equal(joint_x, self.x)
        npt.assert_array_equal(joint_y, self.x)
        density_x = grid.ax_marg_x.lines[0].get_xydata()[:, 1]
        density_y = grid.ax_marg_y.lines[0].get_xydata()[:, 0]
        npt.assert_array_equal(density_x, density_y)

    def test_space(self):
        grid = ag.JointGrid(x="x", y="y", data=self.data, space=0)
        _, _, joint_w, joint_h = grid.ax_joint.bbox.bounds
        assert joint_w == grid.ax_marg_x.bbox.bounds[2]
        assert joint_h == grid.ax_marg_y.bbox.bounds[3]

    @pytest.mark.parametrize("as_vector", [True, False])
    def test_hue(self, long_df, as_vector):
        if as_vector:
            data, x, y, hue = None, long_df["x"], long_df["y"], long_df["a"]
        else:
            data, x, y, hue = long_df, "x", "y", "a"

        grid = ag.JointGrid(data=data, x=x, y=y, hue=hue)
        grid.plot_joint(scatterplot)
        grid.plot_marginals(histplot)

        expected = ag.JointGrid()
        scatterplot(data=long_df, x=x, y=y, hue=hue, ax=expected.ax_joint)
        histplot(data=long_df, x=x, hue=hue, ax=expected.ax_marg_x)
        histplot(data=long_df, y=y, hue=hue, ax=expected.ax_marg_y)

        assert_plots_equal(grid.ax_joint, expected.ax_joint)
        for attr in ("ax_marg_x", "ax_marg_y"):
            assert_plots_equal(
                getattr(grid, attr), getattr(expected, attr), labels=False
            )

    def test_refline(self):
        grid = ag.JointGrid(x="x", y="y", data=self.data)
        grid.plot(scatterplot, histplot)
        joint, marg_x, marg_y = grid.ax_joint, grid.ax_marg_x, grid.ax_marg_y

        def no_lines():
            return not joint.lines and not marg_x.lines and not marg_y.lines

        # No reference values -> nothing drawn.
        grid.refline()
        assert no_lines()

        refx = refy = 0.5
        hline = np.array([[0, refy], [1, refy]])
        vline = np.array([[refx, 0], [refx, 1]])

        # Both targets disabled -> nothing drawn.
        grid.refline(x=refx, y=refy, joint=False, marginal=False)
        assert no_lines()

        grid.refline(x=refx, y=refy)
        assert joint.lines[0].get_color() == '.5'
        assert joint.lines[0].get_linestyle() == '--'
        assert len(joint.lines) == 2
        assert len(marg_x.lines) == 1
        assert len(marg_y.lines) == 1
        npt.assert_array_equal(joint.lines[0].get_xydata(), vline)
        npt.assert_array_equal(joint.lines[1].get_xydata(), hline)
        npt.assert_array_equal(marg_x.lines[0].get_xydata(), vline)
        npt.assert_array_equal(marg_y.lines[0].get_xydata(), hline)

        color, linestyle = 'red', '-'
        grid.refline(x=refx, marginal=False, color=color, linestyle=linestyle)
        newest = joint.lines[-1]
        npt.assert_array_equal(newest.get_xydata(), vline)
        assert newest.get_color() == color
        assert newest.get_linestyle() == linestyle
        assert len(marg_x.lines) == len(marg_y.lines)

        grid.refline(x=refx, joint=False)
        npt.assert_array_equal(marg_x.lines[-1].get_xydata(), vline)
        assert len(marg_x.lines) == len(marg_y.lines) + 1

        grid.refline(y=refy, joint=False)
        npt.assert_array_equal(marg_y.lines[-1].get_xydata(), hline)
        assert len(marg_x.lines) == len(marg_y.lines)

        grid.refline(y=refy, marginal=False)
        npt.assert_array_equal(joint.lines[-1].get_xydata(), hline)
        assert len(marg_x.lines) == len(marg_y.lines)
class TestJointPlot:
    """Tests for the ``jointplot`` figure-level function across plot kinds."""

    rs = np.random.RandomState(sum(map(ord, "jointplot")))
    x = rs.randn(100)
    y = rs.randn(100)
    data = pd.DataFrame(dict(x=x, y=y))

    def test_scatter(self):
        """Default kind: one scatter collection and "auto"-binned marginals."""
        g = ag.jointplot(x="x", y="y", data=self.data)
        assert len(g.ax_joint.collections) == 1

        x, y = g.ax_joint.collections[0].get_offsets().T
        assert_array_equal(self.x, x)
        assert_array_equal(self.y, y)

        # Bar left/bottom edges should match numpy's "auto" bin edges.
        assert_array_almost_equal(
            [b.get_x() for b in g.ax_marg_x.patches],
            np.histogram_bin_edges(self.x, "auto")[:-1],
        )

        assert_array_almost_equal(
            [b.get_y() for b in g.ax_marg_y.patches],
            np.histogram_bin_edges(self.y, "auto")[:-1],
        )

    def test_scatter_hue(self, long_df):
        """With hue, marginals match filled kdeplots drawn by hand."""
        g1 = ag.jointplot(data=long_df, x="x", y="y", hue="a")

        g2 = ag.JointGrid()
        scatterplot(data=long_df, x="x", y="y", hue="a", ax=g2.ax_joint)
        kdeplot(data=long_df, x="x", hue="a", ax=g2.ax_marg_x, fill=True)
        kdeplot(data=long_df, y="y", hue="a", ax=g2.ax_marg_y, fill=True)

        assert_plots_equal(g1.ax_joint, g2.ax_joint)
        assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False)
        assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False)

    def test_reg(self):
        """kind="reg": scatter data first, plus marginal patches and lines."""
        g = ag.jointplot(x="x", y="y", data=self.data, kind="reg")
        assert len(g.ax_joint.collections) == 2

        x, y = g.ax_joint.collections[0].get_offsets().T
        assert_array_equal(self.x, x)
        assert_array_equal(self.y, y)

        assert g.ax_marg_x.patches
        assert g.ax_marg_y.patches

        assert g.ax_marg_x.lines
        assert g.ax_marg_y.lines

    def test_resid(self):
        """kind="resid": joint has artists, marginals have no lines."""
        g = ag.jointplot(x="x", y="y", data=self.data, kind="resid")
        assert g.ax_joint.collections
        assert g.ax_joint.lines
        assert not g.ax_marg_x.lines
        assert not g.ax_marg_y.lines

    def test_hist(self, long_df):
        """kind="hist" splits a (x, y) bins tuple across the marginals."""
        bins = 3, 6
        g1 = ag.jointplot(data=long_df, x="x", y="y", kind="hist", bins=bins)

        g2 = ag.JointGrid()
        histplot(data=long_df, x="x", y="y", ax=g2.ax_joint, bins=bins)
        histplot(data=long_df, x="x", ax=g2.ax_marg_x, bins=bins[0])
        histplot(data=long_df, y="y", ax=g2.ax_marg_y, bins=bins[1])

        assert_plots_equal(g1.ax_joint, g2.ax_joint)
        assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False)
        assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False)

    def test_hex(self):
        """kind="hex" draws a joint collection and marginal patches."""
        g = ag.jointplot(x="x", y="y", data=self.data, kind="hex")
        assert g.ax_joint.collections
        assert g.ax_marg_x.patches
        assert g.ax_marg_y.patches

    def test_kde(self, long_df):
        """kind="kde" matches bivariate/univariate kdeplots drawn by hand."""
        g1 = ag.jointplot(data=long_df, x="x", y="y", kind="kde")

        g2 = ag.JointGrid()
        kdeplot(data=long_df, x="x", y="y", ax=g2.ax_joint)
        kdeplot(data=long_df, x="x", ax=g2.ax_marg_x)
        kdeplot(data=long_df, y="y", ax=g2.ax_marg_y)

        assert_plots_equal(g1.ax_joint, g2.ax_joint)
        assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False)
        assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False)

    def test_kde_hue(self, long_df):
        """kind="kde" with hue matches hue-mapped kdeplots drawn by hand."""
        g1 = ag.jointplot(data=long_df, x="x", y="y", hue="a", kind="kde")

        g2 = ag.JointGrid()
        kdeplot(data=long_df, x="x", y="y", hue="a", ax=g2.ax_joint)
        kdeplot(data=long_df, x="x", hue="a", ax=g2.ax_marg_x)
        kdeplot(data=long_df, y="y", hue="a", ax=g2.ax_marg_y)

        assert_plots_equal(g1.ax_joint, g2.ax_joint)
        assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False)
        assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False)

    def test_color(self):
        """A single ``color`` is applied to joint points and marginal bars."""
        g = ag.jointplot(x="x", y="y", data=self.data, color="purple")

        scatter_color = g.ax_joint.collections[0].get_facecolor()
        assert_colors_equal(scatter_color, "purple")

        # Compare RGB only; the bar face alpha may differ.
        hist_color = g.ax_marg_x.patches[0].get_facecolor()[:3]
        assert_colors_equal(hist_color, "purple")

    def test_palette(self, long_df):
        """``palette`` is forwarded to both the joint and marginal plots."""
        kws = dict(data=long_df, hue="a", palette="Set2")

        g1 = ag.jointplot(x="x", y="y", **kws)

        g2 = ag.JointGrid()
        scatterplot(x="x", y="y", ax=g2.ax_joint, **kws)
        kdeplot(x="x", ax=g2.ax_marg_x, fill=True, **kws)
        kdeplot(y="y", ax=g2.ax_marg_y, fill=True, **kws)

        assert_plots_equal(g1.ax_joint, g2.ax_joint)
        assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False)
        assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False)

    def test_hex_customise(self):
        """joint_kws can override the default hexbin gridsize."""
        g = ag.jointplot(x="x", y="y", data=self.data, kind="hex",
                         joint_kws=dict(gridsize=5))
        assert len(g.ax_joint.collections) == 1
        a = g.ax_joint.collections[0].get_array()
        assert a.shape[0] == 28  # 28 hexagons expected for gridsize 5

    def test_bad_kind(self):
        """An unknown ``kind`` raises ValueError."""
        with pytest.raises(ValueError):
            ag.jointplot(x="x", y="y", data=self.data, kind="not_a_kind")

    def test_unsupported_hue_kind(self):
        """kinds that do not support ``hue`` are rejected with ValueError.

        ``self.data`` has no "a" column, so a bare ``pytest.raises(ValueError)``
        would also pass if jointplot failed on the missing hue variable rather
        than rejecting the kind. Matching the message pins the intended check.
        """
        for kind in ["reg", "resid", "hex"]:
            with pytest.raises(ValueError, match="not currently supported"):
                ag.jointplot(x="x", y="y", hue="a", data=self.data, kind=kind)

    def test_leaky_dict(self):
        """jointplot must not mutate the kwarg dicts it is given."""
        # Validate input dicts are unchanged by jointplot plotting function

        for kwarg in ("joint_kws", "marginal_kws"):
            for kind in ("hex", "kde", "resid", "reg", "scatter"):
                empty_dict = {}
                ag.jointplot(x="x", y="y", data=self.data, kind=kind,
                             **{kwarg: empty_dict})
                assert empty_dict == {}

    def test_distplot_kwarg_warning(self, long_df):
        """A distplot-only marginal kwarg (rug) warns but still plots."""
        with pytest.warns(UserWarning):
            g = ag.jointplot(data=long_df, x="x", y="y", marginal_kws=dict(rug=True))
        assert g.ax_marg_x.patches

    def test_ax_warning(self, long_df):
        """Passing ``ax`` warns (a figure-level function ignores it)."""
        ax = plt.gca()
        with pytest.warns(UserWarning):
            g = ag.jointplot(data=long_df, x="x", y="y", ax=ax)
        assert g.ax_joint.collections
================================================
FILE: tests/test_base.py
================================================
import itertools
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
from pandas.testing import assert_frame_equal
from seaborn.axisgrid import FacetGrid
from seaborn._compat import get_colormap, get_converter
from seaborn._base import (
SemanticMapping,
HueMapping,
SizeMapping,
StyleMapping,
VectorPlotter,
variable_type,
infer_orient,
unique_dashes,
unique_markers,
categorical_order,
)
from seaborn.utils import desaturate
from seaborn.palettes import color_palette
@pytest.fixture(params=[
    {"x": "x", "y": "y"},
    {"x": "t", "y": "y"},
    {"x": "a", "y": "y"},
    {"x": "x", "y": "y", "hue": "y"},
    {"x": "x", "y": "y", "hue": "a"},
    {"x": "x", "y": "y", "size": "a"},
    {"x": "x", "y": "y", "style": "a"},
    {"x": "x", "y": "y", "hue": "s"},
    {"x": "x", "y": "y", "size": "s"},
    {"x": "x", "y": "y", "style": "s"},
    {"x": "x", "y": "y", "hue": "a", "style": "a"},
    {"x": "x", "y": "y", "hue": "a", "size": "b", "style": "b"},
])
def long_variables(request):
    """Provide one long-form variable-assignment dict per parametrized run."""
    assignment = request.param
    return assignment
class TestSemanticMapping:
    """Tests for the base ``SemanticMapping`` call protocol."""

    def test_call_lookup(self):
        """Calling the mapping returns the lookup-table entry for a key."""
        mapping = SemanticMapping(VectorPlotter())
        table = {key: val for key, val in zip("abc", (1, 2, 3))}
        mapping.lookup_table = table
        for key in table:
            assert mapping(key) == table[key]
class TestHueMapping:
    """Tests for ``HueMapping``: plotter integration, categorical and numeric
    palettes, norms, palette-validation errors/warnings, and saturation."""

    def test_plotter_default_init(self, long_df):
        # No hue variable -> no hue map attribute on the plotter.
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
        )
        assert not hasattr(p, "_hue_map")

        # Hue assigned -> a HueMapping whose map_type agrees with the
        # plotter's inferred variable type.
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a"),
        )
        assert isinstance(p._hue_map, HueMapping)
        assert p._hue_map.map_type == p.var_types["hue"]

    def test_plotter_customization(self, long_df):
        # map_hue stores the requested palette and level order.
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a"),
        )
        palette = "muted"
        hue_order = ["b", "a", "c"]
        p.map_hue(palette=palette, order=hue_order)
        assert p._hue_map.palette == palette
        assert p._hue_map.levels == hue_order

    def test_hue_map_null(self, flat_series, null_series):
        # An all-null hue vector leaves every mapping attribute as None.
        p = VectorPlotter(variables=dict(x=flat_series, hue=null_series))
        m = HueMapping(p)
        assert m.levels is None
        assert m.map_type is None
        assert m.palette is None
        assert m.cmap is None
        assert m.norm is None
        assert m.lookup_table is None

    def test_hue_map_categorical(self, wide_df, long_df):
        # NOTE: the steps below share and reassign `p`, `m`, `palette`, and
        # `colors` (e.g. the wrong-length list step reuses `colors` from the
        # dict step), so their order matters.

        # Wide-form data: hue levels are the column names.
        p = VectorPlotter(data=wide_df)
        m = HueMapping(p)
        assert m.levels == wide_df.columns.to_list()
        assert m.map_type == "categorical"
        assert m.cmap is None

        # Test named palette
        palette = "Blues"
        expected_colors = color_palette(palette, wide_df.shape[1])
        expected_lookup_table = dict(zip(wide_df.columns, expected_colors))
        m = HueMapping(p, palette=palette)
        assert m.palette == "Blues"
        assert m.lookup_table == expected_lookup_table

        # Test list palette
        palette = color_palette("Reds", wide_df.shape[1])
        expected_lookup_table = dict(zip(wide_df.columns, palette))
        m = HueMapping(p, palette=palette)
        assert m.palette == palette
        assert m.lookup_table == expected_lookup_table

        # Test dict palette
        colors = color_palette("Set1", 8)
        palette = dict(zip(wide_df.columns, colors))
        m = HueMapping(p, palette=palette)
        assert m.palette == palette
        assert m.lookup_table == palette

        # Test dict with missing keys (last column dropped) -> error
        palette = dict(zip(wide_df.columns[:-1], colors))
        with pytest.raises(ValueError):
            HueMapping(p, palette=palette)

        # Test list with wrong number of colors -> warning, not error
        palette = colors[:-1]
        with pytest.warns(UserWarning):
            HueMapping(p, palette=palette)

        # Test hue order
        hue_order = ["a", "c", "d"]
        m = HueMapping(p, order=hue_order)
        assert m.levels == hue_order

        # Test long data
        p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="a"))
        m = HueMapping(p)
        assert m.levels == categorical_order(long_df["a"])
        assert m.map_type == "categorical"
        assert m.cmap is None

        # Test default palette
        m = HueMapping(p)
        hue_levels = categorical_order(long_df["a"])
        expected_colors = color_palette(n_colors=len(hue_levels))
        expected_lookup_table = dict(zip(hue_levels, expected_colors))
        assert m.lookup_table == expected_lookup_table

        # Test missing data: NaN maps to the all-zero RGBA tuple
        m = HueMapping(p)
        assert m(np.nan) == (0, 0, 0, 0)

        # Test default palette with many levels (26 levels -> "husl" expected)
        x = y = np.arange(26)
        hue = pd.Series(list("abcdefghijklmnopqrstuvwxyz"))
        p = VectorPlotter(variables=dict(x=x, y=y, hue=hue))
        m = HueMapping(p)
        expected_colors = color_palette("husl", n_colors=len(hue))
        expected_lookup_table = dict(zip(hue, expected_colors))
        assert m.lookup_table == expected_lookup_table

        # Test binary data: 0/1 values are treated as categorical, both for
        # the full column and for single-valued subsets
        p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="c"))
        m = HueMapping(p)
        assert m.levels == [0, 1]
        assert m.map_type == "categorical"

        for val in [0, 1]:
            p = VectorPlotter(
                data=long_df[long_df["c"] == val],
                variables=dict(x="x", y="y", hue="c"),
            )
            m = HueMapping(p)
            assert m.levels == [val]
            assert m.map_type == "categorical"

        # Test Timestamp data
        p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="t"))
        m = HueMapping(p)
        assert m.levels == [pd.Timestamp(t) for t in long_df["t"].unique()]
        assert m.map_type == "datetime"

        # Test explicit categories: levels follow the categorical dtype order
        p = VectorPlotter(data=long_df, variables=dict(x="x", hue="a_cat"))
        m = HueMapping(p)
        assert m.levels == long_df["a_cat"].cat.categories.to_list()
        assert m.map_type == "categorical"

        # Test numeric data with category type
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="s_cat")
        )
        m = HueMapping(p)
        assert m.levels == categorical_order(long_df["s_cat"])
        assert m.map_type == "categorical"
        assert m.cmap is None

        # Test categorical palette specified for numeric data: a qualitative
        # palette name forces a categorical mapping
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="s")
        )
        palette = "deep"
        levels = categorical_order(long_df["s"])
        expected_colors = color_palette(palette, n_colors=len(levels))
        expected_lookup_table = dict(zip(levels, expected_colors))
        m = HueMapping(p, palette=palette)
        assert m.lookup_table == expected_lookup_table
        assert m.map_type == "categorical"

    def test_hue_map_numeric(self, long_df):
        # Probe values spanning the colormap plus out-of-range values and NaN.
        vals = np.concatenate([np.linspace(0, 1, 256), [-.1, 1.1, np.nan]])

        # Test default colormap
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="s")
        )
        hue_levels = list(np.sort(long_df["s"].unique()))
        m = HueMapping(p)
        assert m.levels == hue_levels
        assert m.map_type == "numeric"
        assert m.cmap.name == "seaborn_cubehelix"

        # Test named colormap
        palette = "Purples"
        m = HueMapping(p, palette=palette)
        assert_array_equal(m.cmap(vals), get_colormap(palette)(vals))

        # Test colormap object
        palette = get_colormap("Greens")
        m = HueMapping(p, palette=palette)
        assert_array_equal(m.cmap(vals), palette(vals))

        # Test cubehelix shorthand
        palette = "ch:2,0,light=.2"
        m = HueMapping(p, palette=palette)
        assert isinstance(m.cmap, mpl.colors.ListedColormap)

        # Test specified hue limits: a (vmin, vmax) tuple becomes a Normalize
        hue_norm = 1, 4
        m = HueMapping(p, norm=hue_norm)
        assert isinstance(m.norm, mpl.colors.Normalize)
        assert m.norm.vmin == hue_norm[0]
        assert m.norm.vmax == hue_norm[1]

        # Test Normalize object: used as-is, not copied
        hue_norm = mpl.colors.PowerNorm(2, vmin=1, vmax=10)
        m = HueMapping(p, norm=hue_norm)
        assert m.norm is hue_norm

        # Test default colormap values: data min/max map to the cmap ends
        hmin, hmax = p.plot_data["hue"].min(), p.plot_data["hue"].max()
        m = HueMapping(p)
        assert m.lookup_table[hmin] == pytest.approx(m.cmap(0.0))
        assert m.lookup_table[hmax] == pytest.approx(m.cmap(1.0))

        # Test specified colormap values: a norm shifted by -1 moves hmin
        # off the bottom of the cmap, while hmax (above vmax) maps to 1.0
        hue_norm = hmin - 1, hmax - 1
        m = HueMapping(p, norm=hue_norm)
        norm_min = (hmin - hue_norm[0]) / (hue_norm[1] - hue_norm[0])
        assert m.lookup_table[hmin] == pytest.approx(m.cmap(norm_min))
        assert m.lookup_table[hmax] == pytest.approx(m.cmap(1.0))

        # Test list of colors
        hue_levels = list(np.sort(long_df["s"].unique()))
        palette = color_palette("Blues", len(hue_levels))
        m = HueMapping(p, palette=palette)
        assert m.lookup_table == dict(zip(hue_levels, palette))

        # A list with one color too many warns.
        palette = color_palette("Blues", len(hue_levels) + 1)
        with pytest.warns(UserWarning):
            HueMapping(p, palette=palette)

        # Test dictionary of colors
        palette = dict(zip(hue_levels, color_palette("Reds")))
        m = HueMapping(p, palette=palette)
        assert m.lookup_table == palette

        # Mutate the same dict to drop one level -> missing key is an error.
        palette.pop(hue_levels[0])
        with pytest.raises(ValueError):
            HueMapping(p, palette=palette)

        # Test invalid palette
        with pytest.raises(ValueError):
            HueMapping(p, palette="not a valid palette")

        # Test bad norm argument
        with pytest.raises(ValueError):
            HueMapping(p, norm="not a norm")

    def test_hue_map_without_hue_dataa(self, long_df):
        # A palette given without any hue data is ignored with a warning.
        # (The doubled "a" in the test name looks like a typo; kept so the
        # test id does not change.)
        p = VectorPlotter(data=long_df, variables=dict(x="x", y="y"))
        with pytest.warns(UserWarning, match="Ignoring `palette`"):
            HueMapping(p, palette="viridis")

    def test_saturation(self, long_df):
        # Mapped colors equal the palette colors desaturated by `saturation`.
        p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="a"))
        levels = categorical_order(long_df["a"])
        palette = color_palette("viridis", len(levels))
        saturation = 0.8

        m = HueMapping(p, palette=palette, saturation=saturation)
        for i, color in enumerate(m(levels)):
            assert mpl.colors.same_color(color, desaturate(palette[i], saturation))
class TestSizeMapping:
    """Tests for ``SizeMapping`` (numeric/categorical sizes, norms, errors),
    plus one ``HueMapping`` numpy-array palette deprecation check."""

    def test_plotter_default_init(self, long_df):
        """A size map is attached only when a size variable is assigned."""
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
        )
        assert not hasattr(p, "_size_map")

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", size="a"),
        )
        assert isinstance(p._size_map, SizeMapping)
        assert p._size_map.map_type == p.var_types["size"]

    def test_plotter_customization(self, long_df):
        """map_size stores the requested sizes in the requested level order."""
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", size="a"),
        )
        sizes = [1, 4, 2]
        size_order = ["b", "a", "c"]
        p.map_size(sizes=sizes, order=size_order)
        assert p._size_map.lookup_table == dict(zip(size_order, sizes))
        assert p._size_map.levels == size_order

    def test_size_map_null(self, flat_series, null_series):
        """An all-null size vector leaves every SizeMapping attribute None."""
        p = VectorPlotter(variables=dict(x=flat_series, size=null_series))
        # Construct a SizeMapping (not a HueMapping) so this test actually
        # exercises the size semantic's null-data path.
        m = SizeMapping(p)
        assert m.levels is None
        assert m.map_type is None
        assert m.norm is None
        assert m.lookup_table is None

    def test_map_size_numeric(self, long_df):
        """Numeric sizes: default range, explicit range, and norm handling."""
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", size="s"),
        )

        # Test default range of keys in the lookup table values
        m = SizeMapping(p)
        size_values = m.lookup_table.values()
        value_range = min(size_values), max(size_values)
        assert value_range == p._default_size_range

        # Test specified range of size values
        sizes = 1, 5
        m = SizeMapping(p, sizes=sizes)
        size_values = m.lookup_table.values()
        # Parenthesized: ``assert a, b == c`` would only check that ``a`` is
        # truthy and use the comparison as the failure message.
        assert (min(size_values), max(size_values)) == sizes

        # Test size values with normalization range
        norm = 1, 10
        m = SizeMapping(p, sizes=sizes, norm=norm)
        normalize = mpl.colors.Normalize(*norm, clip=True)
        for key, val in m.lookup_table.items():
            assert val == sizes[0] + (sizes[1] - sizes[0]) * normalize(key)

        # Test size values with normalization object; the mapping turns
        # clipping on even though the norm was created with clip=False
        norm = mpl.colors.LogNorm(1, 10, clip=False)
        m = SizeMapping(p, sizes=sizes, norm=norm)
        assert m.norm.clip
        for key, val in m.lookup_table.items():
            assert val == sizes[0] + (sizes[1] - sizes[0]) * norm(key)

        # Test bad sizes argument (wrong type)
        with pytest.raises(ValueError):
            SizeMapping(p, sizes="bad_sizes")

        # Test bad sizes argument (range tuple must have exactly two values)
        with pytest.raises(ValueError):
            SizeMapping(p, sizes=(1, 2, 3))

        # Test bad norm argument
        with pytest.raises(ValueError):
            SizeMapping(p, norm="bad_norm")

    def test_map_size_categorical(self, long_df):
        """Categorical sizes: order, list, dict, range, and error handling."""
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", size="a"),
        )

        # Test specified size order
        levels = p.plot_data["size"].unique()
        sizes = [1, 4, 6]
        order = [levels[1], levels[2], levels[0]]
        m = SizeMapping(p, sizes=sizes, order=order)
        assert m.lookup_table == dict(zip(order, sizes))

        # Test list of sizes
        order = categorical_order(p.plot_data["size"])
        sizes = list(np.random.rand(len(levels)))
        m = SizeMapping(p, sizes=sizes)
        assert m.lookup_table == dict(zip(order, sizes))

        # Test dict of sizes
        sizes = dict(zip(levels, np.random.rand(len(levels))))
        m = SizeMapping(p, sizes=sizes)
        assert m.lookup_table == sizes

        # Test specified size range: first level gets the largest size
        sizes = (2, 5)
        m = SizeMapping(p, sizes=sizes)
        values = np.linspace(*sizes, len(m.levels))[::-1]
        assert m.lookup_table == dict(zip(m.levels, values))

        # Test explicit categories
        p = VectorPlotter(data=long_df, variables=dict(x="x", size="a_cat"))
        m = SizeMapping(p)
        assert m.levels == long_df["a_cat"].cat.categories.to_list()
        assert m.map_type == "categorical"

        # Test sizes list with wrong length -> warning
        sizes = list(np.random.rand(len(levels) + 1))
        with pytest.warns(UserWarning):
            SizeMapping(p, sizes=sizes)

        # Test sizes dict with missing levels -> error
        sizes = dict(zip(levels, np.random.rand(len(levels) - 1)))
        with pytest.raises(ValueError):
            SizeMapping(p, sizes=sizes)

        # Test bad sizes argument
        with pytest.raises(ValueError):
            SizeMapping(p, sizes="bad_size")

    def test_array_palette_deprecation(self, long_df):
        """A numpy-array hue palette warns and is converted to a list."""
        p = VectorPlotter(long_df, {"y": "y", "hue": "s"})
        pal = mpl.cm.Blues([.3, .5, .8])[:, :3]
        with pytest.warns(UserWarning, match="Numpy array is not a supported type"):
            m = HueMapping(p, pal)
        assert m.palette == pal.tolist()
class TestStyleMapping:
    """Tests for ``StyleMapping`` and ``VectorPlotter.map_style``."""

    def test_plotter_default_init(self, long_df):
        """A style mapping is attached only when a ``style`` variable is assigned."""
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
        )
        # The mapping is stored as ``_style_map`` (see the isinstance check
        # below); checking for ``_map_style`` could never fail.
        assert not hasattr(p, "_style_map")

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", style="a"),
        )
        assert isinstance(p._style_map, StyleMapping)

    def test_plotter_customization(self, long_df):
        """``map_style`` honors explicit markers and an explicit level order."""
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", style="a"),
        )
        markers = ["s", "p", "h"]
        style_order = ["b", "a", "c"]
        p.map_style(markers=markers, order=style_order)
        assert p._style_map.levels == style_order
        assert p._style_map(style_order, "marker") == markers

    def test_style_map_null(self, flat_series, null_series):
        """An all-null ``style`` vector produces an empty StyleMapping."""
        p = VectorPlotter(variables=dict(x=flat_series, style=null_series))
        # Exercise StyleMapping itself: a HueMapping here would be trivially
        # empty because no ``hue`` variable is assigned.
        m = StyleMapping(p)
        assert m.levels is None
        assert m.map_type is None
        assert m.lookup_table is None

    def test_map_style(self, long_df):
        """StyleMapping resolves markers/dashes from defaults, lists, dicts, and orders."""
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", style="a"),
        )

        # Test defaults
        m = StyleMapping(p, markers=True, dashes=True)

        n = len(m.levels)
        for key, dashes in zip(m.levels, unique_dashes(n)):
            assert m(key, "dashes") == dashes

        actual_marker_paths = {
            k: mpl.markers.MarkerStyle(m(k, "marker")).get_path()
            for k in m.levels
        }
        # ``marker`` (not ``m``) so the comprehension does not shadow the mapping.
        expected_marker_paths = {
            k: mpl.markers.MarkerStyle(marker).get_path()
            for k, marker in zip(m.levels, unique_markers(n))
        }
        assert actual_marker_paths == expected_marker_paths

        # Test lists
        markers, dashes = ["o", "s", "d"], [(1, 0), (1, 1), (2, 1, 3, 1)]
        m = StyleMapping(p, markers=markers, dashes=dashes)
        for key, mark, dash in zip(m.levels, markers, dashes):
            assert m(key, "marker") == mark
            assert m(key, "dashes") == dash

        # Test dicts
        markers = dict(zip(p.plot_data["style"].unique(), markers))
        dashes = dict(zip(p.plot_data["style"].unique(), dashes))
        m = StyleMapping(p, markers=markers, dashes=dashes)
        for key in m.levels:
            assert m(key, "marker") == markers[key]
            assert m(key, "dashes") == dashes[key]

        # Test explicit categories
        p = VectorPlotter(data=long_df, variables=dict(x="x", style="a_cat"))
        m = StyleMapping(p)
        assert m.levels == long_df["a_cat"].cat.categories.to_list()

        # Test style order with defaults
        order = p.plot_data["style"].unique()[[1, 2, 0]]
        m = StyleMapping(p, markers=True, dashes=True, order=order)
        n = len(order)
        for key, mark, dash in zip(order, unique_markers(n), unique_dashes(n)):
            assert m(key, "dashes") == dash
            assert m(key, "marker") == mark
            obj = mpl.markers.MarkerStyle(mark)
            path = obj.get_path().transformed(obj.get_transform())
            assert_array_equal(m(key, "path").vertices, path.vertices)

        # Test too many levels with style lists
        with pytest.warns(UserWarning):
            StyleMapping(p, markers=["o", "s"], dashes=False)

        with pytest.warns(UserWarning):
            StyleMapping(p, markers=False, dashes=[(2, 1)])

        # Test missing keys with style dicts
        markers, dashes = {"a": "o", "b": "s"}, False
        with pytest.raises(ValueError):
            StyleMapping(p, markers=markers, dashes=dashes)

        markers, dashes = False, {"a": (1, 0), "b": (2, 1)}
        with pytest.raises(ValueError):
            StyleMapping(p, markers=markers, dashes=dashes)

        # Test mixture of filled and unfilled markers
        markers, dashes = ["o", "x", "s"], None
        with pytest.raises(ValueError):
            StyleMapping(p, markers=markers, dashes=dashes)
class TestVectorPlotter:
    """Tests for ``VectorPlotter`` variable assignment, iteration, and axis attachment."""

    def test_flat_variables(self, flat_data):
        """Flat input maps its index (or positions) to x and its values to y."""
        p = VectorPlotter()
        p.assign_variables(data=flat_data)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y"]
        assert len(p.plot_data) == len(flat_data)

        try:
            expected_x = flat_data.index
            expected_x_name = flat_data.index.name
        except AttributeError:
            # Inputs without an ``index`` attribute use positional x values.
            expected_x = np.arange(len(flat_data))
            expected_x_name = None

        x = p.plot_data["x"]
        assert_array_equal(x, expected_x)

        expected_y = flat_data
        expected_y_name = getattr(flat_data, "name", None)

        y = p.plot_data["y"]
        assert_array_equal(y, expected_y)

        assert p.variables["x"] == expected_x_name
        assert p.variables["y"] == expected_y_name

    def test_long_df(self, long_df, long_variables):
        """Long-form DataFrame columns are copied into ``plot_data``."""
        p = VectorPlotter()
        p.assign_variables(data=long_df, variables=long_variables)
        assert p.input_format == "long"
        assert p.variables == long_variables

        for key, val in long_variables.items():
            assert_array_equal(p.plot_data[key], long_df[val])

    def test_long_df_with_index(self, long_df, long_variables):
        """Variables may refer to a named index level of the DataFrame."""
        p = VectorPlotter()
        p.assign_variables(
            data=long_df.set_index("a"),
            variables=long_variables,
        )
        assert p.input_format == "long"
        assert p.variables == long_variables

        for key, val in long_variables.items():
            assert_array_equal(p.plot_data[key], long_df[val])

    def test_long_df_with_multiindex(self, long_df, long_variables):
        """Variables may refer to levels of a MultiIndex."""
        p = VectorPlotter()
        p.assign_variables(
            data=long_df.set_index(["a", "x"]),
            variables=long_variables,
        )
        assert p.input_format == "long"
        assert p.variables == long_variables

        for key, val in long_variables.items():
            assert_array_equal(p.plot_data[key], long_df[val])

    def test_long_dict(self, long_dict, long_variables):
        """A dict of columns is accepted as long-form data."""
        p = VectorPlotter()
        p.assign_variables(
            data=long_dict,
            variables=long_variables,
        )
        assert p.input_format == "long"
        assert p.variables == long_variables

        for key, val in long_variables.items():
            assert_array_equal(p.plot_data[key], pd.Series(long_dict[val]))

    @pytest.mark.parametrize(
        "vector_type",
        ["series", "numpy", "list"],
    )
    def test_long_vectors(self, long_df, long_variables, vector_type):
        """Variables may be passed directly as Series, arrays, or lists."""
        variables = {key: long_df[val] for key, val in long_variables.items()}
        if vector_type == "numpy":
            variables = {key: val.to_numpy() for key, val in variables.items()}
        elif vector_type == "list":
            variables = {key: val.to_list() for key, val in variables.items()}

        p = VectorPlotter()
        p.assign_variables(variables=variables)
        assert p.input_format == "long"

        assert list(p.variables) == list(long_variables)
        # Only Series carry a name that can be recovered as the variable name.
        if vector_type == "series":
            assert p.variables == long_variables

        for key, val in long_variables.items():
            assert_array_equal(p.plot_data[key], long_df[val])

    def test_long_undefined_variables(self, long_df):
        """Referencing a column absent from ``data`` raises ValueError."""
        p = VectorPlotter()

        with pytest.raises(ValueError):
            p.assign_variables(
                data=long_df, variables=dict(x="not_in_df"),
            )

        with pytest.raises(ValueError):
            p.assign_variables(
                data=long_df, variables=dict(x="x", y="not_in_df"),
            )

        with pytest.raises(ValueError):
            p.assign_variables(
                data=long_df, variables=dict(x="x", y="y", hue="not_in_df"),
            )

    @pytest.mark.parametrize(
        "arg", [[], np.array([]), pd.DataFrame()],
    )
    def test_empty_data_input(self, arg):
        """Empty data or empty vectors assign no variables."""
        p = VectorPlotter()
        p.assign_variables(data=arg)
        assert not p.variables

        if not isinstance(arg, pd.DataFrame):
            p = VectorPlotter()
            p.assign_variables(variables=dict(x=arg, y=arg))
            assert not p.variables

    def test_units(self, repeated_df):
        """The ``units`` variable is assigned like any other column."""
        p = VectorPlotter()
        p.assign_variables(
            data=repeated_df,
            variables=dict(x="x", y="y", units="u"),
        )
        assert_array_equal(p.plot_data["units"], repeated_df["u"])

    @pytest.mark.parametrize("name", [3, 4.5])
    def test_long_numeric_name(self, long_df, name):
        """Numeric column names are accepted and stored as strings."""
        long_df[name] = long_df["x"]
        p = VectorPlotter()
        p.assign_variables(data=long_df, variables={"x": name})
        assert_array_equal(p.plot_data["x"], long_df[name])
        assert p.variables["x"] == str(name)

    def test_long_hierarchical_index(self, rng):
        """Tuple keys select MultiIndex columns and are stored as strings."""
        cols = pd.MultiIndex.from_product([["a"], ["x", "y"]])
        data = rng.uniform(size=(50, 2))
        df = pd.DataFrame(data, columns=cols)

        name = ("a", "y")
        var = "y"

        p = VectorPlotter()
        p.assign_variables(data=df, variables={var: name})
        assert_array_equal(p.plot_data[var], df[name])
        assert p.variables[var] == str(name)

    def test_long_scalar_and_data(self, long_df):
        """A scalar variable is broadcast and has no variable name."""
        val = 22
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": val})
        assert (p.plot_data["y"] == val).all()
        assert p.variables["y"] is None

    def test_wide_semantic_error(self, wide_df):
        """Wide-form data rejects semantic variable assignments."""
        err = "The following variable cannot be assigned with wide-form data: `hue`"
        with pytest.raises(ValueError, match=err):
            VectorPlotter(data=wide_df, variables={"hue": "a"})

    def test_long_unknown_error(self, long_df):
        """An uninterpretable string value raises a descriptive ValueError."""
        err = "Could not interpret value `what` for `hue`"
        with pytest.raises(ValueError, match=err):
            VectorPlotter(data=long_df, variables={"x": "x", "hue": "what"})

    def test_long_unmatched_size_error(self, long_df, flat_array):
        """Vectors whose length differs from ``data`` raise ValueError."""
        err = "Length of ndarray vectors must match length of `data`"
        with pytest.raises(ValueError, match=err):
            VectorPlotter(data=long_df, variables={"x": "x", "hue": flat_array})

    def test_wide_categorical_columns(self, wide_df):
        """Categorical column labels become the hue levels of wide data."""
        wide_df.columns = pd.CategoricalIndex(wide_df.columns)
        p = VectorPlotter(data=wide_df)
        assert_array_equal(p.plot_data["hue"].unique(), ["a", "b", "c"])

    def test_iter_data_quantitites(self, long_df):
        """``iter_data`` yields one subset per unique combination of semantic values."""
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
        )
        out = p.iter_data("hue")
        assert len(list(out)) == 1

        var = "a"
        n_subsets = len(long_df[var].unique())

        semantics = ["hue", "size", "style"]
        for semantic in semantics:

            p = VectorPlotter(
                data=long_df,
                variables={"x": "x", "y": "y", semantic: var},
            )
            getattr(p, f"map_{semantic}")()
            out = p.iter_data(semantics)
            assert len(list(out)) == n_subsets

        var = "a"
        n_subsets = len(long_df[var].unique())

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var, style=var),
        )
        p.map_hue()
        p.map_style()
        out = p.iter_data(semantics)
        assert len(list(out)) == n_subsets

        # --

        out = p.iter_data(semantics, reverse=True)
        assert len(list(out)) == n_subsets

        # --

        var1, var2 = "a", "s"

        n_subsets = len(long_df[var1].unique())

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, style=var2),
        )
        p.map_hue()
        p.map_style()
        out = p.iter_data(["hue"])
        assert len(list(out)) == n_subsets

        n_subsets = len(set(map(tuple, long_df[[var1, var2]].values)))

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, style=var2),
        )
        p.map_hue()
        p.map_style()
        out = p.iter_data(semantics)
        assert len(list(out)) == n_subsets

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, size=var2, style=var1),
        )
        p.map_hue()
        p.map_size()
        p.map_style()
        out = p.iter_data(semantics)
        assert len(list(out)) == n_subsets

        # --

        var1, var2, var3 = "a", "s", "b"
        cols = [var1, var2, var3]
        n_subsets = len(set(map(tuple, long_df[cols].values)))

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, size=var2, style=var3),
        )
        p.map_hue()
        p.map_size()
        p.map_style()
        out = p.iter_data(semantics)
        assert len(list(out)) == n_subsets

    def test_iter_data_keys(self, long_df):
        """Sub-variable dicts from ``iter_data`` contain only assigned semantics."""
        semantics = ["hue", "size", "style"]

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
        )
        for sub_vars, _ in p.iter_data("hue"):
            assert sub_vars == {}

        # --

        var = "a"

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var),
        )
        for sub_vars, _ in p.iter_data("hue"):
            assert list(sub_vars) == ["hue"]
            assert sub_vars["hue"] in long_df[var].values

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", size=var),
        )
        for sub_vars, _ in p.iter_data("size"):
            assert list(sub_vars) == ["size"]
            assert sub_vars["size"] in long_df[var].values

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var, style=var),
        )
        for sub_vars, _ in p.iter_data(semantics):
            assert list(sub_vars) == ["hue", "style"]
            assert sub_vars["hue"] in long_df[var].values
            assert sub_vars["style"] in long_df[var].values
            assert sub_vars["hue"] == sub_vars["style"]

        var1, var2 = "a", "s"

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, size=var2),
        )
        for sub_vars, _ in p.iter_data(semantics):
            assert list(sub_vars) == ["hue", "size"]
            assert sub_vars["hue"] in long_df[var1].values
            assert sub_vars["size"] in long_df[var2].values

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, col=var2),
        )
        # The assigned facet variable appears in the keys even though only
        # "hue" is requested.
        for sub_vars, _ in p.iter_data("hue"):
            assert list(sub_vars) == ["hue", "col"]
            assert sub_vars["hue"] in long_df[var1].values
            assert sub_vars["col"] in long_df[var2].values

    def test_iter_data_values(self, long_df):
        """Each yielded subset equals the matching rows of ``plot_data``."""
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
        )

        p.sort = True
        _, sub_data = next(p.iter_data("hue"))
        assert_frame_equal(sub_data, p.plot_data)

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a"),
        )

        for sub_vars, sub_data in p.iter_data("hue"):
            rows = p.plot_data["hue"] == sub_vars["hue"]
            assert_frame_equal(sub_data, p.plot_data[rows])

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a", size="s"),
        )
        for sub_vars, sub_data in p.iter_data(["hue", "size"]):
            rows = p.plot_data["hue"] == sub_vars["hue"]
            rows &= p.plot_data["size"] == sub_vars["size"]
            assert_frame_equal(sub_data, p.plot_data[rows])

    def test_iter_data_reverse(self, long_df):
        """``reverse=True`` iterates levels in reverse categorical order."""
        reversed_order = categorical_order(long_df["a"])[::-1]
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a")
        )
        iterator = p.iter_data("hue", reverse=True)
        for i, (sub_vars, _) in enumerate(iterator):
            assert sub_vars["hue"] == reversed_order[i]

    def test_iter_data_dropna(self, null_df):
        """Missing values are dropped by default and kept with ``dropna=False``."""
        p = VectorPlotter(
            data=null_df,
            variables=dict(x="x", y="y", hue="a")
        )
        p.map_hue()
        for _, sub_df in p.iter_data("hue"):
            assert not sub_df.isna().any().any()

        some_missing = False
        for _, sub_df in p.iter_data("hue", dropna=False):
            some_missing |= sub_df.isna().any().any()
        assert some_missing

    def test_axis_labels(self, long_df):
        """Axis labels come from variable names, defaults, or existing labels."""
        f, ax = plt.subplots()

        p = VectorPlotter(data=long_df, variables=dict(x="a"))

        p._add_axis_labels(ax)
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == ""
        ax.clear()

        p = VectorPlotter(data=long_df, variables=dict(y="a"))
        p._add_axis_labels(ax)
        assert ax.get_xlabel() == ""
        assert ax.get_ylabel() == "a"
        ax.clear()

        p = VectorPlotter(data=long_df, variables=dict(x="a"))

        p._add_axis_labels(ax, default_y="default")
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == "default"
        ax.clear()

        p = VectorPlotter(data=long_df, variables=dict(y="a"))
        p._add_axis_labels(ax, default_x="default", default_y="default")
        assert ax.get_xlabel() == "default"
        assert ax.get_ylabel() == "a"
        ax.clear()

        p = VectorPlotter(data=long_df, variables=dict(x="x", y="a"))
        ax.set(xlabel="existing", ylabel="also existing")
        p._add_axis_labels(ax)
        assert ax.get_xlabel() == "existing"
        assert ax.get_ylabel() == "also existing"

        f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
        p = VectorPlotter(data=long_df, variables=dict(x="x", y="y"))

        p._add_axis_labels(ax1)
        p._add_axis_labels(ax2)

        assert ax1.get_xlabel() == "x"
        assert ax1.get_ylabel() == "y"
        assert ax1.yaxis.label.get_visible()

        assert ax2.get_xlabel() == "x"
        assert ax2.get_ylabel() == "y"
        # With a shared y axis, only the first subplot shows its y label.
        assert not ax2.yaxis.label.get_visible()

    @pytest.mark.parametrize(
        "variables",
        [
            dict(x="x", y="y"),
            dict(x="x"),
            dict(y="y"),
            dict(x="t", y="y"),
            dict(x="x", y="a"),
        ]
    )
    def test_attach_basics(self, long_df, variables):
        """``_attach`` stores the target Axes on the plotter."""
        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables=variables)
        p._attach(ax)
        assert p.ax is ax

    def test_attach_disallowed(self, long_df):
        """``_attach`` raises TypeError for variable types not in ``allowed_types``."""
        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "a"})

        with pytest.raises(TypeError):
            p._attach(ax, allowed_types="numeric")

        with pytest.raises(TypeError):
            p._attach(ax, allowed_types=["datetime", "numeric"])

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x"})

        with pytest.raises(TypeError):
            p._attach(ax, allowed_types="categorical")

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "t"})

        with pytest.raises(TypeError):
            p._attach(ax, allowed_types=["numeric", "categorical"])

    def test_attach_log_scale(self, long_df):
        """``log_scale`` applies only to assigned numeric axes."""
        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x"})
        p._attach(ax, log_scale=True)
        assert ax.xaxis.get_scale() == "log"
        assert ax.yaxis.get_scale() == "linear"

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x"})
        p._attach(ax, log_scale=2)
        assert ax.xaxis.get_scale() == "log"
        assert ax.yaxis.get_scale() == "linear"

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"y": "y"})
        p._attach(ax, log_scale=True)
        assert ax.xaxis.get_scale() == "linear"
        assert ax.yaxis.get_scale() == "log"

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y"})
        p._attach(ax, log_scale=True)
        assert ax.xaxis.get_scale() == "log"
        assert ax.yaxis.get_scale() == "log"

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y"})
        p._attach(ax, log_scale=(True, False))
        assert ax.xaxis.get_scale() == "log"
        assert ax.yaxis.get_scale() == "linear"

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y"})
        p._attach(ax, log_scale=(False, 2))
        assert ax.xaxis.get_scale() == "linear"
        assert ax.yaxis.get_scale() == "log"

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "a", "y": "y"})
        p._attach(ax, log_scale=True)
        assert ax.xaxis.get_scale() == "linear"
        assert ax.yaxis.get_scale() == "log"

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "t"})
        p._attach(ax, log_scale=True)
        assert ax.xaxis.get_scale() == "log"
        assert ax.yaxis.get_scale() == "linear"

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "a", "y": "b"})
        p._attach(ax, log_scale=True)
        assert ax.xaxis.get_scale() == "linear"
        assert ax.yaxis.get_scale() == "linear"

    def test_attach_converters(self, long_df):
        """Datetime and categorical axes receive matplotlib unit converters."""
        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "t"})
        p._attach(ax)
        assert get_converter(ax.xaxis) is None
        assert "Date" in get_converter(ax.yaxis).__class__.__name__

        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "a", "y": "y"})
        p._attach(ax)
        assert "CategoryConverter" in get_converter(ax.xaxis).__class__.__name__
        assert get_converter(ax.yaxis) is None

    def test_attach_facets(self, long_df):
        """Attaching to a FacetGrid stores it in ``facets`` and leaves ``ax`` unset."""
        g = FacetGrid(long_df, col="a")
        p = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        p._attach(g)
        assert p.ax is None
        assert p.facets == g

    def test_scale_transform_identity(self, long_df):
        """A linear axis yields identity forward/inverse transforms."""
        _, ax = plt.subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x"})
        p._attach(ax)

        fwd, inv = p._get_scale_transforms("x")
        x = np.arange(1, 10)
        assert_array_equal(fwd(x), x)
        assert_array_equal(inv(x), x)

    def test_scale_transform_identity_facets(self, long_df):
        """Linear faceted axes yield identity transforms."""
        g = FacetGrid(long_df, col="a")
        p = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        p._attach(g)

        fwd, inv = p._get_scale_transforms("x")
        x = np.arange(1, 10)
        assert_array_equal(fwd(x), x)
        assert_array_equal(inv(x), x)

    def test_scale_transform_log(self, long_df):
        """A log-scaled axis yields log10 / power-of-ten transforms."""
        _, ax = plt.subplots()
        ax.set_xscale("log")
        p = VectorPlotter(data=long_df, variables={"x": "x"})
        p._attach(ax)

        fwd, inv = p._get_scale_transforms("x")
        x = np.arange(1, 4)
        assert_array_almost_equal(fwd(x), np.log10(x))
        assert_array_almost_equal(inv(x), 10 ** x)

    def test_scale_transform_facets(self, long_df):
        """Faceted transforms round-trip values."""
        g = FacetGrid(long_df, col="a")
        p = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        p._attach(g)

        fwd, inv = p._get_scale_transforms("x")
        x = np.arange(4)
        assert_array_equal(inv(fwd(x)), x)

    def test_scale_transform_mixed_facets(self, long_df):
        """Facets with differing scales raise RuntimeError."""
        g = FacetGrid(long_df, col="a", sharex=False)
        g.axes.flat[0].set_xscale("log")
        p = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        p._attach(g)

        err = "Cannot determine transform with mixed scales on faceted axes"
        with pytest.raises(RuntimeError, match=err):
            p._get_scale_transforms("x")

    def test_attach_shared_axes(self, long_df):
        """Converter count follows the FacetGrid's sharex/sharey settings."""
        g = FacetGrid(long_df)
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y"})
        p._attach(g)
        assert p.converters["x"].nunique() == 1

        g = FacetGrid(long_df, col="a")
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y", "col": "a"})
        p._attach(g)
        assert p.converters["x"].nunique() == 1
        assert p.converters["y"].nunique() == 1

        g = FacetGrid(long_df, col="a", sharex=False)
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y", "col": "a"})
        p._attach(g)
        assert p.converters["x"].nunique() == p.plot_data["col"].nunique()
        assert p.converters["x"].groupby(p.plot_data["col"]).nunique().max() == 1
        assert p.converters["y"].nunique() == 1

        g = FacetGrid(long_df, col="a", sharex=False, col_wrap=2)
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y", "col": "a"})
        p._attach(g)
        assert p.converters["x"].nunique() == p.plot_data["col"].nunique()
        assert p.converters["x"].groupby(p.plot_data["col"]).nunique().max() == 1
        assert p.converters["y"].nunique() == 1

        g = FacetGrid(long_df, col="a", row="b")
        p = VectorPlotter(
            data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"},
        )
        p._attach(g)
        assert p.converters["x"].nunique() == 1
        assert p.converters["y"].nunique() == 1

        g = FacetGrid(long_df, col="a", row="b", sharex=False)
        p = VectorPlotter(
            data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"},
        )
        p._attach(g)
        assert p.converters["x"].nunique() == len(g.axes.flat)
        assert p.converters["y"].nunique() == 1

        g = FacetGrid(long_df, col="a", row="b", sharex="col")
        p = VectorPlotter(
            data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"},
        )
        p._attach(g)
        assert p.converters["x"].nunique() == p.plot_data["col"].nunique()
        assert p.converters["x"].groupby(p.plot_data["col"]).nunique().max() == 1
        assert p.converters["y"].nunique() == 1

        g = FacetGrid(long_df, col="a", row="b", sharey="row")
        p = VectorPlotter(
            data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"},
        )
        p._attach(g)
        assert p.converters["x"].nunique() == 1
        assert p.converters["y"].nunique() == p.plot_data["row"].nunique()
        assert p.converters["y"].groupby(p.plot_data["row"]).nunique().max() == 1

    def test_get_axes_single(self, long_df):
        """With a single Axes, ``_get_axes`` always returns it."""
        ax = plt.figure().subplots()
        p = VectorPlotter(data=long_df, variables={"x": "x", "hue": "a"})
        p._attach(ax)
        assert p._get_axes({"hue": "a"}) is ax

    def test_get_axes_facets(self, long_df):
        """With facets, ``_get_axes`` looks up the Axes for the row/col levels."""
        g = FacetGrid(long_df, col="a")
        p = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        p._attach(g)
        assert p._get_axes({"col": "b"}) is g.axes_dict["b"]

        g = FacetGrid(long_df, col="a", row="c")
        p = VectorPlotter(
            data=long_df, variables={"x": "x", "col": "a", "row": "c"}
        )
        p._attach(g)
        assert p._get_axes({"row": 1, "col": "b"}) is g.axes_dict[(1, "b")]

    def test_comp_data(self, long_df):
        """``comp_data`` holds axis-converted values of ``plot_data``."""
        p = VectorPlotter(data=long_df, variables={"x": "x", "y": "t"})

        # We have disabled this check for now, while it remains part of
        # the internal API, because it will require updating a number of tests
        # with pytest.raises(AttributeError):
        #     p.comp_data

        _, ax = plt.subplots()
        p._attach(ax)

        assert_array_equal(p.comp_data["x"], p.plot_data["x"])
        assert_array_equal(
            p.comp_data["y"], ax.yaxis.convert_units(p.plot_data["y"])
        )

        p = VectorPlotter(data=long_df, variables={"x": "a"})

        _, ax = plt.subplots()
        p._attach(ax)

        assert_array_equal(
            p.comp_data["x"], ax.xaxis.convert_units(p.plot_data["x"])
        )

    def test_comp_data_log(self, long_df):
        """On a log-scaled axis, ``comp_data`` holds log10 values."""
        p = VectorPlotter(data=long_df, variables={"x": "z", "y": "y"})
        _, ax = plt.subplots()
        p._attach(ax, log_scale=(True, False))

        assert_array_equal(
            p.comp_data["x"], np.log10(p.plot_data["x"])
        )
        assert_array_equal(p.comp_data["y"], p.plot_data["y"])

    def test_comp_data_category_order(self):
        """Categorical ``comp_data`` follows the Series' category order."""
        s = (pd.Series(["a", "b", "c", "a"], dtype="category")
             .cat.set_categories(["b", "c", "a"], ordered=True))

        p = VectorPlotter(variables={"x": s})
        _, ax = plt.subplots()
        p._attach(ax)
        assert_array_equal(
            p.comp_data["x"],
            [2, 0, 1, 2],
        )

    @pytest.fixture(
        params=itertools.product(
            [None, np.nan, pd.NA],
            ["numeric", "category", "datetime"],
        )
    )
    def comp_data_missing_fixture(self, request):
        """Return ``(orig_data, comp_data)`` pairs containing missing values."""
        # This fixture holds the logic for parameterizing
        # the following test (test_comp_data_missing)

        NA, var_type = request.param

        comp_data = [0, 1, np.nan, 2, np.nan, 1]
        if var_type == "numeric":
            orig_data = [0, 1, NA, 2, np.inf, 1]
        elif var_type == "category":
            orig_data = ["a", "b", NA, "c", pd.NA, "b"]
        elif var_type == "datetime":
            # Use 1-based numbers to avoid issue on matplotlib<3.2
            # Could simplify the test a bit when we roll off that version
            comp_data = [1, 2, np.nan, 3, np.nan, 2]
            numbers = [1, 2, 3, 2]

            orig_data = mpl.dates.num2date(numbers)
            orig_data.insert(2, NA)
            orig_data.insert(4, np.inf)

        return orig_data, comp_data

    def test_comp_data_missing(self, comp_data_missing_fixture):
        """Missing and infinite values become NaN in a float ``comp_data``."""
        orig_data, comp_data = comp_data_missing_fixture
        p = VectorPlotter(variables={"x": orig_data})
        ax = plt.figure().subplots()
        p._attach(ax)
        assert_array_equal(p.comp_data["x"], comp_data)
        assert p.comp_data["x"].dtype == "float"

    def test_comp_data_duplicate_index(self):
        """Duplicate index labels do not disturb ``comp_data``."""
        x = pd.Series([1, 2, 3, 4, 5], [1, 1, 1, 2, 2])
        p = VectorPlotter(variables={"x": x})
        ax = plt.figure().subplots()
        p._attach(ax)
        assert_array_equal(p.comp_data["x"], x)

    def test_comp_data_nullable_dtype(self):
        """Nullable integer input is converted to a float ``comp_data``."""
        x = pd.Series([1, 2, 3, 4], dtype="Int64")
        p = VectorPlotter(variables={"x": x})
        ax = plt.figure().subplots()
        p._attach(ax)
        assert_array_equal(p.comp_data["x"], x)
        assert p.comp_data["x"].dtype == "float"

    def test_var_order(self, long_df):
        """An explicit ``order`` passed to ``map_*`` is recorded in ``var_levels``."""
        order = ["c", "b", "a"]
        for var in ["hue", "size", "style"]:
            p = VectorPlotter(data=long_df, variables={"x": "x", var: "a"})

            mapper = getattr(p, f"map_{var}")
            mapper(order=order)

            assert p.var_levels[var] == order

    def test_scale_native(self, long_df):
        """``scale_native`` is not implemented."""
        p = VectorPlotter(data=long_df, variables={"x": "x"})
        with pytest.raises(NotImplementedError):
            p.scale_native("x")

    def test_scale_numeric(self, long_df):
        """``scale_numeric`` is not implemented."""
        p = VectorPlotter(data=long_df, variables={"y": "y"})
        with pytest.raises(NotImplementedError):
            p.scale_numeric("y")

    def test_scale_datetime(self, long_df):
        """``scale_datetime`` is not implemented."""
        p = VectorPlotter(data=long_df, variables={"x": "t"})
        with pytest.raises(NotImplementedError):
            p.scale_datetime("x")

    def test_scale_categorical(self, long_df):
        """``scale_categorical`` sets type, level order, ordering flag, and formatting."""
        p = VectorPlotter(data=long_df, variables={"x": "x"})
        p.scale_categorical("y")
        assert p.variables["y"] is None
        assert p.var_types["y"] == "categorical"
        assert (p.plot_data["y"] == "").all()

        p = VectorPlotter(data=long_df, variables={"x": "s"})
        p.scale_categorical("x")
        assert p.var_types["x"] == "categorical"
        assert hasattr(p.plot_data["x"], "str")
        assert not p._var_ordered["x"]
        assert p.plot_data["x"].is_monotonic_increasing
        assert_array_equal(p.var_levels["x"], p.plot_data["x"].unique())

        p = VectorPlotter(data=long_df, variables={"x": "a"})
        p.scale_categorical("x")
        assert not p._var_ordered["x"]
        assert_array_equal(p.var_levels["x"], categorical_order(long_df["a"]))

        p = VectorPlotter(data=long_df, variables={"x": "a_cat"})
        p.scale_categorical("x")
        assert p._var_ordered["x"]
        assert_array_equal(p.var_levels["x"], categorical_order(long_df["a_cat"]))

        p = VectorPlotter(data=long_df, variables={"x": "a"})
        order = np.roll(long_df["a"].unique(), 1)
        p.scale_categorical("x", order=order)
        assert p._var_ordered["x"]
        assert_array_equal(p.var_levels["x"], order)

        p = VectorPlotter(data=long_df, variables={"x": "s"})
        p.scale_categorical("x", formatter=lambda x: f"{x:%}")
        assert p.plot_data["x"].str.endswith("%").all()
        assert all(s.endswith("%") for s in p.var_levels["x"])
class TestCoreFunc:
    """Tests for module-level helpers: dashes/markers, type inference, orientation, order."""

    def test_unique_dashes(self):
        """``unique_dashes`` returns n distinct specs: a solid line, then even-length tuples."""
        n = 24
        dashes = unique_dashes(n)

        assert len(dashes) == n
        assert len(set(dashes)) == n
        assert dashes[0] == ""
        for spec in dashes[1:]:
            assert isinstance(spec, tuple)
            assert not len(spec) % 2

    def test_unique_markers(self):
        """``unique_markers`` returns n distinct, filled markers."""
        n = 24
        markers = unique_markers(n)

        assert len(markers) == n
        assert len(set(markers)) == n
        for m in markers:
            assert mpl.markers.MarkerStyle(m).is_filled()

    def test_variable_type(self):
        """``variable_type`` classifies vectors as numeric, categorical, or datetime."""
        s = pd.Series([1., 2., 3.])
        assert variable_type(s) == "numeric"
        assert variable_type(s.astype(int)) == "numeric"
        assert variable_type(s.astype(object)) == "numeric"
        assert variable_type(s.to_numpy()) == "numeric"
        assert variable_type(s.to_list()) == "numeric"

        s = pd.Series([1, 2, 3, np.nan], dtype=object)
        assert variable_type(s) == "numeric"

        s = pd.Series([np.nan, np.nan])
        assert variable_type(s) == "numeric"

        s = pd.Series([pd.NA, pd.NA])
        assert variable_type(s) == "numeric"

        s = pd.Series([1, 2, pd.NA], dtype="Int64")
        assert variable_type(s) == "numeric"

        s = pd.Series(["1", "2", "3"])
        assert variable_type(s) == "categorical"
        assert variable_type(s.to_numpy()) == "categorical"
        assert variable_type(s.to_list()) == "categorical"

        # This should arguably be datetime, but we don't currently handle it correctly
        # Test is mainly asserting that this doesn't fail on the boolean check.
        s = pd.timedelta_range(1, periods=3, freq="D").to_series()
        assert variable_type(s) == "categorical"

        s = pd.Series([True, False, False])
        assert variable_type(s) == "numeric"
        assert variable_type(s, boolean_type="categorical") == "categorical"
        s_cat = s.astype("category")
        assert variable_type(s_cat, boolean_type="categorical") == "categorical"
        assert variable_type(s_cat, boolean_type="numeric") == "categorical"

        s = pd.Series([pd.Timestamp(1), pd.Timestamp(2)])
        assert variable_type(s) == "datetime"
        assert variable_type(s.astype(object)) == "datetime"
        assert variable_type(s.to_numpy()) == "datetime"
        assert variable_type(s.to_list()) == "datetime"

    def test_infer_orient(self):
        """``infer_orient`` picks the orientation axis, warning or raising on conflicts."""
        nums = pd.Series(np.arange(6))
        cats = pd.Series(["a", "b"] * 3)
        dates = pd.date_range("1999-09-22", "2006-05-14", periods=6)

        assert infer_orient(cats, nums) == "x"
        assert infer_orient(nums, cats) == "y"

        assert infer_orient(cats, dates, require_numeric=False) == "x"
        assert infer_orient(dates, cats, require_numeric=False) == "y"

        assert infer_orient(nums, None) == "y"
        with pytest.warns(UserWarning, match="Vertical .+ `x`"):
            assert infer_orient(nums, None, "v") == "y"

        assert infer_orient(None, nums) == "x"
        with pytest.warns(UserWarning, match="Horizontal .+ `y`"):
            assert infer_orient(None, nums, "h") == "x"

        # These were bare comparisons before, so they asserted nothing.
        # The second one also passed (cats, None), which gives "y" (as in the
        # first); the "x" result comes from (None, cats).
        assert infer_orient(cats, None, require_numeric=False) == "y"
        with pytest.raises(TypeError, match="Horizontal .+ `x`"):
            infer_orient(cats, None)

        assert infer_orient(None, cats, require_numeric=False) == "x"
        with pytest.raises(TypeError, match="Vertical .+ `y`"):
            infer_orient(None, cats)

        assert infer_orient(nums, nums, "vert") == "x"
        assert infer_orient(nums, nums, "hori") == "y"

        assert infer_orient(cats, cats, "h", require_numeric=False) == "y"
        assert infer_orient(cats, cats, "v", require_numeric=False) == "x"
        assert infer_orient(cats, cats, require_numeric=False) == "x"

        with pytest.raises(TypeError, match="Vertical .+ `y`"):
            infer_orient(cats, cats, "x")
        with pytest.raises(TypeError, match="Horizontal .+ `x`"):
            infer_orient(cats, cats, "y")
        with pytest.raises(TypeError, match="Neither"):
            infer_orient(cats, cats)

        with pytest.raises(ValueError, match="`orient` must start with"):
            infer_orient(cats, nums, orient="bad value")

    def test_categorical_order(self):
        """``categorical_order`` handles explicit orders, numbers, categoricals, and NaN."""
        x = ["a", "c", "c", "b", "a", "d"]
        y = [3, 2, 5, 1, 4]
        order = ["a", "b", "c", "d"]

        out = categorical_order(x)
        assert out == ["a", "c", "b", "d"]

        out = categorical_order(x, order)
        assert out == order

        out = categorical_order(x, ["b", "a"])
        assert out == ["b", "a"]

        out = categorical_order(np.array(x))
        assert out == ["a", "c", "b", "d"]

        out = categorical_order(pd.Series(x))
        assert out == ["a", "c", "b", "d"]

        # Numeric values come back sorted rather than in appearance order.
        out = categorical_order(y)
        assert out == [1, 2, 3, 4, 5]

        out = categorical_order(np.array(y))
        assert out == [1, 2, 3, 4, 5]

        out = categorical_order(pd.Series(y))
        assert out == [1, 2, 3, 4, 5]

        x = pd.Categorical(x, order)
        out = categorical_order(x)
        assert out == list(x.categories)

        x = pd.Series(x)
        out = categorical_order(x)
        assert out == list(x.cat.categories)

        out = categorical_order(x, ["b", "a"])
        assert out == ["b", "a"]

        x = ["a", np.nan, "c", "c", "b", "a", "d"]
        out = categorical_order(x)
        assert out == ["a", "c", "b", "d"]
================================================
FILE: tests/test_categorical.py
================================================
import itertools
from functools import partial
import warnings
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import same_color, to_rgb, to_rgba
import pytest
from pytest import approx
from numpy.testing import (
assert_array_equal,
assert_array_less,
assert_array_almost_equal,
)
from seaborn import categorical as cat
from seaborn._base import categorical_order
from seaborn._compat import get_colormap, get_legend_handles
from seaborn._testing import assert_plots_equal
from seaborn.categorical import (
_CategoricalPlotter,
Beeswarm,
BoxPlotContainer,
catplot,
barplot,
boxplot,
boxenplot,
countplot,
pointplot,
stripplot,
swarmplot,
violinplot,
)
from seaborn.palettes import color_palette
from seaborn.utils import _draw_figure, _version_predates, desaturate
# Plotting functions exercised by the parametrized tests in
# TestCategoricalPlotterNew below.
PLOT_FUNCS = [
    catplot, barplot, boxplot, boxenplot,
    pointplot, stripplot, swarmplot, violinplot,
]
class TestCategoricalPlotterNew:
    """Behavior shared by every function in ``PLOT_FUNCS``."""

    @pytest.mark.parametrize(
        "func,kwargs",
        itertools.product(
            PLOT_FUNCS,
            [
                {"x": "x", "y": "a"},
                {"x": "a", "y": "y"},
                {"x": "y"},
                {"y": "x"},
            ],
        ),
    )
    def test_axis_labels(self, long_df, func, kwargs):
        """Each axis is labeled with its assigned variable, or left blank."""
        func(data=long_df, **kwargs)
        current_ax = plt.gca()
        for axis_name in "xy":
            expected_label = kwargs.get(axis_name, "")
            actual_label = getattr(current_ax, f"get_{axis_name}label")()
            assert actual_label == expected_label

    @pytest.mark.parametrize("func", PLOT_FUNCS)
    def test_empty(self, func):
        """Calling with no data, or with empty vectors, draws no artists."""
        for call_kwargs in ({}, {"x": [], "y": []}):
            func(**call_kwargs)
            current_ax = plt.gca()
            assert not current_ax.collections
            assert not current_ax.patches
            assert not current_ax.lines

    def test_redundant_hue_backcompat(self, long_df):
        """``force_hue=True`` copies the positional variable into ``hue``."""
        plotter = _CategoricalPlotter(
            data=long_df,
            variables={"x": "s", "y": "y"},
        )
        level_palette = dict(zip(long_df["s"].unique(), color_palette()))

        # Positional arguments: color, palette, hue_order.
        new_palette, _ = plotter._hue_backcompat(
            None, level_palette, None, force_hue=True
        )

        assert plotter.variables["hue"] == "s"
        assert_array_equal(plotter.plot_data["hue"], plotter.plot_data["x"])
        assert all(isinstance(key, str) for key in new_palette)
class SharedAxesLevelTests:
    """Tests shared by the axes-level categorical plotting functions.

    Subclasses provide ``func`` (the plotting function under test) and a
    ``get_last_color(ax)`` helper returning the color of the latest artist.
    """

    def orient_indices(self, orient):
        """Return ``(position_index, value_index)`` for ``orient`` in "x"/"y"."""
        axis_names = ("x", "y")
        return axis_names.index(orient), axis_names[::-1].index(orient)

    @pytest.fixture
    def common_kws(self):
        # Extra keyword arguments for test_color; subclasses may override
        return dict()

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_labels_long(self, long_df, orient):
        """Axis labels, tick labels, and legend entries reflect long-form inputs."""
        depend = {"x": "y", "y": "x"}[orient]
        kws = {orient: "a", depend: "y", "hue": "b"}

        ax = self.func(long_df, **kws)
        # Populates tick label texts; only needed on older matplotlibs
        _draw_figure(ax.figure)

        assert getattr(ax, f"get_{orient}label")() == kws[orient]
        assert getattr(ax, f"get_{depend}label")() == kws[depend]

        tick_texts = [
            tick.get_text() for tick in getattr(ax, f"get_{orient}ticklabels")()
        ]
        assert tick_texts == categorical_order(long_df[kws[orient]])

        legend = ax.get_legend()
        assert legend.get_title().get_text() == kws["hue"]
        legend_texts = [text.get_text() for text in legend.texts]
        assert legend_texts == categorical_order(long_df[kws["hue"]])

    def test_labels_wide(self, wide_df):
        """Wide-form columns name the x axis and label its ticks."""
        named_df = wide_df.rename_axis("cols", axis=1)
        ax = self.func(named_df)
        # Populates tick label texts; only needed on older matplotlibs
        _draw_figure(ax.figure)

        assert ax.get_xlabel() == named_df.columns.name
        tick_texts = [tick.get_text() for tick in ax.get_xticklabels()]
        for text, column in zip(tick_texts, named_df.columns):
            assert text == column

    def test_labels_hue_order(self, long_df):
        """Legend entries follow an explicit hue_order."""
        hue_var = "b"
        reversed_levels = categorical_order(long_df[hue_var])[::-1]
        ax = self.func(
            long_df, x="a", y="y", hue=hue_var, hue_order=reversed_levels
        )
        legend_texts = [text.get_text() for text in ax.get_legend().texts]
        assert legend_texts == reversed_levels

    def test_color(self, long_df, common_kws):
        """Default colors follow the property cycle; explicit colors are honored."""
        common_kws.update(data=long_df, x="a", y="y")

        def fresh_ax():
            return plt.figure().subplots()

        ax = fresh_ax()
        self.func(ax=ax, **common_kws)
        assert self.get_last_color(ax) == to_rgba("C0")

        # A second call on the same axes advances to the next cycle color
        ax = fresh_ax()
        for _ in range(2):
            self.func(ax=ax, **common_kws)
        assert self.get_last_color(ax) == to_rgba("C1")

        for explicit_color in ("C2", "C3"):
            ax = fresh_ax()
            self.func(color=explicit_color, ax=ax, **common_kws)
            assert self.get_last_color(ax) == to_rgba(explicit_color)

    def test_two_calls(self):
        """Successive calls on one axes append new categories to the axis."""
        ax = plt.figure().subplots()
        calls = (
            (["a", "b", "c"], [1, 2, 3]),
            (["e", "f"], [4, 5]),
        )
        for x_vals, y_vals in calls:
            self.func(x=x_vals, y=y_vals, ax=ax)
        assert ax.get_xlim() == (-.5, 4.5)

    def test_redundant_hue_legend(self, long_df):
        """A hue identical to x is drawn without a legend unless one is requested."""
        plot_kws = dict(x="a", y="y", hue="a")

        ax = self.func(long_df, **plot_kws)
        assert ax.get_legend() is None

        ax.clear()
        self.func(long_df, legend=True, **plot_kws)
        assert ax.get_legend() is not None

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_log_scale(self, long_df, orient):
        """log_scale=True affects only the value axis, never the categorical one."""
        depvar = {"x": "y", "y": "x"}[orient]
        ax = self.func(long_df, **{orient: "a", depvar: "z"}, log_scale=True)
        assert getattr(ax, f"get_{orient}scale")() == "linear"
        assert getattr(ax, f"get_{depvar}scale")() == "log"
class SharedScatterTests(SharedAxesLevelTests):
    """Tests functionality common to stripplot and swarmplot."""

    def get_last_color(self, ax):
        """Return the single facecolor of the most recently added collection."""
        colors = ax.collections[-1].get_facecolors()
        unique_colors = np.unique(colors, axis=0)
        assert len(unique_colors) == 1
        return to_rgba(unique_colors.squeeze())

    # ------------------------------------------------------------------------------

    def test_color(self, long_df, common_kws):

        super().test_color(long_df, common_kws)

        ax = plt.figure().subplots()
        self.func(data=long_df, x="a", y="y", facecolor="C4", ax=ax)
        assert self.get_last_color(ax) == to_rgba("C4")

        ax = plt.figure().subplots()
        self.func(data=long_df, x="a", y="y", fc="C5", ax=ax)
        assert self.get_last_color(ax) == to_rgba("C5")

    def test_supplied_color_array(self, long_df):

        cmap = get_colormap("Blues")
        norm = mpl.colors.Normalize()
        colors = cmap(norm(long_df["y"].to_numpy()))

        keys = ["c", "fc", "facecolor", "facecolors"]

        for key in keys:

            ax = plt.figure().subplots()
            self.func(x=long_df["y"], **{key: colors})
            _draw_figure(ax.figure)
            assert_array_equal(ax.collections[0].get_facecolors(), colors)

        ax = plt.figure().subplots()
        self.func(x=long_df["y"], c=long_df["y"], cmap=cmap)
        _draw_figure(ax.figure)
        assert_array_equal(ax.collections[0].get_facecolors(), colors)

    def test_unfilled_marker(self, long_df):

        # Any UserWarning raised while plotting an unfilled marker fails the test
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            ax = self.func(long_df, x="y", y="a", marker="x", color="r")

        for points in ax.collections:
            assert same_color(points.get_facecolors().squeeze(), "r")
            assert same_color(points.get_edgecolors().squeeze(), "r")

    @pytest.mark.parametrize(
        "orient,data_type", [
            ("h", "dataframe"), ("h", "dict"),
            ("v", "dataframe"), ("v", "dict"),
            ("y", "dataframe"), ("y", "dict"),
            ("x", "dataframe"), ("x", "dict"),
        ]
    )
    def test_wide(self, wide_df, orient, data_type):

        if data_type == "dict":
            wide_df = {k: v.to_numpy() for k, v in wide_df.items()}

        ax = self.func(data=wide_df, orient=orient, color="C0")
        _draw_figure(ax.figure)

        cat_idx = 0 if orient in "vx" else 1
        val_idx = int(not cat_idx)

        axis_objs = ax.xaxis, ax.yaxis
        cat_axis = axis_objs[cat_idx]

        for i, label in enumerate(cat_axis.get_majorticklabels()):

            key = label.get_text()
            points = ax.collections[i]
            point_pos = points.get_offsets().T
            val_pos = point_pos[val_idx]
            cat_pos = point_pos[cat_idx]

            assert_array_equal(cat_pos.round(), i)
            assert_array_equal(val_pos, wide_df[key])

            for point_color in points.get_facecolors():
                assert tuple(point_color) == to_rgba("C0")

    @pytest.mark.parametrize("orient", ["h", "v"])
    def test_flat(self, flat_series, orient):

        ax = self.func(data=flat_series, orient=orient)
        _draw_figure(ax.figure)

        cat_idx = ["v", "h"].index(orient)
        val_idx = int(not cat_idx)

        points = ax.collections[0]
        pos = points.get_offsets().T

        assert_array_equal(pos[cat_idx].round(), np.zeros(len(flat_series)))
        assert_array_equal(pos[val_idx], flat_series)

    @pytest.mark.parametrize(
        "variables,orient",
        [
            # Order matters for assigning to x/y
            ({"cat": "a", "val": "y", "hue": None}, None),
            ({"val": "y", "cat": "a", "hue": None}, None),
            ({"cat": "a", "val": "y", "hue": "a"}, None),
            ({"val": "y", "cat": "a", "hue": "a"}, None),
            ({"cat": "a", "val": "y", "hue": "b"}, None),
            ({"val": "y", "cat": "a", "hue": "x"}, None),
            ({"cat": "s", "val": "y", "hue": None}, None),
            ({"val": "y", "cat": "s", "hue": None}, "h"),
            ({"cat": "a", "val": "b", "hue": None}, None),
            ({"val": "a", "cat": "b", "hue": None}, "h"),
            ({"cat": "a", "val": "t", "hue": None}, None),
            ({"val": "t", "cat": "a", "hue": None}, None),
            ({"cat": "d", "val": "y", "hue": None}, None),
            ({"val": "y", "cat": "d", "hue": None}, None),
            ({"cat": "a_cat", "val": "y", "hue": None}, None),
            ({"val": "y", "cat": "s_cat", "hue": None}, None),
        ],
    )
    def test_positions(self, long_df, variables, orient):

        cat_var = variables["cat"]
        val_var = variables["val"]
        hue_var = variables["hue"]
        var_names = list(variables.values())
        # The first two dict entries (in insertion order) become x and y
        x_var, y_var, *_ = var_names

        ax = self.func(
            data=long_df, x=x_var, y=y_var, hue=hue_var, orient=orient,
        )

        _draw_figure(ax.figure)

        cat_idx = var_names.index(cat_var)
        val_idx = var_names.index(val_var)

        axis_objs = ax.xaxis, ax.yaxis
        cat_axis = axis_objs[cat_idx]
        val_axis = axis_objs[val_idx]

        cat_data = long_df[cat_var]
        cat_levels = categorical_order(cat_data)

        for i, label in enumerate(cat_levels):

            vals = long_df.loc[cat_data == label, val_var]

            points = ax.collections[i].get_offsets().T
            cat_pos = points[cat_idx]
            val_pos = points[val_idx]

            assert_array_equal(val_pos, val_axis.convert_units(vals))
            assert_array_equal(cat_pos.round(), i)
            assert 0 <= np.ptp(cat_pos) <= .8

            label = pd.Index([label]).astype(str)[0]
            assert cat_axis.get_majorticklabels()[i].get_text() == label

    @pytest.mark.parametrize(
        "variables",
        [
            # Order matters for assigning to x/y
            {"cat": "a", "val": "y", "hue": "b"},
            {"val": "y", "cat": "a", "hue": "c"},
            {"cat": "a", "val": "y", "hue": "f"},
        ],
    )
    def test_positions_dodged(self, long_df, variables):

        cat_var = variables["cat"]
        val_var = variables["val"]
        hue_var = variables["hue"]
        var_names = list(variables.values())
        x_var, y_var, *_ = var_names

        ax = self.func(
            data=long_df, x=x_var, y=y_var, hue=hue_var, dodge=True,
        )

        cat_vals = categorical_order(long_df[cat_var])
        hue_vals = categorical_order(long_df[hue_var])

        n_hue = len(hue_vals)
        # Centers of the dodged hue slots relative to each category position
        offsets = np.linspace(0, .8, n_hue + 1)[:-1]
        offsets -= offsets.mean()
        nest_width = .8 / n_hue

        for i, cat_val in enumerate(cat_vals):
            for j, hue_val in enumerate(hue_vals):
                rows = (long_df[cat_var] == cat_val) & (long_df[hue_var] == hue_val)
                vals = long_df.loc[rows, val_var]

                points = ax.collections[n_hue * i + j].get_offsets().T
                cat_pos = points[var_names.index(cat_var)]
                val_pos = points[var_names.index(val_var)]

                if pd.api.types.is_datetime64_any_dtype(vals):
                    vals = mpl.dates.date2num(vals)

                assert_array_equal(val_pos, vals)

                assert_array_equal(cat_pos.round(), i)
                # NOTE(review): .round() binds before the division, so this only
                # checks |cat_pos - slot_center| < .5; a per-slot check would be
                # ((cat_pos - center) / nest_width).round() — confirm intent.
                assert_array_equal((cat_pos - (i + offsets[j])).round() / nest_width, 0)
                assert 0 <= np.ptp(cat_pos) <= nest_width

    @pytest.mark.parametrize("cat_var", ["a", "s", "d"])
    def test_positions_unfixed(self, long_df, cat_var):

        long_df = long_df.sort_values(cat_var)

        kws = dict(size=.001)
        if "stripplot" in str(self.func):  # can't use __name__ with partial
            kws["jitter"] = False

        ax = self.func(data=long_df, x=cat_var, y="y", native_scale=True, **kws)

        for i, (cat_level, cat_data) in enumerate(long_df.groupby(cat_var)):

            points = ax.collections[i].get_offsets().T
            cat_pos = points[0]
            val_pos = points[1]

            assert_array_equal(val_pos, cat_data["y"])

            comp_level = np.squeeze(ax.xaxis.convert_units(cat_level)).item()
            assert_array_equal(cat_pos.round(), comp_level)

    @pytest.mark.parametrize(
        "x_type,order",
        [
            (str, None),
            (str, ["a", "b", "c"]),
            (str, ["c", "a"]),
            (str, ["a", "b", "c", "d"]),
            (int, None),
            (int, [3, 1, 2]),
            (int, [3, 1]),
            (int, [1, 2, 3, 4]),
            (int, ["3", "1", "2"]),
        ]
    )
    def test_order(self, x_type, order):

        if x_type is str:
            x = ["b", "a", "c"]
        else:
            x = [2, 1, 3]
        y = [1, 2, 3]

        ax = self.func(x=x, y=y, order=order)
        _draw_figure(ax.figure)

        if order is None:
            order = x
            if x_type is int:
                order = np.sort(order)

        assert len(ax.collections) == len(order)
        tick_labels = ax.xaxis.get_majorticklabels()

        assert ax.get_xlim()[1] == (len(order) - .5)

        for i, points in enumerate(ax.collections):
            cat = order[i]
            assert tick_labels[i].get_text() == str(cat)

            positions = points.get_offsets()
            if x_type(cat) in x:
                val = y[x.index(x_type(cat))]
                assert positions[0, 1] == val
            else:
                # Levels in `order` with no data get an empty collection
                assert not positions.size

    @pytest.mark.parametrize("hue_var", ["a", "b"])
    def test_hue_categorical(self, long_df, hue_var):

        cat_var = "b"

        hue_levels = categorical_order(long_df[hue_var])
        cat_levels = categorical_order(long_df[cat_var])

        pal_name = "muted"
        palette = dict(zip(hue_levels, color_palette(pal_name)))
        ax = self.func(data=long_df, x=cat_var, y="y", hue=hue_var, palette=pal_name)

        for i, level in enumerate(cat_levels):

            sub_df = long_df[long_df[cat_var] == level]
            point_hues = sub_df[hue_var]

            points = ax.collections[i]
            point_colors = points.get_facecolors()

            assert len(point_hues) == len(point_colors)

            for hue, color in zip(point_hues, point_colors):
                assert tuple(color) == to_rgba(palette[hue])

    @pytest.mark.parametrize("hue_var", ["a", "b"])
    def test_hue_dodged(self, long_df, hue_var):

        ax = self.func(data=long_df, x="y", y="a", hue=hue_var, dodge=True)
        colors = color_palette(n_colors=long_df[hue_var].nunique())
        collections = iter(ax.collections)

        # Slightly awkward logic to handle how the artists are created: some
        # dodged scatter collections are empty, and an empty collection still
        # reports the default scatter facecolor, so collections whose offsets
        # are all zero/absent are skipped rather than compared to the palette.
        while colors:
            points = next(collections)
            if points.get_offsets().any():
                face_color = tuple(points.get_facecolors()[0])
                expected_color = to_rgba(colors.pop(0))
                assert face_color == expected_color

    @pytest.mark.parametrize(
        "val_var,val_col,hue_col",
        list(itertools.product(["x", "y"], ["b", "y", "t"], [None, "a"])),
    )
    def test_single(self, long_df, val_var, val_col, hue_col):

        var_kws = {val_var: val_col, "hue": hue_col}
        ax = self.func(data=long_df, **var_kws)
        _draw_figure(ax.figure)

        axis_vars = ["x", "y"]
        val_idx = axis_vars.index(val_var)
        cat_idx = int(not val_idx)
        cat_var = axis_vars[cat_idx]

        cat_axis = getattr(ax, f"{cat_var}axis")
        val_axis = getattr(ax, f"{val_var}axis")

        points = ax.collections[0]
        point_pos = points.get_offsets().T
        cat_pos = point_pos[cat_idx]
        val_pos = point_pos[val_idx]

        assert_array_equal(cat_pos.round(), 0)
        assert cat_pos.max() <= .4
        assert cat_pos.min() >= -.4

        num_vals = val_axis.convert_units(long_df[val_col])
        assert_array_equal(val_pos, num_vals)

        if hue_col is not None:
            palette = dict(zip(
                categorical_order(long_df[hue_col]), color_palette()
            ))

        facecolors = points.get_facecolors()
        for i, color in enumerate(facecolors):
            if hue_col is None:
                assert tuple(color) == to_rgba("C0")
            else:
                # Assumes point order matches long_df's row order (index 0..n-1)
                hue_level = long_df.loc[i, hue_col]
                expected_color = palette[hue_level]
                assert tuple(color) == to_rgba(expected_color)

        ticklabels = cat_axis.get_majorticklabels()
        assert len(ticklabels) == 1
        assert not ticklabels[0].get_text()

    def test_attributes(self, long_df):

        kwargs = dict(
            size=2,
            linewidth=1,
            edgecolor="C2",
        )

        ax = self.func(x=long_df["y"], **kwargs)
        points, = ax.collections

        # `size` is a marker diameter; scatter sizes are stored squared
        assert points.get_sizes().item() == kwargs["size"] ** 2
        assert points.get_linewidths().item() == kwargs["linewidth"]
        assert tuple(points.get_edgecolors().squeeze()) == to_rgba(kwargs["edgecolor"])

    def test_three_points(self):

        x = np.arange(3)
        ax = self.func(x=x)
        for point_color in ax.collections[0].get_facecolor():
            assert tuple(point_color) == to_rgba("C0")

    def test_legend_categorical(self, long_df):

        ax = self.func(data=long_df, x="y", y="a", hue="b")
        legend_texts = [t.get_text() for t in ax.legend_.texts]
        expected = categorical_order(long_df["b"])
        assert legend_texts == expected

    def test_legend_numeric(self, long_df):

        ax = self.func(data=long_df, x="y", y="a", hue="z")
        vals = [float(t.get_text()) for t in ax.legend_.texts]
        # Numeric hue legends use evenly spaced ticks
        assert (vals[1] - vals[0]) == approx(vals[2] - vals[1])

    def test_legend_attributes(self, long_df):

        kws = {"edgecolor": "r", "linewidth": 1}
        ax = self.func(data=long_df, x="x", y="y", hue="a", **kws)
        for pt in get_legend_handles(ax.get_legend()):
            assert same_color(pt.get_markeredgecolor(), kws["edgecolor"])
            assert pt.get_markeredgewidth() == kws["linewidth"]

    def test_legend_disabled(self, long_df):

        ax = self.func(data=long_df, x="y", y="a", hue="b", legend=False)
        assert ax.legend_ is None

    def test_palette_from_color_deprecation(self, long_df):

        color = (.9, .4, .5)
        hex_color = mpl.colors.to_hex(color)

        hue_var = "a"
        n_hue = long_df[hue_var].nunique()
        palette = color_palette(f"dark:{hex_color}", n_hue)

        with pytest.warns(FutureWarning, match="Setting a gradient palette"):
            ax = self.func(data=long_df, x="z", hue=hue_var, color=color)

        points = ax.collections[0]
        for point_color in points.get_facecolors():
            assert to_rgb(point_color) in palette

    def test_palette_with_hue_deprecation(self, long_df):
        palette = "Blues"
        with pytest.warns(FutureWarning, match="Passing `palette` without"):
            ax = self.func(data=long_df, x="a", y=long_df["y"], palette=palette)
        strips = ax.collections
        colors = color_palette(palette, len(strips))
        for strip, color in zip(strips, colors):
            assert same_color(strip.get_facecolor()[0], color)

    def test_log_scale(self, long_df):
        """Check log-scaled axes, including the shared log_scale=True behavior.

        This method has the same name as SharedAxesLevelTests.test_log_scale
        and therefore shadows it; the shared check is run explicitly (once per
        orientation, each on a fresh figure) so scatter plots keep that coverage.
        """
        for orient in ["x", "y"]:
            plt.figure()
            super().test_log_scale(long_df, orient)

        x = [1, 10, 100, 1000]

        ax = plt.figure().subplots()
        ax.set_xscale("log")
        self.func(x=x)
        vals = ax.collections[0].get_offsets()[:, 0]
        assert_array_equal(x, vals)

        y = [1, 2, 3, 4]

        ax = plt.figure().subplots()
        ax.set_xscale("log")
        self.func(x=x, y=y, native_scale=True)
        for i, point in enumerate(ax.collections):
            val = point.get_offsets()[0, 0]
            assert val == approx(x[i])

        x = y = np.ones(100)

        ax = plt.figure().subplots()
        ax.set_yscale("log")
        self.func(x=x, y=y, orient="h", native_scale=True)
        cat_points = ax.collections[0].get_offsets().copy()[:, 1]
        # Spread on a log categorical axis is measured in log10 units
        assert np.ptp(np.log10(cat_points)) <= .8

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="long", x="x", color="C3"),
            dict(data="long", y="y", hue="a", jitter=False),
            dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5),
            dict(data="long", x="a", y="y", hue="z", edgecolor="auto", linewidth=.5),
            dict(data="long", x="a_cat", y="y", hue="z"),
            dict(data="long", x="y", y="s", hue="c", orient="h", dodge=True),
            dict(data="long", x="s", y="y", hue="c", native_scale=True),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, kwargs):

        # Copy so the dict defined in the parametrize decorator is not mutated
        kwargs = kwargs.copy()
        if kwargs["data"] == "long":
            kwargs["data"] = long_df
        elif kwargs["data"] == "wide":
            kwargs["data"] = wide_df

        # Strip the "plot" suffix to get the catplot kind; partials expose .func
        try:
            name = self.func.__name__[:-4]
        except AttributeError:
            name = self.func.func.__name__[:-4]
        if name == "swarm":
            kwargs.pop("jitter", None)

        np.random.seed(0)  # for jitter
        ax = self.func(**kwargs)

        np.random.seed(0)
        g = catplot(**kwargs, kind=name)

        assert_plots_equal(ax, g.ax)

    def test_empty_palette(self):
        self.func(x=[], y=[], hue=[], palette=[])
class SharedAggTests(SharedAxesLevelTests):
    """Extra shared axes-level tests; currently labels from a flat, named Series."""

    def test_labels_flat(self):
        """The Series index name labels x, the Series name labels y."""
        index = pd.Index(["a", "b", "c"], name="x")
        series = pd.Series([1, 2, 3], index, name="y")

        ax = self.func(series)
        # Populates tick label texts; only needed on older matplotlibs
        _draw_figure(ax.figure)

        assert ax.get_xlabel() == index.name
        assert ax.get_ylabel() == series.name

        tick_texts = [tick.get_text() for tick in ax.get_xticklabels()]
        for text, level in zip(tick_texts, index):
            assert text == level
class SharedPatchArtistTests:
    """Legend checks for plots whose legend handles expose face/edge colors.

    Mixed into test classes that define ``func``.
    """

    @pytest.mark.parametrize("fill", [True, False])
    def test_legend_fill(self, long_df, fill):
        """Filled handles use the palette as face; unfilled ones use it as edge."""
        palette = color_palette()
        ax = self.func(
            long_df, x="x", y="y", hue="a",
            saturation=1, linecolor="k", fill=fill,
        )
        for i, handle in enumerate(get_legend_handles(ax.get_legend())):
            face, edge = handle.get_facecolor(), handle.get_edgecolor()
            if not fill:
                # Unfilled handles are fully transparent inside
                assert face == (0, 0, 0, 0)
                assert same_color(edge, palette[i])
                continue
            assert same_color(face, palette[i])
            assert same_color(edge, "k")

    def test_legend_attributes(self, long_df):
        """The linewidth keyword propagates to every legend handle."""
        ax = self.func(long_df, x="x", y="y", hue="a", linewidth=3)
        handles = get_legend_handles(ax.get_legend())
        for handle in handles:
            assert handle.get_linewidth() == 3
class TestStripPlot(SharedScatterTests):
    """Run the shared scatter tests against stripplot, plus jitter tests."""

    func = staticmethod(stripplot)

    def test_jitter_unfixed(self, long_df):
        """With native_scale, jitter grows with the spacing of the category values."""
        ax1, ax2 = plt.figure().subplots(2)
        kws = dict(data=long_df, x="y", orient="h", native_scale=True)

        np.random.seed(0)
        stripplot(**kws, y="s", ax=ax1)

        np.random.seed(0)
        stripplot(**kws, y=long_df["s"] * 2, ax=ax2)

        # Compare the spread of the first strip along the categorical (y) axis,
        # i.e. the second offsets column, not the second point.
        p1 = ax1.collections[0].get_offsets()[:, 1]
        p2 = ax2.collections[0].get_offsets()[:, 1]

        assert p2.std() > p1.std()

    @pytest.mark.parametrize(
        "orient,jitter",
        itertools.product(["v", "h"], [True, .1]),
    )
    def test_jitter(self, long_df, orient, jitter):
        """Points are jittered along the categorical axis within the jitter range."""
        cat_var, val_var = "a", "y"
        # Vertical orientation ("v", or the newer "x" spelling) puts the
        # categories on the x axis; horizontal puts them on the y axis.
        if orient in ("v", "x"):
            x_var, y_var = cat_var, val_var
            cat_idx, val_idx = 0, 1
        else:
            x_var, y_var = val_var, cat_var
            cat_idx, val_idx = 1, 0

        cat_vals = categorical_order(long_df[cat_var])

        ax = stripplot(
            data=long_df, x=x_var, y=y_var, jitter=jitter,
        )

        if jitter is True:
            jitter_range = .4
        else:
            jitter_range = 2 * jitter

        for i, level in enumerate(cat_vals):

            vals = long_df.loc[long_df[cat_var] == level, val_var]
            points = ax.collections[i].get_offsets().T
            cat_points = points[cat_idx]
            val_points = points[val_idx]

            assert_array_equal(val_points, vals)
            assert np.std(cat_points) > 0
            assert np.ptp(cat_points) <= jitter_range
class TestSwarmPlot(SharedScatterTests):
    """Run the shared scatter tests against swarmplot."""

    # warn_thresh is pinned through a partial; shared tests that need the
    # function name read it from the partial's ``.func`` attribute.
    func = staticmethod(partial(swarmplot, **{"warn_thresh": 1}))
class TestBoxPlot(SharedAxesLevelTests, SharedPatchArtistTests):
    """Tests for boxplot, on top of the shared axes-level and patch tests."""

    func = staticmethod(boxplot)

    @pytest.fixture
    def common_kws(self):
        # Disable desaturation so facecolors compare equal to requested colors
        return {"saturation": 1}

    def get_last_color(self, ax):
        """Return the single facecolor shared by the boxes of the last container."""
        colors = [b.get_facecolor() for b in ax.containers[-1].boxes]
        unique_colors = np.unique(colors, axis=0)
        assert len(unique_colors) == 1
        return to_rgba(unique_colors.squeeze())

    def get_box_verts(self, box):
        """Return a box path's MOVETO/LINETO vertices as a (2, n) array."""
        path = box.get_path()
        visible_codes = [mpl.path.Path.MOVETO, mpl.path.Path.LINETO]
        visible = np.isin(path.codes, visible_codes)
        return path.vertices[visible].T

    def check_box(self, bxp, data, orient, pos, width=0.8):
        """Assert the box spans the IQR of ``data`` and its median line is correct."""
        pos_idx, val_idx = self.orient_indices(orient)

        p25, p50, p75 = np.percentile(data, [25, 50, 75])

        box = self.get_box_verts(bxp.box)
        assert box[val_idx].min() == approx(p25, 1e-3)
        assert box[val_idx].max() == approx(p75, 1e-3)
        assert box[pos_idx].min() == approx(pos - width / 2)
        assert box[pos_idx].max() == approx(pos + width / 2)

        med = bxp.median.get_xydata().T
        assert np.allclose(med[val_idx], (p50, p50), rtol=1e-3)
        assert np.allclose(med[pos_idx], (pos - width / 2, pos + width / 2))

    def check_whiskers(self, bxp, data, orient, pos, capsize=0.4, whis=1.5):
        """Assert whiskers, caps, and fliers follow the ``whis`` * IQR rule."""
        pos_idx, val_idx = self.orient_indices(orient)

        whis_lo = bxp.whiskers[0].get_xydata().T
        whis_hi = bxp.whiskers[1].get_xydata().T
        caps_lo = bxp.caps[0].get_xydata().T
        caps_hi = bxp.caps[1].get_xydata().T
        fliers = bxp.fliers.get_xydata().T

        p25, p75 = np.percentile(data, [25, 75])
        iqr = p75 - p25

        # Whiskers end at the most extreme data points within whis * IQR
        adj_lo = data[data >= (p25 - iqr * whis)].min()
        adj_hi = data[data <= (p75 + iqr * whis)].max()

        assert whis_lo[val_idx].max() == approx(p25, 1e-3)
        assert whis_lo[val_idx].min() == approx(adj_lo)
        assert np.allclose(whis_lo[pos_idx], (pos, pos))
        assert np.allclose(caps_lo[val_idx], (adj_lo, adj_lo))
        assert np.allclose(caps_lo[pos_idx], (pos - capsize / 2, pos + capsize / 2))

        assert whis_hi[val_idx].min() == approx(p75, 1e-3)
        assert whis_hi[val_idx].max() == approx(adj_hi)
        assert np.allclose(whis_hi[pos_idx], (pos, pos))
        assert np.allclose(caps_hi[val_idx], (adj_hi, adj_hi))
        assert np.allclose(caps_hi[pos_idx], (pos - capsize / 2, pos + capsize / 2))

        flier_data = data[(data < adj_lo) | (data > adj_hi)]
        assert sorted(fliers[val_idx]) == sorted(flier_data)
        assert np.allclose(fliers[pos_idx], pos)

    @pytest.mark.parametrize("orient,col", [("x", "y"), ("y", "z")])
    def test_single_var(self, long_df, orient, col):

        var = {"x": "y", "y": "x"}[orient]
        ax = boxplot(long_df, **{var: col})
        bxp = ax.containers[0][0]
        self.check_box(bxp, long_df[col], orient, 0)
        self.check_whiskers(bxp, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient,col", [(None, "x"), ("x", "y"), ("y", "z")])
    def test_vector_data(self, long_df, orient, col):

        ax = boxplot(long_df[col], orient=orient)
        orient = "x" if orient is None else orient
        bxp = ax.containers[0][0]
        self.check_box(bxp, long_df[col], orient, 0)
        self.check_whiskers(bxp, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient", ["h", "v"])
    def test_wide_data(self, wide_df, orient):

        orient = {"h": "y", "v": "x"}[orient]
        ax = boxplot(wide_df, orient=orient, color="C0")
        # A single color puts every box into one container, so index boxes by
        # column position within it (iterating containers would check only one).
        bxp, = ax.containers
        for i, col in enumerate(wide_df.columns):
            self.check_box(bxp[i], wide_df[col], orient, i)
            self.check_whiskers(bxp[i], wide_df[col], orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_grouped(self, long_df, orient):

        value = {"x": "y", "y": "x"}[orient]
        ax = boxplot(long_df, **{orient: "a", value: "z"})
        bxp, = ax.containers
        levels = categorical_order(long_df["a"])
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "z"]
            self.check_box(bxp[i], data, orient, i)
            self.check_whiskers(bxp[i], data, orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_hue_grouped(self, long_df, orient):

        value = {"x": "y", "y": "x"}[orient]
        ax = boxplot(long_df, hue="c", **{orient: "a", value: "z"})
        for i, hue_level in enumerate(categorical_order(long_df["c"])):
            bxp = ax.containers[i]
            for j, level in enumerate(categorical_order(long_df["a"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = j + [-.2, +.2][i]
                width, capsize = 0.4, 0.2
                self.check_box(bxp[j], data, orient, pos, width)
                self.check_whiskers(bxp[j], data, orient, pos, capsize)

    def test_hue_not_dodged(self, long_df):

        levels = categorical_order(long_df["b"])
        hue = long_df["b"].isin(levels[:2])
        ax = boxplot(long_df, x="b", y="z", hue=hue)
        bxps = ax.containers
        for i, level in enumerate(levels):
            idx = int(i < 2)
            data = long_df.loc[long_df["b"] == level, "z"]
            self.check_box(bxps[idx][i % 2], data, "x", i)
            self.check_whiskers(bxps[idx][i % 2], data, "x", i)

    def test_dodge_native_scale(self, long_df):

        centers = categorical_order(long_df["s"])
        hue_levels = categorical_order(long_df["c"])
        spacing = min(np.diff(centers))
        width = 0.8 * spacing / len(hue_levels)
        offset = width / len(hue_levels)
        ax = boxplot(long_df, x="s", y="z", hue="c", native_scale=True)
        for i, hue_level in enumerate(hue_levels):
            bxp = ax.containers[i]
            for j, center in enumerate(centers):
                rows = (long_df["s"] == center) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = center + [-offset, +offset][i]
                self.check_box(bxp[j], data, "x", pos, width)
                self.check_whiskers(bxp[j], data, "x", pos, width / 2)

    def test_dodge_native_scale_log(self, long_df):

        pos = 10 ** long_df["s"]
        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        boxplot(long_df, x=pos, y="z", hue="c", native_scale=True, ax=ax)
        widths = []
        for bxp in ax.containers:
            for box in bxp.boxes:
                coords = np.log10(box.get_path().vertices.T[0])
                widths.append(np.ptp(coords))
        # On a log axis every box should have the same width in log units
        assert np.std(widths) == approx(0)

    def test_dodge_without_hue(self, long_df):

        ax = boxplot(long_df, x="a", y="y", dodge=True)
        bxp, = ax.containers
        levels = categorical_order(long_df["a"])
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "y"]
            self.check_box(bxp[i], data, "x", i)
            self.check_whiskers(bxp[i], data, "x", i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_log_data_scale(self, long_df, orient):

        var = {"x": "y", "y": "x"}[orient]
        s = long_df["z"]
        ax = mpl.figure.Figure().subplots()
        getattr(ax, f"set_{var}scale")("log")
        boxplot(**{var: s}, whis=np.inf, ax=ax)
        bxp = ax.containers[0][0]
        self.check_box(bxp, s, orient, 0)
        self.check_whiskers(bxp, s, orient, 0, whis=np.inf)

    def test_color(self, long_df, common_kws):

        # This name shadows SharedAxesLevelTests.test_color; run it explicitly
        # so the common_kws/get_last_color hooks above are actually exercised.
        super().test_color(long_df, common_kws)

        color = "#123456"
        ax = boxplot(long_df, x="a", y="y", color=color, saturation=1)
        for box in ax.containers[0].boxes:
            assert same_color(box.get_facecolor(), color)

    def test_wide_data_multicolored(self, wide_df):

        ax = boxplot(wide_df)
        assert len(ax.containers) == wide_df.shape[1]

    def test_wide_data_single_color(self, wide_df):

        ax = boxplot(wide_df, color="C1", saturation=1)
        assert len(ax.containers) == 1
        for box in ax.containers[0].boxes:
            assert same_color(box.get_facecolor(), "C1")

    def test_hue_colors(self, long_df):

        ax = boxplot(long_df, x="a", y="y", hue="b", saturation=1)
        for i, bxp in enumerate(ax.containers):
            for box in bxp.boxes:
                assert same_color(box.get_facecolor(), f"C{i}")

    def test_linecolor(self, long_df):

        color = "#778815"
        ax = boxplot(long_df, x="a", y="y", linecolor=color)
        bxp = ax.containers[0]
        for line in [*bxp.medians, *bxp.whiskers, *bxp.caps]:
            assert same_color(line.get_color(), color)
        for box in bxp.boxes:
            assert same_color(box.get_edgecolor(), color)
        for flier in bxp.fliers:
            assert same_color(flier.get_markeredgecolor(), color)

    def test_linecolor_gray_warning(self, long_df):

        with pytest.warns(FutureWarning, match="Use \"auto\" to set automatic"):
            boxplot(long_df, x="y", linecolor="gray")

    def test_saturation(self, long_df):

        color = "#8912b0"
        ax = boxplot(long_df["x"], color=color, saturation=.5)
        box = ax.containers[0].boxes[0]
        assert np.allclose(box.get_facecolor()[:3], desaturate(color, 0.5))

    def test_linewidth(self, long_df):

        width = 5
        ax = boxplot(long_df, x="a", y="y", linewidth=width)
        bxp = ax.containers[0]
        for line in [*bxp.boxes, *bxp.medians, *bxp.whiskers, *bxp.caps]:
            assert line.get_linewidth() == width

    def test_fill(self, long_df):

        color = "#459900"
        ax = boxplot(x=long_df["z"], fill=False, color=color)
        bxp = ax.containers[0]
        # Unfilled boxes are drawn as lines rather than patches
        assert isinstance(bxp.boxes[0], mpl.lines.Line2D)
        for line in [*bxp.boxes, *bxp.medians, *bxp.whiskers, *bxp.caps]:
            assert same_color(line.get_color(), color)

    @pytest.mark.parametrize("notch_param", ["notch", "shownotches"])
    def test_notch(self, long_df, notch_param):

        ax = boxplot(x=long_df["z"], **{notch_param: True})
        verts = ax.containers[0].boxes[0].get_path().vertices
        assert len(verts) == 12

    def test_whis(self, long_df):

        data = long_df["z"]
        ax = boxplot(x=data, whis=2)
        bxp = ax.containers[0][0]
        self.check_whiskers(bxp, data, "y", 0, whis=2)

    def test_gap(self, long_df):

        ax = boxplot(long_df, x="a", y="z", hue="c", gap=.1)
        for i, hue_level in enumerate(categorical_order(long_df["c"])):
            bxp = ax.containers[i]
            for j, level in enumerate(categorical_order(long_df["a"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = j + [-.2, +.2][i]
                # gap=.1 shrinks each dodged box to 90% of its slot width
                width = 0.9 * 0.4
                self.check_box(bxp[j], data, "x", pos, width)

    def test_prop_dicts(self, long_df):

        prop_dicts = dict(
            boxprops=dict(linewidth=3),
            medianprops=dict(color=".1"),
            whiskerprops=dict(linestyle="--"),
            capprops=dict(solid_capstyle="butt"),
            flierprops=dict(marker="s"),
        )
        attr_map = dict(box="boxes", flier="fliers")
        ax = boxplot(long_df, x="a", y="z", hue="c", **prop_dicts)
        for bxp in ax.containers:
            for element in ["box", "median", "whisker", "cap", "flier"]:
                attr = attr_map.get(element, f"{element}s")
                for artist in getattr(bxp, attr):
                    for k, v in prop_dicts[f"{element}props"].items():
                        assert plt.getp(artist, k) == v

    def test_showfliers(self, long_df):

        ax = boxplot(long_df["x"], showfliers=False)
        assert not ax.containers[0].fliers

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a", y="y"),
            dict(data=None, x="a", y="y"),
            dict(data="long", x="a", y="y", hue="a"),
            dict(data=None, x="a", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="b"),
            dict(data=None, x="s", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="s"),
            dict(data="null", x="a", y="y", hue="a"),
            dict(data="long", x="s", y="y", hue="a", native_scale=True),
            dict(data="long", x="d", y="y", hue="a", native_scale=True),
            dict(data="null", x="a", y="y", hue="b", fill=False, gap=.2),
            dict(data="null", x="a", y="y", whis=1, showfliers=False),
            dict(data="null", x="a", y="y", linecolor="r", linewidth=5),
            dict(data="null", x="a", y="y", shownotches=True, showcaps=False),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):

        # Copy so the dict defined in the parametrize decorator is not mutated
        kwargs = kwargs.copy()
        if kwargs["data"] == "long":
            kwargs["data"] = long_df
        elif kwargs["data"] == "wide":
            kwargs["data"] = wide_df
        elif kwargs["data"] == "flat":
            kwargs["data"] = flat_series
        elif kwargs["data"] == "null":
            kwargs["data"] = null_df
        elif kwargs["data"] is None:
            # No data object: pass the long_df columns as vectors instead
            for var in ["x", "y", "hue"]:
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]

        ax = boxplot(**kwargs)
        g = catplot(**kwargs, kind="box")

        assert_plots_equal(ax, g.ax)
class TestBoxenPlot(SharedAxesLevelTests, SharedPatchArtistTests):
    """Tests for :func:`boxenplot`.

    Each boxen is drawn as a PatchCollection whose paths are the nested
    letter-value boxes; outliers are drawn in a separate collection.
    """

    func = staticmethod(boxenplot)

    @pytest.fixture
    def common_kws(self):
        return {"saturation": 1}

    def get_last_color(self, ax):
        """Return the middle face color of the last-drawn boxen."""
        # collections[-1] is presumably the outlier collection, so the boxes of
        # the final boxen are in collections[-2]; the middle face color is the
        # one the color tests compare against the requested color.
        fcs = ax.collections[-2].get_facecolors()
        return to_rgba(fcs[len(fcs) // 2])

    def get_box_width(self, path, orient="x"):
        """Return the extent of one box ``path`` along the position axis.

        ``orient`` uses the same convention as ``check_boxen``: "x" for
        vertical boxes (positions on the x axis), "y" for horizontal boxes.
        """
        # Index by the position axis from orient_indices so that ``orient``
        # means the same thing here as everywhere else in the class.
        pos_idx, _ = self.orient_indices(orient)
        return np.ptp(path.vertices[:, pos_idx])

    def check_boxen(self, patches, data, orient, pos, width=0.8):
        """Assert that a boxen collection is consistent with ``data``.

        Checks that every box stays within ``pos +/- width / 2`` on the
        position axis, that the 25th/75th percentiles of ``data`` appear among
        the value-axis vertices, and that adjacent boxes share an edge.
        """
        pos_idx, val_idx = self.orient_indices(orient)
        # Stack paths to (n_verts, n_paths, 2) and transpose so the first axis
        # selects the coordinate: verts[axis] has shape (n_paths, n_verts).
        verts = np.stack([v.vertices for v in patches.get_paths()], 1).T

        assert verts[pos_idx].min().round(4) >= np.round(pos - width / 2, 4)
        assert verts[pos_idx].max().round(4) <= np.round(pos + width / 2, 4)
        assert np.isin(
            np.percentile(data, [25, 75]).round(4), verts[val_idx].round(4).flat
        ).all()
        # Vertex 0 of each box matches vertex 2 of the preceding box
        assert_array_equal(verts[val_idx, 1:, 0], verts[val_idx, :-1, 2])

    @pytest.mark.parametrize("orient,col", [("x", "y"), ("y", "z")])
    def test_single_var(self, long_df, orient, col):
        var = {"x": "y", "y": "x"}[orient]
        ax = boxenplot(long_df, **{var: col})
        patches = ax.collections[0]
        self.check_boxen(patches, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient,col", [(None, "x"), ("x", "y"), ("y", "z")])
    def test_vector_data(self, long_df, orient, col):
        orient = "x" if orient is None else orient
        ax = boxenplot(long_df[col], orient=orient)
        patches = ax.collections[0]
        self.check_boxen(patches, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient", ["h", "v"])
    def test_wide_data(self, wide_df, orient):
        orient = {"h": "y", "v": "x"}[orient]
        ax = boxenplot(wide_df, orient=orient)
        collections = ax.findobj(mpl.collections.PatchCollection)
        for i, patches in enumerate(collections):
            col = wide_df.columns[i]
            self.check_boxen(patches, wide_df[col], orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_grouped(self, long_df, orient):
        value = {"x": "y", "y": "x"}[orient]
        ax = boxenplot(long_df, **{orient: "a", value: "z"})
        levels = categorical_order(long_df["a"])
        collections = ax.findobj(mpl.collections.PatchCollection)
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "z"]
            self.check_boxen(collections[i], data, orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_hue_grouped(self, long_df, orient):
        value = {"x": "y", "y": "x"}[orient]
        ax = boxenplot(long_df, hue="c", **{orient: "a", value: "z"})
        collections = iter(ax.findobj(mpl.collections.PatchCollection))
        for i, level in enumerate(categorical_order(long_df["a"])):
            for j, hue_level in enumerate(categorical_order(long_df["c"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = i + [-.2, +.2][j]
                width = 0.4
                self.check_boxen(next(collections), data, orient, pos, width)

    def test_dodge_native_scale(self, long_df):
        centers = categorical_order(long_df["s"])
        hue_levels = categorical_order(long_df["c"])
        spacing = min(np.diff(centers))
        width = 0.8 * spacing / len(hue_levels)
        offset = width / len(hue_levels)
        ax = boxenplot(long_df, x="s", y="z", hue="c", native_scale=True)
        collections = iter(ax.findobj(mpl.collections.PatchCollection))
        for center in centers:
            for i, hue_level in enumerate(hue_levels):
                rows = (long_df["s"] == center) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = center + [-offset, +offset][i]
                self.check_boxen(next(collections), data, "x", pos, width)

    def test_color(self, long_df):
        color = "#123456"
        ax = boxenplot(long_df, x="a", y="y", color=color, saturation=1)
        collections = ax.findobj(mpl.collections.PatchCollection)
        for patches in collections:
            fcs = patches.get_facecolors()
            assert same_color(fcs[len(fcs) // 2], color)

    def test_hue_colors(self, long_df):
        ax = boxenplot(long_df, x="a", y="y", hue="b", saturation=1)
        n_levels = long_df["b"].nunique()
        collections = ax.findobj(mpl.collections.PatchCollection)
        for i, patches in enumerate(collections):
            fcs = patches.get_facecolors()
            assert same_color(fcs[len(fcs) // 2], f"C{i % n_levels}")

    def test_linecolor(self, long_df):
        color = "#669913"
        ax = boxenplot(long_df, x="a", y="y", linecolor=color)
        for patches in ax.findobj(mpl.collections.PatchCollection):
            assert same_color(patches.get_edgecolor(), color)

    def test_linewidth(self, long_df):
        width = 5
        ax = boxenplot(long_df, x="a", y="y", linewidth=width)
        for patches in ax.findobj(mpl.collections.PatchCollection):
            assert patches.get_linewidth() == width

    def test_saturation(self, long_df):
        color = "#8912b0"
        ax = boxenplot(long_df["x"], color=color, saturation=.5)
        fcs = ax.collections[0].get_facecolors()
        assert np.allclose(fcs[len(fcs) // 2, :3], desaturate(color, 0.5))

    def test_gap(self, long_df):
        ax1, ax2 = mpl.figure.Figure().subplots(2)
        boxenplot(long_df, x="a", y="y", hue="s", ax=ax1)
        boxenplot(long_df, x="a", y="y", hue="s", gap=.2, ax=ax2)
        c1 = ax1.findobj(mpl.collections.PatchCollection)
        c2 = ax2.findobj(mpl.collections.PatchCollection)
        for p1, p2 in zip(c1, c2):
            w1 = np.ptp(p1.get_paths()[0].vertices[:, 0])
            w2 = np.ptp(p2.get_paths()[0].vertices[:, 0])
            assert (w2 / w1) == pytest.approx(0.8)

    def test_fill(self, long_df):
        ax = boxenplot(long_df, x="a", y="y", hue="s", fill=False)
        for c in ax.findobj(mpl.collections.PatchCollection):
            assert not c.get_facecolors().size

    def test_k_depth_int(self, rng):
        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x, k_depth=(k := 8))
        assert len(ax.collections[0].get_paths()) == (k * 2 - 1)

    def test_k_depth_full(self, rng):
        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x=x, k_depth="full")
        paths = ax.collections[0].get_paths()
        assert len(paths) == 2 * int(np.log2(x.size)) + 1
        verts = np.concatenate([p.vertices for p in paths]).T
        assert verts[0].min() == x.min()
        assert verts[0].max() == x.max()
        # With every level drawn there should be no outlier markers
        assert not ax.collections[1].get_offsets().size

    def test_trust_alpha(self, rng):
        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x, k_depth="trustworthy", trust_alpha=.1)
        boxenplot(x, k_depth="trustworthy", trust_alpha=.001, ax=ax)
        cs = ax.findobj(mpl.collections.PatchCollection)
        assert len(cs[0].get_paths()) > len(cs[1].get_paths())

    def test_outlier_prop(self, rng):
        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x, k_depth="proportion", outlier_prop=.001)
        boxenplot(x, k_depth="proportion", outlier_prop=.1, ax=ax)
        cs = ax.findobj(mpl.collections.PatchCollection)
        assert len(cs[0].get_paths()) > len(cs[1].get_paths())

    def test_exponential_width_method(self, rng):
        x = rng.normal(0, 1, 10_000)
        # Data on x draws horizontal boxes, i.e. orient="y"
        ax = boxenplot(x=x, width_method="exponential")
        c = ax.findobj(mpl.collections.PatchCollection)[0]
        ws = [self.get_box_width(p, orient="y") for p in c.get_paths()]
        assert (ws[1] / ws[0]) == pytest.approx(ws[2] / ws[1])

    def test_linear_width_method(self, rng):
        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x=x, width_method="linear")
        c = ax.findobj(mpl.collections.PatchCollection)[0]
        ws = [self.get_box_width(p, orient="y") for p in c.get_paths()]
        assert (ws[1] - ws[0]) == pytest.approx(ws[2] - ws[1])

    def test_area_width_method(self, rng):
        x = rng.uniform(0, 1, 10_000)
        ax = boxenplot(x=x, width_method="area", k_depth=2)
        ps = ax.findobj(mpl.collections.PatchCollection)[0].get_paths()
        ws = [self.get_box_width(p, orient="y") for p in ps]
        assert np.greater(ws, 0.7).all()

    def test_box_kws(self, long_df):
        ax = boxenplot(long_df, x="a", y="y", box_kws={"linewidth": (lw := 7.1)})
        for c in ax.findobj(mpl.collections.PatchCollection):
            assert c.get_linewidths() == lw

    def test_line_kws(self, long_df):
        ax = boxenplot(long_df, x="a", y="y", line_kws={"linewidth": (lw := 6.2)})
        for line in ax.lines:
            assert line.get_linewidth() == lw

    def test_flier_kws(self, long_df):
        ax = boxenplot(long_df, x="a", y="y", flier_kws={"marker": (marker := "X")})
        expected = mpl.markers.MarkerStyle(marker).get_path().vertices
        for c in ax.findobj(mpl.collections.PathCollection):
            assert_array_equal(c.get_paths()[0].vertices, expected)

    def test_k_depth_checks(self, long_df):
        with pytest.raises(ValueError, match="The value for `k_depth`"):
            boxenplot(x=long_df["y"], k_depth="auto")
        with pytest.raises(TypeError, match="The `k_depth` parameter"):
            boxenplot(x=long_df["y"], k_depth=(1, 2))

    def test_width_method_check(self, long_df):
        with pytest.raises(ValueError, match="The value for `width_method`"):
            boxenplot(x=long_df["y"], width_method="uniform")

    def test_scale_deprecation(self, long_df):
        with pytest.warns(FutureWarning, match="The `scale` parameter has been"):
            boxenplot(x=long_df["y"], scale="linear")
        with pytest.warns(FutureWarning, match=".+result for 'area' will appear"):
            boxenplot(x=long_df["y"], scale="area")

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a", y="y"),
            dict(data=None, x="a", y="y"),
            dict(data="long", x="a", y="y", hue="a"),
            dict(data=None, x="a", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="b"),
            dict(data=None, x="s", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="s", showfliers=False),
            dict(data="null", x="a", y="y", hue="a", saturation=.5),
            dict(data="long", x="s", y="y", hue="a", native_scale=True),
            dict(data="long", x="d", y="y", hue="a", native_scale=True),
            dict(data="null", x="a", y="y", hue="b", fill=False, gap=.2),
            dict(data="null", x="a", y="y", linecolor="r", linewidth=5),
            dict(data="long", x="a", y="y", k_depth="trustworthy", trust_alpha=.1),
            dict(data="long", x="a", y="y", k_depth="proportion", outlier_prop=.1),
            dict(data="long", x="a", y="z", width_method="area"),
            dict(data="long", x="a", y="z", box_kws={"alpha": .2}, alpha=.4)
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """Check that boxenplot and catplot(kind="boxen") draw identical plots."""
        # Copy so the pytest-owned parametrize dict is never mutated in place
        kwargs = dict(kwargs)
        datasets = {
            "long": long_df,
            "wide": wide_df,
            "flat": flat_series,
            "null": null_df,
        }
        if kwargs["data"] is None:
            # No data object: swap each given column name for the vector
            for var in ["x", "y", "hue"]:
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]
        else:
            kwargs["data"] = datasets[kwargs["data"]]

        ax = boxenplot(**kwargs)
        g = catplot(**kwargs, kind="boxen")
        assert_plots_equal(ax, g.ax)
class TestViolinPlot(SharedAxesLevelTests, SharedPatchArtistTests):
    """Tests for :func:`violinplot`.

    Each violin body is a collection in ``ax.collections``; inner annotations
    are drawn as Line2D artists (box, quartiles) or as an additional
    collection (sticks, points).
    """

    func = staticmethod(violinplot)

    @pytest.fixture
    def common_kws(self):
        return {"saturation": 1}

    def get_last_color(self, ax):
        """Return the face color of the most recently added collection."""
        color = ax.collections[-1].get_facecolor()
        return to_rgba(color)

    def violin_width(self, poly, orient="x"):
        """Return the extent of a violin body along its position axis."""
        idx, _ = self.orient_indices(orient)
        return np.ptp(poly.get_paths()[0].vertices[:, idx])

    def check_violin(self, poly, data, orient, pos, width=0.8):
        """Assert that a violin fits its slot and spans the range of ``data``.

        The value-axis check assumes the violin was drawn with ``cut=0``, so
        the density support ends exactly at the data minimum and maximum.
        """
        pos_idx, val_idx = self.orient_indices(orient)
        verts = poly.get_paths()[0].vertices.T

        assert verts[pos_idx].min() >= (pos - width / 2)
        assert verts[pos_idx].max() <= (pos + width / 2)
        # Assumes violin was computed with cut=0
        assert verts[val_idx].min() == approx(data.min())
        assert verts[val_idx].max() == approx(data.max())

    @pytest.mark.parametrize("orient,col", [("x", "y"), ("y", "z")])
    def test_single_var(self, long_df, orient, col):
        var = {"x": "y", "y": "x"}[orient]
        ax = violinplot(long_df, **{var: col}, cut=0)
        poly = ax.collections[0]
        self.check_violin(poly, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient,col", [(None, "x"), ("x", "y"), ("y", "z")])
    def test_vector_data(self, long_df, orient, col):
        orient = "x" if orient is None else orient
        ax = violinplot(long_df[col], cut=0, orient=orient)
        poly = ax.collections[0]
        self.check_violin(poly, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient", ["h", "v"])
    def test_wide_data(self, wide_df, orient):
        orient = {"h": "y", "v": "x"}[orient]
        ax = violinplot(wide_df, cut=0, orient=orient)
        for i, poly in enumerate(ax.collections):
            col = wide_df.columns[i]
            self.check_violin(poly, wide_df[col], orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_grouped(self, long_df, orient):
        value = {"x": "y", "y": "x"}[orient]
        ax = violinplot(long_df, **{orient: "a", value: "z"}, cut=0)
        levels = categorical_order(long_df["a"])
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "z"]
            self.check_violin(ax.collections[i], data, orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_hue_grouped(self, long_df, orient):
        value = {"x": "y", "y": "x"}[orient]
        ax = violinplot(long_df, hue="c", **{orient: "a", value: "z"}, cut=0)
        polys = iter(ax.collections)
        for i, level in enumerate(categorical_order(long_df["a"])):
            for j, hue_level in enumerate(categorical_order(long_df["c"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = i + [-.2, +.2][j]
                width = 0.4
                self.check_violin(next(polys), data, orient, pos, width)

    def test_hue_not_dodged(self, long_df):
        levels = categorical_order(long_df["b"])
        hue = long_df["b"].isin(levels[:2])
        ax = violinplot(long_df, x="b", y="z", hue=hue, cut=0)
        for i, level in enumerate(levels):
            poly = ax.collections[i]
            data = long_df.loc[long_df["b"] == level, "z"]
            self.check_violin(poly, data, "x", i)

    def test_dodge_native_scale(self, long_df):
        centers = categorical_order(long_df["s"])
        hue_levels = categorical_order(long_df["c"])
        spacing = min(np.diff(centers))
        width = 0.8 * spacing / len(hue_levels)
        offset = width / len(hue_levels)
        ax = violinplot(long_df, x="s", y="z", hue="c", native_scale=True, cut=0)
        violins = iter(ax.collections)
        for center in centers:
            for i, hue_level in enumerate(hue_levels):
                rows = (long_df["s"] == center) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = center + [-offset, +offset][i]
                poly = next(violins)
                self.check_violin(poly, data, "x", pos, width)

    def test_dodge_native_scale_log(self, long_df):
        pos = 10 ** long_df["s"]
        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        variables = dict(x=pos, y="z", hue="c")
        violinplot(long_df, **variables, native_scale=True, density_norm="width", ax=ax)
        widths = []
        n_violins = long_df["s"].nunique() * long_df["c"].nunique()
        for poly in ax.collections[:n_violins]:
            verts = poly.get_paths()[0].vertices[:, 0]
            coords = np.log10(verts)
            widths.append(np.ptp(coords))
        # With density_norm="width" every violin is equally wide in log space
        assert np.std(widths) == approx(0)

    def test_color(self, long_df):
        color = "#123456"
        ax = violinplot(long_df, x="a", y="y", color=color, saturation=1)
        for poly in ax.collections:
            assert same_color(poly.get_facecolor(), color)

    def test_hue_colors(self, long_df):
        ax = violinplot(long_df, x="a", y="y", hue="b", saturation=1)
        n_levels = long_df["b"].nunique()
        for i, poly in enumerate(ax.collections):
            assert same_color(poly.get_facecolor(), f"C{i % n_levels}")

    @pytest.mark.parametrize("inner", ["box", "quart", "stick", "point"])
    def test_linecolor(self, long_df, inner):
        color = "#669913"
        ax = violinplot(long_df, x="a", y="y", linecolor=color, inner=inner)
        for poly in ax.findobj(mpl.collections.PolyCollection):
            assert same_color(poly.get_edgecolor(), color)
        for lines in ax.findobj(mpl.collections.LineCollection):
            assert same_color(lines.get_color(), color)
        for line in ax.lines:
            assert same_color(line.get_color(), color)

    def test_linewidth(self, long_df):
        width = 5
        ax = violinplot(long_df, x="a", y="y", linewidth=width)
        poly = ax.collections[0]
        assert poly.get_linewidth() == width

    def test_saturation(self, long_df):
        color = "#8912b0"
        ax = violinplot(long_df["x"], color=color, saturation=.5)
        poly = ax.collections[0]
        assert np.allclose(poly.get_facecolors()[0, :3], desaturate(color, 0.5))

    @pytest.mark.parametrize("inner", ["box", "quart", "stick", "point"])
    def test_fill(self, long_df, inner):
        color = "#459900"
        ax = violinplot(x=long_df["z"], fill=False, color=color, inner=inner)
        for poly in ax.findobj(mpl.collections.PolyCollection):
            assert poly.get_facecolor().size == 0
            assert same_color(poly.get_edgecolor(), color)
        for lines in ax.findobj(mpl.collections.LineCollection):
            assert same_color(lines.get_color(), color)
        for line in ax.lines:
            assert same_color(line.get_color(), color)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_inner_box(self, long_df, orient):
        pos_idx, val_idx = self.orient_indices(orient)
        ax = violinplot(long_df["y"], orient=orient)
        stats = mpl.cbook.boxplot_stats(long_df["y"])[0]

        # The default inner box is three lines: whiskers, box, median marker
        whiskers = ax.lines[0].get_xydata()
        assert whiskers[0, val_idx] == stats["whislo"]
        assert whiskers[1, val_idx] == stats["whishi"]
        assert whiskers[:, pos_idx].tolist() == [0, 0]

        box = ax.lines[1].get_xydata()
        assert box[0, val_idx] == stats["q1"]
        assert box[1, val_idx] == stats["q3"]
        assert box[:, pos_idx].tolist() == [0, 0]

        median = ax.lines[2].get_xydata()
        assert median[0, val_idx] == stats["med"]
        assert median[0, pos_idx] == 0

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_inner_quartiles(self, long_df, orient):
        pos_idx, val_idx = self.orient_indices(orient)
        ax = violinplot(long_df["y"], orient=orient, inner="quart")
        quartiles = np.percentile(long_df["y"], [25, 50, 75])
        for q, line in zip(quartiles, ax.lines):
            pts = line.get_xydata()
            for pt in pts:
                assert pt[val_idx] == q
            assert pts[0, pos_idx] == -pts[1, pos_idx]

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_inner_stick(self, long_df, orient):
        pos_idx, val_idx = self.orient_indices(orient)
        ax = violinplot(long_df["y"], orient=orient, inner="stick")
        for i, pts in enumerate(ax.collections[1].get_segments()):
            for pt in pts:
                assert pt[val_idx] == long_df["y"].iloc[i]
            assert pts[0, pos_idx] == -pts[1, pos_idx]

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_inner_points(self, long_df, orient):
        pos_idx, val_idx = self.orient_indices(orient)
        ax = violinplot(long_df["y"], orient=orient, inner="points")
        points = ax.collections[1]
        for i, pt in enumerate(points.get_offsets()):
            assert pt[val_idx] == long_df["y"].iloc[i]
            assert pt[pos_idx] == 0

    def test_split_single(self, long_df):
        ax = violinplot(long_df, x="a", y="z", split=True, cut=0)
        levels = categorical_order(long_df["a"])
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "z"]
            self.check_violin(ax.collections[i], data, "x", i)
            verts = ax.collections[i].get_paths()[0].vertices
            assert np.isclose(verts[:, 0], i + .4).sum() >= 100

    def test_split_multi(self, long_df):
        ax = violinplot(long_df, x="a", y="z", hue="c", split=True, cut=0)
        polys = iter(ax.collections)
        for i, level in enumerate(categorical_order(long_df["a"])):
            for j, hue_level in enumerate(categorical_order(long_df["c"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = i + [-.2, +.2][j]
                poly = next(polys)
                self.check_violin(poly, data, "x", pos, width=0.4)
                verts = poly.get_paths()[0].vertices
                assert np.isclose(verts[:, 0], i).sum() >= 100

    def test_density_norm_area(self, long_df):
        y = long_df["y"].to_numpy()
        ax = violinplot([y, y * 5], color="C0")
        widths = []
        for poly in ax.collections:
            widths.append(self.violin_width(poly))
        assert widths[0] / widths[1] == approx(5)

    def test_density_norm_count(self, long_df):
        y = long_df["y"].to_numpy()
        ax = violinplot([np.repeat(y, 3), y], density_norm="count", color="C0")
        widths = []
        for poly in ax.collections:
            widths.append(self.violin_width(poly))
        assert widths[0] / widths[1] == approx(3)

    def test_density_norm_width(self, long_df):
        ax = violinplot(long_df, x="a", y="y", density_norm="width")
        for poly in ax.collections:
            assert self.violin_width(poly) == approx(0.8)

    def test_common_norm(self, long_df):
        ax = violinplot(long_df, x="a", y="y", hue="c", common_norm=True)
        widths = []
        for poly in ax.collections:
            widths.append(self.violin_width(poly))
        assert sum(w > 0.3999 for w in widths) == 1

    def test_scale_deprecation(self, long_df):
        with pytest.warns(FutureWarning, match=r".+Pass `density_norm='count'`"):
            violinplot(long_df, x="a", y="y", hue="b", scale="count")

    def test_scale_hue_deprecation(self, long_df):
        with pytest.warns(FutureWarning, match=r".+Pass `common_norm=True`"):
            violinplot(long_df, x="a", y="y", hue="b", scale_hue=False)

    def test_bw_adjust(self, long_df):
        ax = violinplot(long_df["y"], bw_adjust=.2)
        # Draw explicitly on the same axes instead of relying on the implicit
        # "current axes" global state.
        violinplot(long_df["y"], bw_adjust=2, ax=ax)
        kde1 = ax.collections[0].get_paths()[0].vertices[:100, 0]
        kde2 = ax.collections[1].get_paths()[0].vertices[:100, 0]
        # A smaller bandwidth yields a bumpier outline
        assert np.std(np.diff(kde1)) > np.std(np.diff(kde2))

    def test_bw_deprecation(self, long_df):
        with pytest.warns(FutureWarning, match=r".*Setting `bw_method='silverman'`"):
            violinplot(long_df["y"], bw="silverman")

    def test_gap(self, long_df):
        ax = violinplot(long_df, y="y", hue="c", gap=.2)
        a = ax.collections[0].get_paths()[0].vertices[:, 0].max()
        b = ax.collections[1].get_paths()[0].vertices[:, 0].min()
        assert (b - a) == approx(0.2 * 0.8 / 2)

    def test_inner_kws(self, long_df):
        kws = {"linewidth": 3}
        ax = violinplot(long_df, x="a", y="y", inner="stick", inner_kws=kws)
        # Sticks are drawn as a LineCollection (see test_inner_stick), so
        # ``ax.lines`` is empty here and iterating it would check nothing.
        sticks = ax.findobj(mpl.collections.LineCollection)
        assert sticks
        for lines in sticks:
            assert np.all(lines.get_linewidth() == kws["linewidth"])

    def test_box_inner_kws(self, long_df):
        kws = {"box_width": 10, "whis_width": 2, "marker": "x"}
        ax = violinplot(long_df, x="a", y="y", inner_kws=kws)
        # Each inner box contributes (whiskers, box, median) in that order
        for line in ax.lines[::3]:
            assert line.get_linewidth() == kws["whis_width"]
        for line in ax.lines[1::3]:
            assert line.get_linewidth() == kws["box_width"]
        for line in ax.lines[2::3]:
            assert line.get_marker() == kws["marker"]

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a", y="y"),
            dict(data=None, x="a", y="y", split=True),
            dict(data="long", x="a", y="y", hue="a"),
            dict(data=None, x="a", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="b"),
            dict(data=None, x="s", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="s", split=True),
            dict(data="null", x="a", y="y", hue="a"),
            dict(data="long", x="s", y="y", hue="a", native_scale=True),
            dict(data="long", x="d", y="y", hue="a", native_scale=True),
            dict(data="null", x="a", y="y", hue="b", fill=False, gap=.2),
            dict(data="null", x="a", y="y", linecolor="r", linewidth=5),
            dict(data="long", x="a", y="y", inner="stick"),
            dict(data="long", x="a", y="y", inner="points"),
            dict(data="long", x="a", y="y", hue="b", inner="quartiles", split=True),
            dict(data="long", x="a", y="y", density_norm="count", common_norm=True),
            dict(data="long", x="a", y="y", bw_adjust=2),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """Check that violinplot and catplot(kind="violin") draw identical plots."""
        # Copy so the pytest-owned parametrize dict is never mutated in place
        kwargs = dict(kwargs)
        datasets = {
            "long": long_df,
            "wide": wide_df,
            "flat": flat_series,
            "null": null_df,
        }
        if kwargs["data"] is None:
            # No data object: swap each given column name for the vector
            for var in ["x", "y", "hue"]:
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]
        else:
            kwargs["data"] = datasets[kwargs["data"]]

        ax = violinplot(**kwargs)
        g = catplot(**kwargs, kind="violin")
        assert_plots_equal(ax, g.ax)
class TestBarPlot(SharedAggTests):
func = staticmethod(barplot)
@pytest.fixture
def common_kws(self):
    """Keyword arguments for the shared tests: turn off desaturation."""
    kws = dict(saturation=1)
    return kws
def get_last_color(self, ax):
    """Return the single face color shared by the bars of the last container."""
    distinct = np.unique(
        [patch.get_facecolor() for patch in ax.containers[-1]], axis=0
    )
    # All bars in one container are expected to be drawn in one color
    assert len(distinct) == 1
    return to_rgba(distinct.squeeze())
@pytest.mark.parametrize("orient", ["x", "y"])
def test_single_var(self, orient):
    """A single vector yields one bar whose length is the vector's mean."""
    values = pd.Series([1, 3, 10])
    ax = barplot(**{orient: values})
    (only_bar,) = ax.patches
    size_attr = "width" if orient == "x" else "height"
    measured = getattr(only_bar, "get_" + size_attr)()
    assert measured == approx(values.mean())
@pytest.mark.parametrize("orient", ["x", "y", "h", "v"])
def test_wide_df(self, wide_df, orient):
    """Each wide-form column becomes one bar sized by the column mean."""
    ax = barplot(wide_df, orient=orient)
    # Translate the legacy "h"/"v" aliases to the x/y convention
    orient = {"h": "y", "v": "x"}.get(orient, orient)
    size_attr = {"x": "height", "y": "width"}[orient]
    for col_idx, bar in enumerate(ax.patches):
        expected = wide_df.iloc[:, col_idx].mean()
        assert getattr(bar, f"get_{size_attr}")() == approx(expected)
@pytest.mark.parametrize("orient", ["x", "y", "h", "v"])
def test_vector_orient(self, orient):
    """Dict input draws one bar per key, positioned and sized by orient."""
    labels = ["a", "b", "c"]
    heights = [1, 2, 3]
    orient = {"h": "y", "v": "x"}.get(orient, orient)
    size_attr = {"x": "height", "y": "width"}[orient]
    ax = barplot(dict(zip(labels, heights)), orient=orient)
    for idx, bar in enumerate(ax.patches):
        # Bars are 0.8 wide and centered on integer positions
        assert getattr(bar, f"get_{orient}")() == approx(idx - 0.4)
        assert getattr(bar, f"get_{size_attr}")() == approx(heights[idx])
def test_xy_vertical(self):
    """Categorical x with numeric y draws vertical bars at 0, 1, 2."""
    categories = ["a", "b", "c"]
    values = [1, 3, 2.5]
    ax = barplot(x=categories, y=values)
    for pos, bar in enumerate(ax.patches):
        center = bar.get_x() + bar.get_width() / 2
        assert center == approx(pos)
        assert bar.get_y() == approx(0)
        assert bar.get_height() == approx(values[pos])
        assert bar.get_width() == approx(0.8)
def test_xy_horizontal(self):
    """Numeric x with categorical y draws horizontal bars at 0, 1, 2."""
    values = [1, 3, 2.5]
    categories = ["a", "b", "c"]
    ax = barplot(x=values, y=categories)
    for pos, bar in enumerate(ax.patches):
        assert bar.get_x() == approx(0)
        center = bar.get_y() + bar.get_height() / 2
        assert center == approx(pos)
        assert bar.get_height() == approx(0.8)
        assert bar.get_width() == approx(values[pos])
def test_xy_with_na_grouper(self):
    """An observation whose category is missing is dropped along with its tick."""
    x, y = ["a", None, "b"], [1, 2, 3]
    ax = barplot(x=x, y=y)
    _draw_figure(ax.figure)  # For matplotlib<3.5
    # get_xticks() returns an ndarray; ``array == list`` is elementwise, and
    # asserting a multi-element boolean array raises ValueError instead of
    # checking anything, so compare with assert_array_equal.
    assert_array_equal(ax.get_xticks(), [0, 1])
    assert [t.get_text() for t in ax.get_xticklabels()] == ["a", "b"]
    assert ax.patches[0].get_height() == 1
    assert ax.patches[1].get_height() == 3
def test_xy_with_na_value(self):
    """A category whose only value is missing keeps its tick but draws no bar."""
    x, y = ["a", "b", "c"], [1, None, 3]
    ax = barplot(x=x, y=y)
    _draw_figure(ax.figure)  # For matplotlib<3.5
    # Compare the ndarray of tick locations elementwise; a bare ``==``
    # inside ``assert`` would raise on a multi-element array.
    assert_array_equal(ax.get_xticks(), [0, 1, 2])
    assert [t.get_text() for t in ax.get_xticklabels()] == ["a", "b", "c"]
    # The second patch belongs to "c": "b" is expected to have no bar
    assert ax.patches[0].get_height() == 1
    assert ax.patches[1].get_height() == 3
def test_hue_redundant(self):
    """When hue repeats x, bars are not dodged and take successive colors."""
    cats, vals = ["a", "b", "c"], [1, 2, 3]
    ax = barplot(x=cats, y=vals, hue=cats, saturation=1)
    for pos, bar in enumerate(ax.patches):
        assert bar.get_x() + bar.get_width() / 2 == approx(pos)
        assert bar.get_y() == 0
        assert bar.get_height() == vals[pos]
        assert bar.get_width() == approx(0.8)
        assert same_color(bar.get_facecolor(), f"C{pos}")
def test_hue_matched(self):
    """Bars sharing a hue level share a color and stay undodged."""
    cats, vals = ["a", "b", "c"], [1, 2, 3]
    levels = ["x", "x", "y"]
    ax = barplot(x=cats, y=vals, hue=levels, saturation=1, legend=False)
    for pos, bar in enumerate(ax.patches):
        assert bar.get_x() + bar.get_width() / 2 == approx(pos)
        assert bar.get_y() == 0
        assert bar.get_height() == vals[pos]
        assert bar.get_width() == approx(0.8)
        assert same_color(bar.get_facecolor(), f"C{pos // 2}")
def test_hue_matched_by_name(self):
    """hue naming the same column as x behaves like a redundant hue."""
    data = {"x": ["a", "b", "c"], "y": [1, 2, 3]}
    ax = barplot(data, x="x", y="y", hue="x", saturation=1)
    heights = data["y"]
    for pos, bar in enumerate(ax.patches):
        assert bar.get_x() + bar.get_width() / 2 == approx(pos)
        assert bar.get_y() == 0
        assert bar.get_height() == heights[pos]
        assert bar.get_width() == approx(0.8)
        assert same_color(bar.get_facecolor(), f"C{pos}")
def test_hue_dodged(self):
    """Distinct hue levels within a category are dodged side by side."""
    x = ["a", "b", "a", "b"]
    y = [1, 2, 3, 4]
    hue = ["x", "x", "y", "y"]
    ax = barplot(x=x, y=y, hue=hue, saturation=1, legend=False)
    for idx, bar in enumerate(ax.patches):
        hue_idx, cat_idx = divmod(idx, 2)
        # First hue level shifts left and second shifts right, by a quarter
        # of the full 0.8 slot
        shift = (1 if hue_idx else -1) * 0.8 / 4
        center = bar.get_x() + bar.get_width() / 2
        assert center == approx(cat_idx + shift)
        assert bar.get_y() == 0
        assert bar.get_height() == y[idx]
        assert bar.get_width() == approx(0.8 / 2)
        assert same_color(bar.get_facecolor(), f"C{hue_idx}")
def test_gap(self):
    """``gap`` shrinks each dodged bar by the requested fraction."""
    x = ["a", "b", "a", "b"]
    y = [1, 2, 3, 4]
    hue = ["x", "x", "y", "y"]
    ax = barplot(x=x, y=y, hue=hue, gap=.25, legend=False)
    expected_width = approx(0.8 / 2 * .75)
    for bar in ax.patches:
        assert bar.get_width() == expected_width
def test_hue_undodged(self):
    """dodge=False overlays every hue level at its category position."""
    x = ["a", "b", "a", "b"]
    y = [1, 2, 3, 4]
    hue = ["x", "x", "y", "y"]
    ax = barplot(x=x, y=y, hue=hue, saturation=1, dodge=False, legend=False)
    for idx, bar in enumerate(ax.patches):
        hue_idx, cat_idx = divmod(idx, 2)
        assert bar.get_x() + bar.get_width() / 2 == approx(cat_idx)
        assert bar.get_y() == 0
        assert bar.get_height() == y[idx]
        assert bar.get_width() == approx(0.8)
        assert same_color(bar.get_facecolor(), f"C{hue_idx}")
def test_hue_order(self):
    """hue_order drives both palette assignment and patch order."""
    cats, vals = ["a", "b", "c"], [1, 2, 3]
    reversed_order = ["c", "b", "a"]
    ax = barplot(x=cats, y=vals, hue=cats, hue_order=reversed_order, saturation=1)
    for idx, bar in enumerate(ax.patches):
        assert same_color(bar.get_facecolor(), f"C{idx}")
        # Patches follow hue_order, the reverse of the x order
        assert bar.get_x() + bar.get_width() / 2 == approx(2 - idx)
def test_hue_norm(self):
    """hue_norm=(2, 3) collapses values outside the range onto its ends."""
    x = [1, 2, 3, 4]
    y = [1, 2, 3, 4]
    ax = barplot(x=x, y=y, hue=x, hue_norm=(2, 3))
    facecolors = [patch.get_facecolor() for patch in ax.patches]
    # Values 1 and 2 are expected to share the low color, 3 and 4 the high
    assert facecolors[0] == facecolors[1]
    assert facecolors[1] != facecolors[2]
    assert facecolors[2] == facecolors[3]
def test_fill(self):
    """fill=False puts hue colors on the edges and leaves faces transparent."""
    x = ["a", "b", "a", "b"]
    y = [1, 2, 3, 4]
    hue = ["x", "x", "y", "y"]
    ax = barplot(x=x, y=y, hue=hue, fill=False, legend=False)
    transparent = (0, 0, 0, 0)
    for idx, bar in enumerate(ax.patches):
        assert same_color(bar.get_edgecolor(), f"C{idx // 2}")
        assert same_color(bar.get_facecolor(), transparent)
def test_xy_native_scale(self):
    """native_scale centers bars on their numeric x values."""
    x, y = [2, 4, 8], [1, 2, 3]
    ax = barplot(x=x, y=y, native_scale=True)
    # Width is 0.8 times the spacing unit, presumably the smallest gap (2)
    expected_width = approx(0.8 * 2)
    for idx, bar in enumerate(ax.patches):
        assert bar.get_x() + bar.get_width() / 2 == approx(x[idx])
        assert bar.get_y() == 0
        assert bar.get_height() == y[idx]
        assert bar.get_width() == expected_width
def test_xy_native_scale_log_transform(self):
    """On a log axis, bars are centered on their x values in log space."""
    x, y = [1, 10, 100], [1, 2, 3]
    ax = mpl.figure.Figure().subplots()
    ax.set_xscale("log")
    barplot(x=x, y=y, native_scale=True, ax=ax)
    for idx, bar in enumerate(ax.patches):
        left = bar.get_x()
        log_lo, log_hi = np.log10([left, left + bar.get_width()])
        log_center = log_lo + (log_hi - log_lo) / 2
        assert 10 ** log_center == approx(x[idx])
        assert bar.get_y() == 0
        assert bar.get_height() == y[idx]
    # Equal widths in log space mean wider bars in data space further right
    assert ax.patches[1].get_width() > ax.patches[0].get_width()
def test_datetime_native_scale_axis(self):
    """native_scale with datetimes sets up matplotlib date handling."""
    dates = pd.date_range("2010-01-01", periods=20, freq="MS")
    ax = barplot(x=dates, y=np.arange(20), native_scale=True)
    locator_name = type(ax.xaxis.get_major_locator()).__name__
    assert "Date" in locator_name
    day = "2003-02-28"
    assert_array_equal(ax.xaxis.convert_units([day]), mpl.dates.date2num([day]))
def test_native_scale_dodged(self):
    """Dodged bars on a native scale meet exactly at the shared x value."""
    x, y = [2, 4, 2, 4], [1, 2, 3, 4]
    hue = ["x", "x", "y", "y"]
    ax = barplot(x=x, y=y, hue=hue, native_scale=True)
    left_bars, right_bars = ax.patches[:2], ax.patches[2:]
    for pos, bar in zip(x[:2], left_bars):
        # First hue level ends at the position
        assert bar.get_x() + bar.get_width() == approx(pos)
    for pos, bar in zip(x[2:], right_bars):
        # Second hue level starts at the position
        assert bar.get_x() == approx(pos)
def test_native_scale_log_transform_dodged(self):
    """Dodging on a log-scaled native axis still meets at each x value."""
    x, y = [1, 100, 1, 100], [1, 2, 3, 4]
    hue = ["x", "x", "y", "y"]
    ax = mpl.figure.Figure().subplots()
    ax.set_xscale("log")
    barplot(x=x, y=y, hue=hue, native_scale=True, ax=ax)
    left_bars, right_bars = ax.patches[:2], ax.patches[2:]
    for pos, bar in zip(x[:2], left_bars):
        assert bar.get_x() + bar.get_width() == approx(pos)
    for pos, bar in zip(x[2:], right_bars):
        assert bar.get_x() == approx(pos)
def test_estimate_default(self, long_df):
    """By default each bar shows the mean of its group."""
    group_col, value_col = "a", "y"
    expected = long_df.groupby(group_col)[value_col].mean()
    ax = barplot(long_df, x=group_col, y=value_col, errorbar=None)
    levels = categorical_order(long_df[group_col])
    for idx, bar in enumerate(ax.patches):
        assert bar.get_height() == approx(expected[levels[idx]])
def test_estimate_string(self, long_df):
    """estimator may be given by name, e.g. "median"."""
    group_col, value_col = "a", "y"
    expected = long_df.groupby(group_col)[value_col].median()
    ax = barplot(
        long_df, x=group_col, y=value_col, estimator="median", errorbar=None
    )
    levels = categorical_order(long_df[group_col])
    for idx, bar in enumerate(ax.patches):
        assert bar.get_height() == approx(expected[levels[idx]])
def test_estimate_func(self, long_df):
    """estimator may be a callable such as np.median."""
    group_col, value_col = "a", "y"
    expected = long_df.groupby(group_col)[value_col].median()
    ax = barplot(
        long_df, x=group_col, y=value_col, estimator=np.median, errorbar=None
    )
    levels = categorical_order(long_df[group_col])
    for idx, bar in enumerate(ax.patches):
        assert bar.get_height() == approx(expected[levels[idx]])
def test_weighted_estimate(self, long_df):
    """weights= turns the default mean into a weighted average."""
    ax = barplot(long_df, y="y", weights="x")
    expected = np.average(long_df["y"], weights=long_df["x"])
    assert ax.patches[0].get_height() == expected
def test_estimate_log_transform(self, long_df):
    """On a log axis the mean is taken in log space (a geometric mean)."""
    ax = mpl.figure.Figure().subplots()
    ax.set_xscale("log")
    barplot(x=long_df["z"], ax=ax)
    (bar,) = ax.patches
    geometric_mean = 10 ** np.log10(long_df["z"]).mean()
    assert bar.get_width() == geometric_mean
def test_errorbars(self, long_df):
    """errorbar="sd" spans one standard deviation around each group mean."""
    group_col, value_col = "a", "y"
    stats = long_df.groupby(group_col)[value_col].agg(["mean", "std"])
    ax = barplot(long_df, x=group_col, y=value_col, errorbar="sd")
    levels = categorical_order(long_df[group_col])
    for idx, line in enumerate(ax.lines):
        group_stats = stats.loc[levels[idx]]
        low, high = line.get_ydata()
        assert low == approx(group_stats["mean"] - group_stats["std"])
        assert high == approx(group_stats["mean"] + group_stats["std"])
def test_width(self):
    """width= sets the bar width on a categorical axis."""
    width = .5
    ax = barplot(x=["a", "b", "c"], y=[1, 2, 3], width=width)
    for pos, bar in enumerate(ax.patches):
        assert bar.get_x() + bar.get_width() / 2 == approx(pos)
        assert bar.get_width() == width
def test_width_native_scale(self):
    """On a native scale, width= is relative to the position spacing."""
    width = .5
    ax = barplot(x=[4, 6, 10], y=[1, 2, 3], width=width, native_scale=True)
    # Smallest gap between positions is 2
    expected = width * 2
    for bar in ax.patches:
        assert bar.get_width() == expected
def test_width_spaced_categories(self):
    """Plotting a subset of categories on the same axes keeps 0.8-wide bars."""
    ax = barplot(x=["a", "b", "c"], y=[4, 5, 6])
    barplot(x=["a", "c"], y=[1, 3], ax=ax)
    expected = pytest.approx(0.8)
    for patch in ax.patches:
        assert patch.get_width() == expected
def test_saturation_color(self):
    """The default saturation desaturates a user-supplied ``color``."""
    color = (.1, .9, .2)
    x, y = ["a", "b", "c"], [1, 2, 3]
    # Pass the color explicitly: without it the bars use the default palette
    # and the comparison below would say nothing about ``color``.
    ax = barplot(x=x, y=y, color=color)
    for bar in ax.patches:
        # Desaturation pulls the RGB channels together, lowering their variance
        assert np.var(bar.get_facecolor()[:3]) < np.var(color)
def test_saturation_palette(self):
    """The default saturation desaturates palette colors as well."""
    palette = color_palette("viridis", 3)
    cats, vals = ["a", "b", "c"], [1, 2, 3]
    ax = barplot(x=cats, y=vals, hue=cats, palette=palette)
    for idx, bar in enumerate(ax.patches):
        drawn_rgb = bar.get_facecolor()[:3]
        assert np.var(drawn_rgb) < np.var(palette[idx])
def test_legend_numeric_auto(self, long_df):
    """A numeric hue gets an abbreviated legend by default."""
    legend = barplot(long_df, x="x", y="y", hue="x").get_legend()
    assert len(legend.texts) <= 6
def test_legend_numeric_full(self, long_df):
    """legend="full" lists every numeric hue level in sorted order."""
    ax = barplot(long_df, x="x", y="y", hue="x", legend="full")
    expected = [str(level) for level in sorted(long_df["x"].unique())]
    observed = [text.get_text() for text in ax.get_legend().texts]
    assert observed == expected
def test_legend_disabled(self, long_df):
    """legend=False suppresses the legend even with a hue variable."""
    legend = barplot(long_df, x="x", y="y", hue="b", legend=False).get_legend()
    assert legend is None
def test_error_caps(self):
    """Caps with capsize=.8 span exactly from the left to the right bar edge."""
    categories = ["a", "b", "c"] * 2
    values = [1, 2, 3, 4, 5, 6]
    ax = barplot(x=categories, y=values, capsize=.8, errorbar="pi")
    assert len(ax.patches) == len(ax.lines)
    for patch, err_line in zip(ax.patches, ax.lines):
        cap_x = err_line.get_xdata()
        left = patch.get_x()
        right = left + patch.get_width()
        assert len(cap_x) == 8
        assert np.nanmin(cap_x) == approx(left)
        assert np.nanmax(cap_x) == approx(right)
def test_error_caps_native_scale(self):
    """On a native numeric axis, caps still match the bar edges."""
    positions = [2, 4, 20] * 2
    values = [1, 2, 3, 4, 5, 6]
    ax = barplot(
        x=positions, y=values, capsize=.8, native_scale=True, errorbar="pi"
    )
    assert len(ax.patches) == len(ax.lines)
    for patch, err_line in zip(ax.patches, ax.lines):
        cap_x = err_line.get_xdata()
        left = patch.get_x()
        right = left + patch.get_width()
        assert len(cap_x) == 8
        assert np.nanmin(cap_x) == approx(left)
        assert np.nanmax(cap_x) == approx(right)
def test_error_caps_native_scale_log_transform(self):
    """On a log-scaled native axis, caps still match the bar edges."""
    positions = [1, 10, 1000] * 2
    values = [1, 2, 3, 4, 5, 6]
    ax = mpl.figure.Figure().subplots()
    ax.set_xscale("log")
    barplot(
        x=positions, y=values, capsize=.8, native_scale=True, errorbar="pi", ax=ax
    )
    assert len(ax.patches) == len(ax.lines)
    for patch, err_line in zip(ax.patches, ax.lines):
        cap_x = err_line.get_xdata()
        left = patch.get_x()
        right = left + patch.get_width()
        assert len(cap_x) == 8
        assert np.nanmin(cap_x) == approx(left)
        assert np.nanmax(cap_x) == approx(right)
def test_bar_kwargs(self):
    """Extra keyword arguments are forwarded to every bar artist."""
    bar_kws = dict(linewidth=3, facecolor=(.5, .4, .3, .2), rasterized=True)
    ax = barplot(x=["a", "b", "c"], y=[1, 2, 3], **bar_kws)
    for patch in ax.patches:
        assert patch.get_linewidth() == bar_kws["linewidth"]
        assert patch.get_facecolor() == bar_kws["facecolor"]
        assert patch.get_rasterized() == bar_kws["rasterized"]
def test_legend_attributes(self, long_df):
    """Legend patches reproduce the bars' face color, edge color and line width."""
    colors = color_palette()
    ax = barplot(
        long_df, x="a", y="y", hue="c", saturation=1, edgecolor="k", linewidth=3
    )
    handles = get_legend_handles(ax.get_legend())
    for idx, handle in enumerate(handles):
        assert same_color(handle.get_facecolor(), colors[idx])
        assert same_color(handle.get_edgecolor(), "k")
        assert handle.get_linewidth() == 3
def test_legend_unfilled(self, long_df):
    """With fill=False, legend patches are transparent and hue colors the edges."""
    colors = color_palette()
    ax = barplot(long_df, x="a", y="y", hue="c", fill=False, linewidth=3)
    handles = get_legend_handles(ax.get_legend())
    for idx, handle in enumerate(handles):
        assert handle.get_facecolor() == (0, 0, 0, 0)
        assert same_color(handle.get_edgecolor(), colors[idx])
        assert handle.get_linewidth() == 3
@pytest.mark.parametrize("fill", [True, False])
def test_err_kws(self, fill):
    """err_kws styles the error lines for both filled and unfilled bars."""
    err_kws = dict(color=(1, 1, .5, .5), linewidth=5)
    ax = barplot(x=["a", "b", "c"], y=[1, 2, 3], fill=fill, err_kws=err_kws)
    want_color = err_kws["color"]
    want_width = err_kws["linewidth"]
    for err_line in ax.lines:
        assert err_line.get_color() == want_color
        assert err_line.get_linewidth() == want_width
@pytest.mark.parametrize(
    "kwargs",
    [
        dict(data="wide"),
        dict(data="wide", orient="h"),
        dict(data="flat"),
        dict(data="long", x="a", y="y"),
        dict(data=None, x="a", y="y"),
        dict(data="long", x="a", y="y", hue="a"),
        dict(data=None, x="a", y="y", hue="a"),
        dict(data="long", x="a", y="y", hue="b"),
        dict(data=None, x="s", y="y", hue="a"),
        dict(data="long", x="a", y="y", hue="s"),
        dict(data="long", x="a", y="y", units="c"),
        dict(data="null", x="a", y="y", hue="a", gap=.1, fill=False),
        dict(data="long", x="s", y="y", hue="a", native_scale=True),
        dict(data="long", x="d", y="y", hue="a", native_scale=True),
        dict(data="long", x="a", y="y", errorbar=("pi", 50)),
        dict(data="long", x="a", y="y", errorbar=None),
        dict(data="long", x="a", y="y", capsize=.3, err_kws=dict(c="k")),
        dict(data="long", x="a", y="y", color="blue", edgecolor="green", alpha=.5),
    ]
)
def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
    """barplot and catplot(kind="bar") draw identical artists for the same inputs."""
    # Fixed seed and a small bootstrap keep the two runs identical and fast
    kwargs = {**kwargs, "seed": 0, "n_boot": 10}
    datasets = {
        "long": long_df,
        "wide": wide_df,
        "flat": flat_series,
        "null": null_df,
    }
    source = kwargs["data"]
    if source is None:
        # Vector-only call: swap column names for the actual Series
        for var in ("x", "y", "hue"):
            if var in kwargs:
                kwargs[var] = long_df[kwargs[var]]
    else:
        kwargs["data"] = datasets.get(source, source)
    ax = barplot(**kwargs)
    g = catplot(**kwargs, kind="bar")
    assert_plots_equal(ax, g.ax)
def test_errwidth_deprecation(self):
    """The deprecated ``errwidth`` warns but still sets the error line width."""
    width = 5
    with pytest.warns(FutureWarning, match="\n\nThe `errwidth` parameter"):
        ax = barplot(x=["a", "b", "c"], y=[1, 2, 3], errwidth=width)
    for err_line in ax.lines:
        assert err_line.get_linewidth() == width
def test_errcolor_deprecation(self):
    """The deprecated ``errcolor`` warns but still sets the error line color."""
    rgba = (1, .7, .4, .8)
    with pytest.warns(FutureWarning, match="\n\nThe `errcolor` parameter"):
        ax = barplot(x=["a", "b", "c"], y=[1, 2, 3], errcolor=rgba)
    for err_line in ax.lines:
        assert err_line.get_color() == rgba
def test_capsize_as_none_deprecation(self):
    """capsize=None warns and draws uncapped (two-point) error lines."""
    with pytest.warns(FutureWarning, match="\n\nPassing `capsize=None`"):
        ax = barplot(x=["a", "b", "c"], y=[1, 2, 3], capsize=None)
    for err_line in ax.lines:
        n_points = len(err_line.get_xdata())
        assert n_points == 2
def test_hue_implied_by_palette_deprecation(self):
    """A palette without hue warns and colors each bar from the palette."""
    cats = ["a", "b", "c"]
    heights = [1, 2, 3]
    pal_name = "Set1"
    expected_colors = color_palette(pal_name, len(cats))
    msg = "Passing `palette` without assigning `hue` is deprecated."
    with pytest.warns(FutureWarning, match=msg):
        ax = barplot(x=cats, y=heights, saturation=1, palette=pal_name)
    for idx, patch in enumerate(ax.patches):
        assert same_color(patch.get_facecolor(), expected_colors[idx])
class TestPointPlot(SharedAggTests):
    """Tests for ``pointplot``; generic aggregation checks come from SharedAggTests."""

    func = staticmethod(pointplot)

    def get_last_color(self, ax):
        """Return the RGBA color of the most recently drawn line on ``ax``."""
        color = ax.lines[-1].get_color()
        return to_rgba(color)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_single_var(self, orient):
        """A single vector is summarized by one point at its mean."""
        vals = pd.Series([1, 3, 10])
        ax = pointplot(**{orient: vals})
        line = ax.lines[0]
        assert getattr(line, f"get_{orient}data")() == approx(vals.mean())

    @pytest.mark.parametrize("orient", ["x", "y", "h", "v"])
    def test_wide_df(self, wide_df, orient):
        """Wide-form data gives one point per column, located at the column mean."""
        ax = pointplot(wide_df, orient=orient)
        # Map the legacy "h"/"v" aliases onto the "y"/"x" axis names
        orient = {"h": "y", "v": "x"}.get(orient, orient)
        depend = {"x": "y", "y": "x"}[orient]
        line = ax.lines[0]
        assert_array_equal(
            getattr(line, f"get_{orient}data")(),
            np.arange(len(wide_df.columns)),
        )
        assert_array_almost_equal(
            getattr(line, f"get_{depend}data")(),
            wide_df.mean(axis=0),
        )

    @pytest.mark.parametrize("orient", ["x", "y", "h", "v"])
    def test_vector_orient(self, orient):
        """A dict of scalars is drawn as one point per key along ``orient``."""
        keys, vals = ["a", "b", "c"], [1, 2, 3]
        data = dict(zip(keys, vals))
        orient = {"h": "y", "v": "x"}.get(orient, orient)
        depend = {"x": "y", "y": "x"}[orient]
        ax = pointplot(data, orient=orient)
        line = ax.lines[0]
        assert_array_equal(
            getattr(line, f"get_{orient}data")(),
            np.arange(len(keys)),
        )
        assert_array_equal(getattr(line, f"get_{depend}data")(), vals)

    def test_xy_vertical(self):
        """Categorical x with numeric y places points at (index, value)."""
        x, y = ["a", "b", "c"], [1, 3, 2.5]
        ax = pointplot(x=x, y=y)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == (i, y[i])

    def test_xy_horizontal(self):
        """Numeric x with categorical y places points at (value, index)."""
        x, y = [1, 3, 2.5], ["a", "b", "c"]
        ax = pointplot(x=x, y=y)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == (x[i], i)

    def test_xy_with_na_grouper(self):
        """A missing category drops that observation and its tick."""
        x, y = ["a", None, "b"], [1, 2, 3]
        ax = pointplot(x=x, y=y)
        _draw_figure(ax.figure)  # For matplotlib<3.5
        # Compare element-wise so the check holds whether the ticks come back
        # as a list or an ndarray (an ndarray == list result cannot be asserted)
        assert_array_equal(ax.get_xticks(), [0, 1])
        assert [t.get_text() for t in ax.get_xticklabels()] == ["a", "b"]
        assert_array_equal(ax.lines[0].get_xdata(), [0, 1])
        assert_array_equal(ax.lines[0].get_ydata(), [1, 3])

    def test_xy_with_na_value(self):
        """A missing value keeps its category tick and yields a NaN point."""
        x, y = ["a", "b", "c"], [1, np.nan, 3]
        ax = pointplot(x=x, y=y)
        _draw_figure(ax.figure)  # For matplotlib<3.5
        assert_array_equal(ax.get_xticks(), [0, 1, 2])
        assert [t.get_text() for t in ax.get_xticklabels()] == x
        assert_array_equal(ax.lines[0].get_xdata(), [0, 1, 2])
        assert_array_equal(ax.lines[0].get_ydata(), y)

    def test_hue(self):
        """Each hue level gets its own line, colored from the default cycle."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        ax = pointplot(x=x, y=y, hue=hue, errorbar=None)
        for i, line in enumerate(ax.lines[:2]):
            assert_array_equal(line.get_ydata(), y[i::2])
            assert same_color(line.get_color(), f"C{i}")

    def test_wide_data_is_joined(self, wide_df):
        """Wide-form columns are joined into a single line."""
        ax = pointplot(wide_df, errorbar=None)
        assert len(ax.lines) == 1

    def test_xy_native_scale(self):
        """With native_scale, points sit at their numeric x values."""
        x, y = [2, 4, 8], [1, 2, 3]
        ax = pointplot(x=x, y=y, native_scale=True)
        line = ax.lines[0]
        assert_array_equal(line.get_xdata(), x)
        assert_array_equal(line.get_ydata(), y)

    # Use lambda around np.mean to avoid uninformative pandas deprecation warning
    @pytest.mark.parametrize("estimator", ["mean", lambda x: np.mean(x)])
    def test_estimate(self, long_df, estimator):
        """Each point sits at the estimator's value for its group."""
        agg_var, val_var = "a", "y"
        agg_df = long_df.groupby(agg_var)[val_var].agg(estimator)
        # Pass the parametrized estimator through; otherwise only the default
        # would ever be exercised and the parametrization would be vacuous
        ax = pointplot(
            long_df, x=agg_var, y=val_var, estimator=estimator, errorbar=None
        )
        order = categorical_order(long_df[agg_var])
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == approx((i, agg_df[order[i]]))

    def test_weighted_estimate(self, long_df):
        """With ``weights``, the point is the weighted average."""
        ax = pointplot(long_df, y="y", weights="x")
        val = ax.lines[0].get_ydata().item()
        expected = np.average(long_df["y"], weights=long_df["x"])
        assert val == expected

    def test_estimate_log_transform(self, long_df):
        """On a log axis the estimate is the mean in log space."""
        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        pointplot(x=long_df["z"], ax=ax)
        val, = ax.lines[0].get_xdata()
        assert val == 10 ** np.log10(long_df["z"]).mean()

    def test_errorbars(self, long_df):
        """With errorbar="sd", error lines span mean +/- one standard deviation."""
        agg_var, val_var = "a", "y"
        agg_df = long_df.groupby(agg_var)[val_var].agg(["mean", "std"])
        ax = pointplot(long_df, x=agg_var, y=val_var, errorbar="sd")
        order = categorical_order(long_df[agg_var])
        # ax.lines[0] is the point line; the rest are the error lines
        for i, line in enumerate(ax.lines[1:]):
            row = agg_df.loc[order[i]]
            lo, hi = line.get_ydata()
            assert lo == approx(row["mean"] - row["std"])
            assert hi == approx(row["mean"] + row["std"])

    def test_marker_linestyle(self):
        """Singular ``marker``/``linestyle`` apply to the point line."""
        x, y = ["a", "b", "c"], [1, 2, 3]
        ax = pointplot(x=x, y=y, marker="s", linestyle="--")
        line = ax.lines[0]
        assert line.get_marker() == "s"
        assert line.get_linestyle() == "--"

    def test_markers_linestyles_single(self):
        """Plural ``markers``/``linestyles`` given as scalars also apply without hue."""
        x, y = ["a", "b", "c"], [1, 2, 3]
        ax = pointplot(x=x, y=y, markers="s", linestyles="--")
        line = ax.lines[0]
        assert line.get_marker() == "s"
        assert line.get_linestyle() == "--"

    def test_markers_linestyles_mapped(self):
        """List-valued ``markers``/``linestyles`` are mapped onto hue levels."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        markers = ["d", "s"]
        linestyles = ["--", ":"]
        ax = pointplot(
            x=x, y=y, hue=hue,
            markers=markers, linestyles=linestyles,
            errorbar=None,
        )
        for i, line in enumerate(ax.lines[:2]):
            assert line.get_marker() == markers[i]
            assert line.get_linestyle() == linestyles[i]

    def test_dodge_boolean(self):
        """dodge=True shifts the two hue levels by -/+ .025."""
        x, y = ["a", "b", "a", "b"], [1, 2, 3, 4]
        hue = ["x", "x", "y", "y"]
        ax = pointplot(x=x, y=y, hue=hue, dodge=True, errorbar=None)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == (i - .025, y[i])
        for i, xy in enumerate(ax.lines[1].get_xydata()):
            assert tuple(xy) == (i + .025, y[2 + i])

    def test_dodge_float(self):
        """A numeric dodge of .2 shifts the two hue levels by -/+ .1."""
        x, y = ["a", "b", "a", "b"], [1, 2, 3, 4]
        hue = ["x", "x", "y", "y"]
        ax = pointplot(x=x, y=y, hue=hue, dodge=.2, errorbar=None)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == (i - .1, y[i])
        for i, xy in enumerate(ax.lines[1].get_xydata()):
            assert tuple(xy) == (i + .1, y[2 + i])

    def test_dodge_log_scale(self):
        """On a log native axis, dodging is applied in log10 space."""
        x, y = [10, 1000, 10, 1000], [1, 2, 3, 4]
        hue = ["x", "x", "y", "y"]
        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        pointplot(x=x, y=y, hue=hue, dodge=.2, native_scale=True, errorbar=None, ax=ax)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == approx((10 ** (np.log10(x[i]) - .2), y[i]))
        for i, xy in enumerate(ax.lines[1].get_xydata()):
            assert tuple(xy) == approx((10 ** (np.log10(x[2 + i]) + .2), y[2 + i]))

    def test_err_kws(self):
        """err_kws styles the error lines."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        err_kws = dict(color=(.2, .5, .3), linewidth=10)
        ax = pointplot(x=x, y=y, errorbar=("pi", 100), err_kws=err_kws)
        for line in ax.lines[1:]:
            assert same_color(line.get_color(), err_kws["color"])
            assert line.get_linewidth() == err_kws["linewidth"]

    def test_err_kws_inherited(self):
        """Without err_kws, error lines inherit color and linewidth from the plot."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        kws = dict(color=(.2, .5, .3), linewidth=10)
        ax = pointplot(x=x, y=y, errorbar=("pi", 100), **kws)
        for line in ax.lines[1:]:
            assert same_color(line.get_color(), kws["color"])
            assert line.get_linewidth() == kws["linewidth"]

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Legend handle missing marker property"
    )
    def test_legend_contents(self):
        """Legend entries show each hue level's marker, line style and color."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        ax = pointplot(x=x, y=y, hue=hue)
        _draw_figure(ax.figure)
        legend = ax.get_legend()
        assert [t.get_text() for t in legend.texts] == ["x", "y"]
        for i, handle in enumerate(get_legend_handles(legend)):
            assert handle.get_marker() == "o"
            assert handle.get_linestyle() == "-"
            assert same_color(handle.get_color(), f"C{i}")

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Legend handle missing marker property"
    )
    def test_legend_set_props(self):
        """Scalar marker/linewidth keywords are reflected in every legend handle."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        kws = dict(marker="s", linewidth=1)
        ax = pointplot(x=x, y=y, hue=hue, **kws)
        legend = ax.get_legend()
        for handle in get_legend_handles(legend):
            assert handle.get_marker() == kws["marker"]
            assert handle.get_linewidth() == kws["linewidth"]

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Legend handle missing marker property"
    )
    def test_legend_synced_props(self):
        """Hue-mapped markers/linestyles are mirrored in the legend handles."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        kws = dict(markers=["s", "d"], linestyles=["--", ":"])
        ax = pointplot(x=x, y=y, hue=hue, **kws)
        legend = ax.get_legend()
        for i, handle in enumerate(get_legend_handles(legend)):
            assert handle.get_marker() == kws["markers"][i]
            assert handle.get_linestyle() == kws["linestyles"][i]

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a", y="y"),
            dict(data=None, x="a", y="y"),
            dict(data="long", x="a", y="y", hue="a"),
            dict(data=None, x="a", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="b"),
            dict(data=None, x="s", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="s"),
            dict(data="long", x="a", y="y", units="c"),
            dict(data="null", x="a", y="y", hue="a"),
            dict(data="long", x="s", y="y", hue="a", native_scale=True),
            dict(data="long", x="d", y="y", hue="a", native_scale=True),
            dict(data="long", x="a", y="y", errorbar=("pi", 50)),
            dict(data="long", x="a", y="y", errorbar=None),
            dict(data="null", x="a", y="y", hue="a", dodge=True),
            dict(data="null", x="a", y="y", hue="a", dodge=.2),
            dict(data="long", x="a", y="y", capsize=.3, err_kws=dict(c="k")),
            dict(data="long", x="a", y="y", color="blue", marker="s"),
            dict(data="long", x="a", y="y", hue="a", markers=["s", "d", "p"]),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """pointplot and catplot(kind="point") draw identical artists."""
        kwargs = kwargs.copy()
        # Fixed seed and a small bootstrap keep both runs identical and fast
        kwargs["seed"] = 0
        kwargs["n_boot"] = 10
        if kwargs["data"] == "long":
            kwargs["data"] = long_df
        elif kwargs["data"] == "wide":
            kwargs["data"] = wide_df
        elif kwargs["data"] == "flat":
            kwargs["data"] = flat_series
        elif kwargs["data"] == "null":
            kwargs["data"] = null_df
        elif kwargs["data"] is None:
            # Vector-only call: swap column names for the actual Series
            for var in ["x", "y", "hue"]:
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]
        ax = pointplot(**kwargs)
        g = catplot(**kwargs, kind="point")
        assert_plots_equal(ax, g.ax)

    def test_legend_disabled(self, long_df):
        """legend=False suppresses the legend even when hue is assigned."""
        ax = pointplot(long_df, x="x", y="y", hue="b", legend=False)
        assert ax.get_legend() is None

    def test_join_deprecation(self):
        """The deprecated ``join=False`` warns and removes the connecting line."""
        with pytest.warns(UserWarning, match="The `join` parameter"):
            ax = pointplot(x=["a", "b", "c"], y=[1, 2, 3], join=False)
        assert ax.lines[0].get_linestyle().lower() == "none"

    def test_scale_deprecation(self):
        """The deprecated ``scale`` warns and enlarges line width and marker size."""
        x, y = ["a", "b", "c"], [1, 2, 3]
        ax = pointplot(x=x, y=y, errorbar=None)
        with pytest.warns(UserWarning, match="The `scale` parameter"):
            # Draws onto the same (current) Axes as the first call
            pointplot(x=x, y=y, errorbar=None, scale=2)
        l1, l2 = ax.lines
        assert l2.get_linewidth() == 2 * l1.get_linewidth()
        assert l2.get_markersize() > l1.get_markersize()

    def test_layered_plot_clipping(self):
        """Layering a second pointplot expands the view to include both layers."""
        x, y = ['a'], [4]
        pointplot(x=x, y=y)
        x, y = ['b'], [5]
        ax = pointplot(x=x, y=y)
        y_range = ax.viewLim.intervaly
        assert y_range[0] < 4 and y_range[1] > 5
class TestCountPlot:
    """Tests for ``countplot``."""

    @staticmethod
    def _repeated_series(levels, counts):
        """Return a Series in which each level appears ``count`` times, in order."""
        return pd.Series([lvl for lvl, k in zip(levels, counts) for _ in range(k)])

    def test_empty(self):
        """Calling with no data, or an empty vector, draws no bars."""
        ax = countplot()
        assert not ax.patches
        ax = countplot(x=[])
        assert not ax.patches

    def test_labels_long(self, long_df):
        """Axis labels come from the variable name and the chosen stat."""
        fig = mpl.figure.Figure()
        ax_count, ax_pct = fig.subplots(2)
        countplot(long_df, x="a", ax=ax_count)
        countplot(long_df, x="b", stat="percent", ax=ax_pct)
        # To populate texts; only needed on older matplotlibs
        _draw_figure(fig)
        assert (ax_count.get_xlabel(), ax_pct.get_xlabel()) == ("a", "b")
        assert (ax_count.get_ylabel(), ax_pct.get_ylabel()) == ("count", "percent")

    def test_wide_data(self, wide_df):
        """Wide-form data gives one full-height bar per column."""
        ax = countplot(wide_df)
        bars = ax.patches
        assert len(bars) == len(wide_df.columns)
        n_rows = len(wide_df)
        for pos, bar in enumerate(bars):
            center = bar.get_x() + bar.get_width() / 2
            assert center == approx(pos)
            assert bar.get_y() == 0
            assert bar.get_height() == n_rows
            assert bar.get_width() == approx(0.8)

    def test_flat_series(self):
        """A bare Series is counted along the y axis (horizontal bars)."""
        counts = [2, 1, 4]
        ax = countplot(self._repeated_series(["a", "b", "c"], counts))
        for pos, bar in enumerate(ax.patches):
            center = bar.get_y() + bar.get_height() / 2
            assert bar.get_x() == 0
            assert center == approx(pos)
            assert bar.get_height() == approx(0.8)
            assert bar.get_width() == counts[pos]

    def test_x_series(self):
        """A Series passed as x yields vertical bars whose heights are the counts."""
        counts = [2, 1, 4]
        ax = countplot(x=self._repeated_series(["a", "b", "c"], counts))
        for pos, bar in enumerate(ax.patches):
            center = bar.get_x() + bar.get_width() / 2
            assert center == approx(pos)
            assert bar.get_y() == 0
            assert bar.get_height() == counts[pos]
            assert bar.get_width() == approx(0.8)

    def test_y_series(self):
        """A Series passed as y yields horizontal bars whose widths are the counts."""
        counts = [2, 1, 4]
        ax = countplot(y=self._repeated_series(["a", "b", "c"], counts))
        for pos, bar in enumerate(ax.patches):
            center = bar.get_y() + bar.get_height() / 2
            assert bar.get_x() == 0
            assert center == approx(pos)
            assert bar.get_height() == approx(0.8)
            assert bar.get_width() == counts[pos]

    def test_hue_redundant(self):
        """Hue equal to x colors each bar without dodging."""
        counts = [2, 1, 4]
        vals = self._repeated_series(["a", "b", "c"], counts)
        ax = countplot(x=vals, hue=vals, saturation=1)
        for pos, bar in enumerate(ax.patches):
            center = bar.get_x() + bar.get_width() / 2
            assert center == approx(pos)
            assert bar.get_y() == 0
            assert bar.get_height() == counts[pos]
            assert bar.get_width() == approx(0.8)
            assert same_color(bar.get_facecolor(), f"C{pos}")

    def test_hue_dodged(self):
        """Two hue levels split each category into two half-width dodged bars."""
        vals = ["a", "a", "a", "b", "b", "b"]
        hue = ["x", "y", "y", "x", "x", "x"]
        # Patches are ordered by hue level, then category: x/a, x/b, y/a, y/b
        expected_heights = [1, 3, 2, 0]
        ax = countplot(x=vals, hue=hue, saturation=1, legend=False)
        for idx, bar in enumerate(ax.patches):
            hue_idx, cat_idx = divmod(idx, 2)
            sign = 1 if hue_idx else -1
            center = bar.get_x() + bar.get_width() / 2
            assert center == approx(cat_idx + sign * 0.8 / 4)
            assert bar.get_y() == 0
            assert bar.get_height() == expected_heights[idx]
            assert bar.get_width() == approx(0.8 / 2)
            assert same_color(bar.get_facecolor(), f"C{hue_idx}")

    @pytest.mark.parametrize("stat", ["percent", "probability", "proportion"])
    def test_stat(self, long_df, stat):
        """Normalized stats give bar heights equal to (scaled) level frequencies."""
        col = "a"
        order = categorical_order(long_df[col])
        proportions = long_df[col].value_counts(normalize=True)
        expected = proportions * 100 if stat == "percent" else proportions
        ax = countplot(long_df, x=col, stat=stat)
        for pos, bar in enumerate(ax.patches):
            assert bar.get_height() == approx(expected[order[pos]])

    def test_xy_error(self, long_df):
        """Passing both x and y is rejected."""
        with pytest.raises(TypeError, match="Cannot pass values for both"):
            countplot(long_df, x="a", y="b")

    def test_legend_numeric_auto(self, long_df):
        """A numeric hue under the default legend mode shows at most six entries."""
        legend = countplot(long_df, x="x", hue="x").get_legend()
        assert len(legend.texts) <= 6

    def test_legend_disabled(self, long_df):
        """legend=False suppresses the legend even when hue is assigned."""
        legend = countplot(long_df, x="x", hue="b", legend=False).get_legend()
        assert legend is None

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a"),
            dict(data=None, x="a"),
            dict(data="long", y="b"),
            dict(data="long", x="a", hue="a"),
            dict(data=None, x="a", hue="a"),
            dict(data="long", x="a", hue="b"),
            dict(data=None, x="s", hue="a"),
            dict(data="long", x="a", hue="s"),
            dict(data="null", x="a", hue="a"),
            dict(data="long", x="s", hue="a", native_scale=True),
            dict(data="long", x="d", hue="a", native_scale=True),
            dict(data="long", x="a", stat="percent"),
            dict(data="long", x="a", hue="b", stat="proportion"),
            dict(data="long", x="a", color="blue", ec="green", alpha=.5),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """countplot and catplot(kind="count") draw identical artists."""
        kwargs = dict(kwargs)
        datasets = {
            "long": long_df,
            "wide": wide_df,
            "flat": flat_series,
            "null": null_df,
        }
        source = kwargs["data"]
        if source is None:
            # Vector-only call: swap column names for the actual Series
            for var in ("x", "y", "hue"):
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]
        else:
            kwargs["data"] = datasets.get(source, source)
        ax = countplot(**kwargs)
        g = catplot(**kwargs, kind="count")
        assert_plots_equal(ax, g.ax)
class CategoricalFixture:
    """Shared class-level random data for the legacy categorical tests.

    Provides wide-form (``x``, ``x_df``) and long-form (``y``, ``g``, ``h``,
    ``u``, ``df``) data built once, when the class is defined, from a seeded
    RandomState. Also provides a helper that finds the box artists on an Axes.
    """
    # Fixed seed so every test sees identical data; the statement order below
    # matters because each draw advances this generator
    rs = np.random.RandomState(30)
    n_total = 60
    # Wide-form array: n_total / 3 rows by 3 columns of standard normal draws
    x = rs.randn(int(n_total / 3), 3)
    # Same values as a DataFrame with columns X/Y/Z; the column index is named "big"
    x_df = pd.DataFrame(x, columns=pd.Series(list("XYZ"), name="big"))
    # Long-form numeric observations
    y = pd.Series(rs.randn(n_total), name="y_data")
    # y reindexed by a random permutation of its index (drawn without replacement)
    y_perm = y.reindex(rs.choice(y.index, y.size, replace=False))
    # Grouping vectors: "a"/"b"/"c" in contiguous runs, "m"/"n" alternating,
    # and "j"/"k"/"h" cycling (the last one is unnamed)
    g = pd.Series(np.repeat(list("abc"), int(n_total / 3)), name="small")
    h = pd.Series(np.tile(list("mn"), int(n_total / 2)), name="medium")
    u = pd.Series(np.tile(list("jkh"), int(n_total / 3)))
    # Long-form frame combining the observations with the grouping vectors
    df = pd.DataFrame(dict(y=y, g=g, h=h, u=u))
    # Add a "W" grouping column to the wide-form frame. g has n_total entries,
    # but x_df has only n_total / 3 rows and pandas aligns on the index, so W
    # receives just g's first n_total / 3 values
    x_df["W"] = g

    def get_box_artists(self, ax):
        """Return the artists that draw the boxes on ``ax``.

        On matplotlib older than 3.5.0b0 these are read from ``ax.artists``.
        Otherwise they are the patches on ``ax`` that have no label.
        """
        if _version_predates(mpl, "3.5.0b0"):
            return ax.artists
        else:
            # Exclude labeled patches, which are for the legend
            return [p for p in ax.patches if not p.get_label()]
class TestCatPlot(CategoricalFixture):
    """Tests for the figure-level ``catplot`` interface."""

    def test_facet_organization(self):
        """col/row variables determine the shape of the facet grid."""
        g = cat.catplot(x="g", y="y", data=self.df)
        assert g.axes.shape == (1, 1)
        g = cat.catplot(x="g", y="y", col="h", data=self.df)
        assert g.axes.shape == (1, 2)
        g = cat.catplot(x="g", y="y", row="h", data=self.df)
        assert g.axes.shape == (2, 1)
        g = cat.catplot(x="g", y="y", col="u", row="h", data=self.df)
        assert g.axes.shape == (2, 3)

    def test_plot_elements(self):
        """Each kind draws the expected number of lines, patches and collections."""
        g = cat.catplot(x="g", y="y", data=self.df, kind="point")
        # One joined point line plus one error line per category
        want_lines = 1 + self.g.unique().size
        assert len(g.ax.lines) == want_lines

        g = cat.catplot(x="g", y="y", hue="h", data=self.df, kind="point")
        want_lines = (
            len(self.g.unique()) * len(self.h.unique()) + 2 * len(self.h.unique())
        )
        assert len(g.ax.lines) == want_lines

        g = cat.catplot(x="g", y="y", data=self.df, kind="bar")
        want_elements = self.g.unique().size
        assert len(g.ax.patches) == want_elements
        assert len(g.ax.lines) == want_elements

        g = cat.catplot(x="g", y="y", hue="h", data=self.df, kind="bar")
        want_elements = self.g.nunique() * self.h.nunique()
        assert len(g.ax.patches) == (want_elements + self.h.nunique())
        assert len(g.ax.lines) == want_elements

        g = cat.catplot(x="g", data=self.df, kind="count")
        want_elements = self.g.unique().size
        assert len(g.ax.patches) == want_elements
        assert len(g.ax.lines) == 0

        g = cat.catplot(x="g", hue="h", data=self.df, kind="count")
        want_elements = self.g.nunique() * self.h.nunique() + self.h.nunique()
        assert len(g.ax.patches) == want_elements
        assert len(g.ax.lines) == 0

        g = cat.catplot(y="y", data=self.df, kind="box")
        want_artists = 1
        assert len(self.get_box_artists(g.ax)) == want_artists

        g = cat.catplot(x="g", y="y", data=self.df, kind="box")
        want_artists = self.g.unique().size
        assert len(self.get_box_artists(g.ax)) == want_artists

        g = cat.catplot(x="g", y="y", hue="h", data=self.df, kind="box")
        want_artists = self.g.nunique() * self.h.nunique()
        assert len(self.get_box_artists(g.ax)) == want_artists

        g = cat.catplot(x="g", y="y", data=self.df,
                        kind="violin", inner=None)
        want_elements = self.g.unique().size
        assert len(g.ax.collections) == want_elements

        g = cat.catplot(x="g", y="y", hue="h", data=self.df,
                        kind="violin", inner=None)
        want_elements = self.g.nunique() * self.h.nunique()
        assert len(g.ax.collections) == want_elements

        g = cat.catplot(x="g", y="y", data=self.df, kind="strip")
        want_elements = self.g.unique().size
        assert len(g.ax.collections) == want_elements
        for strip in g.ax.collections:
            assert same_color(strip.get_facecolors(), "C0")

        g = cat.catplot(x="g", y="y", hue="h", data=self.df, kind="strip")
        want_elements = self.g.nunique()
        assert len(g.ax.collections) == want_elements

    def test_bad_plot_kind_error(self):
        """An unknown ``kind`` raises ValueError."""
        with pytest.raises(ValueError):
            cat.catplot(x="g", y="y", data=self.df, kind="not_a_kind")

    def test_count_x_and_y(self):
        """kind="count" rejects having both x and y."""
        with pytest.raises(ValueError):
            cat.catplot(x="g", y="y", data=self.df, kind="count")

    def test_plot_colors(self):
        """Axes-level functions and catplot color their artists identically."""
        ax = cat.barplot(x="g", y="y", data=self.df)
        g = cat.catplot(x="g", y="y", data=self.df, kind="bar")
        for p1, p2 in zip(ax.patches, g.ax.patches):
            assert p1.get_facecolor() == p2.get_facecolor()
        plt.close("all")

        ax = cat.barplot(x="g", y="y", data=self.df, color="purple")
        g = cat.catplot(x="g", y="y", data=self.df,
                        kind="bar", color="purple")
        for p1, p2 in zip(ax.patches, g.ax.patches):
            assert p1.get_facecolor() == p2.get_facecolor()
        plt.close("all")

        ax = cat.barplot(x="g", y="y", data=self.df, palette="Set2", hue="h")
        g = cat.catplot(x="g", y="y", data=self.df,
                        kind="bar", palette="Set2", hue="h")
        for p1, p2 in zip(ax.patches, g.ax.patches):
            assert p1.get_facecolor() == p2.get_facecolor()
        plt.close("all")

        ax = cat.pointplot(x="g", y="y", data=self.df)
        # catplot defaults to kind="strip", which draws no Line2D artists; without
        # an explicit kind the zip below would be empty and the check vacuous
        g = cat.catplot(x="g", y="y", data=self.df, kind="point")
        assert len(ax.lines) == len(g.ax.lines)
        for l1, l2 in zip(ax.lines, g.ax.lines):
            # Compare resolved colors: the same default may be given as "C0"
            # in one path and as an explicit hex/RGB value in the other
            assert same_color(l1.get_color(), l2.get_color())
        plt.close("all")

        ax = cat.pointplot(x="g", y="y", data=self.df, color="purple")
        g = cat.catplot(x="g", y="y", data=self.df, color="purple", kind="point")
        for l1, l2 in zip(ax.lines, g.ax.lines):
            assert same_color(l1.get_color(), l2.get_color())
        plt.close("all")

        ax = cat.pointplot(x="g", y="y", data=self.df, palette="Set2", hue="h")
        g = cat.catplot(
            x="g", y="y", data=self.df, palette="Set2", hue="h", kind="point"
        )
        for l1, l2 in zip(ax.lines, g.ax.lines):
            assert same_color(l1.get_color(), l2.get_color())
        plt.close("all")

    def test_ax_kwarg_removal(self):
        """Passing ``ax`` warns and does not draw on the supplied Axes."""
        f, ax = plt.subplots()
        with pytest.warns(UserWarning, match="catplot is a figure-level"):
            g = cat.catplot(x="g", y="y", data=self.df, ax=ax)
        assert len(ax.collections) == 0
        assert len(g.ax.collections) > 0

    def test_share_xy(self):
        """sharex/sharey control whether facets show all categories or only their own."""
        # Test default behavior works
        g = cat.catplot(x="g", y="y", col="g", data=self.df, sharex=True)
        for ax in g.axes.flat:
            assert len(ax.collections) == len(self.df.g.unique())

        g = cat.catplot(x="y", y="g", col="g", data=self.df, sharey=True)
        for ax in g.axes.flat:
            assert len(ax.collections) == len(self.df.g.unique())

        # Test unsharing works
        g = cat.catplot(
            x="g", y="y", col="g", data=self.df, sharex=False, kind="bar",
        )
        for ax in g.axes.flat:
            assert len(ax.patches) == 1

        g = cat.catplot(
            x="y", y="g", col="g", data=self.df, sharey=False, kind="bar",
        )
        for ax in g.axes.flat:
            assert len(ax.patches) == 1

        g = cat.catplot(
            x="g", y="y", col="g", data=self.df, sharex=False, color="b"
        )
        for ax in g.axes.flat:
            assert ax.get_xlim() == (-.5, .5)

        g = cat.catplot(
            x="y", y="g", col="g", data=self.df, sharey=False, color="r"
        )
        for ax in g.axes.flat:
            assert ax.get_ylim() == (.5, -.5)

        # Make sure order is used if given, regardless of sharex value
        order = self.df.g.unique()
        g = cat.catplot(x="g", y="y", col="g", data=self.df, sharex=False, order=order)
        for ax in g.axes.flat:
            assert len(ax.collections) == len(self.df.g.unique())

        g = cat.catplot(x="y", y="g", col="g", data=self.df, sharey=False, order=order)
        for ax in g.axes.flat:
            assert len(ax.collections) == len(self.df.g.unique())

    def test_facetgrid_data(self, long_df):
        """The grid keeps the input DataFrame, or assembles one from vectors."""
        g1 = catplot(data=long_df, x="a", y="y", col="c")
        assert g1.data is long_df
        g2 = catplot(x=long_df["a"], y=long_df["y"], col=long_df["c"])
        assert g2.data.equals(long_df[["a", "y", "c"]])

    @pytest.mark.parametrize("var", ["col", "row"])
    def test_array_faceter(self, long_df, var):
        """Faceting by an array gives the same plots as faceting by a column name."""
        g1 = catplot(data=long_df, x="y", **{var: "a"})
        g2 = catplot(data=long_df, x="y", **{var: long_df["a"].to_numpy()})
        for ax1, ax2 in zip(g1.axes.flat, g2.axes.flat):
            assert_plots_equal(ax1, ax2)

    def test_invalid_kind(self, long_df):
        """The error for an invalid kind names the offending value."""
        with pytest.raises(ValueError, match="Invalid `kind`: 'wrong'"):
            catplot(long_df, kind="wrong")

    def test_legend_with_auto(self):
        """legend='auto' omits a redundant hue legend; legend=True forces one."""
        g1 = catplot(self.df, x="g", y="y", hue="g", legend='auto')
        assert g1._legend is None
        g2 = catplot(self.df, x="g", y="y", hue="g", legend=True)
        assert g2._legend is not None

    def test_weights_warning(self, long_df):
        """Passing ``weights`` to a kind that does not use it warns but still plots."""
        with pytest.warns(UserWarning, match="The `weights` parameter"):
            g = catplot(long_df, x="a", y="y", weights="z")
        assert g.ax is not None
class TestBeeswarm:
    """Unit tests for the Beeswarm point-placement helper."""

    def test_could_overlap(self):
        """Only neighbors close enough to collide with the target are returned."""
        swarm = Beeswarm()
        target = (1, 1, .5)
        others = [(0, 0, .5), (1, .1, .2), (.5, .5, .5)]
        overlapping = swarm.could_overlap(target, others)
        assert_array_equal(overlapping, [(.5, .5, .5)])

    def test_position_candidates(self):
        """Candidate positions are generated beside each neighbor, in a fixed order."""
        swarm = Beeswarm()
        point = (0, 1, .5)
        neighbors = [(0, 1, .5), (0, 1.5, .5)]
        candidates = swarm.position_candidates(point, neighbors)
        near = 1.05
        far = np.sqrt(1 - .5 ** 2) * 1.05
        expected = [
            (0, 1, .5),
            (-near, 1, .5),
            (near, 1, .5),
            (far, 1, .5),
            (-far, 1, .5),
        ]
        assert_array_equal(candidates, expected)

    def test_find_first_non_overlapping_candidate(self):
        """The first candidate that clears all neighbors is chosen."""
        swarm = Beeswarm()
        candidates = [(.5, 1, .5), (1, 1, .5), (1.5, 1, .5)]
        neighbors = np.array([(0, 1, .5)])
        chosen = swarm.first_non_overlapping_candidate(candidates, neighbors)
        assert_array_equal(chosen, (1, 1, .5))

    def test_beeswarm(self, long_df):
        """Swarmed points keep their values and are all at least ``d`` apart."""
        swarm = Beeswarm()
        values = long_df["y"]
        d = values.diff().mean() * 1.5
        ys = np.sort(values)
        xs = np.zeros(values.size)
        radii = np.full_like(ys, d)
        placed = swarm.beeswarm(np.c_[xs, ys, radii])[:, :2]
        offsets = placed[:, None] - placed
        dist = np.sqrt(np.sum(np.square(offsets), axis=-1))
        upper = dist[np.triu_indices_from(dist, 1)]
        assert_array_less(d, upper)
        assert_array_equal(ys, placed[:, 1])

    def test_add_gutters(self):
        """Points beyond the half-width are clipped to it, with a warning."""
        swarm = Beeswarm(width=1)

        def identity(v):
            return v

        centered = np.zeros(10)
        assert_array_equal(centered, swarm.add_gutters(centered, 0, identity, identity))

        spread = np.array([0, -1, .4, .8])
        msg = r"50.0% of the points cannot be placed.+$"
        with pytest.warns(UserWarning, match=msg):
            clipped = swarm.add_gutters(spread, 0, identity, identity)
        assert_array_equal(clipped, np.array([0, -.5, .4, .5]))
class TestBoxPlotContainer:
    """Tests for ``BoxPlotContainer``, which wraps the dict from ``Axes.boxplot``."""

    @pytest.fixture
    def container(self, wide_array):
        """Wrap the artists of a matplotlib boxplot drawn from ``wide_array``."""
        ax = mpl.figure.Figure().subplots()
        artist_dict = ax.boxplot(wide_array)
        return BoxPlotContainer(artist_dict)

    def test_repr(self, container, wide_array):
        """The repr names the class and reports one box per data column."""
        n = wide_array.shape[1]
        # The expected text must include the box count; an empty expected
        # string can never equal a real repr and left `n` unused
        assert str(container) == f"<BoxPlotContainer object with {n} boxes>"

    def test_iteration(self, container):
        """Iterating yields per-box records exposing every artist group."""
        for artist_tuple in container:
            for attr in ["box", "median", "whiskers", "caps", "fliers", "mean"]:
                assert hasattr(artist_tuple, attr)

    def test_label(self, container):
        """set_label/get_label round-trip."""
        label = "a box plot"
        container.set_label(label)
        assert container.get_label() == label

    def test_children(self, container):
        """Every child returned by get_children is a matplotlib Artist."""
        children = container.get_children()
        for child in children:
            assert isinstance(child, mpl.artist.Artist)
================================================
FILE: tests/test_distributions.py
================================================
import itertools
import warnings
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb, to_rgba
import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
from seaborn import distributions as dist
from seaborn.palettes import (
color_palette,
light_palette,
)
from seaborn._base import (
categorical_order,
)
from seaborn._statistics import (
KDE,
Histogram,
_no_scipy,
)
from seaborn.distributions import (
_DistributionPlotter,
displot,
distplot,
histplot,
ecdfplot,
kdeplot,
rugplot,
)
from seaborn.utils import _version_predates
from seaborn.axisgrid import FacetGrid
from seaborn._testing import (
assert_plots_equal,
assert_legends_equal,
assert_colors_equal,
)
def get_contour_coords(c, filter_empty=False):
    """Return the vertex arrays of a contour artist, across matplotlib versions.

    A ``LineCollection`` returns its segments directly. For a ``PathCollection``
    or ``QuadContourSet``, each path's vertices are kept up to and including the
    first occurrence of the largest path code. Empty paths are skipped when
    ``filter_empty`` is true. Any other artist type gives ``None``.
    """
    if isinstance(c, mpl.collections.LineCollection):
        # See https://github.com/matplotlib/matplotlib/issues/20906
        return c.get_segments()
    path_types = (mpl.collections.PathCollection, mpl.contour.QuadContourSet)
    if not isinstance(c, path_types):
        return None
    coords = []
    for path in c.get_paths():
        if len(path) or not filter_empty:
            coords.append(path.vertices[:np.argmax(path.codes) + 1])
    return coords
def get_contour_color(c):
    """Return the color(s) of a contour artist, across matplotlib versions.

    A ``LineCollection`` reports ``get_color()``. A ``PathCollection`` or
    ``QuadContourSet`` reports its face colors when it has any, and its edge
    colors otherwise. Any other artist type gives ``None``.
    """
    if isinstance(c, mpl.collections.LineCollection):
        # See https://github.com/matplotlib/matplotlib/issues/20906
        return c.get_color()
    if not isinstance(c, (mpl.collections.PathCollection, mpl.contour.QuadContourSet)):
        return None
    facecolors = c.get_facecolor()
    return facecolors if facecolors.size else c.get_edgecolor()
class TestDistPlot:
    """Tests for the deprecated ``distplot`` interface."""

    rs = np.random.RandomState(0)
    x = rs.randn(100)

    def test_hist_bins(self):
        """Bars begin at the bin edges numpy computes for the same bin rule."""
        auto_edges = np.histogram_bin_edges(self.x, "fd")
        with pytest.warns(UserWarning):
            ax = distplot(self.x)
        for expected, patch in zip(auto_edges, ax.patches):
            assert pytest.approx(expected) == patch.get_x()

        plt.close(ax.figure)

        n_bins = 25
        fixed_edges = np.histogram_bin_edges(self.x, n_bins)
        with pytest.warns(UserWarning):
            ax = distplot(self.x, bins=n_bins)
        for expected, patch in zip(fixed_edges, ax.patches):
            assert pytest.approx(expected) == patch.get_x()

    def test_elements(self):
        """Each component switch contributes exactly its own artists."""

        class Norm:
            """Dummy object that looks like a scipy RV"""
            def fit(self, x):
                return ()

            def pdf(self, x, *params):
                return np.zeros_like(x)

        # (distplot keywords, expected (n_patches, n_lines, n_collections))
        cases = [
            (dict(bins=10, hist=True, kde=False, rug=False, fit=None), (10, 0, 0)),
            (dict(hist=False, kde=True, rug=False, fit=None), (0, 1, 0)),
            (dict(hist=False, kde=False, rug=True, fit=None), (0, 0, 1)),
            (dict(hist=False, kde=False, rug=False, fit=Norm()), (0, 1, 0)),
        ]

        with pytest.warns(UserWarning):
            for index, (kws, expected) in enumerate(cases):
                if index:
                    # Start every case after the first on a clean figure.
                    plt.close(ax.figure)
                ax = distplot(self.x, **kws)
                counts = len(ax.patches), len(ax.lines), len(ax.collections)
                assert counts == expected

    def test_distplot_with_nans(self):
        """Null values are dropped and do not change the drawn output."""
        f, (ax_clean, ax_nan) = plt.subplots(2)
        with_nan = np.append(self.x, [np.nan])

        with pytest.warns(UserWarning):
            distplot(self.x, ax=ax_clean)
            distplot(with_nan, ax=ax_nan)

        clean_line, nan_line = ax_clean.lines[0], ax_nan.lines[0]
        assert np.array_equal(clean_line.get_xydata(), nan_line.get_xydata())

        for clean_bar, nan_bar in zip(ax_clean.patches, ax_nan.patches):
            assert clean_bar.get_xy() == nan_bar.get_xy()
            assert clean_bar.get_height() == nan_bar.get_height()
class SharedAxesLevelTests:
    """Color-cycle checks shared by the axes-level distribution functions.

    Subclasses provide ``func`` (the plotting function under test) and
    ``get_last_color(ax, **kwargs)`` (how to read the newest artist's color).
    """

    def test_color(self, long_df, **kwargs):

        def draw(ax, **extra):
            self.func(data=long_df, x="y", ax=ax, **extra, **kwargs)

        def check(ax, expected):
            observed = self.get_last_color(ax, **kwargs)
            assert_colors_equal(observed, expected, check_alpha=False)

        # A single call takes the first cycle color.
        ax = plt.figure().subplots()
        draw(ax)
        check(ax, "C0")

        # A second call on the same Axes advances the cycle.
        ax = plt.figure().subplots()
        draw(ax)
        draw(ax)
        check(ax, "C1")

        # An explicit color wins over the cycle.
        ax = plt.figure().subplots()
        draw(ax, color="C2")
        check(ax, "C2")
class TestRugPlot(SharedAxesLevelTests):
    """Tests for ``rugplot``."""

    func = staticmethod(rugplot)

    def get_last_color(self, ax, **kwargs):
        """Color of the most recently added rug collection."""
        newest = ax.collections[-1]
        return newest.get_color()

    def assert_rug_equal(self, a, b):
        """Two rug collections must contain identical line segments."""
        assert_array_equal(a.get_segments(), b.get_segments())

    @pytest.mark.parametrize("variable", ["x", "y"])
    def test_long_data(self, long_df, variable):
        column = long_df[variable]
        specs = [variable, column, np.asarray(column), column.to_list()]

        f, ax = plt.subplots()
        for spec in specs:
            rugplot(data=long_df, **{variable: spec})

        for first, second in itertools.product(ax.collections, repeat=2):
            self.assert_rug_equal(first, second)

    def test_bivariate_data(self, long_df):
        f, (ax_both, ax_each) = plt.subplots(ncols=2)
        rugplot(data=long_df, x="x", y="y", ax=ax_both)
        rugplot(data=long_df, x="x", ax=ax_each)
        rugplot(data=long_df, y="y", ax=ax_each)
        for idx in (0, 1):
            self.assert_rug_equal(ax_both.collections[idx], ax_each.collections[idx])

    def test_wide_vs_long_data(self, wide_df):
        f, (ax_wide, ax_long) = plt.subplots(ncols=2)
        rugplot(data=wide_df, ax=ax_wide)
        for col in wide_df:
            rugplot(data=wide_df, x=col, ax=ax_long)

        wide_segments = np.sort(np.array(ax_wide.collections[0].get_segments()))
        long_segments = np.sort(
            np.concatenate([coll.get_segments() for coll in ax_long.collections])
        )
        assert_array_equal(wide_segments, long_segments)

    def test_flat_vector(self, long_df):
        f, ax = plt.subplots()
        rugplot(data=long_df["x"])
        rugplot(x=long_df["x"])
        self.assert_rug_equal(*ax.collections)

    def test_datetime_data(self, long_df):
        ax = rugplot(data=long_df["t"])
        positions = np.stack(ax.collections[0].get_segments())[:, 0, 0]
        assert_array_equal(positions, mpl.dates.date2num(long_df["t"]))

    def test_empty_data(self):
        assert not rugplot(x=[]).collections

    def test_a_deprecation(self, flat_series):
        f, ax = plt.subplots()
        with pytest.warns(UserWarning):
            rugplot(a=flat_series)
        rugplot(x=flat_series)
        self.assert_rug_equal(*ax.collections)

    @pytest.mark.parametrize("variable", ["x", "y"])
    def test_axis_deprecation(self, flat_series, variable):
        f, ax = plt.subplots()
        with pytest.warns(UserWarning):
            rugplot(flat_series, axis=variable)
        rugplot(**{variable: flat_series})
        self.assert_rug_equal(*ax.collections)

    def test_vertical_deprecation(self, flat_series):
        f, ax = plt.subplots()
        with pytest.warns(UserWarning):
            rugplot(flat_series, vertical=True)
        rugplot(y=flat_series)
        self.assert_rug_equal(*ax.collections)

    def test_rug_data(self, flat_array):
        height = .05
        ax = rugplot(x=flat_array, height=height)
        # Stacked segments are indexed [segment, point, coordinate].
        segments = np.stack(ax.collections[0].get_segments())
        count = flat_array.size
        assert_array_equal(segments[:, 0, 1], np.zeros(count))
        assert_array_equal(segments[:, 1, 1], np.full(count, height))
        assert_array_equal(segments[:, 1, 0], flat_array)

    def test_rug_colors(self, long_df):
        ax = rugplot(data=long_df, x="x", hue="a")
        levels = categorical_order(long_df["a"])
        palette = color_palette()
        expected = np.ones((len(long_df), 4))
        for row, level in enumerate(long_df["a"]):
            expected[row, :3] = palette[levels.index(level)]
        assert_array_equal(ax.collections[0].get_color(), expected)

    def test_expand_margins(self, flat_array):
        # Without expansion, the margins are untouched.
        f, ax = plt.subplots()
        x_before, y_before = ax.margins()
        rugplot(x=flat_array, expand_margins=False)
        x_after, y_after = ax.margins()
        assert x_before == x_after
        assert y_before == y_after

        # With expansion, the y margin grows by twice the rug height.
        f, ax = plt.subplots()
        x_before, y_before = ax.margins()
        height = .05
        rugplot(x=flat_array, height=height)
        x_after, y_after = ax.margins()
        assert x_before == x_after
        assert y_before + height * 2 == pytest.approx(y_after)

    def test_multiple_rugs(self):
        values = np.linspace(start=0, stop=1, num=5)
        ax = rugplot(x=values)
        limits = ax.get_ylim()
        rugplot(x=values, ax=ax, expand_margins=False)
        assert limits == ax.get_ylim()

    def test_matplotlib_kwargs(self, flat_series):
        lw, alpha = 2, .2
        ax = rugplot(y=flat_series, linewidth=lw, alpha=alpha)
        rug = ax.collections[0]
        assert np.all(rug.get_alpha() == alpha)
        assert np.all(rug.get_linewidth() == lw)

    def test_axis_labels(self, flat_series):
        ax = rugplot(x=flat_series)
        assert ax.get_xlabel() == flat_series.name
        assert not ax.get_ylabel()

    def test_log_scale(self, long_df):
        ax_lin, ax_log = plt.figure().subplots(2)
        ax_log.set_xscale("log")
        rugplot(data=long_df, x="z", ax=ax_lin)
        rugplot(data=long_df, x="z", ax=ax_log)
        segs_lin = np.stack(ax_lin.collections[0].get_segments())
        segs_log = np.stack(ax_log.collections[0].get_segments())
        assert_array_almost_equal(segs_lin, segs_log)
class TestKDEPlotUnivariate(SharedAxesLevelTests):
    """Tests for ``kdeplot`` with a single (x or y) data variable."""

    func = staticmethod(kdeplot)

    def get_last_color(self, ax, fill=True):
        """Return the color of the most recent density artist on ``ax``."""
        if fill:
            return ax.collections[-1].get_facecolor()
        else:
            return ax.lines[-1].get_color()

    @staticmethod
    def _roughness(line):
        """Mean absolute step between consecutive density values of ``line``."""
        return np.abs(np.diff(line.get_ydata())).mean()

    @pytest.mark.parametrize("fill", [True, False])
    def test_color(self, long_df, fill):

        super().test_color(long_df, fill=fill)

        if fill:

            # Facecolor keywords (long and short alias) set the fill color.
            ax = plt.figure().subplots()
            self.func(data=long_df, x="y", facecolor="C3", fill=True, ax=ax)
            assert_colors_equal(self.get_last_color(ax), "C3", check_alpha=False)

            ax = plt.figure().subplots()
            self.func(data=long_df, x="y", fc="C4", fill=True, ax=ax)
            assert_colors_equal(self.get_last_color(ax), "C4", check_alpha=False)

    @pytest.mark.parametrize(
        "variable", ["x", "y"],
    )
    def test_long_vectors(self, long_df, variable):

        vector = long_df[variable]
        vectors = [
            variable, vector, vector.to_numpy(), vector.to_list(),
        ]

        f, ax = plt.subplots()
        for spec in vectors:
            kdeplot(data=long_df, **{variable: spec})
        # Guard against a specification silently drawing nothing.
        assert len(ax.lines) == len(vectors)

        xdata = [line.get_xdata() for line in ax.lines]
        for a, b in itertools.product(xdata, xdata):
            assert_array_equal(a, b)

        ydata = [line.get_ydata() for line in ax.lines]
        for a, b in itertools.product(ydata, ydata):
            assert_array_equal(a, b)

    def test_wide_vs_long_data(self, wide_df):

        f, (ax1, ax2) = plt.subplots(ncols=2)
        kdeplot(data=wide_df, ax=ax1, common_norm=False, common_grid=False)
        for col in wide_df:
            kdeplot(data=wide_df, x=col, ax=ax2)

        # Reversing ax1.lines aligns the wide-form lines with the per-column ones.
        for l1, l2 in zip(ax1.lines[::-1], ax2.lines):
            assert_array_equal(l1.get_xydata(), l2.get_xydata())

    def test_flat_vector(self, long_df):

        f, ax = plt.subplots()
        kdeplot(data=long_df["x"])
        kdeplot(x=long_df["x"])
        assert_array_equal(ax.lines[0].get_xydata(), ax.lines[1].get_xydata())

    def test_empty_data(self):

        ax = kdeplot(x=[])
        assert not ax.lines

    def test_singular_data(self):

        with pytest.warns(UserWarning):
            ax = kdeplot(x=np.ones(10))
        assert not ax.lines

        with pytest.warns(UserWarning):
            ax = kdeplot(x=[5])
        assert not ax.lines

        with pytest.warns(UserWarning):
            # https://github.com/mwaskom/seaborn/issues/2762
            ax = kdeplot(x=[1929245168.06679] * 18)
        assert not ax.lines

        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            ax = kdeplot(x=[5], warn_singular=False)
        assert not ax.lines

    def test_variable_assignment(self, long_df):

        f, ax = plt.subplots()
        kdeplot(data=long_df, x="x", fill=True)
        kdeplot(data=long_df, y="x", fill=True)

        v0 = ax.collections[0].get_paths()[0].vertices
        v1 = ax.collections[1].get_paths()[0].vertices[:, [1, 0]]

        assert_array_equal(v0, v1)

    def test_vertical_deprecation(self, long_df):

        f, ax = plt.subplots()
        kdeplot(data=long_df, y="x")

        with pytest.warns(UserWarning):
            kdeplot(data=long_df, x="x", vertical=True)

        assert_array_equal(ax.lines[0].get_xydata(), ax.lines[1].get_xydata())

    def test_bw_deprecation(self, long_df):

        f, ax = plt.subplots()
        kdeplot(data=long_df, x="x", bw_method="silverman")

        with pytest.warns(UserWarning):
            kdeplot(data=long_df, x="x", bw="silverman")

        assert_array_equal(ax.lines[0].get_xydata(), ax.lines[1].get_xydata())

    def test_kernel_deprecation(self, long_df):

        f, ax = plt.subplots()
        kdeplot(data=long_df, x="x")

        with pytest.warns(UserWarning):
            kdeplot(data=long_df, x="x", kernel="epi")

        assert_array_equal(ax.lines[0].get_xydata(), ax.lines[1].get_xydata())

    def test_shade_deprecation(self, long_df):

        f, ax = plt.subplots()
        with pytest.warns(FutureWarning):
            kdeplot(data=long_df, x="x", shade=True)
        kdeplot(data=long_df, x="x", fill=True)
        fill1, fill2 = ax.collections
        assert_array_equal(
            fill1.get_paths()[0].vertices, fill2.get_paths()[0].vertices
        )

    @pytest.mark.parametrize("multiple", ["layer", "stack", "fill"])
    def test_hue_colors(self, long_df, multiple):

        ax = kdeplot(
            data=long_df, x="x", hue="a",
            multiple=multiple,
            fill=True, legend=False
        )

        # Note that hue order is reversed in the plot
        lines = ax.lines[::-1]
        fills = ax.collections[::-1]

        palette = color_palette()

        for line, fill, color in zip(lines, fills, palette):
            assert_colors_equal(line.get_color(), color)
            assert_colors_equal(fill.get_facecolor(), to_rgba(color, .25))

    def test_hue_stacking(self, long_df):

        f, (ax1, ax2) = plt.subplots(ncols=2)

        kdeplot(
            data=long_df, x="x", hue="a",
            multiple="layer", common_grid=True,
            legend=False, ax=ax1,
        )
        kdeplot(
            data=long_df, x="x", hue="a",
            multiple="stack", fill=False,
            legend=False, ax=ax2,
        )

        layered_densities = np.stack([line.get_ydata() for line in ax1.lines])
        stacked_densities = np.stack([line.get_ydata() for line in ax2.lines])

        assert_array_equal(layered_densities.cumsum(axis=0), stacked_densities)

    def test_hue_filling(self, long_df):

        f, (ax1, ax2) = plt.subplots(ncols=2)

        kdeplot(
            data=long_df, x="x", hue="a",
            multiple="layer", common_grid=True,
            legend=False, ax=ax1,
        )
        kdeplot(
            data=long_df, x="x", hue="a",
            multiple="fill", fill=False,
            legend=False, ax=ax2,
        )

        layered = np.stack([line.get_ydata() for line in ax1.lines])
        filled = np.stack([line.get_ydata() for line in ax2.lines])

        assert_array_almost_equal(
            (layered / layered.sum(axis=0)).cumsum(axis=0),
            filled,
        )

    @pytest.mark.parametrize("multiple", ["stack", "fill"])
    def test_fill_default(self, long_df, multiple):

        ax = kdeplot(
            data=long_df, x="x", hue="a", multiple=multiple, fill=None
        )

        assert len(ax.collections) > 0

    @pytest.mark.parametrize("multiple", ["layer", "stack", "fill"])
    def test_fill_nondefault(self, long_df, multiple):

        f, (ax1, ax2) = plt.subplots(ncols=2)

        kws = dict(data=long_df, x="x", hue="a")
        kdeplot(**kws, multiple=multiple, fill=False, ax=ax1)
        kdeplot(**kws, multiple=multiple, fill=True, ax=ax2)

        assert len(ax1.collections) == 0
        assert len(ax2.collections) > 0

    def test_color_cycle_interaction(self, flat_series):

        color = (.2, 1, .6)

        f, ax = plt.subplots()
        kdeplot(flat_series)
        kdeplot(flat_series)
        assert_colors_equal(ax.lines[0].get_color(), "C0")
        assert_colors_equal(ax.lines[1].get_color(), "C1")
        plt.close(f)

        # An explicit color must not advance the cycle.
        f, ax = plt.subplots()
        kdeplot(flat_series, color=color)
        kdeplot(flat_series)
        assert_colors_equal(ax.lines[0].get_color(), color)
        assert_colors_equal(ax.lines[1].get_color(), "C0")
        plt.close(f)

        f, ax = plt.subplots()
        kdeplot(flat_series, fill=True)
        kdeplot(flat_series, fill=True)
        assert_colors_equal(ax.collections[0].get_facecolor(), to_rgba("C0", .25))
        assert_colors_equal(ax.collections[1].get_facecolor(), to_rgba("C1", .25))
        plt.close(f)

    @pytest.mark.parametrize("fill", [True, False])
    def test_artist_color(self, long_df, fill):

        color = (.2, 1, .6)
        alpha = .5

        f, ax = plt.subplots()

        def last_artist_color():
            if fill:
                return ax.collections[-1].get_facecolor().squeeze()
            return ax.lines[-1].get_color()

        kdeplot(long_df["x"], fill=fill, color=color)
        default_alpha = .25 if fill else 1
        assert_colors_equal(last_artist_color(), to_rgba(color, default_alpha))

        kdeplot(long_df["x"], fill=fill, color=color, alpha=alpha)
        assert_colors_equal(last_artist_color(), to_rgba(color, alpha))

    def test_datetime_scale(self, long_df):

        f, (ax1, ax2) = plt.subplots(2)
        kdeplot(x=long_df["t"], fill=True, ax=ax1)
        kdeplot(x=long_df["t"], fill=False, ax=ax2)
        assert ax1.get_xlim() == ax2.get_xlim()

    def test_multiple_argument_check(self, long_df):

        with pytest.raises(ValueError, match="`multiple` must be"):
            kdeplot(data=long_df, x="x", hue="a", multiple="bad_input")

    def test_cut(self, rng):

        x = rng.normal(0, 3, 1000)

        f, ax = plt.subplots()
        kdeplot(x=x, cut=0, legend=False)

        xdata_0 = ax.lines[0].get_xdata()
        assert xdata_0.min() == x.min()
        assert xdata_0.max() == x.max()

        kdeplot(x=x, cut=2, legend=False)

        xdata_2 = ax.lines[1].get_xdata()
        assert xdata_2.min() < xdata_0.min()
        assert xdata_2.max() > xdata_0.max()

        assert len(xdata_0) == len(xdata_2)

    def test_clip(self, rng):

        x = rng.normal(0, 3, 1000)

        clip = -1, 1
        ax = kdeplot(x=x, clip=clip)

        xdata = ax.lines[0].get_xdata()

        assert xdata.min() >= clip[0]
        assert xdata.max() <= clip[1]

    def test_line_is_density(self, long_df):

        ax = kdeplot(data=long_df, x="x", cut=5)
        x, y = ax.lines[0].get_xydata().T
        assert integrate(y, x) == pytest.approx(1)

    @pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
    def test_cumulative(self, long_df):

        ax = kdeplot(data=long_df, x="x", cut=5, cumulative=True)
        y = ax.lines[0].get_ydata()
        assert y[0] == pytest.approx(0)
        assert y[-1] == pytest.approx(1)

    @pytest.mark.skipif(not _no_scipy, reason="Test requires scipy's absence")
    def test_cumulative_requires_scipy(self, long_df):

        with pytest.raises(RuntimeError):
            kdeplot(data=long_df, x="x", cut=5, cumulative=True)

    def test_common_norm(self, long_df):

        f, (ax1, ax2) = plt.subplots(ncols=2)

        kdeplot(
            data=long_df, x="x", hue="c", common_norm=True, cut=10, ax=ax1
        )
        kdeplot(
            data=long_df, x="x", hue="c", common_norm=False, cut=10, ax=ax2
        )

        # With a common norm the areas sum to 1; otherwise each area is 1.
        total_area = 0
        for line in ax1.lines:
            xdata, ydata = line.get_xydata().T
            total_area += integrate(ydata, xdata)
        assert total_area == pytest.approx(1)

        for line in ax2.lines:
            xdata, ydata = line.get_xydata().T
            assert integrate(ydata, xdata) == pytest.approx(1)

    def test_common_grid(self, long_df):

        f, (ax1, ax2) = plt.subplots(ncols=2)

        order = "a", "b", "c"

        kdeplot(
            data=long_df, x="x", hue="a", hue_order=order,
            common_grid=False, cut=0, ax=ax1,
        )
        kdeplot(
            data=long_df, x="x", hue="a", hue_order=order,
            common_grid=True, cut=0, ax=ax2,
        )

        for line, level in zip(ax1.lines[::-1], order):
            xdata = line.get_xdata()
            assert xdata.min() == long_df.loc[long_df["a"] == level, "x"].min()
            assert xdata.max() == long_df.loc[long_df["a"] == level, "x"].max()

        for line in ax2.lines:
            # get_xdata() is already 1-D; the former ``.T`` was a no-op.
            xdata = line.get_xdata()
            assert xdata.min() == long_df["x"].min()
            assert xdata.max() == long_df["x"].max()

    def test_bw_method(self, long_df):

        f, ax = plt.subplots()
        kdeplot(data=long_df, x="x", bw_method=0.2, legend=False)
        kdeplot(data=long_df, x="x", bw_method=1.0, legend=False)
        kdeplot(data=long_df, x="x", bw_method=3.0, legend=False)

        # Larger bandwidths should give smoother (less jagged) curves.
        l1, l2, l3 = ax.lines
        assert self._roughness(l1) > self._roughness(l2)
        assert self._roughness(l2) > self._roughness(l3)

    def test_bw_adjust(self, long_df):

        f, ax = plt.subplots()
        kdeplot(data=long_df, x="x", bw_adjust=0.2, legend=False)
        kdeplot(data=long_df, x="x", bw_adjust=1.0, legend=False)
        kdeplot(data=long_df, x="x", bw_adjust=3.0, legend=False)

        l1, l2, l3 = ax.lines
        assert self._roughness(l1) > self._roughness(l2)
        assert self._roughness(l2) > self._roughness(l3)

    def test_log_scale_implicit(self, rng):

        x = rng.lognormal(0, 1, 100)

        # A log-scaled Axes is honored without passing log_scale. (The former
        # second kdeplot on this Axes and the unused second Axes were dropped;
        # no assertion read them.)
        f, ax1 = plt.subplots()
        ax1.set_xscale("log")
        kdeplot(x=x, ax=ax1)

        xdata_log = ax1.lines[0].get_xdata()
        assert (xdata_log > 0).all()
        assert (np.diff(xdata_log, 2) > 0).all()
        # The support points are evenly spaced in log space.
        assert np.allclose(np.diff(np.log(xdata_log), 2), 0)

        f, ax = plt.subplots()
        ax.set_yscale("log")
        kdeplot(y=x, ax=ax)
        assert_array_equal(ax.lines[0].get_xdata(), ax1.lines[0].get_ydata())

    def test_log_scale_explicit(self, rng):

        x = rng.lognormal(0, 1, 100)

        f, (ax1, ax2, ax3) = plt.subplots(ncols=3)

        ax1.set_xscale("log")
        kdeplot(x=x, ax=ax1)
        kdeplot(x=x, log_scale=True, ax=ax2)
        kdeplot(x=x, log_scale=10, ax=ax3)

        for ax in f.axes:
            assert ax.get_xscale() == "log"

        supports = [ax.lines[0].get_xdata() for ax in f.axes]
        for a, b in itertools.product(supports, supports):
            assert_array_equal(a, b)

        densities = [ax.lines[0].get_ydata() for ax in f.axes]
        for a, b in itertools.product(densities, densities):
            assert_array_equal(a, b)

        f, ax = plt.subplots()
        kdeplot(y=x, log_scale=True, ax=ax)
        assert ax.get_yscale() == "log"

    def test_log_scale_with_hue(self, rng):

        data = rng.lognormal(0, 1, 50), rng.lognormal(0, 2, 100)
        ax = kdeplot(data=data, log_scale=True, common_grid=True)
        assert_array_equal(ax.lines[0].get_xdata(), ax.lines[1].get_xdata())

    def test_log_scale_normalization(self, rng):

        x = rng.lognormal(0, 1, 100)
        ax = kdeplot(x=x, log_scale=True, cut=10)
        xdata, ydata = ax.lines[0].get_xydata().T
        # The density integrates to 1 over log10(x), not over x.
        integral = integrate(ydata, np.log10(xdata))
        assert integral == pytest.approx(1)

    def test_weights(self):

        x = [1, 2]
        weights = [2, 1]

        ax = kdeplot(x=x, weights=weights, bw_method=.1)

        xdata, ydata = ax.lines[0].get_xydata().T

        y1 = ydata[np.abs(xdata - 1).argmin()]
        y2 = ydata[np.abs(xdata - 2).argmin()]

        assert y1 == pytest.approx(2 * y2)

    def test_weight_norm(self, rng):

        vals = rng.normal(0, 1, 50)
        x = np.concatenate([vals, vals])
        w = np.repeat([1, 2], 50)
        ax = kdeplot(x=x, weights=w, hue=w, common_norm=True)

        # Recall that artists are added in reverse of hue order
        x1, y1 = ax.lines[0].get_xydata().T
        x2, y2 = ax.lines[1].get_xydata().T

        assert integrate(y1, x1) == pytest.approx(2 * integrate(y2, x2))

    def test_sticky_edges(self, long_df):

        f, (ax1, ax2) = plt.subplots(ncols=2)

        kdeplot(data=long_df, x="x", fill=True, ax=ax1)
        assert ax1.collections[0].sticky_edges.y[:] == [0, np.inf]

        kdeplot(
            data=long_df, x="x", hue="a", multiple="fill", fill=True, ax=ax2
        )
        assert ax2.collections[0].sticky_edges.y[:] == [0, 1]

    def test_line_kws(self, flat_array):

        lw = 3
        color = (.2, .5, .8)
        ax = kdeplot(x=flat_array, linewidth=lw, color=color)
        line, = ax.lines
        assert line.get_linewidth() == lw
        assert_colors_equal(line.get_color(), color)

    def test_input_checking(self, long_df):

        err = "The x variable is categorical,"
        with pytest.raises(TypeError, match=err):
            kdeplot(data=long_df, x="a")

    def test_axis_labels(self, long_df):

        f, (ax1, ax2) = plt.subplots(ncols=2)

        kdeplot(data=long_df, x="x", ax=ax1)
        assert ax1.get_xlabel() == "x"
        assert ax1.get_ylabel() == "Density"

        kdeplot(data=long_df, y="y", ax=ax2)
        assert ax2.get_xlabel() == "Density"
        assert ax2.get_ylabel() == "y"

    def test_legend(self, long_df):

        ax = kdeplot(data=long_df, x="x", hue="a")

        assert ax.legend_.get_title().get_text() == "a"

        legend_labels = ax.legend_.get_texts()
        order = categorical_order(long_df["a"])
        for label, level in zip(legend_labels, order):
            assert label.get_text() == level

        legend_artists = ax.legend_.findobj(mpl.lines.Line2D)
        if _version_predates(mpl, "3.5.0b0"):
            # https://github.com/matplotlib/matplotlib/pull/20699
            legend_artists = legend_artists[::2]
        palette = color_palette()
        for artist, color in zip(legend_artists, palette):
            assert_colors_equal(artist.get_color(), color)

        ax.clear()

        kdeplot(data=long_df, x="x", hue="a", legend=False)

        assert ax.legend_ is None

    def test_replaced_kws(self, long_df):
        with pytest.raises(TypeError, match=r"`data2` has been removed"):
            kdeplot(data=long_df, x="x", data2="y")
class TestKDEPlotBivariate:
    """Tests for ``kdeplot`` with both x and y assigned (contour output)."""

    @staticmethod
    def _assert_contours_equal(ax_a, ax_b):
        """Assert paired contour artists on two Axes have identical vertices."""
        for c1, c2 in zip(ax_a.collections, ax_b.collections):
            coords1 = get_contour_coords(c1)
            coords2 = get_contour_coords(c2)
            assert len(coords1) == len(coords2)
            for arr1, arr2 in zip(coords1, coords2):
                assert_array_equal(arr1, arr2)

    def test_long_vectors(self, long_df):

        ax1 = kdeplot(data=long_df, x="x", y="y")

        x = long_df["x"]
        x_values = [x, x.to_numpy(), x.to_list()]

        y = long_df["y"]
        y_values = [y, y.to_numpy(), y.to_list()]

        for x, y in zip(x_values, y_values):
            f, ax2 = plt.subplots()
            kdeplot(x=x, y=y, ax=ax2)

            for c1, c2 in zip(ax1.collections, ax2.collections):
                assert_array_equal(c1.get_offsets(), c2.get_offsets())

    def test_singular_data(self):

        # Bivariate densities are drawn as contour collections (see
        # test_fill_artists), so checking ax.lines alone could never fail;
        # also require that no collections were added.
        with pytest.warns(UserWarning):
            ax = dist.kdeplot(x=np.ones(10), y=np.arange(10))
        assert not ax.lines
        assert not ax.collections

        with pytest.warns(UserWarning):
            ax = dist.kdeplot(x=[5], y=[6])
        assert not ax.lines
        assert not ax.collections

        with pytest.warns(UserWarning):
            ax = kdeplot(x=[1929245168.06679] * 18, y=np.arange(18))
        assert not ax.lines
        assert not ax.collections

        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            ax = kdeplot(x=[5], y=[7], warn_singular=False)
        assert not ax.lines
        assert not ax.collections

    def test_fill_artists(self, long_df):

        for fill in [True, False]:
            f, ax = plt.subplots()
            kdeplot(data=long_df, x="x", y="y", hue="c", fill=fill)
            for c in ax.collections:
                if not _version_predates(mpl, "3.8.0rc1"):
                    assert isinstance(c, mpl.contour.QuadContourSet)
                elif fill or not _version_predates(mpl, "3.5.0b0"):
                    assert isinstance(c, mpl.collections.PathCollection)
                else:
                    assert isinstance(c, mpl.collections.LineCollection)

    def test_common_norm(self, rng):

        hue = np.repeat(["a", "a", "a", "b"], 40)
        x, y = rng.multivariate_normal([0, 0], [(.2, .5), (.5, 2)], len(hue)).T
        x[hue == "a"] -= 2
        x[hue == "b"] += 2

        f, (ax1, ax2) = plt.subplots(ncols=2)
        kdeplot(x=x, y=y, hue=hue, common_norm=True, ax=ax1)
        kdeplot(x=x, y=y, hue=hue, common_norm=False, ax=ax2)

        n_seg_1 = sum(len(get_contour_coords(c, True)) for c in ax1.collections)
        n_seg_2 = sum(len(get_contour_coords(c, True)) for c in ax2.collections)
        assert n_seg_2 > n_seg_1

    def test_log_scale(self, rng):

        x = rng.lognormal(0, 1, 100)
        y = rng.uniform(0, 1, 100)

        levels = .2, .5, 1

        f, ax = plt.subplots()
        kdeplot(x=x, y=y, log_scale=True, levels=levels, ax=ax)
        assert ax.get_xscale() == "log"
        assert ax.get_yscale() == "log"

        f, (ax1, ax2) = plt.subplots(ncols=2)
        kdeplot(x=x, y=y, log_scale=(10, False), levels=levels, ax=ax1)
        assert ax1.get_xscale() == "log"
        assert ax1.get_yscale() == "linear"

        # Reproduce the expected contours by hand: fit in log10(x) space,
        # then draw back on the original x scale.
        p = _DistributionPlotter()
        kde = KDE()
        density, (xx, yy) = kde(np.log10(x), y)
        levels = p._quantile_to_level(density, levels)
        ax2.contour(10 ** xx, yy, density, levels=levels)

        self._assert_contours_equal(ax1, ax2)

    def test_bandwidth(self, rng):

        n = 100
        x, y = rng.multivariate_normal([0, 0], [(.2, .5), (.5, 2)], n).T

        f, (ax1, ax2) = plt.subplots(ncols=2)

        kdeplot(x=x, y=y, ax=ax1)
        kdeplot(x=x, y=y, bw_adjust=2, ax=ax2)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            seg1, seg2 = get_contour_coords(c1), get_contour_coords(c2)
            # Both sides must have a first segment to compare; the former
            # ``seg1 + seg2`` check still indexed an empty side.
            if seg1 and seg2:
                x1 = seg1[0][:, 0]
                x2 = seg2[0][:, 0]
                assert np.abs(x2).max() > np.abs(x1).max()

    def test_weights(self, rng):

        n = 100
        x, y = rng.multivariate_normal([1, 3], [(.2, .5), (.5, 2)], n).T
        hue = np.repeat([0, 1], n // 2)
        weights = rng.uniform(0, 1, n)

        f, (ax1, ax2) = plt.subplots(ncols=2)
        kdeplot(x=x, y=y, hue=hue, ax=ax1)
        kdeplot(x=x, y=y, hue=hue, weights=weights, ax=ax2)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            if get_contour_coords(c1) and get_contour_coords(c2):
                seg1 = np.concatenate(get_contour_coords(c1), axis=0)
                seg2 = np.concatenate(get_contour_coords(c2), axis=0)
                assert not np.array_equal(seg1, seg2)

    def test_hue_ignores_cmap(self, long_df):

        with pytest.warns(UserWarning, match="cmap parameter ignored"):
            ax = kdeplot(data=long_df, x="x", y="y", hue="c", cmap="viridis")

        assert_colors_equal(get_contour_color(ax.collections[0]), "C0")

    def test_contour_line_colors(self, long_df):

        color = (.2, .9, .8, 1)
        ax = kdeplot(data=long_df, x="x", y="y", color=color)

        for c in ax.collections:
            assert_colors_equal(get_contour_color(c), color)

    def test_contour_line_cmap(self, long_df):

        color_list = color_palette("Blues", 12)
        cmap = mpl.colors.ListedColormap(color_list)
        ax = kdeplot(data=long_df, x="x", y="y", cmap=cmap)
        for c in ax.collections:
            for color in get_contour_color(c):
                assert to_rgb(color) in color_list

    def test_contour_fill_colors(self, long_df):

        n = 6
        color = (.2, .9, .8, 1)
        ax = kdeplot(
            data=long_df, x="x", y="y", fill=True, color=color, levels=n,
        )

        cmap = light_palette(color, reverse=True, as_cmap=True)
        lut = cmap(np.linspace(0, 1, 256))
        for c in ax.collections:
            for color in c.get_facecolor():
                # NOTE(review): ``color in lut`` on a NumPy array means
                # ``(lut == color).any()``, so a single matching channel value
                # passes. A row-wise check would be stricter, but only if this
                # lut matches kdeplot's fill colormap exactly -- confirm first.
                assert color in lut

    def test_colorbar(self, long_df):

        ax = kdeplot(data=long_df, x="x", y="y", fill=True, cbar=True)
        assert len(ax.figure.axes) == 2

    def test_levels_and_thresh(self, long_df):

        f, (ax1, ax2) = plt.subplots(ncols=2)

        n = 8
        thresh = .1
        plot_kws = dict(data=long_df, x="x", y="y")
        kdeplot(**plot_kws, levels=n, thresh=thresh, ax=ax1)
        kdeplot(**plot_kws, levels=np.linspace(thresh, 1, n), ax=ax2)

        self._assert_contours_equal(ax1, ax2)

        with pytest.raises(ValueError):
            kdeplot(**plot_kws, levels=[0, 1, 2])

        ax1.clear()
        ax2.clear()

        kdeplot(**plot_kws, levels=n, thresh=None, ax=ax1)
        kdeplot(**plot_kws, levels=n, thresh=0, ax=ax2)

        self._assert_contours_equal(ax1, ax2)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            assert_array_equal(c1.get_facecolors(), c2.get_facecolors())

    def test_quantile_to_level(self, rng):

        x = rng.uniform(0, 1, 100000)
        isoprop = np.linspace(.1, 1, 6)

        levels = _DistributionPlotter()._quantile_to_level(x, isoprop)
        for h, p in zip(levels, isoprop):
            assert (x[x <= h].sum() / x.sum()) == pytest.approx(p, abs=1e-4)

    def test_input_checking(self, long_df):

        with pytest.raises(TypeError, match="The x variable is categorical,"):
            kdeplot(data=long_df, x="a", y="y")
class TestHistPlotUnivariate(SharedAxesLevelTests):
func = staticmethod(histplot)
def get_last_color(self, ax, element="bars", fill=True):
    """Return the color of the most recently drawn histogram artist.

    Bars are read from ``ax.patches``; other elements are read from the last
    collection when filled, or the last line when unfilled.
    """
    if element != "bars":
        if not fill:
            return ax.lines[-1].get_color()
        artist = ax.collections[-1]
        facecolor = artist.get_facecolor()
        # Filled step/poly artists should share face and edge color (alpha aside).
        assert_colors_equal(facecolor, artist.get_edgecolor(), check_alpha=False)
        return facecolor
    last_patch = ax.patches[-1]
    return last_patch.get_facecolor() if fill else last_patch.get_edgecolor()
@pytest.mark.parametrize(
    "element,fill",
    [(el, fl) for el in ("bars", "step", "poly") for fl in (True, False)],
)
def test_color(self, long_df, element, fill):
    """Run the shared color-cycle checks for every element/fill combination."""
    super().test_color(long_df, element=element, fill=fill)
@pytest.mark.parametrize(
    "variable", ["x", "y"],
)
def test_long_vectors(self, long_df, variable):
    """Every way of specifying a long-form vector yields identical bars."""
    vector = long_df[variable]
    vectors = [
        variable, vector, vector.to_numpy(), vector.to_list(),
    ]

    # One Axes per specification. A hard-coded count of 3 made zip() silently
    # drop the last (list) specification.
    f, axs = plt.subplots(len(vectors))
    for spec, ax in zip(vectors, axs):
        histplot(data=long_df, ax=ax, **{variable: spec})

    bars = [ax.patches for ax in axs]
    for a_bars, b_bars in itertools.product(bars, bars):
        assert len(a_bars) == len(b_bars)
        for a, b in zip(a_bars, b_bars):
            assert_array_equal(a.get_height(), b.get_height())
            assert_array_equal(a.get_xy(), b.get_xy())
def test_wide_vs_long_data(self, wide_df):
    """Wide-form input matches plotting each column separately."""
    f, (ax_wide, ax_long) = plt.subplots(2)
    histplot(data=wide_df, ax=ax_wide, common_bins=False)
    # Walk the columns backwards so the long-form patches line up with the
    # wide-form ones.
    for column in wide_df.columns[::-1]:
        histplot(data=wide_df, x=column, ax=ax_long)

    for wide_bar, long_bar in zip(ax_wide.patches, ax_long.patches):
        assert wide_bar.get_height() == long_bar.get_height()
        assert wide_bar.get_xy() == long_bar.get_xy()
def test_flat_vector(self, long_df):
    """A bare Series as ``data`` matches naming the column explicitly."""
    f, (ax_series, ax_named) = plt.subplots(2)
    histplot(data=long_df["x"], ax=ax_series)
    histplot(data=long_df, x="x", ax=ax_named)

    for series_bar, named_bar in zip(ax_series.patches, ax_named.patches):
        assert series_bar.get_height() == named_bar.get_height()
        assert series_bar.get_xy() == named_bar.get_xy()
def test_empty_data(self):
    """An empty input draws no bars."""
    assert not histplot(x=[]).patches
def test_variable_assignment(self, long_df):
    """Assigning to y gives horizontal bars mirroring the x-assigned ones."""
    f, (ax_x, ax_y) = plt.subplots(2)
    histplot(data=long_df, x="x", ax=ax_x)
    histplot(data=long_df, y="x", ax=ax_y)

    for vertical_bar, horizontal_bar in zip(ax_x.patches, ax_y.patches):
        assert vertical_bar.get_height() == horizontal_bar.get_width()
@pytest.mark.parametrize("element", ["bars", "step", "poly"])
@pytest.mark.parametrize("multiple", ["layer", "dodge", "stack", "fill"])
def test_hue_fill_colors(self, long_df, multiple, element):
    """Filled hue artists use palette colors with the expected alpha."""
    ax = histplot(
        data=long_df, x="x", hue="a",
        multiple=multiple, bins=1,
        fill=True, element=element, legend=False,
    )

    palette = color_palette()

    # Expected alpha: .75 unless layered; layered bars .5, layered step/poly .25.
    if multiple != "layer":
        alpha = .75
    elif element == "bars":
        alpha = .5
    else:
        alpha = .25

    for patch, color in zip(ax.patches[::-1], palette):
        assert_colors_equal(patch.get_facecolor(), to_rgba(color, alpha))

    for poly, color in zip(ax.collections[::-1], palette):
        assert_colors_equal(poly.get_facecolor(), to_rgba(color, alpha))
def test_hue_stack(self, long_df):
    """Stacking keeps bar heights but offsets each level by those below it."""
    f, (ax_layer, ax_stack) = plt.subplots(2)

    n_bins = 10
    kws = dict(data=long_df, x="x", hue="a", bins=n_bins, element="bars")
    histplot(**kws, multiple="layer", ax=ax_layer)
    histplot(**kws, multiple="stack", ax=ax_stack)

    def heights(ax):
        return np.reshape([p.get_height() for p in ax.patches], (-1, n_bins))

    layer_heights = heights(ax_layer)
    stack_heights = heights(ax_stack)
    assert_array_equal(layer_heights, stack_heights)

    stack_xys = np.reshape(
        [p.get_xy() for p in ax_stack.patches], (-1, n_bins, 2)
    )
    # Each bar's top (bottom + height) equals the running total of heights.
    tops = stack_xys[..., 1] + stack_heights
    assert_array_equal(tops, stack_heights.cumsum(axis=0))
def test_hue_fill(self, long_df):
    """``multiple="fill"`` rescales each bin so the hue levels sum to 1."""
    f, (ax_layer, ax_fill) = plt.subplots(2)

    n_bins = 10
    kws = dict(data=long_df, x="x", hue="a", bins=n_bins, element="bars")
    histplot(**kws, multiple="layer", ax=ax_layer)
    histplot(**kws, multiple="fill", ax=ax_fill)

    layer_heights = np.reshape(
        [p.get_height() for p in ax_layer.patches], (-1, n_bins)
    )
    fill_heights = np.reshape(
        [p.get_height() for p in ax_fill.patches], (-1, n_bins)
    )
    proportions = layer_heights / layer_heights.sum(axis=0)
    assert_array_almost_equal(proportions, fill_heights)

    fill_xys = np.reshape(
        [p.get_xy() for p in ax_fill.patches], (-1, n_bins, 2)
    )
    normalized_tops = (fill_xys[..., 1] + fill_heights) / fill_heights.sum(axis=0)
    assert_array_almost_equal(normalized_tops, fill_heights.cumsum(axis=0))
def test_hue_dodge(self, long_df):
    """Dodging keeps heights but shifts one hue level by half a bin."""
    f, (ax_layer, ax_dodge) = plt.subplots(2)

    bw = 2
    kws = dict(data=long_df, x="x", hue="c", binwidth=bw, element="bars")
    histplot(**kws, multiple="layer", ax=ax_layer)
    histplot(**kws, multiple="dodge", ax=ax_dodge)

    assert_array_equal(
        [p.get_height() for p in ax_layer.patches],
        [p.get_height() for p in ax_dodge.patches],
    )

    layer_xs = np.reshape([p.get_x() for p in ax_layer.patches], (2, -1))
    dodge_xs = np.reshape([p.get_x() for p in ax_dodge.patches], (2, -1))
    assert_array_almost_equal(layer_xs[1], dodge_xs[1])
    assert_array_almost_equal(layer_xs[0], dodge_xs[0] - bw / 2)
def test_hue_as_numpy_dodged(self, long_df):
    """A NumPy hue vector can be dodged."""
    # https://github.com/mwaskom/seaborn/issues/2452
    hue_values = long_df["a"].to_numpy()
    ax = histplot(long_df, x="y", hue=hue_values, multiple="dodge", bins=1)

    # Note hue order reversal
    first_patch, second_patch = ax.patches[0], ax.patches[1]
    assert second_patch.get_x() < first_patch.get_x()
def test_multiple_input_check(self, flat_series):
    """An unrecognized ``multiple`` value raises ValueError."""
    bad_option = "invalid"
    with pytest.raises(ValueError, match="`multiple` must be"):
        histplot(flat_series, multiple=bad_option)
def test_element_input_check(self, flat_series):
    """An unrecognized ``element`` value raises ValueError."""
    bad_option = "invalid"
    with pytest.raises(ValueError, match="`element` must be"):
        histplot(flat_series, element=bad_option)
def test_count_stat(self, flat_series):
    """With stat="count" the bar heights add up to the number of observations."""
    ax = histplot(flat_series, stat="count")
    total_height = sum([patch.get_height() for patch in ax.patches])
    assert total_height == len(flat_series)
def test_density_stat(self, flat_series):
    """With stat="density" the total bar area is one."""
    ax = histplot(flat_series, stat="density")
    heights = [patch.get_height() for patch in ax.patches]
    widths = [patch.get_width() for patch in ax.patches]
    total_area = np.multiply(heights, widths).sum()
    assert total_area == pytest.approx(1)
def test_density_stat_common_norm(self, long_df):
    """With common_norm=True the bar area summed across all hue levels is one."""
    ax = histplot(
        data=long_df, x="x", hue="a",
        stat="density", common_norm=True, element="bars",
    )
    heights = [patch.get_height() for patch in ax.patches]
    widths = [patch.get_width() for patch in ax.patches]
    total_area = np.multiply(heights, widths).sum()
    assert total_area == pytest.approx(1)
def test_density_stat_unique_norm(self, long_df):
    """With common_norm=False each hue level's bar area is one on its own."""
    n_bins = 10
    ax = histplot(
        data=long_df, x="x", hue="a",
        stat="density", bins=n_bins, common_norm=False, element="bars",
    )
    first_group = ax.patches[:n_bins]
    last_group = ax.patches[-n_bins:]
    for group in (first_group, last_group):
        areas = np.multiply(
            [patch.get_height() for patch in group],
            [patch.get_width() for patch in group],
        )
        assert areas.sum() == pytest.approx(1)
@pytest.fixture(params=["probability", "proportion"])
def height_norm_arg(self, request):
    """Yield each stat name whose bar heights are expected to sum to one."""
    stat_name = request.param
    return stat_name
def test_probability_stat(self, flat_series, height_norm_arg):
    """Proportion-style stats give bar heights that sum to one."""
    ax = histplot(flat_series, stat=height_norm_arg)
    total_height = sum([patch.get_height() for patch in ax.patches])
    assert total_height == pytest.approx(1)
def test_probability_stat_common_norm(self, long_df, height_norm_arg):
    """With common_norm=True, heights across all hue levels sum to one."""
    ax = histplot(
        data=long_df, x="x", hue="a",
        stat=height_norm_arg, common_norm=True, element="bars",
    )
    total_height = sum([patch.get_height() for patch in ax.patches])
    assert total_height == pytest.approx(1)
def test_probability_stat_unique_norm(self, long_df, height_norm_arg):
    """With common_norm=False, each hue level's heights sum to one."""
    n_bins = 10
    ax = histplot(
        data=long_df, x="x", hue="a",
        stat=height_norm_arg, bins=n_bins, common_norm=False, element="bars",
    )
    for group in (ax.patches[:n_bins], ax.patches[-n_bins:]):
        group_total = sum([patch.get_height() for patch in group])
        assert group_total == pytest.approx(1)
def test_percent_stat(self, flat_series):
    """With stat="percent" the bar heights add up to 100.

    The heights are floats, so their sum can differ from 100 by rounding
    error; compare with a tolerance, as the proportion-stat tests do.
    """
    ax = histplot(flat_series, stat="percent")
    bar_heights = [b.get_height() for b in ax.patches]
    assert sum(bar_heights) == pytest.approx(100)
def test_common_bins(self, long_df):
    """With common_bins=True every hue level uses identical bar positions."""
    n_bins = 10
    ax = histplot(
        long_df, x="x", hue="a", common_bins=True, bins=n_bins, element="bars",
    )
    first_corners = [bar.get_xy() for bar in ax.patches[:n_bins]]
    last_corners = [bar.get_xy() for bar in ax.patches[-n_bins:]]
    assert_array_equal(first_corners, last_corners)
def test_unique_bins(self, wide_df):
    """With common_bins=False each column's bins span exactly its own range."""
    ax = histplot(wide_df, common_bins=False, bins=10, element="bars")
    groups = np.split(np.array(ax.patches), len(wide_df.columns))
    # Patch groups line up with the columns in reverse order
    for bars, column in zip(groups, wide_df.columns[::-1]):
        left_edge = bars[0].get_x()
        right_edge = bars[-1].get_x() + bars[-1].get_width()
        assert_array_almost_equal(left_edge, wide_df[column].min())
        assert_array_almost_equal(right_edge, wide_df[column].max())
def test_range_with_inf(self, rng):
    """An infinite value does not move the leftmost bin edge off the finite min."""
    finite = rng.normal(0, 1, 20)
    ax = histplot([-np.inf, *finite])
    left_edges = [patch.get_x() for patch in ax.patches]
    assert min(left_edges) == finite.min()
def test_weights_with_missing(self, null_df):
    """Only rows with both x and weight present contribute to the bar heights."""
    ax = histplot(null_df, x="x", weights="s", bins=5)
    total_height = sum([bar.get_height() for bar in ax.patches])
    complete_rows = null_df[["x", "s"]].dropna()
    assert total_height == pytest.approx(complete_rows["s"].sum())
def test_weight_norm(self, rng):
    """With common_norm=True, doubling the weights doubles the density."""
    sample = rng.normal(0, 1, 50)
    x = np.concatenate([sample, sample])
    w = np.repeat([1, 2], 50)
    ax = histplot(
        x=x, weights=w, hue=w, common_norm=True, stat="density", bins=5
    )
    # Artists are added in reverse hue order: the weight-2 level comes first
    heavy = [bar.get_height() for bar in ax.patches[:5]]
    light = [bar.get_height() for bar in ax.patches[5:]]
    assert sum(heavy) == 2 * sum(light)
def test_discrete(self, long_df):
    """discrete=True draws one unit-width bar per integer, centered on it."""
    ax = histplot(long_df, x="s", discrete=True)
    lo = long_df["s"].min()
    hi = long_df["s"].max()
    assert len(ax.patches) == (hi - lo + 1)
    for offset, bar in enumerate(ax.patches):
        assert bar.get_width() == 1
        assert bar.get_x() == (lo + offset - .5)
def test_discrete_categorical_default(self, long_df):
    """A categorical variable is binned discretely (unit-width bars) by default."""
    ax = histplot(long_df, x="a")
    # Guard against the loop below passing vacuously if nothing is drawn
    assert ax.patches
    for bar in ax.patches:
        assert bar.get_width() == 1
def test_categorical_yaxis_inversion(self, long_df):
    """A categorical y variable leaves the y axis inverted (bottom > top)."""
    ax = histplot(long_df, y="a")
    bottom, top = ax.get_ylim()
    assert bottom > top
def test_datetime_scale(self, long_df):
    """Filled and unfilled datetime histograms get the same x limits."""
    _, (filled_ax, unfilled_ax) = plt.subplots(2)
    histplot(x=long_df["t"], fill=True, ax=filled_ax)
    histplot(x=long_df["t"], fill=False, ax=unfilled_ax)
    assert filled_ax.get_xlim() == unfilled_ax.get_xlim()
@pytest.mark.parametrize("stat", ["count", "density", "probability"])
def test_kde(self, flat_series, stat):
    """The KDE curve encloses the same area as the histogram bars."""
    ax = histplot(
        flat_series, kde=True, stat=stat, kde_kws={"cut": 10}
    )
    widths = [bar.get_width() for bar in ax.patches]
    heights = [bar.get_height() for bar in ax.patches]
    hist_area = np.multiply(widths, heights).sum()
    (kde_line,) = ax.lines
    kde_area = integrate(kde_line.get_ydata(), kde_line.get_xdata())
    assert kde_area == pytest.approx(hist_area)
@pytest.mark.parametrize("multiple", ["layer", "dodge"])
@pytest.mark.parametrize("stat", ["count", "density", "probability"])
def test_kde_with_hue(self, long_df, stat, multiple):
    """Each hue level's KDE area tracks the area of its own bars."""
    n_bins = 10
    ax = histplot(
        long_df, x="x", hue="c", multiple=multiple,
        kde=True, stat=stat, element="bars",
        kde_kws={"cut": 10}, bins=n_bins,
    )
    groups = (ax.patches[:n_bins], ax.patches[-n_bins:])
    for idx, bars in enumerate(groups):
        widths = [bar.get_width() for bar in bars]
        heights = [bar.get_height() for bar in bars]
        hist_area = np.multiply(widths, heights).sum()
        xs, ys = ax.lines[idx].get_xydata().T
        kde_area = integrate(ys, xs)
        # With dodge the KDE area is expected to be twice the (narrower) bar area
        if multiple == "dodge":
            assert kde_area == pytest.approx(hist_area * 2)
        elif multiple == "layer":
            assert kde_area == pytest.approx(hist_area)
def test_kde_default_cut(self, flat_series):
    """By default the KDE support spans exactly the data range."""
    ax = histplot(flat_series, kde=True)
    grid = ax.lines[0].get_xdata()
    assert grid.min() == flat_series.min()
    assert grid.max() == flat_series.max()
def test_kde_hue(self, long_df):
    """Each KDE line uses the color of its hue level's bars."""
    n_bins = 10
    ax = histplot(data=long_df, x="x", hue="a", kde=True, bins=n_bins)
    leading_bars = ax.patches[::n_bins]
    for bar, kde_line in zip(leading_bars, ax.lines):
        assert_colors_equal(
            bar.get_facecolor(), kde_line.get_color(), check_alpha=False
        )
def test_kde_yaxis(self, flat_series):
    """A KDE drawn for y is the x KDE with its coordinates swapped."""
    _, ax = plt.subplots()
    # Neither call passes ax, so both draw on the axes created above
    histplot(x=flat_series, kde=True)
    histplot(y=flat_series, kde=True)
    x_line, y_line = ax.lines
    assert_array_equal(x_line.get_xdata(), y_line.get_ydata())
    assert_array_equal(x_line.get_ydata(), y_line.get_xdata())
def test_kde_line_kws(self, flat_series):
    """``line_kws`` are forwarded to the KDE line."""
    width = 5
    ax = histplot(flat_series, kde=True, line_kws={"lw": width})
    kde_line = ax.lines[0]
    assert kde_line.get_linewidth() == width
def test_kde_singular_data(self):
    """Degenerate data skips the KDE line without emitting any warning."""
    for degenerate in (np.ones(10), [5]):
        with warnings.catch_warnings():
            # Promote any warning to an error so the test fails loudly
            warnings.simplefilter("error")
            ax = histplot(x=degenerate, kde=True)
        assert not ax.lines
def test_element_default(self, long_df):
    """The default element draws as many patches as element="bars"."""
    for extra in ({}, {"hue": "a"}):
        _, (default_ax, bars_ax) = plt.subplots(2)
        histplot(long_df, x="x", ax=default_ax, **extra)
        histplot(long_df, x="x", ax=bars_ax, element="bars", **extra)
        assert len(default_ax.patches) == len(bars_ax.patches)
def test_bars_no_fill(self, flat_series):
    """Unfilled bars have a transparent face and apply alpha to the edge."""
    alpha = .5
    ax = histplot(flat_series, element="bars", fill=False, alpha=alpha)
    transparent = (0, 0, 0, 0)
    for patch in ax.patches:
        assert patch.get_facecolor() == transparent
        edge_rgba = patch.get_edgecolor()
        assert edge_rgba[-1] == alpha
def test_step_fill(self, flat_series):
    """A filled step polygon traces the same outline as the bars."""
    n_bins = 10
    _, (bar_ax, step_ax) = plt.subplots(2)
    histplot(flat_series, element="bars", fill=True, bins=n_bins, ax=bar_ax)
    histplot(flat_series, element="step", fill=True, bins=n_bins, ax=step_ax)

    bars = bar_ax.patches
    heights = [bar.get_height() for bar in bars]
    lefts = [bar.get_x() for bar in bars]
    right_edge = bars[-1].get_x() + bars[-1].get_width()

    # Read in reverse, every other vertex from index 1 is a bar's top-left corner
    xs, ys = step_ax.collections[0].get_paths()[0].vertices[::-1].T
    assert_array_equal(xs[1:2 * n_bins:2], lefts)
    assert_array_equal(ys[1:2 * n_bins:2], heights)
    assert xs[n_bins * 2] == right_edge
    assert ys[n_bins * 2] == heights[-1]
def test_poly_fill(self, flat_series):
    """A filled polygon passes through the top center of each bar."""
    n_bins = 10
    _, (bar_ax, poly_ax) = plt.subplots(2)
    histplot(flat_series, element="bars", fill=True, bins=n_bins, ax=bar_ax)
    histplot(flat_series, element="poly", fill=True, bins=n_bins, ax=poly_ax)

    bars = bar_ax.patches
    heights = np.array([bar.get_height() for bar in bars])
    centers = np.array([bar.get_x() + bar.get_width() / 2 for bar in bars])

    xs, ys = poly_ax.collections[0].get_paths()[0].vertices[::-1].T
    assert_array_equal(xs[1:n_bins + 1], centers)
    assert_array_equal(ys[1:n_bins + 1], heights)
def test_poly_no_fill(self, flat_series):
    """An unfilled polygon is a line through the top center of each bar."""
    n_bins = 10
    _, (bar_ax, poly_ax) = plt.subplots(2)
    histplot(flat_series, element="bars", fill=False, bins=n_bins, ax=bar_ax)
    histplot(flat_series, element="poly", fill=False, bins=n_bins, ax=poly_ax)

    bars = bar_ax.patches
    heights = np.array([bar.get_height() for bar in bars])
    centers = np.array([bar.get_x() + bar.get_width() / 2 for bar in bars])

    xs, ys = poly_ax.lines[0].get_xydata().T
    assert_array_equal(xs, centers)
    assert_array_equal(ys, heights)
def test_step_no_fill(self, flat_series):
    """An unfilled step line hits each bar's top-left corner, then closes
    at the right edge of the last bar."""
    _, (bar_ax, step_ax) = plt.subplots(2)
    histplot(flat_series, element="bars", fill=False, ax=bar_ax)
    histplot(flat_series, element="step", fill=False, ax=step_ax)

    bars = bar_ax.patches
    xs, ys = step_ax.lines[0].get_xydata().T
    assert_array_equal(xs[:-1], [bar.get_x() for bar in bars])
    assert_array_equal(ys[:-1], [bar.get_height() for bar in bars])
    assert xs[-1] == bars[-1].get_x() + bars[-1].get_width()
    assert ys[-1] == ys[-2]
def test_step_fill_xy(self, flat_series):
    """Switching from x to y transposes the filled step polygon."""
    _, ax = plt.subplots()
    histplot(x=flat_series, element="step", fill=True)
    histplot(y=flat_series, element="step", fill=True)
    x_vertices = ax.collections[0].get_paths()[0].vertices
    y_vertices = ax.collections[1].get_paths()[0].vertices
    assert_array_equal(x_vertices, y_vertices[:, ::-1])
def test_step_no_fill_xy(self, flat_series):
    """Switching from x to y swaps the coordinates of the step line."""
    _, ax = plt.subplots()
    histplot(x=flat_series, element="step", fill=False)
    histplot(y=flat_series, element="step", fill=False)
    x_step, y_step = ax.lines
    assert_array_equal(x_step.get_xdata(), y_step.get_ydata())
    assert_array_equal(x_step.get_ydata(), y_step.get_xdata())
def test_weighted_histogram(self):
    """Discrete bar heights equal the weights of the corresponding values."""
    ax = histplot(x=[0, 1, 2], weights=[1, 2, 3], discrete=True)
    heights = [patch.get_height() for patch in ax.patches]
    assert heights == [1, 2, 3]
def test_weights_with_auto_bins(self, long_df):
    """Automatic bins with weights warn and produce ten bars."""
    with pytest.warns(UserWarning):
        ax = histplot(long_df, x="x", weights="f")
    n_bars = len(ax.patches)
    assert n_bars == 10
def test_shrink(self, long_df):
    """``shrink`` narrows each bar while keeping its center fixed."""
    binwidth, shrink = 2, .4
    _, (full_ax, shrunk_ax) = plt.subplots(2)
    histplot(long_df, x="x", binwidth=binwidth, ax=full_ax)
    histplot(long_df, x="x", binwidth=binwidth, shrink=shrink, ax=shrunk_ax)
    for full, shrunk in zip(full_ax.patches, shrunk_ax.patches):
        full_w, shrunk_w = full.get_width(), shrunk.get_width()
        assert shrunk_w == pytest.approx(shrink * full_w)
        full_center = full.get_x() + full_w / 2
        shrunk_center = shrunk.get_x() + shrunk_w / 2
        assert shrunk_center == pytest.approx(full_center)
def test_log_scale_explicit(self, rng):
    """With log_scale=True, binwidth=1 spans one decade, so widths grow 10x."""
    data = rng.lognormal(0, 2, 1000)
    ax = histplot(data, log_scale=True, binrange=(-3, 3), binwidth=1)
    widths = np.array([bar.get_width() for bar in ax.patches])
    ratios = widths[1:] / widths[:-1]
    assert np.allclose(ratios, 10)
def test_log_scale_implicit(self, rng):
    """A log-scaled target axes gives log bins without passing log_scale."""
    data = rng.lognormal(0, 2, 1000)
    _, ax = plt.subplots()
    ax.set_xscale("log")
    histplot(data, binrange=(-3, 3), binwidth=1, ax=ax)
    widths = np.array([bar.get_width() for bar in ax.patches])
    ratios = widths[1:] / widths[:-1]
    assert np.allclose(ratios, 10)
def test_log_scale_dodge(self, rng):
    """Dodged bars on a log axis all have the same width in log space."""
    x = rng.lognormal(0, 2, 100)
    hue = np.repeat(["a", "b"], 50)
    ax = histplot(x=x, hue=hue, bins=5, log_scale=True, multiple="dodge")
    lefts = np.array([bar.get_x() for bar in ax.patches])
    rights = np.array([bar.get_x() + bar.get_width() for bar in ax.patches])
    log_widths = np.log(rights) - np.log(lefts)
    assert np.unique(np.round(log_widths, 10)).size == 1
def test_log_scale_kde(self, rng):
    """On a log scale, the KDE peak is within 10% of the tallest bar."""
    x = rng.lognormal(0, 1, 1000)
    ax = histplot(x=x, log_scale=True, kde=True, bins=20)
    tallest_bar = max([patch.get_height() for patch in ax.patches])
    kde_peak = max(ax.lines[0].get_ydata())
    assert tallest_bar == pytest.approx(kde_peak, rel=.1)
@pytest.mark.parametrize(
    "fill", [True, False],
)
def test_auto_linewidth(self, flat_series, fill):
    """The automatic bar edge width responds to bin count and axes size."""

    def first_bar_lw(ax):
        return ax.patches[0].get_linewidth()

    kws = {"element": "bars", "fill": fill}

    # More bins -> thinner edges
    _, (few_ax, many_ax) = plt.subplots(2)
    histplot(flat_series, **kws, bins=10, ax=few_ax)
    histplot(flat_series, **kws, bins=100, ax=many_ax)
    assert first_bar_lw(few_ax) > first_bar_lw(many_ax)

    # Wider axes -> thicker edges
    _, wide_ax = plt.subplots(figsize=(10, 5))
    _, narrow_ax = plt.subplots(figsize=(2, 5))
    histplot(flat_series, **kws, bins=30, ax=wide_ax)
    histplot(flat_series, **kws, bins=30, ax=narrow_ax)
    assert first_bar_lw(wide_ax) > first_bar_lw(narrow_ax)

    # Log-transforming the data and the axis leaves the width unchanged
    _, linear_ax = plt.subplots(figsize=(4, 5))
    _, log_ax = plt.subplots(figsize=(4, 5))
    histplot(flat_series, **kws, bins=30, ax=linear_ax)
    histplot(10 ** flat_series, **kws, bins=30, log_scale=True, ax=log_ax)
    assert first_bar_lw(linear_ax) == pytest.approx(first_bar_lw(log_ax))

    # Discrete numeric and categorical data with the same layout match
    _, numeric_ax = plt.subplots(figsize=(4, 5))
    _, categorical_ax = plt.subplots(figsize=(4, 5))
    histplot(y=[0, 1, 1], **kws, discrete=True, ax=numeric_ax)
    histplot(y=["a", "b", "b"], **kws, ax=categorical_ax)
    assert first_bar_lw(numeric_ax) == pytest.approx(first_bar_lw(categorical_ax))
def test_bar_kwargs(self, flat_series):
    """Extra keyword arguments reach every bar patch."""
    width = 2
    edge_rgba = (1, .2, .9, .5)
    ax = histplot(flat_series, binwidth=1, ec=edge_rgba, lw=width)
    for patch in ax.patches:
        assert_colors_equal(patch.get_edgecolor(), edge_rgba)
        assert patch.get_linewidth() == width
def test_step_fill_kwargs(self, flat_series):
    """Extra keyword arguments reach the filled step polygon."""
    width = 2
    edge_rgba = (1, .2, .9, .5)
    ax = histplot(flat_series, element="step", ec=edge_rgba, lw=width)
    fill_poly = ax.collections[0]
    assert_colors_equal(fill_poly.get_edgecolor(), edge_rgba)
    assert fill_poly.get_linewidth() == width
def test_step_line_kwargs(self, flat_series):
    """Extra keyword arguments reach the unfilled step line."""
    width = 2
    style = "--"
    ax = histplot(flat_series, element="step", fill=False, lw=width, ls=style)
    step_line = ax.lines[0]
    assert step_line.get_linewidth() == width
    assert step_line.get_linestyle() == style
def test_label(self, flat_series):
    """``label`` creates exactly one legend entry with that text."""
    ax = histplot(flat_series, label="a label")
    legend_handles, legend_labels = ax.get_legend_handles_labels()
    assert len(legend_handles) == 1
    assert legend_labels == ["a label"]
def test_default_color_scout_cleanup(self, flat_series):
    """Only one container remains on the axes after plotting.

    Presumably this checks that any temporary artist used to pick the default
    color is removed; TODO confirm against histplot's implementation.
    """
    ax = histplot(flat_series)
    n_containers = len(ax.containers)
    assert n_containers == 1
class TestHistPlotBivariate:
    """Tests for ``histplot`` with both x and y, which draws a 2D mesh."""

    @staticmethod
    def _assert_mesh_cells_at_edges(mesh, x_edges, y_edges):
        """Assert each mesh cell's first vertex is its lower (x, y) bin edge.

        Cells are checked in ``itertools.product(y_edges[:-1], x_edges[:-1])``
        order, i.e. with x varying fastest.
        """
        paths = mesh.get_paths()
        origins = itertools.product(y_edges[:-1], x_edges[:-1])
        for cell, (y, x) in enumerate(origins):
            vertex = paths[cell].vertices[0]
            assert vertex[0] == x
            assert vertex[1] == y

    def test_mesh(self, long_df):
        """The mesh holds the 2D counts and masks empty bins."""
        hist = Histogram()
        counts, (x_edges, y_edges) = hist(long_df["x"], long_df["y"])

        ax = histplot(long_df, x="x", y="y")
        mesh = ax.collections[0]
        mesh_data = mesh.get_array()

        assert_array_equal(mesh_data.data.flat, counts.T.flat)
        assert_array_equal(mesh_data.mask.flat, counts.T.flat == 0)
        self._assert_mesh_cells_at_edges(mesh, x_edges, y_edges)

    def test_mesh_with_hue(self, long_df):
        """Each hue level gets its own mesh, binned on the shared bins."""
        ax = histplot(long_df, x="x", y="y", hue="c")

        hist = Histogram()
        hist.define_bin_params(long_df["x"], long_df["y"])

        # Assumes the levels of "c" are 0..k-1 so they index the collections
        for level, sub_df in long_df.groupby("c"):
            mesh = ax.collections[level]
            mesh_data = mesh.get_array()

            counts, (x_edges, y_edges) = hist(sub_df["x"], sub_df["y"])

            assert_array_equal(mesh_data.data.flat, counts.T.flat)
            assert_array_equal(mesh_data.mask.flat, counts.T.flat == 0)
            self._assert_mesh_cells_at_edges(mesh, x_edges, y_edges)

    def test_mesh_with_hue_unique_bins(self, long_df):
        """With common_bins=False, each hue level is binned independently."""
        ax = histplot(long_df, x="x", y="y", hue="c", common_bins=False)

        for level, sub_df in long_df.groupby("c"):
            hist = Histogram()

            mesh = ax.collections[level]
            mesh_data = mesh.get_array()

            counts, (x_edges, y_edges) = hist(sub_df["x"], sub_df["y"])

            assert_array_equal(mesh_data.data.flat, counts.T.flat)
            assert_array_equal(mesh_data.mask.flat, counts.T.flat == 0)
            self._assert_mesh_cells_at_edges(mesh, x_edges, y_edges)

    def test_mesh_with_col_unique_bins(self, long_df):
        """With common_bins=False, each displot facet is binned independently."""
        g = displot(long_df, x="x", y="y", col="c", common_bins=False)

        for level, sub_df in long_df.groupby("c"):
            hist = Histogram()

            mesh = g.axes.flat[level].collections[0]
            mesh_data = mesh.get_array()

            counts, (x_edges, y_edges) = hist(sub_df["x"], sub_df["y"])

            assert_array_equal(mesh_data.data.flat, counts.T.flat)
            assert_array_equal(mesh_data.mask.flat, counts.T.flat == 0)
            self._assert_mesh_cells_at_edges(mesh, x_edges, y_edges)

    def test_mesh_log_scale(self, rng):
        """With log_scale=True, bins are computed on log10 data and mapped back."""
        x, y = rng.lognormal(0, 1, (2, 1000))
        hist = Histogram()
        counts, (x_edges, y_edges) = hist(np.log10(x), np.log10(y))

        ax = histplot(x=x, y=y, log_scale=True)
        mesh = ax.collections[0]
        mesh_data = mesh.get_array()

        assert_array_equal(mesh_data.data.flat, counts.T.flat)

        paths = mesh.get_paths()
        origins = itertools.product(y_edges[:-1], x_edges[:-1])
        for cell, (y_i, x_i) in enumerate(origins):
            vertex = paths[cell].vertices[0]
            assert vertex[0] == pytest.approx(10 ** x_i)
            assert vertex[1] == pytest.approx(10 ** y_i)

    def test_mesh_thresh(self, long_df):
        """Bins with counts at or below ``thresh`` are masked."""
        hist = Histogram()
        counts, (x_edges, y_edges) = hist(long_df["x"], long_df["y"])

        thresh = 5
        ax = histplot(long_df, x="x", y="y", thresh=thresh)
        mesh = ax.collections[0]
        mesh_data = mesh.get_array()

        assert_array_equal(mesh_data.data.flat, counts.T.flat)
        assert_array_equal(mesh_data.mask.flat, (counts <= thresh).T.flat)

    def test_mesh_sticky_edges(self, long_df):
        """thresh=None sets sticky edges at the data limits; the default does not."""
        ax = histplot(long_df, x="x", y="y", thresh=None)
        mesh = ax.collections[0]
        assert mesh.sticky_edges.x == [long_df["x"].min(), long_df["x"].max()]
        assert mesh.sticky_edges.y == [long_df["y"].min(), long_df["y"].max()]

        ax.clear()
        ax = histplot(long_df, x="x", y="y")
        mesh = ax.collections[0]
        assert not mesh.sticky_edges.x
        assert not mesh.sticky_edges.y

    def test_mesh_common_norm(self, long_df):
        """With common_norm=True each level's density is scaled by its row share."""
        stat = "density"
        ax = histplot(
            long_df, x="x", y="y", hue="c", common_norm=True, stat=stat,
        )

        hist = Histogram(stat="density")
        hist.define_bin_params(long_df["x"], long_df["y"])

        for level, sub_df in long_df.groupby("c"):
            mesh = ax.collections[level]
            mesh_data = mesh.get_array()

            density, (x_edges, y_edges) = hist(sub_df["x"], sub_df["y"])

            scale = len(sub_df) / len(long_df)
            assert_array_equal(mesh_data.data.flat, (density * scale).T.flat)

    def test_mesh_unique_norm(self, long_df):
        """With common_norm=False each level's density is normalized on its own."""
        stat = "density"
        ax = histplot(
            long_df, x="x", y="y", hue="c", common_norm=False, stat=stat,
        )

        hist = Histogram()
        bin_kws = hist.define_bin_params(long_df["x"], long_df["y"])

        for level, sub_df in long_df.groupby("c"):
            sub_hist = Histogram(bins=bin_kws["bins"], stat=stat)

            mesh = ax.collections[level]
            mesh_data = mesh.get_array()

            density, (x_edges, y_edges) = sub_hist(sub_df["x"], sub_df["y"])
            assert_array_equal(mesh_data.data.flat, density.T.flat)

    @pytest.mark.parametrize("stat", ["probability", "proportion", "percent"])
    def test_mesh_normalization(self, long_df, stat):
        """Normalized stats sum to 1 (or 100 for percent) over the mesh."""
        ax = histplot(
            long_df, x="x", y="y", stat=stat,
        )

        mesh_data = ax.collections[0].get_array()
        expected_sum = {"percent": 100}.get(stat, 1)
        # Summing many floats accumulates rounding error; compare with tolerance
        assert mesh_data.data.sum() == pytest.approx(expected_sum)

    def test_mesh_colors(self, long_df):
        """The mesh colormap is built from the requested or palette color."""
        color = "r"
        f, ax = plt.subplots()
        histplot(
            long_df, x="x", y="y", color=color,
        )
        mesh = ax.collections[0]
        assert_array_equal(
            mesh.get_cmap().colors,
            _DistributionPlotter()._cmap_from_color(color).colors,
        )

        f, ax = plt.subplots()
        histplot(
            long_df, x="x", y="y", hue="c",
        )
        colors = color_palette()
        for i, mesh in enumerate(ax.collections):
            assert_array_equal(
                mesh.get_cmap().colors,
                _DistributionPlotter()._cmap_from_color(colors[i]).colors,
            )

    def test_color_limits(self, long_df):
        """vmax and pmax/pthresh control the color limits and the mask."""
        _, (ax1, ax2, ax3) = plt.subplots(3)
        kws = dict(data=long_df, x="x", y="y")
        hist = Histogram()
        # The same data is plotted on every axes, so counts are computed once
        counts, _ = hist(long_df["x"], long_df["y"])
        quantile_to_level = _DistributionPlotter()._quantile_to_level

        histplot(**kws, ax=ax1)
        assert ax1.collections[0].get_clim() == (0, counts.max())

        vmax = 10
        histplot(**kws, vmax=vmax, ax=ax2)
        assert ax2.collections[0].get_clim() == (0, vmax)

        pmax = .8
        pthresh = .1
        histplot(**kws, pmax=pmax, pthresh=pthresh, ax=ax3)
        mesh = ax3.collections[0]
        assert mesh.get_clim() == (0, quantile_to_level(counts, pmax))
        assert_array_equal(
            mesh.get_array().mask.flat,
            (counts <= quantile_to_level(counts, pthresh)).T.flat,
        )

    def test_hue_color_limits(self, long_df):
        """Color limits use full or per-level counts depending on common_norm."""
        _, (ax1, ax2, ax3, ax4) = plt.subplots(4)
        kws = dict(data=long_df, x="x", y="y", hue="c", bins=4)

        hist = Histogram(bins=kws["bins"])
        hist.define_bin_params(long_df["x"], long_df["y"])
        full_counts, _ = hist(long_df["x"], long_df["y"])

        sub_counts = []
        for _, sub_df in long_df.groupby(kws["hue"]):
            level_counts, _ = hist(sub_df["x"], sub_df["y"])
            sub_counts.append(level_counts)

        pmax = .8
        pthresh = .05
        quantile_to_level = _DistributionPlotter()._quantile_to_level

        histplot(**kws, common_norm=True, ax=ax1)
        for mesh in ax1.collections:
            assert mesh.get_clim() == (0, full_counts.max())

        histplot(**kws, common_norm=False, ax=ax2)
        for i, mesh in enumerate(ax2.collections):
            assert mesh.get_clim() == (0, sub_counts[i].max())

        histplot(**kws, common_norm=True, pmax=pmax, pthresh=pthresh, ax=ax3)
        for i, mesh in enumerate(ax3.collections):
            assert mesh.get_clim() == (0, quantile_to_level(full_counts, pmax))
            assert_array_equal(
                mesh.get_array().mask.flat,
                (sub_counts[i] <= quantile_to_level(full_counts, pthresh)).T.flat,
            )

        histplot(**kws, common_norm=False, pmax=pmax, pthresh=pthresh, ax=ax4)
        for i, mesh in enumerate(ax4.collections):
            assert mesh.get_clim() == (0, quantile_to_level(sub_counts[i], pmax))
            assert_array_equal(
                mesh.get_array().mask.flat,
                (sub_counts[i] <= quantile_to_level(sub_counts[i], pthresh)).T.flat,
            )

    def test_colorbar(self, long_df):
        """cbar=True adds a colorbar axes, or uses the supplied cbar_ax."""
        f, ax = plt.subplots()
        histplot(long_df, x="x", y="y", cbar=True, ax=ax)
        assert len(ax.figure.axes) == 2

        f, (ax, cax) = plt.subplots(2)
        histplot(long_df, x="x", y="y", cbar=True, cbar_ax=cax, ax=ax)
        assert len(ax.figure.axes) == 2
class TestECDFPlotUnivariate(SharedAxesLevelTests):
    """Tests for univariate ``ecdfplot``."""

    func = staticmethod(ecdfplot)

    def get_last_color(self, ax):
        """Return the RGB color of the most recently added line."""
        return to_rgb(ax.lines[-1].get_color())

    @pytest.mark.parametrize("variable", ["x", "y"])
    def test_long_vectors(self, long_df, variable):
        """A column name, Series, array, and list of one column give the same ECDF."""
        series = long_df[variable]
        data_vectors = [
            variable, series, series.to_numpy(), series.to_list(),
        ]

        _, ax = plt.subplots()
        for data_vector in data_vectors:
            ecdfplot(data=long_df, ax=ax, **{variable: data_vector})

        xdata = [line.get_xdata() for line in ax.lines]
        for a, b in itertools.product(xdata, xdata):
            assert_array_equal(a, b)

        ydata = [line.get_ydata() for line in ax.lines]
        for a, b in itertools.product(ydata, ydata):
            assert_array_equal(a, b)

    def test_hue(self, long_df):
        """Hue levels take palette colors; lines are stored in reverse order."""
        ax = ecdfplot(long_df, x="x", hue="a")

        for line, color in zip(ax.lines[::-1], color_palette()):
            assert_colors_equal(line.get_color(), color)

    def test_line_kwargs(self, long_df):
        """Extra keyword arguments are applied to the ECDF lines."""
        color = "r"
        ls = "--"
        lw = 3
        ax = ecdfplot(long_df, x="x", color=color, ls=ls, lw=lw)

        for line in ax.lines:
            assert_colors_equal(line.get_color(), color)
            assert line.get_linestyle() == ls
            assert line.get_linewidth() == lw

    @pytest.mark.parametrize("data_var", ["x", "y"])
    def test_drawstyle(self, flat_series, data_var):
        """x data uses "steps-post" and y data uses "steps-pre"."""
        ax = ecdfplot(**{data_var: flat_series})
        drawstyles = dict(x="steps-post", y="steps-pre")
        assert ax.lines[0].get_drawstyle() == drawstyles[data_var]

    @pytest.mark.parametrize(
        "data_var,stat_var", [["x", "y"], ["y", "x"]],
    )
    def test_proportion_limits(self, flat_series, data_var, stat_var):
        """The proportion runs from 0 to 1, with sticky edges at both ends."""
        ax = ecdfplot(**{data_var: flat_series})
        data = getattr(ax.lines[0], f"get_{stat_var}data")()
        assert data[0] == 0
        assert data[-1] == 1
        sticky_edges = getattr(ax.lines[0].sticky_edges, stat_var)
        assert sticky_edges[:] == [0, 1]

    @pytest.mark.parametrize(
        "data_var,stat_var", [["x", "y"], ["y", "x"]],
    )
    def test_proportion_limits_complementary(self, flat_series, data_var, stat_var):
        """With complementary=True the proportion runs from 1 down to 0."""
        ax = ecdfplot(**{data_var: flat_series}, complementary=True)
        data = getattr(ax.lines[0], f"get_{stat_var}data")()
        assert data[0] == 1
        assert data[-1] == 0
        sticky_edges = getattr(ax.lines[0].sticky_edges, stat_var)
        assert sticky_edges[:] == [0, 1]

    @pytest.mark.parametrize(
        "data_var,stat_var", [["x", "y"], ["y", "x"]],
    )
    def test_proportion_count(self, flat_series, data_var, stat_var):
        """With stat="count" the curve runs from 0 to the number of observations."""
        n = len(flat_series)
        ax = ecdfplot(**{data_var: flat_series}, stat="count")
        data = getattr(ax.lines[0], f"get_{stat_var}data")()
        assert data[0] == 0
        assert data[-1] == n
        sticky_edges = getattr(ax.lines[0].sticky_edges, stat_var)
        assert sticky_edges[:] == [0, n]

    def test_weights(self):
        """Weights set the size of each ECDF step."""
        ax = ecdfplot(x=[1, 2, 3], weights=[1, 1, 2])
        y = ax.lines[0].get_ydata()
        assert_array_equal(y, [0, .25, .5, 1])

    def test_bivariate_error(self, long_df):
        """Passing both x and y raises NotImplementedError."""
        with pytest.raises(NotImplementedError, match="Bivariate ECDF plots"):
            ecdfplot(data=long_df, x="x", y="y")

    def test_log_scale(self, long_df):
        """A log-scaled ECDF matches the linear one apart from the first point."""
        ax1, ax2 = plt.figure().subplots(2)

        ecdfplot(data=long_df, x="z", ax=ax1)
        ecdfplot(data=long_df, x="z", log_scale=True, ax=ax2)

        # Ignore the first point, which is either -inf (linear) or 0 (log)
        line1 = ax1.lines[0].get_xydata()[1:]
        line2 = ax2.lines[0].get_xydata()[1:]

        assert_array_almost_equal(line1, line2)
class TestDisPlot:
    """Tests for the figure-level ``displot`` function.

    Several tests check that ``displot`` reproduces the matching axes-level
    function's output, both directly and inside a single-facet grid.
    """

    @pytest.mark.parametrize(
        "kwargs", [
            dict(),
            dict(x="x"),
            dict(x="t"),
            dict(x="a"),
            dict(x="z", log_scale=True),
            dict(x="x", binwidth=4),
            dict(x="x", weights="f", bins=5),
            dict(x="x", color="green", linewidth=2, binwidth=4),
            dict(x="x", hue="a", fill=False),
            dict(x="y", hue="a", fill=False),
            dict(x="x", hue="a", multiple="stack"),
            dict(x="x", hue="a", element="step"),
            dict(x="x", hue="a", palette="muted"),
            dict(x="x", hue="a", kde=True),
            dict(x="x", hue="a", stat="density", common_norm=False),
            dict(x="x", y="y"),
        ],
    )
    def test_versus_single_histplot(self, long_df, kwargs):
        """displot(kind="hist") matches histplot, with or without faceting."""
        ax = histplot(long_df, **kwargs)
        g = displot(long_df, **kwargs)
        assert_plots_equal(ax, g.ax)

        if ax.legend_ is not None:
            assert_legends_equal(ax.legend_, g._legend)

        if kwargs:
            # Facet on a constant column; copy instead of mutating the fixture
            faceted_df = long_df.assign(_="_")
            g2 = displot(faceted_df, col="_", **kwargs)
            assert_plots_equal(ax, g2.ax)

    @pytest.mark.parametrize(
        "kwargs", [
            dict(),
            dict(x="x"),
            dict(x="t"),
            dict(x="z", log_scale=True),
            dict(x="x", bw_adjust=.5),
            dict(x="x", weights="f"),
            dict(x="x", color="green", linewidth=2),
            dict(x="x", hue="a", multiple="stack"),
            dict(x="x", hue="a", fill=True),
            dict(x="y", hue="a", fill=False),
            dict(x="x", hue="a", palette="muted"),
            dict(x="x", y="y"),
        ],
    )
    def test_versus_single_kdeplot(self, long_df, kwargs):
        """displot(kind="kde") matches kdeplot, with or without faceting."""
        ax = kdeplot(data=long_df, **kwargs)
        g = displot(long_df, kind="kde", **kwargs)
        assert_plots_equal(ax, g.ax)

        if ax.legend_ is not None:
            assert_legends_equal(ax.legend_, g._legend)

        if kwargs:
            # Facet on a constant column; copy instead of mutating the fixture
            faceted_df = long_df.assign(_="_")
            g2 = displot(faceted_df, kind="kde", col="_", **kwargs)
            assert_plots_equal(ax, g2.ax)

    @pytest.mark.parametrize(
        "kwargs", [
            dict(),
            dict(x="x"),
            dict(x="t"),
            dict(x="z", log_scale=True),
            dict(x="x", weights="f"),
            dict(y="x"),
            dict(x="x", color="green", linewidth=2),
            dict(x="x", hue="a", complementary=True),
            dict(x="x", hue="a", stat="count"),
            dict(x="x", hue="a", palette="muted"),
        ],
    )
    def test_versus_single_ecdfplot(self, long_df, kwargs):
        """displot(kind="ecdf") matches ecdfplot, with or without faceting."""
        ax = ecdfplot(data=long_df, **kwargs)
        g = displot(long_df, kind="ecdf", **kwargs)
        assert_plots_equal(ax, g.ax)

        if ax.legend_ is not None:
            assert_legends_equal(ax.legend_, g._legend)

        if kwargs:
            # Facet on a constant column; copy instead of mutating the fixture
            faceted_df = long_df.assign(_="_")
            g2 = displot(faceted_df, kind="ecdf", col="_", **kwargs)
            assert_plots_equal(ax, g2.ax)

    @pytest.mark.parametrize(
        "kwargs", [
            dict(x="x"),
            dict(x="x", y="y"),
            dict(x="x", hue="a"),
        ]
    )
    def test_with_rug(self, long_df, kwargs):
        """rug=True matches histplot plus rugplot on the same axes."""
        ax = plt.figure().subplots()
        histplot(data=long_df, **kwargs, ax=ax)
        rugplot(data=long_df, **kwargs, ax=ax)

        g = displot(long_df, rug=True, **kwargs)

        assert_plots_equal(ax, g.ax, labels=False)

        # Facet on a constant column; copy instead of mutating the fixture
        faceted_df = long_df.assign(_="_")
        g2 = displot(faceted_df, col="_", rug=True, **kwargs)

        assert_plots_equal(ax, g2.ax, labels=False)

    @pytest.mark.parametrize(
        "facet_var", ["col", "row"],
    )
    def test_facets(self, long_df, facet_var):
        """Faceting by a variable reproduces the per-hue kdeplot curves."""
        kwargs = {facet_var: "a"}
        ax = kdeplot(data=long_df, x="x", hue="a")
        g = displot(long_df, x="x", kind="kde", **kwargs)

        legend_texts = ax.legend_.get_texts()

        for i, line in enumerate(ax.lines[::-1]):
            facet_ax = g.axes.flat[i]
            facet_line = facet_ax.lines[0]
            assert_array_equal(line.get_xydata(), facet_line.get_xydata())

            text = legend_texts[i].get_text()
            assert text in facet_ax.get_title()

    @pytest.mark.parametrize("multiple", ["dodge", "stack", "fill"])
    def test_facet_multiple(self, long_df, multiple):
        """A facet drawn with ``multiple`` matches histplot on that subset."""
        bins = np.linspace(0, 20, 5)
        ax = histplot(
            data=long_df[long_df["c"] == 0],
            x="x", hue="a", hue_order=["a", "b", "c"],
            multiple=multiple, bins=bins,
        )

        g = displot(
            data=long_df, x="x", hue="a", col="c", hue_order=["a", "b", "c"],
            multiple=multiple, bins=bins,
        )

        assert_plots_equal(ax, g.axes_dict[0])

    def test_ax_warning(self, long_df):
        """Passing ``ax`` to the figure-level displot warns."""
        ax = plt.figure().subplots()
        with pytest.warns(UserWarning, match="`displot` is a figure-level"):
            displot(long_df, x="x", ax=ax)

    @pytest.mark.parametrize("key", ["col", "row"])
    def test_array_faceting(self, long_df, key):
        """A numpy array can be used as a faceting variable."""
        a = long_df["a"].to_numpy()
        vals = categorical_order(a)
        g = displot(long_df, x="x", **{key: a})
        assert len(g.axes.flat) == len(vals)
        for ax, val in zip(g.axes.flat, vals):
            assert val in ax.get_title()

    def test_legend(self, long_df):
        """A hue variable produces a figure legend."""
        g = displot(long_df, x="x", hue="a")
        assert g._legend is not None

    def test_empty(self):
        """Empty data still returns a FacetGrid."""
        g = displot(x=[], y=[])
        assert isinstance(g, FacetGrid)

    def test_bivariate_ecdf_error(self, long_df):
        """Bivariate ECDF is not supported."""
        with pytest.raises(NotImplementedError):
            displot(long_df, x="x", y="y", kind="ecdf")

    def test_bivariate_kde_norm(self, rng):
        """common_norm controls whether facets share bivariate KDE contour levels."""
        x, y = rng.normal(0, 1, (2, 100))
        z = [0] * 80 + [1] * 20

        def count_contours(ax):
            if _version_predates(mpl, "3.8.0rc1"):
                return sum(bool(get_contour_coords(c)) for c in ax.collections)
            else:
                return sum(bool(p.vertices.size) for p in ax.collections[0].get_paths())

        g = displot(x=x, y=y, col=z, kind="kde", levels=10)
        l1 = count_contours(g.axes.flat[0])
        l2 = count_contours(g.axes.flat[1])
        assert l1 > l2

        g = displot(x=x, y=y, col=z, kind="kde", levels=10, common_norm=False)
        l1 = count_contours(g.axes.flat[0])
        l2 = count_contours(g.axes.flat[1])
        assert l1 == l2

    def test_bivariate_hist_norm(self, rng):
        """common_norm controls whether facets share bivariate color limits."""
        x, y = rng.normal(0, 1, (2, 100))
        z = [0] * 80 + [1] * 20

        g = displot(x=x, y=y, col=z, kind="hist")
        clim1 = g.axes.flat[0].collections[0].get_clim()
        clim2 = g.axes.flat[1].collections[0].get_clim()
        assert clim1 == clim2

        g = displot(x=x, y=y, col=z, kind="hist", common_norm=False)
        clim1 = g.axes.flat[0].collections[0].get_clim()
        clim2 = g.axes.flat[1].collections[0].get_clim()
        assert clim1[1] > clim2[1]

    def test_facetgrid_data(self, long_df):
        """Vector inputs are added to the grid's data under derived names."""
        g = displot(
            data=long_df.to_dict(orient="list"),
            x="z",
            hue=long_df["a"].rename("hue_var"),
            col=long_df["c"].to_numpy(),
        )
        expected_cols = set(long_df.columns.to_list() + ["hue_var", "_col_"])
        assert set(g.data.columns) == expected_cols
        assert_array_equal(g.data["hue_var"], long_df["a"])
        assert_array_equal(g.data["_col_"], long_df["c"])
def integrate(y, x):
    """Simple numerical integration for testing KDE code.

    Applies the trapezoidal rule to samples ``y`` taken at locations ``x``.

    Parameters
    ----------
    y : array-like
        Sampled function values.
    x : array-like
        Sample locations, same length as ``y``.

    Returns
    -------
    float
        Approximate integral of ``y`` over ``x``.
    """
    y = np.asarray(y)
    x = np.asarray(x)
    dx = np.diff(x)
    return (dx * y[:-1] + dx * y[1:]).sum() / 2
================================================
FILE: tests/test_docstrings.py
================================================
from seaborn._docstrings import DocstringComponents
# Raw parameter text fed to DocstringComponents in the tests below
EXAMPLE_DICT = {
    "param_a": """
a : str
The first parameter.
""",
}
class ExampleClass:
    # Fixture class: TestDocstringComponents.test_from_method passes
    # example_method to DocstringComponents.from_function_params, so the
    # docstring text below is test data and must not be edited casually.
    def example_method(self):
        """An example method.
Parameters
----------
a : str
A method parameter.
"""
def example_func():
    # Fixture function: TestDocstringComponents.test_from_function passes it
    # to DocstringComponents.from_function_params, so the docstring text
    # below is test data and must not be edited casually.
    """An example function.
Parameters
----------
a : str
A function parameter.
"""
class TestDocstringComponents:
    """Build DocstringComponents from a dict, nested components, and callables."""

    def test_from_dict(self):
        """Entries of a plain dict become attributes."""
        components = DocstringComponents(EXAMPLE_DICT)
        expected = "a : str\n The first parameter."
        assert components.param_a == expected

    def test_from_nested_components(self):
        """Component objects can be nested under a keyword name."""
        inner = DocstringComponents(EXAMPLE_DICT)
        outer = DocstringComponents.from_nested_components(inner=inner)
        expected = "a : str\n The first parameter."
        assert outer.inner.param_a == expected

    def test_from_function(self):
        """Parameter docs are parsed from a function's docstring."""
        components = DocstringComponents.from_function_params(example_func)
        expected = "a : str\n A function parameter."
        assert components.a == expected

    def test_from_method(self):
        """Parameter docs are parsed from a method's docstring."""
        method = ExampleClass.example_method
        components = DocstringComponents.from_function_params(method)
        expected = "a : str\n A method parameter."
        assert components.a == expected
================================================
FILE: tests/test_matrix.py
================================================
import tempfile
import copy
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
try:
from scipy.spatial import distance
from scipy.cluster import hierarchy
_no_scipy = False
except ImportError:
_no_scipy = True
try:
import fastcluster
assert fastcluster
_no_fastcluster = False
except ImportError:
_no_fastcluster = True
import numpy.testing as npt
import pandas.testing as pdt
import pytest
from seaborn import matrix as mat
from seaborn import color_palette
from seaborn._compat import get_colormap
from seaborn._testing import assert_colors_equal
class TestHeatmap:
    """Tests for the ``_HeatMapper`` helper and the ``heatmap`` function."""

    # Class-level fixtures drawn from a fixed seed. The draw order
    # (x_norm before x_unif) determines the values, so keep it stable.
    rs = np.random.RandomState(sum(map(ord, "heatmap")))

    x_norm = rs.randn(4, 8)
    letters = pd.Series(["A", "B", "C", "D"], name="letters")
    df_norm = pd.DataFrame(x_norm, index=letters)

    x_unif = rs.rand(20, 13)
    df_unif = pd.DataFrame(x_unif)

    # Neutral keyword arguments for constructing _HeatMapper directly.
    default_kws = dict(vmin=None, vmax=None, cmap=None, center=None,
                       robust=False, annot=False, fmt=".2f", annot_kws=None,
                       cbar=True, cbar_kws=None, mask=None)

    def test_ndarray_input(self):

        p = mat._HeatMapper(self.x_norm, **self.default_kws)
        npt.assert_array_equal(p.plot_data, self.x_norm)
        pdt.assert_frame_equal(p.data, pd.DataFrame(self.x_norm))

        npt.assert_array_equal(p.xticklabels, np.arange(8))
        npt.assert_array_equal(p.yticklabels, np.arange(4))

        assert p.xlabel == ""
        assert p.ylabel == ""

    def test_array_like_input(self):

        class ArrayLike:
            # Minimal object exposing only the __array__ protocol
            def __init__(self, data):
                self.data = data

            def __array__(self, **kwargs):
                return np.asarray(self.data, **kwargs)

        p = mat._HeatMapper(ArrayLike(self.x_norm), **self.default_kws)
        npt.assert_array_equal(p.plot_data, self.x_norm)
        pdt.assert_frame_equal(p.data, pd.DataFrame(self.x_norm))

        npt.assert_array_equal(p.xticklabels, np.arange(8))
        npt.assert_array_equal(p.yticklabels, np.arange(4))

        assert p.xlabel == ""
        assert p.ylabel == ""

    def test_df_input(self):

        p = mat._HeatMapper(self.df_norm, **self.default_kws)
        npt.assert_array_equal(p.plot_data, self.x_norm)
        pdt.assert_frame_equal(p.data, self.df_norm)

        npt.assert_array_equal(p.xticklabels, np.arange(8))
        npt.assert_array_equal(p.yticklabels, self.letters.values)

        assert p.xlabel == ""
        assert p.ylabel == "letters"

    def test_df_multindex_input(self):

        df = self.df_norm.copy()
        index = pd.MultiIndex.from_tuples([("A", 1), ("B", 2),
                                           ("C", 3), ("D", 4)],
                                          names=["letter", "number"])
        index.name = "letter-number"
        df.index = index

        p = mat._HeatMapper(df, **self.default_kws)

        # MultiIndex levels are joined with "-" to form tick labels
        combined_tick_labels = ["A-1", "B-2", "C-3", "D-4"]
        npt.assert_array_equal(p.yticklabels, combined_tick_labels)
        assert p.ylabel == "letter-number"

        p = mat._HeatMapper(df.T, **self.default_kws)

        npt.assert_array_equal(p.xticklabels, combined_tick_labels)
        assert p.xlabel == "letter-number"

    @pytest.mark.parametrize("dtype", [float, np.int64, object])
    def test_mask_input(self, dtype):
        kws = self.default_kws.copy()

        mask = self.x_norm > 0
        kws['mask'] = mask
        data = self.x_norm.astype(dtype)
        p = mat._HeatMapper(data, **kws)
        plot_data = np.ma.masked_where(mask, data)

        npt.assert_array_equal(p.plot_data, plot_data)

    def test_mask_limits(self):
        """Make sure masked cells are not used to calculate extremes"""

        kws = self.default_kws.copy()

        mask = self.x_norm > 0
        kws['mask'] = mask
        p = mat._HeatMapper(self.x_norm, **kws)

        assert p.vmax == np.ma.array(self.x_norm, mask=mask).max()
        assert p.vmin == np.ma.array(self.x_norm, mask=mask).min()

        mask = self.x_norm < 0
        kws['mask'] = mask
        p = mat._HeatMapper(self.x_norm, **kws)

        assert p.vmin == np.ma.array(self.x_norm, mask=mask).min()
        assert p.vmax == np.ma.array(self.x_norm, mask=mask).max()

    def test_default_vlims(self):

        p = mat._HeatMapper(self.df_unif, **self.default_kws)
        assert p.vmin == self.x_unif.min()
        assert p.vmax == self.x_unif.max()

    def test_robust_vlims(self):

        kws = self.default_kws.copy()
        kws["robust"] = True
        p = mat._HeatMapper(self.df_unif, **kws)

        assert p.vmin == np.percentile(self.x_unif, 2)
        assert p.vmax == np.percentile(self.x_unif, 98)

    def test_custom_sequential_vlims(self):

        kws = self.default_kws.copy()
        kws["vmin"] = 0
        kws["vmax"] = 1
        p = mat._HeatMapper(self.df_unif, **kws)

        assert p.vmin == 0
        assert p.vmax == 1

    def test_custom_diverging_vlims(self):

        kws = self.default_kws.copy()
        kws["vmin"] = -4
        kws["vmax"] = 5
        kws["center"] = 0
        p = mat._HeatMapper(self.df_norm, **kws)

        assert p.vmin == -4
        assert p.vmax == 5

    def test_array_with_nans(self):

        x1 = self.rs.rand(10, 10)
        nulls = np.full(10, np.nan)
        x2 = np.c_[x1, nulls]

        m1 = mat._HeatMapper(x1, **self.default_kws)
        m2 = mat._HeatMapper(x2, **self.default_kws)

        # An all-NaN column must not change the computed limits
        assert m1.vmin == m2.vmin
        assert m1.vmax == m2.vmax

    def test_mask(self):

        df = pd.DataFrame(data={'a': [1, 1, 1],
                                'b': [2, np.nan, 2],
                                'c': [3, 3, np.nan]})

        kws = self.default_kws.copy()
        kws["mask"] = np.isnan(df.values)

        m = mat._HeatMapper(df, **kws)

        npt.assert_array_equal(np.isnan(m.plot_data.data),
                               m.plot_data.mask)

    def test_custom_cmap(self):

        kws = self.default_kws.copy()
        kws["cmap"] = "BuGn"
        p = mat._HeatMapper(self.df_unif, **kws)
        assert p.cmap == mpl.cm.BuGn

    def test_centered_vlims(self):

        kws = self.default_kws.copy()
        kws["center"] = .5

        p = mat._HeatMapper(self.df_unif, **kws)

        assert p.vmin == self.df_unif.values.min()
        assert p.vmax == self.df_unif.values.max()

    def test_default_colors(self):

        vals = np.linspace(.2, 1, 9)
        cmap = mpl.cm.binary
        ax = mat.heatmap([vals], cmap=cmap)
        fc = ax.collections[0].get_facecolors()
        cvals = np.linspace(0, 1, 9)
        npt.assert_array_almost_equal(fc, cmap(cvals), 2)

    def test_custom_vlim_colors(self):

        vals = np.linspace(.2, 1, 9)
        cmap = mpl.cm.binary
        ax = mat.heatmap([vals], vmin=0, cmap=cmap)
        fc = ax.collections[0].get_facecolors()
        npt.assert_array_almost_equal(fc, cmap(vals), 2)

    def test_custom_center_colors(self):

        vals = np.linspace(.2, 1, 9)
        cmap = mpl.cm.binary
        ax = mat.heatmap([vals], center=.5, cmap=cmap)
        fc = ax.collections[0].get_facecolors()
        npt.assert_array_almost_equal(fc, cmap(vals), 2)

    def test_cmap_with_properties(self):
        """Bad/under/over colors set on a custom cmap are preserved."""
        # (Colormap setter, probe value that selects that special color).
        # The over case must probe +inf; probing -inf would hit the under
        # color and never exercise set_over.
        cases = [
            ("set_bad", np.ma.masked_invalid([np.nan])),
            ("set_under", -np.inf),
            ("set_over", np.inf),
        ]
        for setter, probe in cases:
            cmap = copy.copy(get_colormap("BrBG"))
            getattr(cmap, setter)("red")
            kws = self.default_kws.copy()
            kws["cmap"] = cmap
            # Check both without and with a center value
            for center in (None, .5):
                kws["center"] = center
                hm = mat._HeatMapper(self.df_unif, **kws)
                npt.assert_array_equal(cmap(probe), hm.cmap(probe))

    def test_explicit_none_norm(self):

        vals = np.linspace(.2, 1, 9)
        cmap = mpl.cm.binary
        _, (ax1, ax2) = plt.subplots(2)

        mat.heatmap([vals], vmin=0, cmap=cmap, ax=ax1)
        fc_default_norm = ax1.collections[0].get_facecolors()

        mat.heatmap([vals], vmin=0, norm=None, cmap=cmap, ax=ax2)
        fc_explicit_norm = ax2.collections[0].get_facecolors()

        npt.assert_array_almost_equal(fc_default_norm, fc_explicit_norm, 2)

    def test_ticklabels_off(self):
        kws = self.default_kws.copy()
        kws['xticklabels'] = False
        kws['yticklabels'] = False
        p = mat._HeatMapper(self.df_norm, **kws)
        assert p.xticklabels == []
        assert p.yticklabels == []

    def test_custom_ticklabels(self):
        kws = self.default_kws.copy()
        xticklabels = list('iheartheatmaps'[:self.df_norm.shape[1]])
        yticklabels = list('heatmapsarecool'[:self.df_norm.shape[0]])
        kws['xticklabels'] = xticklabels
        kws['yticklabels'] = yticklabels
        p = mat._HeatMapper(self.df_norm, **kws)
        assert p.xticklabels == xticklabels
        assert p.yticklabels == yticklabels

    def test_custom_ticklabel_interval(self):

        kws = self.default_kws.copy()
        xstep, ystep = 2, 3
        kws['xticklabels'] = xstep
        kws['yticklabels'] = ystep
        p = mat._HeatMapper(self.df_norm, **kws)

        nx, ny = self.df_norm.T.shape
        # Ticks sit at cell centers, hence the +.5 offset
        npt.assert_array_equal(p.xticks, np.arange(0, nx, xstep) + .5)
        npt.assert_array_equal(p.yticks, np.arange(0, ny, ystep) + .5)
        npt.assert_array_equal(p.xticklabels,
                               self.df_norm.columns[0:nx:xstep])
        npt.assert_array_equal(p.yticklabels,
                               self.df_norm.index[0:ny:ystep])

    def test_heatmap_annotation(self):

        ax = mat.heatmap(self.df_norm, annot=True, fmt=".1f",
                         annot_kws={"fontsize": 14})
        for val, text in zip(self.x_norm.flat, ax.texts):
            assert text.get_text() == f"{val:.1f}"
            assert text.get_fontsize() == 14

    def test_heatmap_annotation_overwrite_kws(self):

        annot_kws = dict(color="0.3", va="bottom", ha="left")
        ax = mat.heatmap(self.df_norm, annot=True, fmt=".1f",
                         annot_kws=annot_kws)
        for text in ax.texts:
            assert text.get_color() == "0.3"
            assert text.get_ha() == "left"
            assert text.get_va() == "bottom"

    def test_heatmap_annotation_with_mask(self):

        df = pd.DataFrame(data={'a': [1, 1, 1],
                                'b': [2, np.nan, 2],
                                'c': [3, 3, np.nan]})
        mask = np.isnan(df.values)
        df_masked = np.ma.masked_where(mask, df)
        ax = mat.heatmap(df, annot=True, fmt='.1f', mask=mask)
        # Masked cells get no annotation text
        assert len(df_masked.compressed()) == len(ax.texts)
        for val, text in zip(df_masked.compressed(), ax.texts):
            assert f"{val:.1f}" == text.get_text()

    def test_heatmap_annotation_mesh_colors(self):

        ax = mat.heatmap(self.df_norm, annot=True)
        mesh = ax.collections[0]
        assert len(mesh.get_facecolors()) == self.df_norm.values.size

        plt.close("all")

    def test_heatmap_annotation_other_data(self):
        annot_data = self.df_norm + 10

        ax = mat.heatmap(self.df_norm, annot=annot_data, fmt=".1f",
                         annot_kws={"fontsize": 14})

        for val, text in zip(annot_data.values.flat, ax.texts):
            assert text.get_text() == f"{val:.1f}"
            assert text.get_fontsize() == 14

    def test_heatmap_annotation_different_shapes(self):

        annot_data = self.df_norm.iloc[:-1]
        with pytest.raises(ValueError):
            mat.heatmap(self.df_norm, annot=annot_data)

    def test_heatmap_annotation_with_limited_ticklabels(self):
        ax = mat.heatmap(self.df_norm, fmt=".2f", annot=True,
                         xticklabels=False, yticklabels=False)
        for val, text in zip(self.x_norm.flat, ax.texts):
            assert text.get_text() == f"{val:.2f}"

    def test_heatmap_cbar(self):

        f = plt.figure()
        mat.heatmap(self.df_norm)
        assert len(f.axes) == 2
        plt.close(f)

        f = plt.figure()
        mat.heatmap(self.df_norm, cbar=False)
        assert len(f.axes) == 1
        plt.close(f)

        f, (ax1, ax2) = plt.subplots(2)
        mat.heatmap(self.df_norm, ax=ax1, cbar_ax=ax2)
        assert len(f.axes) == 2
        plt.close(f)

    def test_heatmap_axes(self):

        ax = mat.heatmap(self.df_norm)

        xtl = [int(label.get_text()) for label in ax.get_xticklabels()]
        assert xtl == list(self.df_norm.columns)
        ytl = [label.get_text() for label in ax.get_yticklabels()]
        assert ytl == list(self.df_norm.index)

        assert ax.get_xlabel() == ""
        assert ax.get_ylabel() == "letters"

        # The y axis is inverted so the first row is drawn at the top
        assert ax.get_xlim() == (0, 8)
        assert ax.get_ylim() == (4, 0)

    def test_heatmap_ticklabel_rotation(self):

        f, ax = plt.subplots(figsize=(2, 2))
        mat.heatmap(self.df_norm, xticklabels=1, yticklabels=1, ax=ax)

        for t in ax.get_xticklabels():
            assert t.get_rotation() == 0

        for t in ax.get_yticklabels():
            assert t.get_rotation() == 90

        plt.close(f)

        # Long labels on a small figure should trigger rotation
        df = self.df_norm.copy()
        df.columns = [str(c) * 10 for c in df.columns]
        df.index = [i * 10 for i in df.index]

        f, ax = plt.subplots(figsize=(2, 2))
        mat.heatmap(df, xticklabels=1, yticklabels=1, ax=ax)

        for t in ax.get_xticklabels():
            assert t.get_rotation() == 90

        for t in ax.get_yticklabels():
            assert t.get_rotation() == 0

        plt.close(f)

    def test_heatmap_inner_lines(self):

        c = (0, 0, 1, 1)
        ax = mat.heatmap(self.df_norm, linewidths=2, linecolor=c)
        mesh = ax.collections[0]
        assert mesh.get_linewidths()[0] == 2
        assert tuple(mesh.get_edgecolor()[0]) == c

    def test_square_aspect(self):

        ax = mat.heatmap(self.df_norm, square=True)
        npt.assert_equal(ax.get_aspect(), 1)

    def test_mask_validation(self):

        mask = mat._matrix_mask(self.df_norm, None)
        assert mask.shape == self.df_norm.shape
        assert mask.values.sum() == 0

        # Build the bad masks outside pytest.raises so that only the
        # validation call can satisfy the expected ValueError.
        # ndarray mask with the wrong shape:
        bad_array_mask = self.rs.randn(3, 6) > 0
        with pytest.raises(ValueError):
            mat._matrix_mask(self.df_norm, bad_array_mask)

        # DataFrame mask with the right shape but a default 0..3 index,
        # which does not match df_norm's A..D index:
        bad_df_mask = pd.DataFrame(self.rs.randn(4, 8) > 0)
        with pytest.raises(ValueError):
            mat._matrix_mask(self.df_norm, bad_df_mask)

    def test_missing_data_mask(self):

        data = pd.DataFrame(np.arange(4, dtype=float).reshape(2, 2))
        data.loc[0, 0] = np.nan
        mask = mat._matrix_mask(data, None)
        npt.assert_array_equal(mask, [[True, False], [False, False]])

        # A user mask is combined with the missing-data mask
        mask_in = np.array([[False, True], [False, False]])
        mask_out = mat._matrix_mask(data, mask_in)
        npt.assert_array_equal(mask_out, [[True, True], [False, False]])

    def test_cbar_ticks(self):

        f, (ax1, ax2) = plt.subplots(2)
        mat.heatmap(self.df_norm, ax=ax1, cbar_ax=ax2,
                    cbar_kws=dict(drawedges=True))
        assert len(ax2.collections) == 2
@pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
class TestDendrogram:
    """Tests for ``_DendrogramPlotter`` and the ``dendrogram`` function."""

    rs = np.random.RandomState(sum(map(ord, "dendrogram")))

    default_kws = dict(linkage=None, metric='euclidean', method='single',
                       axis=1, label=True, rotate=False)

    # Offsets along both axes give the data a clear cluster structure
    x_norm = rs.randn(4, 8) + np.arange(8)
    x_norm = (x_norm.T + np.arange(4)).T
    letters = pd.Series(["A", "B", "C", "D", "E", "F", "G", "H"],
                        name="letters")

    df_norm = pd.DataFrame(x_norm, columns=letters)

    # Reference linkage/dendrogram; scipy names only exist when it imported.
    if not _no_scipy:
        if _no_fastcluster:
            x_norm_distances = distance.pdist(x_norm.T, metric='euclidean')
            x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single')
        else:
            x_norm_linkage = fastcluster.linkage_vector(x_norm.T,
                                                        metric='euclidean',
                                                        method='single')

        x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True,
                                                 color_threshold=-np.inf)
        x_norm_leaves = x_norm_dendrogram['leaves']
        df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves])

    def test_ndarray_input(self):
        p = mat._DendrogramPlotter(self.x_norm, **self.default_kws)
        npt.assert_array_equal(p.array.T, self.x_norm)
        pdt.assert_frame_equal(p.data.T, pd.DataFrame(self.x_norm))

        npt.assert_array_equal(p.linkage, self.x_norm_linkage)
        assert p.dendrogram == self.x_norm_dendrogram

        npt.assert_array_equal(p.reordered_ind, self.x_norm_leaves)

        npt.assert_array_equal(p.xticklabels, self.x_norm_leaves)
        npt.assert_array_equal(p.yticklabels, [])

        assert p.xlabel is None
        assert p.ylabel == ''

    def test_df_input(self):
        p = mat._DendrogramPlotter(self.df_norm, **self.default_kws)
        npt.assert_array_equal(p.array.T, np.asarray(self.df_norm))
        pdt.assert_frame_equal(p.data.T, self.df_norm)

        npt.assert_array_equal(p.linkage, self.x_norm_linkage)
        assert p.dendrogram == self.x_norm_dendrogram

        npt.assert_array_equal(p.xticklabels,
                               np.asarray(self.df_norm.columns)[
                                   self.x_norm_leaves])
        npt.assert_array_equal(p.yticklabels, [])

        assert p.xlabel == 'letters'
        assert p.ylabel == ''

    def test_df_multindex_input(self):

        df = self.df_norm.copy()
        index = pd.MultiIndex.from_tuples([("A", 1), ("B", 2),
                                           ("C", 3), ("D", 4)],
                                          names=["letter", "number"])
        index.name = "letter-number"
        df.index = index
        kws = self.default_kws.copy()
        kws['label'] = True

        p = mat._DendrogramPlotter(df.T, **kws)

        # Labels follow the dendrogram's leaf order
        xticklabels = ["A-1", "B-2", "C-3", "D-4"]
        xticklabels = [xticklabels[i] for i in p.reordered_ind]
        npt.assert_array_equal(p.xticklabels, xticklabels)
        npt.assert_array_equal(p.yticklabels, [])
        assert p.xlabel == "letter-number"

    def test_axis0_input(self):
        kws = self.default_kws.copy()
        kws['axis'] = 0
        p = mat._DendrogramPlotter(self.df_norm.T, **kws)

        npt.assert_array_equal(p.array, np.asarray(self.df_norm.T))
        pdt.assert_frame_equal(p.data, self.df_norm.T)

        npt.assert_array_equal(p.linkage, self.x_norm_linkage)
        assert p.dendrogram == self.x_norm_dendrogram

        npt.assert_array_equal(p.xticklabels, self.df_norm_leaves)
        npt.assert_array_equal(p.yticklabels, [])

        assert p.xlabel == 'letters'
        assert p.ylabel == ''

    def test_rotate_input(self):
        kws = self.default_kws.copy()
        kws['rotate'] = True
        p = mat._DendrogramPlotter(self.df_norm, **kws)
        npt.assert_array_equal(p.array.T, np.asarray(self.df_norm))
        pdt.assert_frame_equal(p.data.T, self.df_norm)

        # Rotation moves the labels from the x axis to the y axis
        npt.assert_array_equal(p.xticklabels, [])
        npt.assert_array_equal(p.yticklabels, self.df_norm_leaves)

        assert p.xlabel == ''
        assert p.ylabel == 'letters'

    def test_rotate_axis0_input(self):
        kws = self.default_kws.copy()
        kws['rotate'] = True
        kws['axis'] = 0
        p = mat._DendrogramPlotter(self.df_norm.T, **kws)

        npt.assert_array_equal(p.reordered_ind, self.x_norm_leaves)

    def test_custom_linkage(self):
        kws = self.default_kws.copy()

        # Same backend selection as the class-level reference linkage
        if _no_fastcluster:
            d = distance.pdist(self.x_norm, metric='euclidean')
            linkage = hierarchy.linkage(d, method='single')
        else:
            linkage = fastcluster.linkage_vector(self.x_norm, method='single',
                                                 metric='euclidean')
        dendrogram = hierarchy.dendrogram(linkage, no_plot=True,
                                          color_threshold=-np.inf)
        kws['linkage'] = linkage
        p = mat._DendrogramPlotter(self.df_norm, **kws)

        npt.assert_array_equal(p.linkage, linkage)
        assert p.dendrogram == dendrogram

    def test_label_false(self):
        kws = self.default_kws.copy()
        kws['label'] = False
        p = mat._DendrogramPlotter(self.df_norm, **kws)
        assert p.xticks == []
        assert p.yticks == []
        assert p.xticklabels == []
        assert p.yticklabels == []
        assert p.xlabel == ""
        assert p.ylabel == ""

    def test_linkage_scipy(self):
        p = mat._DendrogramPlotter(self.x_norm, **self.default_kws)

        scipy_linkage = p._calculate_linkage_scipy()

        # distance/hierarchy come from the module-level scipy import
        dists = distance.pdist(self.x_norm.T,
                               metric=self.default_kws['metric'])
        linkage = hierarchy.linkage(dists, method=self.default_kws['method'])

        npt.assert_array_equal(scipy_linkage, linkage)

    @pytest.mark.skipif(_no_fastcluster, reason="fastcluster not installed")
    def test_fastcluster_other_method(self):
        # fastcluster is the module-level import guaranteed by the skipif
        kws = self.default_kws.copy()
        kws['method'] = 'average'
        linkage = fastcluster.linkage(self.x_norm.T, method='average',
                                      metric='euclidean')
        p = mat._DendrogramPlotter(self.x_norm, **kws)
        npt.assert_array_equal(p.linkage, linkage)

    @pytest.mark.skipif(_no_fastcluster, reason="fastcluster not installed")
    def test_fastcluster_non_euclidean(self):
        # fastcluster is the module-level import guaranteed by the skipif
        kws = self.default_kws.copy()
        kws['metric'] = 'cosine'
        kws['method'] = 'average'
        linkage = fastcluster.linkage(self.x_norm.T, method=kws['method'],
                                      metric=kws['metric'])
        p = mat._DendrogramPlotter(self.x_norm, **kws)
        npt.assert_array_equal(p.linkage, linkage)

    def test_dendrogram_plot(self):
        d = mat.dendrogram(self.x_norm, **self.default_kws)

        ax = plt.gca()
        xlim = ax.get_xlim()
        # 10 comes from _plot_dendrogram in scipy.cluster.hierarchy
        xmax = len(d.reordered_ind) * 10

        assert xlim[0] == 0
        assert xlim[1] == xmax

        assert len(ax.collections[0].get_paths()) == len(d.dependent_coord)

    def test_dendrogram_rotate(self):
        kws = self.default_kws.copy()
        kws['rotate'] = True

        d = mat.dendrogram(self.x_norm, **kws)

        ax = plt.gca()
        ylim = ax.get_ylim()

        # 10 comes from _plot_dendrogram in scipy.cluster.hierarchy
        ymax = len(d.reordered_ind) * 10

        # Since y axis is inverted, ylim is (80, 0)
        # and therefore not (0, 80) as usual:
        assert ylim[1] == 0
        assert ylim[0] == ymax

    def test_dendrogram_ticklabel_rotation(self):
        f, ax = plt.subplots(figsize=(2, 2))
        mat.dendrogram(self.df_norm, ax=ax)

        for t in ax.get_xticklabels():
            assert t.get_rotation() == 0

        plt.close(f)

        # Long labels on a small figure should trigger rotation
        df = self.df_norm.copy()
        df.columns = [str(c) * 10 for c in df.columns]
        df.index = [i * 10 for i in df.index]

        f, ax = plt.subplots(figsize=(2, 2))
        mat.dendrogram(df, ax=ax)

        for t in ax.get_xticklabels():
            assert t.get_rotation() == 90

        plt.close(f)

        # Pass ax explicitly rather than relying on plt.gca()
        f, ax = plt.subplots(figsize=(2, 2))
        mat.dendrogram(df.T, axis=0, rotate=True, ax=ax)
        for t in ax.get_yticklabels():
            assert t.get_rotation() == 0
        plt.close(f)
@pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
class TestClustermap:
rs = np.random.RandomState(sum(map(ord, "clustermap")))
x_norm = rs.randn(4, 8) + np.arange(8)
x_norm = (x_norm.T + np.arange(4)).T
letters = pd.Series(["A", "B", "C", "D", "E", "F", "G", "H"],
name="letters")
df_norm = pd.DataFrame(x_norm, columns=letters)
default_kws = dict(pivot_kws=None, z_score=None, standard_scale=None,
figsize=(10, 10), row_colors=None, col_colors=None,
dendrogram_ratio=.2, colors_ratio=.03,
cbar_pos=(0, .8, .05, .2))
default_plot_kws = dict(metric='euclidean', method='average',
colorbar_kws=None,
row_cluster=True, col_cluster=True,
row_linkage=None, col_linkage=None,
tree_kws=None)
row_colors = color_palette('Set2', df_norm.shape[0])
col_colors = color_palette('Dark2', df_norm.shape[1])
if not _no_scipy:
if _no_fastcluster:
x_norm_distances = distance.pdist(x_norm.T, metric='euclidean')
x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single')
else:
x_norm_linkage = fastcluster.linkage_vector(x_norm.T,
metric='euclidean',
method='single')
x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True,
color_threshold=-np.inf)
x_norm_leaves = x_norm_dendrogram['leaves']
df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves])
def test_ndarray_input(self):
cg = mat.ClusterGrid(self.x_norm, **self.default_kws)
pdt.assert_frame_equal(cg.data, pd.DataFrame(self.x_norm))
assert len(cg.fig.axes) == 4
assert cg.ax_row_colors is None
assert cg.ax_col_colors is None
def test_df_input(self):
cg = mat.ClusterGrid(self.df_norm, **self.default_kws)
pdt.assert_frame_equal(cg.data, self.df_norm)
def test_corr_df_input(self):
df = self.df_norm.corr()
cg = mat.ClusterGrid(df, **self.default_kws)
cg.plot(**self.default_plot_kws)
diag = cg.data2d.values[np.diag_indices_from(cg.data2d)]
npt.assert_array_almost_equal(diag, np.ones(cg.data2d.shape[0]))
def test_pivot_input(self):
df_norm = self.df_norm.copy()
df_norm.index.name = 'numbers'
df_long = pd.melt(df_norm.reset_index(), var_name='letters',
id_vars='numbers')
kws = self.default_kws.copy()
kws['pivot_kws'] = dict(index='numbers', columns='letters',
values='value')
cg = mat.ClusterGrid(df_long, **kws)
pdt.assert_frame_equal(cg.data2d, df_norm)
def test_colors_input(self):
kws = self.default_kws.copy()
kws['row_colors'] = self.row_colors
kws['col_colors'] = self.col_colors
cg = mat.ClusterGrid(self.df_norm, **kws)
npt.assert_array_equal(cg.row_colors, self.row_colors)
npt.assert_array_equal(cg.col_colors, self.col_colors)
assert len(cg.fig.axes) == 6
def test_categorical_colors_input(self):
kws = self.default_kws.copy()
row_colors = pd.Series(self.row_colors, dtype="category")
col_colors = pd.Series(
self.col_colors, dtype="category", index=self.df_norm.columns
)
kws['row_colors'] = row_colors
kws['col_colors'] = col_colors
exp_row_colors = list(map(mpl.colors.to_rgb, row_colors))
exp_col_colors = list(map(mpl.colors.to_rgb, col_colors))
cg = mat.ClusterGrid(self.df_norm, **kws)
npt.assert_array_equal(cg.row_colors, exp_row_colors)
npt.assert_array_equal(cg.col_colors, exp_col_colors)
assert len(cg.fig.axes) == 6
def test_nested_colors_input(self):
kws = self.default_kws.copy()
row_colors = [self.row_colors, self.row_colors]
col_colors = [self.col_colors, self.col_colors]
kws['row_colors'] = row_colors
kws['col_colors'] = col_colors
cm = mat.ClusterGrid(self.df_norm, **kws)
npt.assert_array_equal(cm.row_colors, row_colors)
npt.assert_array_equal(cm.col_colors, col_colors)
assert len(cm.fig.axes) == 6
def test_colors_input_custom_cmap(self):
kws = self.default_kws.copy()
kws['cmap'] = mpl.cm.PRGn
kws['row_colors'] = self.row_colors
kws['col_colors'] = self.col_colors
cg = mat.clustermap(self.df_norm, **kws)
npt.assert_array_equal(cg.row_colors, self.row_colors)
npt.assert_array_equal(cg.col_colors, self.col_colors)
assert len(cg.fig.axes) == 6
def test_z_score(self):
df = self.df_norm.copy()
df = (df - df.mean()) / df.std()
kws = self.default_kws.copy()
kws['z_score'] = 1
cg = mat.ClusterGrid(self.df_norm, **kws)
pdt.assert_frame_equal(cg.data2d, df)
def test_z_score_axis0(self):
df = self.df_norm.copy()
df = df.T
df = (df - df.mean()) / df.std()
df = df.T
kws = self.default_kws.copy()
kws['z_score'] = 0
cg = mat.ClusterGrid(self.df_norm, **kws)
pdt.assert_frame_equal(cg.data2d, df)
def test_standard_scale(self):
df = self.df_norm.copy()
df = (df - df.min()) / (df.max() - df.min())
kws = self.default_kws.copy()
kws['standard_scale'] = 1
cg = mat.ClusterGrid(self.df_norm, **kws)
pdt.assert_frame_equal(cg.data2d, df)
def test_standard_scale_axis0(self):
df = self.df_norm.copy()
df = df.T
df = (df - df.min()) / (df.max() - df.min())
df = df.T
kws = self.default_kws.copy()
kws['standard_scale'] = 0
cg = mat.ClusterGrid(self.df_norm, **kws)
pdt.assert_frame_equal(cg.data2d, df)
def test_z_score_standard_scale(self):
kws = self.default_kws.copy()
kws['z_score'] = True
kws['standard_scale'] = True
with pytest.raises(ValueError):
mat.ClusterGrid(self.df_norm, **kws)
def test_color_list_to_matrix_and_cmap(self):
# Note this uses the attribute named col_colors but tests row colors
matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
self.col_colors, self.x_norm_leaves, axis=0)
for i, leaf in enumerate(self.x_norm_leaves):
color = self.col_colors[leaf]
assert_colors_equal(cmap(matrix[i, 0]), color)
def test_nested_color_list_to_matrix_and_cmap(self):
# Note this uses the attribute named col_colors but tests row colors
colors = [self.col_colors, self.col_colors[::-1]]
matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
colors, self.x_norm_leaves, axis=0)
for i, leaf in enumerate(self.x_norm_leaves):
for j, color_row in enumerate(colors):
color = color_row[leaf]
assert_colors_equal(cmap(matrix[i, j]), color)
def test_color_list_to_matrix_and_cmap_axis1(self):
matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
self.col_colors, self.x_norm_leaves, axis=1)
for j, leaf in enumerate(self.x_norm_leaves):
color = self.col_colors[leaf]
assert_colors_equal(cmap(matrix[0, j]), color)
def test_color_list_to_matrix_and_cmap_different_sizes(self):
colors = [self.col_colors, self.col_colors * 2]
with pytest.raises(ValueError):
matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
colors, self.x_norm_leaves, axis=1)
def test_savefig(self):
# Not sure if this is the right way to test....
cg = mat.ClusterGrid(self.df_norm, **self.default_kws)
cg.plot(**self.default_plot_kws)
cg.savefig(tempfile.NamedTemporaryFile(), format='png')
def test_plot_dendrograms(self):
cm = mat.clustermap(self.df_norm, **self.default_kws)
assert len(cm.ax_row_dendrogram.collections[0].get_paths()) == len(
cm.dendrogram_row.independent_coord
)
assert len(cm.ax_col_dendrogram.collections[0].get_paths()) == len(
cm.dendrogram_col.independent_coord
)
data2d = self.df_norm.iloc[cm.dendrogram_row.reordered_ind,
cm.dendrogram_col.reordered_ind]
pdt.assert_frame_equal(cm.data2d, data2d)
def test_cluster_false(self):
kws = self.default_kws.copy()
kws['row_cluster'] = False
kws['col_cluster'] = False
cm = mat.clustermap(self.df_norm, **kws)
assert len(cm.ax_row_dendrogram.lines) == 0
assert len(cm.ax_col_dendrogram.lines) == 0
assert len(cm.ax_row_dendrogram.get_xticks()) == 0
assert len(cm.ax_row_dendrogram.get_yticks()) == 0
assert len(cm.ax_col_dendrogram.get_xticks()) == 0
assert len(cm.ax_col_dendrogram.get_yticks()) == 0
pdt.assert_frame_equal(cm.data2d, self.df_norm)
def test_row_col_colors(self):
kws = self.default_kws.copy()
kws['row_colors'] = self.row_colors
kws['col_colors'] = self.col_colors
cm = mat.clustermap(self.df_norm, **kws)
assert len(cm.ax_row_colors.collections) == 1
assert len(cm.ax_col_colors.collections) == 1
def test_cluster_false_row_col_colors(self):
kws = self.default_kws.copy()
kws['row_cluster'] = False
kws['col_cluster'] = False
kws['row_colors'] = self.row_colors
kws['col_colors'] = self.col_colors
cm = mat.clustermap(self.df_norm, **kws)
assert len(cm.ax_row_dendrogram.lines) == 0
assert len(cm.ax_col_dendrogram.lines) == 0
assert len(cm.ax_row_dendrogram.get_xticks()) == 0
assert len(cm.ax_row_dendrogram.get_yticks()) == 0
assert len(cm.ax_col_dendrogram.get_xticks()) == 0
assert len(cm.ax_col_dendrogram.get_yticks()) == 0
assert len(cm.ax_row_colors.collections) == 1
assert len(cm.ax_col_colors.collections) == 1
pdt.assert_frame_equal(cm.data2d, self.df_norm)
def test_row_col_colors_df(self):
kws = self.default_kws.copy()
kws['row_colors'] = pd.DataFrame({'row_1': list(self.row_colors),
'row_2': list(self.row_colors)},
index=self.df_norm.index,
columns=['row_1', 'row_2'])
kws['col_colors'] = pd.DataFrame({'col_1': list(self.col_colors),
'col_2': list(self.col_colors)},
index=self.df_norm.columns,
columns=['col_1', 'col_2'])
cm = mat.clustermap(self.df_norm, **kws)
row_labels = [l.get_text() for l in
cm.ax_row_colors.get_xticklabels()]
assert cm.row_color_labels == ['row_1', 'row_2']
assert row_labels == cm.row_color_labels
col_labels = [l.get_text() for l in
cm.ax_col_colors.get_yticklabels()]
assert cm.col_color_labels == ['col_1', 'col_2']
assert col_labels == cm.col_color_labels
def test_row_col_colors_df_shuffled(self):
# Tests if colors are properly matched, even if given in wrong order
m, n = self.df_norm.shape
shuffled_inds = [self.df_norm.index[i] for i in
list(range(0, m, 2)) + list(range(1, m, 2))]
shuffled_cols = [self.df_norm.columns[i] for i in
list(range(0, n, 2)) + list(range(1, n, 2))]
kws = self.default_kws.copy()
row_colors = pd.DataFrame({'row_annot': list(self.row_colors)},
index=self.df_norm.index)
kws['row_colors'] = row_colors.loc[shuffled_inds]
col_colors = pd.DataFrame({'col_annot': list(self.col_colors)},
index=self.df_norm.columns)
kws['col_colors'] = col_colors.loc[shuffled_cols]
cm = mat.clustermap(self.df_norm, **kws)
assert list(cm.col_colors)[0] == list(self.col_colors)
assert list(cm.row_colors)[0] == list(self.row_colors)
def test_row_col_colors_df_missing(self):
kws = self.default_kws.copy()
row_colors = pd.DataFrame({'row_annot': list(self.row_colors)},
index=self.df_norm.index)
kws['row_colors'] = row_colors.drop(self.df_norm.index[0])
col_colors = pd.DataFrame({'col_annot': list(self.col_colors)},
index=self.df_norm.columns)
kws['col_colors'] = col_colors.drop(self.df_norm.columns[0])
cm = mat.clustermap(self.df_norm, **kws)
assert list(cm.col_colors)[0] == [(1.0, 1.0, 1.0)] + list(self.col_colors[1:])
assert list(cm.row_colors)[0] == [(1.0, 1.0, 1.0)] + list(self.row_colors[1:])
def test_row_col_colors_df_one_axis(self):
# Test case with only row annotation.
kws1 = self.default_kws.copy()
kws1['row_colors'] = pd.DataFrame({'row_1': list(self.row_colors),
'row_2': list(self.row_colors)},
index=self.df_norm.index,
columns=['row_1', 'row_2'])
cm1 = mat.clustermap(self.df_norm, **kws1)
row_labels = [l.get_text() for l in
cm1.ax_row_colors.get_xticklabels()]
assert cm1.row_color_labels == ['row_1', 'row_2']
assert row_labels == cm1.row_color_labels
# Test case with only col annotation.
kws2 = self.default_kws.copy()
kws2['col_colors'] = pd.DataFrame({'col_1': list(self.col_colors),
'col_2': list(self.col_colors)},
index=self.df_norm.columns,
columns=['col_1', 'col_2'])
cm2 = mat.clustermap(self.df_norm, **kws2)
col_labels = [l.get_text() for l in
cm2.ax_col_colors.get_yticklabels()]
assert cm2.col_color_labels == ['col_1', 'col_2']
assert col_labels == cm2.col_color_labels
def test_row_col_colors_series(self):
kws = self.default_kws.copy()
kws['row_colors'] = pd.Series(list(self.row_colors), name='row_annot',
index=self.df_norm.index)
kws['col_colors'] = pd.Series(list(self.col_colors), name='col_annot',
index=self.df_norm.columns)
cm = mat.clustermap(self.df_norm, **kws)
row_labels = [l.get_text() for l in cm.ax_row_colors.get_xticklabels()]
assert cm.row_color_labels == ['row_annot']
assert row_labels == cm.row_color_labels
col_labels = [l.get_text() for l in cm.ax_col_colors.get_yticklabels()]
assert cm.col_color_labels == ['col_annot']
assert col_labels == cm.col_color_labels
def test_row_col_colors_series_shuffled(self):
    """Series annotations are aligned by label even when given out of order."""
    n_rows, n_cols = self.df_norm.shape

    def evens_then_odds(labels, size):
        # Positions 0, 2, 4, ... followed by 1, 3, 5, ...
        order = [*range(0, size, 2), *range(1, size, 2)]
        return [labels[pos] for pos in order]

    row_order = evens_then_odds(self.df_norm.index, n_rows)
    col_order = evens_then_odds(self.df_norm.columns, n_cols)

    row_series = pd.Series(list(self.row_colors), name='row_annot',
                           index=self.df_norm.index)
    col_series = pd.Series(list(self.col_colors), name='col_annot',
                           index=self.df_norm.columns)

    kws = self.default_kws.copy()
    kws['row_colors'] = row_series.loc[row_order]
    kws['col_colors'] = col_series.loc[col_order]
    grid = mat.clustermap(self.df_norm, **kws)

    assert list(grid.col_colors) == list(self.col_colors)
    assert list(grid.row_colors) == list(self.row_colors)
def test_row_col_colors_series_missing(self):
    """Labels absent from an annotation Series are drawn in white."""
    white = (1.0, 1.0, 1.0)
    row_series = pd.Series(list(self.row_colors), name='row_annot',
                           index=self.df_norm.index)
    col_series = pd.Series(list(self.col_colors), name='col_annot',
                           index=self.df_norm.columns)

    kws = self.default_kws.copy()
    # Drop the first label on each axis so it has no annotation color.
    kws['row_colors'] = row_series.drop(self.df_norm.index[0])
    kws['col_colors'] = col_series.drop(self.df_norm.columns[0])
    grid = mat.clustermap(self.df_norm, **kws)

    assert list(grid.col_colors) == [white, *self.col_colors[1:]]
    assert list(grid.row_colors) == [white, *self.row_colors[1:]]
def test_row_col_colors_ignore_heatmap_kwargs(self):
    """cmap/norm/vmax given for the heatmap do not recolor the annotations."""
    data = self.rs.uniform(0, 200, self.df_norm.shape)
    grid = mat.clustermap(data,
                          row_colors=self.row_colors,
                          col_colors=self.col_colors,
                          cmap="Spectral",
                          norm=mpl.colors.LogNorm(),
                          vmax=100)

    checks = (
        (self.row_colors, grid.dendrogram_row, grid.ax_row_colors),
        (self.col_colors, grid.dendrogram_col, grid.ax_col_colors),
    )
    for colors, dendrogram, annot_ax in checks:
        expected = np.array(colors)[dendrogram.reordered_ind]
        drawn = annot_ax.collections[0].get_facecolors()[:, :3]
        assert np.array_equal(expected, drawn)
def test_row_col_colors_raise_on_mixed_index_types(self):
    """Indexed (pandas) annotations with un-indexed array data raise TypeError."""
    annotations = {
        "row_colors": pd.Series(
            list(self.row_colors), name="row_annot", index=self.df_norm.index
        ),
        "col_colors": pd.Series(
            list(self.col_colors), name="col_annot", index=self.df_norm.columns
        ),
    }
    for keyword, colors in annotations.items():
        with pytest.raises(TypeError):
            mat.clustermap(self.x_norm, **{keyword: colors})
def test_mask_reorganization(self):
    """The mask is reordered together with the data by the dendrograms."""
    kws = self.default_kws.copy()
    kws["mask"] = self.df_norm > 0
    grid = mat.clustermap(self.df_norm, **kws)
    mask = grid.mask

    # The mask shares the reordered data's labels ...
    npt.assert_array_equal(grid.data2d.index, mask.index)
    npt.assert_array_equal(grid.data2d.columns, mask.columns)

    # ... and those labels follow the dendrogram leaf order.
    row_order = self.df_norm.index[grid.dendrogram_row.reordered_ind]
    col_order = self.df_norm.columns[grid.dendrogram_col.reordered_ind]
    npt.assert_array_equal(mask.index, row_order)
    npt.assert_array_equal(mask.columns, col_order)
def test_ticklabel_reorganization(self):
    """Custom tick labels follow the clustered order of rows and columns."""
    kws = self.default_kws.copy()
    xtl = np.arange(self.df_norm.shape[1])
    kws["xticklabels"] = list(xtl)
    ytl = self.letters.loc[:self.df_norm.shape[0]]
    kws["yticklabels"] = ytl

    g = mat.clustermap(self.df_norm, **kws)

    xtl_actual = [t.get_text() for t in g.ax_heatmap.get_xticklabels()]
    ytl_actual = [t.get_text() for t in g.ax_heatmap.get_yticklabels()]

    # Expected labels: the inputs reordered by the dendrogram leaves, as text.
    xtl_want = xtl[g.dendrogram_col.reordered_ind].astype("<U1")
    ytl_want = ytl[g.dendrogram_row.reordered_ind].astype("<U1")

    npt.assert_array_equal(xtl_actual, xtl_want)
    npt.assert_array_equal(ytl_actual, ytl_want)

def test_noticklabels(self):
    """xticklabels/yticklabels=False suppress all heatmap tick labels."""
    kws = self.default_kws.copy()
    kws["xticklabels"] = False
    kws["yticklabels"] = False

    g = mat.clustermap(self.df_norm, **kws)

    xtl_actual = [t.get_text() for t in g.ax_heatmap.get_xticklabels()]
    ytl_actual = [t.get_text() for t in g.ax_heatmap.get_yticklabels()]
    assert xtl_actual == []
    assert ytl_actual == []

def test_size_ratios(self):
    """Larger dendrogram/colors ratios enlarge those axes at the heatmap's expense."""
    # The mapping from input ratio to the actual axes width/height goes
    # through GridSpec wspace/hspace, so only comparative relationships
    # are asserted.
    kws1 = self.default_kws.copy()
    kws1.update(dendrogram_ratio=.2, colors_ratio=.03,
                col_colors=self.col_colors, row_colors=self.row_colors)

    kws2 = kws1.copy()
    kws2.update(dendrogram_ratio=.3, colors_ratio=.05)

    g1 = mat.clustermap(self.df_norm, **kws1)
    g2 = mat.clustermap(self.df_norm, **kws2)

    assert (g2.ax_col_dendrogram.get_position().height
            > g1.ax_col_dendrogram.get_position().height)
    assert (g2.ax_col_colors.get_position().height
            > g1.ax_col_colors.get_position().height)
    assert (g2.ax_heatmap.get_position().height
            < g1.ax_heatmap.get_position().height)
    assert (g2.ax_row_dendrogram.get_position().width
            > g1.ax_row_dendrogram.get_position().width)
    assert (g2.ax_row_colors.get_position().width
            > g1.ax_row_colors.get_position().width)
    assert (g2.ax_heatmap.get_position().width
            < g1.ax_heatmap.get_position().width)

    # Two rows of column annotations take more height than one.
    kws1 = self.default_kws.copy()
    kws1.update(col_colors=self.col_colors)
    kws2 = kws1.copy()
    kws2.update(col_colors=[self.col_colors, self.col_colors])
    g1 = mat.clustermap(self.df_norm, **kws1)
    g2 = mat.clustermap(self.df_norm, **kws2)
    assert (g2.ax_col_colors.get_position().height
            > g1.ax_col_colors.get_position().height)

    # A (row, col) dendrogram_ratio tuple controls each axis independently.
    kws1 = self.default_kws.copy()
    kws1.update(dendrogram_ratio=(.2, .2))
    kws2 = kws1.copy()
    kws2.update(dendrogram_ratio=(.2, .3))
    g1 = mat.clustermap(self.df_norm, **kws1)
    g2 = mat.clustermap(self.df_norm, **kws2)
    # Fails on pinned matplotlib?
    # assert (g2.ax_row_dendrogram.get_position().width
    # == g1.ax_row_dendrogram.get_position().width)
    assert g1.gs.get_width_ratios() == g2.gs.get_width_ratios()
    assert (g2.ax_col_dendrogram.get_position().height
            > g1.ax_col_dendrogram.get_position().height)
def test_cbar_pos(self):
    """cbar_pos places the colorbar axes exactly; None removes the colorbar."""
    cbar_pos = (.2, .1, .4, .3)
    kws = self.default_kws.copy()
    kws["cbar_pos"] = cbar_pos

    grid = mat.clustermap(self.df_norm, **kws)
    bbox = grid.ax_cbar.get_position()
    left, bottom, width, height = cbar_pos
    assert pytest.approx(tuple(bbox.p0)) == (left, bottom)
    assert pytest.approx(bbox.width) == width
    assert pytest.approx(bbox.height) == height

    kws["cbar_pos"] = None
    grid = mat.clustermap(self.df_norm, **kws)
    assert grid.ax_cbar is None
def test_square_warning(self):
    """square=True warns and leaves the heatmap geometry unchanged."""
    kws = self.default_kws.copy()
    plain = mat.clustermap(self.df_norm, **kws)
    with pytest.warns(UserWarning):
        kws["square"] = True
        squared = mat.clustermap(self.df_norm, **kws)

    plain_pts, squared_pts = (
        grid.ax_heatmap.get_position().get_points() for grid in (plain, squared)
    )
    assert np.array_equal(plain_pts, squared_pts)
def test_clustermap_annotation(self):
    """Cell annotations follow the clustered data order and honor fmt."""
    # annot=True annotates with the data itself; a DataFrame supplies values.
    for annot in (True, self.df_norm):
        grid = mat.clustermap(self.df_norm, annot=annot, fmt=".1f")
        cell_values = np.asarray(grid.data2d).flat
        for value, text in zip(cell_values, grid.ax_heatmap.texts):
            assert text.get_text() == f"{value:.1f}"
def test_tree_kws(self):
    """tree_kws are applied to the line collection of both dendrograms."""
    rgb = (1, .5, .2)
    grid = mat.clustermap(self.df_norm, tree_kws=dict(color=rgb))
    for dendro_ax in (grid.ax_col_dendrogram, grid.ax_row_dendrogram):
        (tree,) = dendro_ax.collections
        drawn_rgba = tuple(tree.get_color().squeeze())
        assert drawn_rgba[:3] == rgb
if _no_scipy:

    def test_required_scipy_errors():
        """Clustering entry points raise RuntimeError when scipy is unavailable."""
        x = np.random.normal(0, 1, (10, 10))
        for entry_point in (mat.clustermap, mat.ClusterGrid, mat.dendrogram):
            with pytest.raises(RuntimeError):
                entry_point(x)
================================================
FILE: tests/test_miscplot.py
================================================
import matplotlib.pyplot as plt
from seaborn import miscplot as misc
from seaborn.palettes import color_palette
from .test_utils import _network
class TestPalPlot:
    """Tests for misc.palplot, which draws a row of color swatches."""

    def test_palplot_size(self):
        """Figure size is len(palette) x 1 inches, scaled by the size argument."""
        cases = [
            (4, (), (4, 1)),
            (5, (), (5, 1)),
            (3, (2,), (6, 2)),
        ]
        for n_colors, size_args, expected in cases:
            pal = color_palette("husl", n_colors)
            misc.palplot(pal, *size_args)
            assert tuple(plt.gcf().get_size_inches()) == expected
class TestDogPlot:
    """Tests for misc.dogplot (guarded by the _network helper)."""

    @_network(url="https://github.com/mwaskom/seaborn-data")
    def test_dogplot(self):
        misc.dogplot()
        assert len(plt.gca().images) == 1
================================================
FILE: tests/test_objects.py
================================================
import seaborn.objects
from seaborn._core.plot import Plot
from seaborn._core.moves import Move
from seaborn._core.scales import Scale
from seaborn._marks.base import Mark
from seaborn._stats.base import Stat
def test_objects_namespace():
    """Every non-dunder name in seaborn.objects is a core object class."""
    allowed = (Plot, Mark, Stat, Move, Scale)
    exported = (n for n in dir(seaborn.objects) if not n.startswith("__"))
    for name in exported:
        assert issubclass(getattr(seaborn.objects, name), allowed)
================================================
FILE: tests/test_palettes.py
================================================
import colorsys
import numpy as np
import matplotlib as mpl
import pytest
import numpy.testing as npt
from seaborn import palettes, utils, rcmod
from seaborn.external import husl
from seaborn._compat import get_colormap
from seaborn.colors import xkcd_rgb, crayons
class TestColorPalettes:
def test_current_palette(self):
    """set_palette installs the palette as the active color cycle."""
    pal = palettes.color_palette(["red", "blue", "green"])
    rcmod.set_palette(pal)
    try:
        assert pal == utils.get_color_cycle()
    finally:
        # Restore defaults even if the assertion fails, so the modified
        # global color cycle cannot leak into later tests.
        rcmod.set()
def test_palette_context(self):
    """color_palette used as a context manager restores the prior cycle."""
    before = palettes.color_palette()
    muted = palettes.color_palette("muted")
    with palettes.color_palette(muted):
        active = utils.get_color_cycle()
        assert active == muted
    assert utils.get_color_cycle() == before
def test_big_palette_context(self):
    """A 10-color context palette is applied in full, then the old one returns."""
    original_pal = palettes.color_palette("deep", n_colors=8)
    context_pal = palettes.color_palette("husl", 10)
    rcmod.set_palette(original_pal)
    try:
        with palettes.color_palette(context_pal, 10):
            assert utils.get_color_cycle() == context_pal
        assert utils.get_color_cycle() == original_pal
    finally:
        # Reset default even on failure so the global state does not leak.
        rcmod.set()
def test_palette_size(self):
    """Qualitative palettes use their registered size; others default to 6."""
    for name in ("deep", "pastel6", "Set3"):
        pal = palettes.color_palette(name)
        assert len(pal) == palettes.QUAL_PALETTE_SIZES[name]
    for name in ("husl", "Greens"):
        pal = palettes.color_palette(name)
        assert len(pal) == 6
def test_seaborn_palettes(self):
    """Each "<name>6" palette is a fixed subset of the 10-color "<name>"."""
    # Positions of b, g, r, m, y, c within the ten-color palette.
    six_from_ten = (0, 2, 3, 4, 8, 9)
    for name in ("deep", "muted", "pastel", "bright", "dark", "colorblind"):
        full = palettes.color_palette(name, 10).as_hex()
        short = palettes.color_palette(name + "6", 6).as_hex()
        assert len(full) == 10
        assert [full[i] for i in six_from_ten] == list(short)
def test_hls_palette(self):
    """hls_palette matches color_palette("hls") as a list and as a colormap."""
    npt.assert_array_equal(palettes.hls_palette(),
                           palettes.color_palette("hls"))
    probe = [.2, .8]
    direct = palettes.hls_palette(as_cmap=True)
    by_name = palettes.color_palette("hls", as_cmap=True)
    npt.assert_array_equal(direct(probe), by_name(probe))
def test_husl_palette(self):
    """husl_palette matches color_palette("husl") as a list and as a colormap."""
    npt.assert_array_equal(palettes.husl_palette(),
                           palettes.color_palette("husl"))
    probe = [.2, .8]
    direct = palettes.husl_palette(as_cmap=True)
    by_name = palettes.color_palette("husl", as_cmap=True)
    npt.assert_array_equal(direct(probe), by_name(probe))
def test_mpl_palette(self):
    """mpl_palette agrees with color_palette and with matplotlib's colormap."""
    npt.assert_array_equal(palettes.mpl_palette("Reds"),
                           palettes.color_palette("Reds"))
    reference = get_colormap("Reds")
    candidates = (
        palettes.mpl_palette("Reds", as_cmap=True),
        palettes.color_palette("Reds", as_cmap=True),
    )
    for cmap in candidates:
        npt.assert_array_equal(reference, cmap)
def test_mpl_dark_palette(self):
    """The "_d" (and reversed "_r_d") suffixes resolve identically in both APIs."""
    for name in ("Blues_d", "Blues_r_d"):
        npt.assert_array_equal(palettes.mpl_palette(name),
                               palettes.color_palette(name))
def test_bad_palette_name(self):
    """An unknown palette name raises ValueError."""
    bogus_name = "IAmNotAPalette"
    with pytest.raises(ValueError):
        palettes.color_palette(bogus_name)
def test_terrible_palette_name(self):
    """Requesting the "jet" palette raises ValueError."""
    rejected_name = "jet"
    with pytest.raises(ValueError):
        palettes.color_palette(rejected_name)
def test_bad_palette_colors(self):
    """A color list containing an unparseable color raises ValueError."""
    with pytest.raises(ValueError):
        palettes.color_palette(["red", "blue", "iamnotacolor"])
def test_palette_desat(self):
    """desat= desaturates each color just like utils.desaturate."""
    expected = [utils.desaturate(color, .5)
                for color in palettes.husl_palette(6)]
    npt.assert_array_equal(expected,
                           palettes.color_palette("husl", desat=.5))
def test_palette_is_list_of_tuples(self):
    """An array of color names becomes a list of float RGB tuples."""
    names = np.array(["red", "blue", "green"])
    pal_out = palettes.color_palette(names, 3)
    assert isinstance(pal_out, list)
    first = pal_out[0]
    assert isinstance(first, tuple)
    assert isinstance(first[0], float)
    assert len(first) == 3
def test_palette_cycles(self):
    """Requesting more colors than a palette holds repeats the palette."""
    base = palettes.color_palette("deep6")
    repeated = palettes.color_palette("deep6", 12)
    assert repeated == list(base) * 2
def test_hls_values(self):
    """h rotates the HLS palette; l and s move brightness and spread."""
    base = palettes.hls_palette(6, h=0)
    shifted = palettes.hls_palette(6, h=.5)
    # Half a turn of hue is a rotation by three of the six colors.
    npt.assert_array_almost_equal(base, shifted[3:] + shifted[:3])

    dark = palettes.hls_palette(5, l=.2)  # noqa
    bright = palettes.hls_palette(5, l=.8)  # noqa
    npt.assert_array_less([sum(c) for c in dark],
                          [sum(c) for c in bright])

    flat = palettes.hls_palette(5, s=.1)
    bold = palettes.hls_palette(5, s=.9)
    npt.assert_array_less([np.std(c) for c in flat],
                          [np.std(c) for c in bold])
def test_husl_values(self):
    """h rotates the HUSL palette; l and s move brightness and spread."""
    base = palettes.husl_palette(6, h=0)
    shifted = palettes.husl_palette(6, h=.5)
    # Half a turn of hue is a rotation by three of the six colors.
    npt.assert_array_almost_equal(base, shifted[3:] + shifted[:3])

    dark = palettes.husl_palette(5, l=.2)  # noqa
    bright = palettes.husl_palette(5, l=.8)  # noqa
    npt.assert_array_less([sum(c) for c in dark],
                          [sum(c) for c in bright])

    flat = palettes.husl_palette(5, s=.1)
    bold = palettes.husl_palette(5, s=.9)
    npt.assert_array_less([np.std(c) for c in flat],
                          [np.std(c) for c in bold])
def test_cbrewer_qual(self):
    """Shorter qualitative palettes are prefixes of longer ones."""
    for name, n_short, n_long in (("Set1", 4, 6), ("Set2", 8, 10)):
        short_pal = palettes.mpl_palette(name, n_short)
        long_pal = palettes.mpl_palette(name, n_long)
        assert short_pal == long_pal[:n_short]
def test_mpl_reversal(self):
    """The "_r" suffix reverses a matplotlib palette."""
    forward, backward = (palettes.mpl_palette(name, 6)
                         for name in ("BuPu", "BuPu_r"))
    npt.assert_array_almost_equal(forward, backward[::-1])
def test_rgb_from_hls(self):
    """HLS input is converted with colorsys.hls_to_rgb."""
    hls = (.5, .8, .4)
    assert palettes._color_to_rgb(hls, "hls") == colorsys.hls_to_rgb(*hls)
def test_rgb_from_husl(self):
    """HUSL input matches husl_to_rgb and stays inside [0, 1] for every hue."""
    husl_color = (120, 50, 40)
    expected = tuple(husl.husl_to_rgb(*husl_color))
    assert palettes._color_to_rgb(husl_color, "husl") == expected

    for hue in range(360):
        rgb = palettes._color_to_rgb((hue, 100, 100), "husl")
        assert min(rgb) >= 0
        assert max(rgb) <= 1
def test_rgb_from_xkcd(self):
    """xkcd color names resolve through the xkcd_rgb table."""
    name = "dull red"
    converted = palettes._color_to_rgb(name, "xkcd")
    assert converted == mpl.colors.to_rgb(xkcd_rgb[name])
def test_light_palette(self):
    """light_palette: reversal, endpoint, "light:" string spec and cmap forms."""
    n = 4
    forward = palettes.light_palette("red", n)
    backward = palettes.light_palette("red", n, reverse=True)
    assert np.allclose(forward, backward[::-1])
    assert forward[-1] == mpl.colors.colorConverter.to_rgb("red")

    # The string spec, plain and reversed, matches the function call.
    for spec, direct in (("light:red", forward), ("light:red_r", backward)):
        assert direct[3] == palettes.color_palette(spec, n)[3]

    cmap = palettes.light_palette("blue", as_cmap=True)
    assert isinstance(cmap, mpl.colors.LinearSegmentedColormap)
    assert cmap(.8) == palettes.color_palette("light:blue", as_cmap=True)(.8)

    cmap_r = palettes.light_palette("blue", as_cmap=True, reverse=True)
    assert cmap_r(.8) == palettes.color_palette("light:blue_r", as_cmap=True)(.8)
def test_dark_palette(self):
    """dark_palette: reversal, endpoint, "dark:" string spec and cmap forms."""
    n = 4
    forward = palettes.dark_palette("red", n)
    backward = palettes.dark_palette("red", n, reverse=True)
    assert np.allclose(forward, backward[::-1])
    assert forward[-1] == mpl.colors.colorConverter.to_rgb("red")

    # The string spec, plain and reversed, matches the function call.
    for spec, direct in (("dark:red", forward), ("dark:red_r", backward)):
        assert direct[3] == palettes.color_palette(spec, n)[3]

    cmap = palettes.dark_palette("blue", as_cmap=True)
    assert isinstance(cmap, mpl.colors.LinearSegmentedColormap)
    assert cmap(.8) == palettes.color_palette("dark:blue", as_cmap=True)(.8)

    cmap_r = palettes.dark_palette("blue", as_cmap=True, reverse=True)
    assert cmap_r(.8) == palettes.color_palette("dark:blue_r", as_cmap=True)(.8)
def test_diverging_palette(self):
    """Diverging endpoints match the light palettes; center="dark" darkens the middle."""
    h_neg, h_pos = 100, 200
    sat, lum = 70, 50
    args = h_neg, h_pos, sat, lum
    n = 12
    half = n // 2

    pal = palettes.diverging_palette(*args, n=n)
    neg_pal, pos_pal = (
        palettes.light_palette((hue, sat, lum), half, input="husl")
        for hue in (h_neg, h_pos)
    )
    assert len(pal) == n
    assert pal[0] == neg_pal[-1]
    assert pal[-1] == pos_pal[-1]

    pal_dark = palettes.diverging_palette(*args, n=n, center="dark")
    assert np.mean(pal[half]) > np.mean(pal_dark[half])

    cmap = palettes.diverging_palette(*args, as_cmap=True)
    assert isinstance(cmap, mpl.colors.LinearSegmentedColormap)
def test_blend_palette(self):
    """blend_palette yields a colormap and matches the "blend:" string spec."""
    cmap = palettes.blend_palette(["red", "yellow", "white"], as_cmap=True)
    assert isinstance(cmap, mpl.colors.LinearSegmentedColormap)

    pair = ["red", "blue"]
    joined = ",".join(pair)
    spec = f"blend:{joined}"
    assert palettes.blend_palette(pair) == palettes.color_palette(spec)
def test_cubehelix_against_matplotlib(self):
    """These cubehelix parameters reproduce matplotlib's cubehelix colormap."""
    samples = np.linspace(0, 1, 8)
    expected = mpl.cm.cubehelix(samples)[:, :3].tolist()
    produced = palettes.cubehelix_palette(8, start=0.5, rot=-1.5, hue=1,
                                          dark=0, light=1, reverse=True)
    assert produced == expected
def test_cubehelix_n_colors(self):
    """cubehelix_palette returns exactly the requested number of colors."""
    sizes = (3, 5, 8)
    lengths = [len(palettes.cubehelix_palette(k)) for k in sizes]
    assert lengths == list(sizes)
def test_cubehelix_reverse(self):
    """reverse=True yields the default cubehelix palette in reverse order."""
    forward = palettes.cubehelix_palette()
    backward = palettes.cubehelix_palette(reverse=True)
    assert forward == backward[::-1]
def test_cubehelix_cmap(self):
    """The cubehelix colormap samples to the palette and reverses correctly."""
    cmap = palettes.cubehelix_palette(as_cmap=True)
    assert isinstance(cmap, mpl.colors.ListedColormap)

    samples = np.linspace(0, 1, 6)
    npt.assert_array_equal(cmap(samples)[:, :3], palettes.cubehelix_palette())

    cmap_rev = palettes.cubehelix_palette(as_cmap=True, reverse=True)
    assert cmap(samples).tolist() == cmap_rev(samples[::-1]).tolist()
def test_cubehelix_code(self):
    """The "ch:" string spec parses into the equivalent cubehelix_palette call."""
    color_palette = palettes.color_palette
    cubehelix_palette = palettes.cubehelix_palette

    cases = [
        ("ch:", 8, (), {}),
        ("ch:.5, -.25,hue = .5,light=.75", 8, (.5, -.25),
         dict(hue=.5, light=.75)),
        ("ch:h=1,r=.5", 9, (), dict(hue=1, rot=.5)),
        ("ch:_r", 6, (), dict(reverse=True)),
    ]
    for spec, n, args, kwargs in cases:
        from_spec = color_palette(spec, n)
        direct = color_palette(cubehelix_palette(n, *args, **kwargs))
        assert from_spec == direct

    cmap_from_spec = color_palette("ch:_r", as_cmap=True)
    cmap_direct = cubehelix_palette(6, reverse=True, as_cmap=True)
    assert cmap_from_spec(.5) == cmap_direct(.5)
def test_xkcd_palette(self):
    """xkcd_palette reproduces the hex codes from the xkcd_rgb table."""
    names = list(xkcd_rgb)[10:15]
    for name, rgb in zip(names, palettes.xkcd_palette(names)):
        assert mpl.colors.rgb2hex(rgb) == xkcd_rgb[name]
def test_crayon_palette(self):
    """crayon_palette reproduces the (lower-cased) crayon hex codes."""
    names = list(crayons)[10:15]
    for name, rgb in zip(names, palettes.crayon_palette(names)):
        assert mpl.colors.rgb2hex(rgb) == crayons[name].lower()
def test_color_codes(self):
    """set_color_codes remaps single-letter codes; unknown names raise."""
    palettes.set_color_codes("deep")
    try:
        colors = palettes.color_palette("deep6") + [".1"]
        for code, color in zip("bgrmyck", colors):
            rgb_want = mpl.colors.colorConverter.to_rgb(color)
            rgb_got = mpl.colors.colorConverter.to_rgb(code)
            assert rgb_want == rgb_got
    finally:
        # Reset even when an assertion fails, so the global letter-code
        # remapping cannot leak into later tests.
        palettes.set_color_codes("reset")

    with pytest.raises(ValueError):
        palettes.set_color_codes("Set1")
def test_as_hex(self):
    """as_hex() agrees with matplotlib's rgb2hex for every color."""
    pal = palettes.color_palette("deep")
    for rgb, code in zip(pal, pal.as_hex()):
        assert mpl.colors.rgb2hex(rgb) == code
def test_preserved_palette_length(self):
    """Passing a palette back through color_palette keeps all its colors."""
    ten_colors = palettes.color_palette("Set1", 10)
    assert palettes.color_palette(ten_colors) == ten_colors
def test_html_repr(self):
    """The notebook HTML repr mentions the hex code of every color."""
    pal = palettes.color_palette()
    html = pal._repr_html_()
    missing = [code for code in pal.as_hex() if code not in html]
    assert not missing
def test_colormap_display_patch(self):
# Exercise palettes._patch_colormap_display() on a built-in matplotlib colormap.
# Remember whatever rich-display hooks Colormap had before the patch is applied
# (None when the attribute does not exist).
orig_repr_png = getattr(mpl.colors.Colormap, "_repr_png_", None)
orig_repr_html = getattr(mpl.colors.Colormap, "_repr_html_", None)
try:
palettes._patch_colormap_display()
cmap = mpl.cm.Reds
# NOTE(review): this copy of the file is truncated starting at the string
# literal on the next line. The expected HTML prefix, the rest of this method
# (presumably a `finally` that restores the saved hooks — confirm), and the
# code up to the orphaned `yhat_log[0]` fragment that follows are missing.
# Restore this span from the upstream test_palettes.py / test_regression.py.
assert cmap._repr_html_().startswith('
yhat_log[0]
assert yhat_log[20] > yhat_lin[20]
assert yhat_lin[90] > yhat_log[90]
@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_regress_n_boot(self):
    """Every fitting backend returns one bootstrap curve per n_boot draw."""
    plotter = lm._RegressionPlotter("x", "y", data=self.df,
                                    n_boot=self.n_boot)
    want = (self.n_boot, self.grid.size)

    # Fast (linear algebra) version
    _, boots_fast = plotter.fit_fast(self.grid)
    npt.assert_equal(boots_fast.shape, want)

    # Slower (np.polyfit) version
    _, boots_poly = plotter.fit_poly(self.grid, 1)
    npt.assert_equal(boots_poly.shape, want)

    # Slowest (statsmodels) version
    _, boots_smod = plotter.fit_statsmodels(self.grid, smlm.OLS)
    npt.assert_equal(boots_smod.shape, want)
@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_regress_without_bootstrap(self):
    """With ci=None no backend produces bootstrap samples."""
    plotter = lm._RegressionPlotter("x", "y", data=self.df,
                                    n_boot=self.n_boot, ci=None)

    # Fast (linear algebra) version
    assert plotter.fit_fast(self.grid)[1] is None

    # Slower (np.polyfit) version
    assert plotter.fit_poly(self.grid, 1)[1] is None

    # Slowest (statsmodels) version
    assert plotter.fit_statsmodels(self.grid, smlm.OLS)[1] is None
def test_regress_bootstrap_seed(self):
    """Equal seeds give identical bootstrap samples."""
    seed = 200
    plotters = [
        lm._RegressionPlotter("x", "y", data=self.df,
                              n_boot=self.n_boot, seed=seed)
        for _ in range(2)
    ]
    first, second = (p.fit_fast(self.grid)[1] for p in plotters)
    npt.assert_array_equal(first, second)
def test_numeric_bins(self):
    """An integer bin count yields that many bins, used as the binned values."""
    plotter = lm._RegressionPlotter(self.df.x, self.df.y)
    x_binned, bins = plotter.bin_predictor(self.bins_numeric)
    npt.assert_equal(len(bins), self.bins_numeric)
    npt.assert_array_equal(np.unique(x_binned), bins)
def test_provided_bins(self):
    """Explicit bins are used verbatim as the binned values."""
    plotter = lm._RegressionPlotter(self.df.x, self.df.y)
    x_binned, _ = plotter.bin_predictor(self.bins_given)
    npt.assert_array_equal(np.unique(x_binned), self.bins_given)
def test_bin_results(self):
    """Observations in a higher bin all exceed those in the bin below."""
    plotter = lm._RegressionPlotter(self.df.x, self.df.y)
    x_binned, _ = plotter.bin_predictor(self.bins_given)
    x = self.df.x
    for lower, upper in ((-1, 0), (0, 1)):
        assert x[x_binned == upper].min() > x[x_binned == lower].max()
def test_scatter_data(self):
    """scatter_data returns the input data, plus bounded jitter when requested."""
    # Column x: returned unchanged.
    p = lm._RegressionPlotter(self.df.x, self.df.y)
    x, y = p.scatter_data
    npt.assert_array_equal(x, self.df.x)
    npt.assert_array_equal(y, self.df.y)

    # Column d without jitter: returned unchanged.
    p = lm._RegressionPlotter(self.df.d, self.df.y)
    x, y = p.scatter_data
    npt.assert_array_equal(x, self.df.d)
    npt.assert_array_equal(y, self.df.y)

    # x jitter moves x by strictly less than the jitter amount; y is untouched.
    x_jitter = .1
    p = lm._RegressionPlotter(self.df.d, self.df.y, x_jitter=x_jitter)
    x, y = p.scatter_data
    assert (x != self.df.d).any()
    npt.assert_array_less(np.abs(self.df.d - x), np.repeat(x_jitter, len(x)))
    npt.assert_array_equal(y, self.df.y)

    # y jitter mirrors the x case. The bound is the requested jitter (.05);
    # the previous bound of .1 was twice that, so an oversized y jitter
    # would not have been caught. Also check that y actually moved.
    y_jitter = .05
    p = lm._RegressionPlotter(self.df.d, self.df.y, y_jitter=y_jitter)
    x, y = p.scatter_data
    npt.assert_array_equal(x, self.df.d)
    assert (y != self.df.y).any()
    npt.assert_array_less(np.abs(self.df.y - y), np.repeat(y_jitter, len(y)))
def test_estimate_data(self):
    """x_estimator aggregates y per unique x and brackets it with a CI."""
    plotter = lm._RegressionPlotter(self.df.d, self.df.y, x_estimator=np.mean)
    x, y, ci = plotter.estimate_data
    npt.assert_array_equal(x, np.sort(np.unique(self.df.d)))
    npt.assert_array_almost_equal(y, self.df.groupby("d").y.mean())

    ci = np.array(ci)
    lower, upper = ci[:, 0], ci[:, 1]
    npt.assert_array_less(lower, y)
    npt.assert_array_less(y, upper)
def test_estimate_cis(self):
    """Wider ci levels give wider intervals; ci=None gives no intervals."""
    seed = 123

    def intervals(ci, **extra):
        plotter = lm._RegressionPlotter(self.df.d, self.df.y,
                                        x_estimator=np.mean, ci=ci, **extra)
        return plotter.estimate_data[2]

    ci_big = intervals(95, seed=seed)
    ci_wee = intervals(50, seed=seed)
    npt.assert_array_less(np.diff(ci_wee), np.diff(ci_big))

    ci_nil = intervals(None)
    npt.assert_array_equal(ci_nil, [None] * len(ci_nil))
def test_estimate_units(self):
    """Resampling by units widens the binned-estimate confidence intervals."""
    # Seed the RNG locally
    seed = 345

    by_units = lm._RegressionPlotter("x", "y", data=self.df,
                                     units="s", seed=seed, x_bins=3)
    ci_big = np.diff(by_units.estimate_data[2], axis=1)

    pooled = lm._RegressionPlotter("x", "y", data=self.df, seed=seed, x_bins=3)
    ci_wee = np.diff(pooled.estimate_data[2], axis=1)

    npt.assert_array_less(ci_wee, ci_big)
def test_partial(self):
    """Partialling out a shared confound lowers the y/z correlation."""
    x = self.rs.randn(100)
    y = x + self.rs.randn(100)
    z = x + self.rs.randn(100)

    def corr(plotter):
        return np.corrcoef(plotter.x, plotter.y)[0][1]

    r_orig = corr(lm._RegressionPlotter(y, z))

    r_semipartial = corr(lm._RegressionPlotter(y, z, y_partial=x))
    assert r_semipartial < r_orig

    r_partial = corr(lm._RegressionPlotter(y, z, x_partial=x, y_partial=x))
    assert r_partial < r_orig

    # Same check with pandas Series inputs.
    x_series = pd.Series(x)
    y_series = pd.Series(y)
    r_partial = corr(lm._RegressionPlotter(y_series, z,
                                           x_partial=x_series,
                                           y_partial=x_series))
    assert r_partial < r_orig
@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_logistic_regression(self):
    """Logistic fits produce predictions strictly between 0 and 1."""
    plotter = lm._RegressionPlotter("x", "c", data=self.df,
                                    logistic=True, n_boot=self.n_boot)
    _, yhat, _ = plotter.fit_regression(x_range=(-3, 3))
    npt.assert_array_less(0, yhat)
    npt.assert_array_less(yhat, 1)
@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_logistic_perfect_separation(self):
    """A perfectly separable outcome yields an all-NaN logistic fit."""
    separated = self.df.x > self.df.x.mean()
    plotter = lm._RegressionPlotter("x", separated, data=self.df,
                                    logistic=True, n_boot=10)
    with warnings.catch_warnings():
        # The failing fit emits RuntimeWarnings; they are expected here.
        warnings.simplefilter("ignore", RuntimeWarning)
        _, yhat, _ = plotter.fit_regression(x_range=(-3, 3))
    assert np.isnan(yhat).all()
@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_robust_regression(self):
    """Robust and OLS fits are evaluated on grids of the same size."""
    fitted = []
    for extra in ({}, {"robust": True}):
        plotter = lm._RegressionPlotter("x", "y", data=self.df,
                                        n_boot=self.n_boot, **extra)
        _, yhat, _ = plotter.fit_regression(x_range=(-3, 3))
        fitted.append(yhat)
    ols_yhat, robust_yhat = fitted
    assert len(ols_yhat) == len(robust_yhat)
@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_lowess_regression(self):
    """A lowess fit spans the grid and has no error bands."""
    plotter = lm._RegressionPlotter("x", "y", data=self.df, lowess=True)
    grid, yhat, err_bands = plotter.fit_regression(x_range=(-3, 3))
    assert err_bands is None
    assert len(grid) == len(yhat)
def test_regression_options(self):
    """lowess cannot be combined with a polynomial order or logistic fit."""
    for conflicting in ({"order": 2}, {"logistic": True}):
        with pytest.raises(ValueError):
            lm._RegressionPlotter("x", "y", data=self.df,
                                  lowess=True, **conflicting)
def test_regression_limits(self):
    """The fit grid spans the axes limits, or the data range when truncated."""
    _, ax = plt.subplots()
    ax.scatter(self.df.x, self.df.y)
    full = lm._RegressionPlotter("x", "y", data=self.df)
    grid, _, _ = full.fit_regression(ax)
    x_lo, x_hi = ax.get_xlim()
    assert grid.min() == x_lo
    assert grid.max() == x_hi

    truncated = lm._RegressionPlotter("x", "y", data=self.df, truncate=True)
    grid, _, _ = truncated.fit_regression()
    assert grid.min() == self.df.x.min()
    assert grid.max() == self.df.x.max()
class TestRegressionPlots:
rs = np.random.RandomState(56)
df = pd.DataFrame(dict(x=rs.randn(90),
y=rs.randn(90) + 5,
z=rs.randint(0, 1, 90),
g=np.repeat(list("abc"), 30),
h=np.tile(list("xy"), 45),
u=np.tile(np.arange(6), 15)))
bw_err = rs.randn(6)[df.u.values]
df.y += bw_err
def test_regplot_basic(self):
f, ax = plt.subplots()
lm.regplot(x="x", y="y", data=self.df)
assert len(ax.lines) == 1
assert len(ax.collections) == 2
x, y = ax.collections[0].get_offsets().T
npt.assert_array_equal(x, self.df.x)
npt.assert_array_equal(y, self.df.y)
def test_regplot_selective(self):
f, ax = plt.subplots()
ax = lm.regplot(x="x", y="y", data=self.df, scatter=False, ax=ax)
assert len(ax.lines) == 1
assert len(ax.collections) == 1
ax.clear()
f, ax = plt.subplots()
ax = lm.regplot(x="x", y="y", data=self.df, fit_reg=False)
assert len(ax.lines) == 0
assert len(ax.collections) == 1
ax.clear()
f, ax = plt.subplots()
ax = lm.regplot(x="x", y="y", data=self.df, ci=None)
assert len(ax.lines) == 1
assert len(ax.collections) == 1
ax.clear()
def test_regplot_scatter_kws_alpha(self):
f, ax = plt.subplots()
color = np.array([[0.3, 0.8, 0.5, 0.5]])
ax = lm.regplot(x="x", y="y", data=self.df,
scatter_kws={'color': color})
assert ax.collections[0]._alpha is None
assert ax.collections[0]._facecolors[0, 3] == 0.5
f, ax = plt.subplots()
color = np.array([[0.3, 0.8, 0.5]])
ax = lm.regplot(x="x", y="y", data=self.df,
scatter_kws={'color': color})
assert ax.collections[0]._alpha == 0.8
f, ax = plt.subplots()
color = np.array([[0.3, 0.8, 0.5]])
ax = lm.regplot(x="x", y="y", data=self.df,
scatter_kws={'color': color, 'alpha': 0.4})
assert ax.collections[0]._alpha == 0.4
f, ax = plt.subplots()
color = 'r'
ax = lm.regplot(x="x", y="y", data=self.df,
scatter_kws={'color': color})
assert ax.collections[0]._alpha == 0.8
f, ax = plt.subplots()
alpha = .3
ax = lm.regplot(x="x", y="y", data=self.df,
x_bins=5, fit_reg=False,
scatter_kws={"alpha": alpha})
for line in ax.lines:
assert line.get_alpha() == alpha
def test_regplot_binned(self):
ax = lm.regplot(x="x", y="y", data=self.df, x_bins=5)
assert len(ax.lines) == 6
assert len(ax.collections) == 2
def test_lmplot_no_data(self):
with pytest.raises(TypeError):
# keyword argument `data` is required
lm.lmplot(x="x", y="y")
def test_lmplot_basic(self):
g = lm.lmplot(x="x", y="y", data=self.df)
ax = g.axes[0, 0]
assert len(ax.lines) == 1
assert len(ax.collections) == 2
x, y = ax.collections[0].get_offsets().T
npt.assert_array_equal(x, self.df.x)
npt.assert_array_equal(y, self.df.y)
def test_lmplot_hue(self):
g = lm.lmplot(x="x", y="y", data=self.df, hue="h")
ax = g.axes[0, 0]
assert len(ax.lines) == 2
assert len(ax.collections) == 4
def test_lmplot_markers(self):
g1 = lm.lmplot(x="x", y="y", data=self.df, hue="h", markers="s")
assert g1.hue_kws == {"marker": ["s", "s"]}
g2 = lm.lmplot(x="x", y="y", data=self.df, hue="h", markers=["o", "s"])
assert g2.hue_kws == {"marker": ["o", "s"]}
with pytest.raises(ValueError):
lm.lmplot(x="x", y="y", data=self.df, hue="h",
markers=["o", "s", "d"])
def test_lmplot_marker_linewidths(self):
g = lm.lmplot(x="x", y="y", data=self.df, hue="h",
fit_reg=False, markers=["o", "+"])
c = g.axes[0, 0].collections
assert c[1].get_linewidths()[0] == mpl.rcParams["lines.linewidth"]
def test_lmplot_facets(self):
g = lm.lmplot(x="x", y="y", data=self.df, row="g", col="h")
assert g.axes.shape == (3, 2)
g = lm.lmplot(x="x", y="y", data=self.df, col="u", col_wrap=4)
assert g.axes.shape == (6,)
g = lm.lmplot(x="x", y="y", data=self.df, hue="h", col="u")
assert g.axes.shape == (1, 6)
def test_lmplot_hue_col_nolegend(self):
g = lm.lmplot(x="x", y="y", data=self.df, col="h", hue="h")
assert g._legend is None
def test_lmplot_scatter_kws(self):
g = lm.lmplot(x="x", y="y", hue="h", data=self.df, ci=None)
red_scatter, blue_scatter = g.axes[0, 0].collections
red, blue = color_palette(n_colors=2)
npt.assert_array_equal(red, red_scatter.get_facecolors()[0, :3])
npt.assert_array_equal(blue, blue_scatter.get_facecolors()[0, :3])
@pytest.mark.parametrize("sharex", [True, False])
def test_lmplot_facet_truncate(self, sharex):
g = lm.lmplot(
data=self.df, x="x", y="y", hue="g", col="h",
truncate=False, facet_kws=dict(sharex=sharex),
)
for ax in g.axes.flat:
for line in ax.lines:
xdata = line.get_xdata()
assert ax.get_xlim() == tuple(xdata[[0, -1]])
def test_lmplot_sharey(self):
df = pd.DataFrame(dict(
x=[0, 1, 2, 0, 1, 2],
y=[1, -1, 0, -100, 200, 0],
z=["a", "a", "a", "b", "b", "b"],
))
with pytest.warns(UserWarning):
g = lm.lmplot(data=df, x="x", y="y", col="z", sharey=False)
ax1, ax2 = g.axes.flat
assert ax1.get_ylim()[0] > ax2.get_ylim()[0]
assert ax1.get_ylim()[1] < ax2.get_ylim()[1]
def test_lmplot_facet_kws(self):
xlim = -4, 20
g = lm.lmplot(
data=self.df, x="x", y="y", col="h", facet_kws={"xlim": xlim}
)
for ax in g.axes.flat:
assert ax.get_xlim() == xlim
def test_residplot(self):
x, y = self.df.x, self.df.y
ax = lm.residplot(x=x, y=y)
resid = y - np.polyval(np.polyfit(x, y, 1), x)
x_plot, y_plot = ax.collections[0].get_offsets().T
npt.assert_array_equal(x, x_plot)
npt.assert_array_almost_equal(resid, y_plot)
@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_residplot_lowess(self):
ax = lm.residplot(x="x", y="y", data=self.df, lowess=True)
assert len(ax.lines) == 2
x, y = ax.lines[1].get_xydata().T
npt.assert_array_equal(x, np.sort(self.df.x))
@pytest.mark.parametrize("option", ["robust", "lowess"])
@pytest.mark.skipif(not _no_statsmodels, reason="statsmodels installed")
def test_residplot_statsmodels_missing_errors(self, long_df, option):
with pytest.raises(RuntimeError, match=rf"`{option}=True` requires"):
lm.residplot(long_df, x="x", y="y", **{option: True})
def test_three_point_colors(self):
    """An RGB ``color`` is honored when plotting exactly three points.

    Presumably guards against three data points being mistaken for an RGB
    specification.
    """
    xs, ys = np.random.randn(2, 3)
    ax = lm.regplot(x=xs, y=ys, color=(1, 0, 0))
    first_rgb = ax.collections[0].get_facecolors()[0, :3]
    npt.assert_almost_equal(first_rgb, (1, 0, 0))
def test_regplot_xlim(self):
    """With truncate=False, two fits on the same axes share the same x grid."""
    _, ax = plt.subplots()
    x, first_y, second_y = np.random.randn(3, 50)
    # regplot draws on the current axes, i.e. the one just created.
    for y in (first_y, second_y):
        lm.regplot(x=x, y=y, truncate=False)
    first_fit, second_fit = ax.lines
    assert np.array_equal(first_fit.get_xdata(), second_fit.get_xdata())
================================================
FILE: tests/test_relational.py
================================================
from itertools import product
import warnings
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import same_color, to_rgba
import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
from seaborn.palettes import color_palette
from seaborn._base import categorical_order, unique_markers
from seaborn.relational import (
_RelationalPlotter,
_LinePlotter,
_ScatterPlotter,
relplot,
lineplot,
scatterplot
)
from seaborn.utils import _draw_figure, _version_predates
from seaborn._compat import get_colormap, get_legend_handles
from seaborn._testing import assert_plots_equal
@pytest.fixture(params=[
    {"x": "x", "y": "y"},
    {"x": "t", "y": "y"},
    {"x": "a", "y": "y"},
    {"x": "x", "y": "y", "hue": "y"},
    {"x": "x", "y": "y", "hue": "a"},
    {"x": "x", "y": "y", "size": "a"},
    {"x": "x", "y": "y", "style": "a"},
    {"x": "x", "y": "y", "hue": "s"},
    {"x": "x", "y": "y", "size": "s"},
    {"x": "x", "y": "y", "style": "s"},
    {"x": "x", "y": "y", "hue": "a", "style": "a"},
    {"x": "x", "y": "y", "hue": "a", "size": "b", "style": "b"},
])
def long_semantics(request):
    """Yield one long-form variable spec (x/y plus optional semantics) per param."""
    semantics = request.param
    return semantics
class Helpers:
    """Mixin with fixtures and comparison helpers shared by relational tests."""

    @pytest.fixture
    def levels(self, long_df):
        """Map columns "a" and "b" of ``long_df`` to their categorical order."""
        return dict((name, categorical_order(long_df[name])) for name in ("a", "b"))

    def scatter_rgbs(self, collections):
        """Return the RGB part of each collection's (squeezed) facecolor."""
        return [
            tuple(collection.get_facecolor().squeeze()[:3])
            for collection in collections
        ]

    def paths_equal(self, *args):
        """Return whether the path sequences match in length and pairwise content.

        Corresponding paths must have equal ``vertices`` and ``codes``.
        """
        same_length = all(len(seq) == len(args[0]) for seq in args)
        pairs_match = all(
            np.array_equal(left.vertices, right.vertices)
            and np.array_equal(left.codes, right.codes)
            for left, right in zip(*args)
        )
        return same_length and pairs_match
class SharedAxesLevelTests:
    """Tests shared by axes-level plotters.

    Subclasses provide ``func`` (the plotting function) and ``get_last_color``.
    """

    def test_color(self, long_df):
        """Colors follow the property cycle unless ``color`` or ``c`` is given."""
        # Each scenario: keyword sets for successive calls on a fresh axes,
        # and the expected color of the last artist drawn.
        scenarios = [
            ([{}], "C0"),
            ([{}, {}], "C1"),
            ([{"color": "C2"}], "C2"),
            ([{"c": "C2"}], "C2"),
        ]
        for call_kws, expected in scenarios:
            ax = plt.figure().subplots()
            for extra in call_kws:
                self.func(data=long_df, x="x", y="y", ax=ax, **extra)
            assert self.get_last_color(ax) == to_rgba(expected)
class TestRelationalPlotter(Helpers):
    """Tests for ``_RelationalPlotter`` input handling and for ``relplot``."""

    def _assert_padded_wide_data(self, data, vectors, hue_levels):
        """Check wide-form input made of unnamed vectors with a default index.

        ``data`` is passed to ``assign_variables``. ``vectors`` are its columns
        in order. ``hue_levels`` holds the expected hue/style label of each
        column. Shorter columns are padded up to the longest one, so ``y`` is
        compared after dropping missing values.
        """
        p = _RelationalPlotter()
        p.assign_variables(data=data)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y", "hue", "style"]

        chunks = len(vectors)
        chunk_size = max(len(v) for v in vectors)
        assert len(p.plot_data) == chunks * chunk_size

        expected_x = np.tile(np.arange(chunk_size), chunks)
        assert_array_equal(p.plot_data["x"], expected_x)

        expected_y = np.concatenate(vectors)
        assert_array_equal(p.plot_data["y"].dropna(), expected_y)

        # Wide-form column labels map onto both hue and style.
        expected_hue = np.repeat(hue_levels, chunk_size)
        assert_array_equal(p.plot_data["hue"], expected_hue)
        assert_array_equal(p.plot_data["style"], expected_hue)

        # Plain containers carry no variable names.
        for var in ["x", "y", "hue", "style"]:
            assert p.variables[var] is None

    @staticmethod
    def _assert_facet_scatter_data(g, grouped):
        """Check each facet's first scatter collection against its data group.

        Groups are paired with ``g.axes.flat`` positionally (row-major order).
        """
        for (_, grp_df), ax in zip(grouped, g.axes.flat):
            x, y = ax.collections[0].get_offsets().T
            assert_array_equal(x, grp_df["x"])
            assert_array_equal(y, grp_df["y"])

    def test_wide_df_variables(self, wide_df):
        """A wide DataFrame melts to x=index, y=values, hue/style=columns."""
        p = _RelationalPlotter()
        p.assign_variables(data=wide_df)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y", "hue", "style"]
        assert len(p.plot_data) == np.prod(wide_df.shape)

        x = p.plot_data["x"]
        expected_x = np.tile(wide_df.index, wide_df.shape[1])
        assert_array_equal(x, expected_x)

        # Column-major ravel: every row of the first column comes first.
        y = p.plot_data["y"]
        expected_y = wide_df.to_numpy().ravel(order="f")
        assert_array_equal(y, expected_y)

        hue = p.plot_data["hue"]
        expected_hue = np.repeat(wide_df.columns.to_numpy(), wide_df.shape[0])
        assert_array_equal(hue, expected_hue)

        style = p.plot_data["style"]
        expected_style = expected_hue
        assert_array_equal(style, expected_style)

        assert p.variables["x"] == wide_df.index.name
        assert p.variables["y"] is None
        assert p.variables["hue"] == wide_df.columns.name
        assert p.variables["style"] == wide_df.columns.name

    def test_wide_df_with_nonnumeric_variables(self, long_df):
        """Only the numeric columns of a mixed-type frame are used as wide data."""
        p = _RelationalPlotter()
        p.assign_variables(data=long_df)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y", "hue", "style"]

        numeric_df = long_df.select_dtypes("number")

        assert len(p.plot_data) == np.prod(numeric_df.shape)

        x = p.plot_data["x"]
        expected_x = np.tile(numeric_df.index, numeric_df.shape[1])
        assert_array_equal(x, expected_x)

        y = p.plot_data["y"]
        expected_y = numeric_df.to_numpy().ravel(order="f")
        assert_array_equal(y, expected_y)

        hue = p.plot_data["hue"]
        expected_hue = np.repeat(
            numeric_df.columns.to_numpy(), numeric_df.shape[0]
        )
        assert_array_equal(hue, expected_hue)

        style = p.plot_data["style"]
        expected_style = expected_hue
        assert_array_equal(style, expected_style)

        assert p.variables["x"] == numeric_df.index.name
        assert p.variables["y"] is None
        assert p.variables["hue"] == numeric_df.columns.name
        assert p.variables["style"] == numeric_df.columns.name

    def test_wide_array_variables(self, wide_array):
        """A 2D array melts with positional x and column-position hue/style."""
        p = _RelationalPlotter()
        p.assign_variables(data=wide_array)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y", "hue", "style"]
        assert len(p.plot_data) == np.prod(wide_array.shape)

        nrow, ncol = wide_array.shape

        x = p.plot_data["x"]
        expected_x = np.tile(np.arange(nrow), ncol)
        assert_array_equal(x, expected_x)

        y = p.plot_data["y"]
        expected_y = wide_array.ravel(order="f")
        assert_array_equal(y, expected_y)

        hue = p.plot_data["hue"]
        expected_hue = np.repeat(np.arange(ncol), nrow)
        assert_array_equal(hue, expected_hue)

        style = p.plot_data["style"]
        expected_style = expected_hue
        assert_array_equal(style, expected_style)

        assert p.variables["x"] is None
        assert p.variables["y"] is None
        assert p.variables["hue"] is None
        assert p.variables["style"] is None

    def test_flat_array_variables(self, flat_array):
        """A 1D array becomes y, with positional x and no other semantics."""
        p = _RelationalPlotter()
        p.assign_variables(data=flat_array)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y"]
        assert len(p.plot_data) == np.prod(flat_array.shape)

        x = p.plot_data["x"]
        expected_x = np.arange(flat_array.shape[0])
        assert_array_equal(x, expected_x)

        y = p.plot_data["y"]
        expected_y = flat_array
        assert_array_equal(y, expected_y)

        assert p.variables["x"] is None
        assert p.variables["y"] is None

    def test_flat_list_variables(self, flat_list):
        """A flat list becomes y, with positional x and no other semantics."""
        p = _RelationalPlotter()
        p.assign_variables(data=flat_list)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y"]
        assert len(p.plot_data) == len(flat_list)

        x = p.plot_data["x"]
        expected_x = np.arange(len(flat_list))
        assert_array_equal(x, expected_x)

        y = p.plot_data["y"]
        expected_y = flat_list
        assert_array_equal(y, expected_y)

        assert p.variables["x"] is None
        assert p.variables["y"] is None

    def test_flat_series_variables(self, flat_series):
        """A Series becomes y with its index as x, keeping both names."""
        p = _RelationalPlotter()
        p.assign_variables(data=flat_series)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y"]
        assert len(p.plot_data) == len(flat_series)

        x = p.plot_data["x"]
        expected_x = flat_series.index
        assert_array_equal(x, expected_x)

        y = p.plot_data["y"]
        expected_y = flat_series
        assert_array_equal(y, expected_y)

        # Compare names by value: equal strings need not be the same object.
        assert p.variables["x"] == flat_series.index.name
        assert p.variables["y"] == flat_series.name

    def test_wide_list_of_series_variables(self, wide_list_of_series):
        """A list of Series is aligned on the union of the Series' indexes."""
        p = _RelationalPlotter()
        p.assign_variables(data=wide_list_of_series)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y", "hue", "style"]

        chunks = len(wide_list_of_series)
        chunk_size = max(len(s) for s in wide_list_of_series)

        assert len(p.plot_data) == chunks * chunk_size

        index_union = np.unique(
            np.concatenate([s.index for s in wide_list_of_series])
        )

        x = p.plot_data["x"]
        expected_x = np.tile(index_union, chunks)
        assert_array_equal(x, expected_x)

        y = p.plot_data["y"]
        expected_y = np.concatenate([
            s.reindex(index_union) for s in wide_list_of_series
        ])
        assert_array_equal(y, expected_y)

        hue = p.plot_data["hue"]
        series_names = [s.name for s in wide_list_of_series]
        expected_hue = np.repeat(series_names, chunk_size)
        assert_array_equal(hue, expected_hue)

        style = p.plot_data["style"]
        expected_style = expected_hue
        assert_array_equal(style, expected_style)

        assert p.variables["x"] is None
        assert p.variables["y"] is None
        assert p.variables["hue"] is None
        assert p.variables["style"] is None

    def test_wide_list_of_arrays_variables(self, wide_list_of_arrays):
        """A list of arrays is labeled by list position."""
        self._assert_padded_wide_data(
            wide_list_of_arrays,
            wide_list_of_arrays,
            np.arange(len(wide_list_of_arrays)),
        )

    def test_wide_list_of_list_variables(self, wide_list_of_lists):
        """A list of lists is labeled by list position."""
        self._assert_padded_wide_data(
            wide_list_of_lists,
            wide_list_of_lists,
            np.arange(len(wide_list_of_lists)),
        )

    def test_wide_dict_of_series_variables(self, wide_dict_of_series):
        """A dict of Series is labeled by its keys."""
        self._assert_padded_wide_data(
            wide_dict_of_series,
            list(wide_dict_of_series.values()),
            list(wide_dict_of_series),
        )

    def test_wide_dict_of_arrays_variables(self, wide_dict_of_arrays):
        """A dict of arrays is labeled by its keys."""
        self._assert_padded_wide_data(
            wide_dict_of_arrays,
            list(wide_dict_of_arrays.values()),
            list(wide_dict_of_arrays),
        )

    def test_wide_dict_of_lists_variables(self, wide_dict_of_lists):
        """A dict of lists is labeled by its keys."""
        self._assert_padded_wide_data(
            wide_dict_of_lists,
            list(wide_dict_of_lists.values()),
            list(wide_dict_of_lists),
        )

    def test_relplot_simple(self, long_df):
        """relplot draws scatter and line kinds and rejects unknown kinds."""
        g = relplot(data=long_df, x="x", y="y", kind="scatter")
        x, y = g.ax.collections[0].get_offsets().T
        assert_array_equal(x, long_df["x"])
        assert_array_equal(y, long_df["y"])

        g = relplot(data=long_df, x="x", y="y", kind="line")
        x, y = g.ax.lines[0].get_xydata().T
        expected = long_df.groupby("x").y.mean()
        assert_array_equal(x, expected.index)
        assert y == pytest.approx(expected.values)

        with pytest.raises(ValueError):
            relplot(data=long_df, x="x", y="y", kind="not_a_kind")

    def test_relplot_complex(self, long_df):
        """Semantic mappings combine with col and row faceting."""
        for sem in ["hue", "size", "style"]:
            g = relplot(data=long_df, x="x", y="y", **{sem: "a"})
            x, y = g.ax.collections[0].get_offsets().T
            assert_array_equal(x, long_df["x"])
            assert_array_equal(y, long_df["y"])

        for sem in ["hue", "size", "style"]:
            g = relplot(
                data=long_df, x="x", y="y", col="c", **{sem: "a"}
            )
            self._assert_facet_scatter_data(g, long_df.groupby("c"))

        for sem in ["size", "style"]:
            g = relplot(
                data=long_df, x="x", y="y", hue="b", col="c", **{sem: "a"}
            )
            self._assert_facet_scatter_data(g, long_df.groupby("c"))

        for sem in ["hue", "size", "style"]:
            # Sorting keeps each facet's rows in groupby order.
            g = relplot(
                data=long_df.sort_values(["c", "b"]),
                x="x", y="y", col="b", row="c", **{sem: "a"}
            )
            self._assert_facet_scatter_data(g, long_df.groupby(["c", "b"]))

    @pytest.mark.parametrize("vector_type", ["series", "numpy", "list"])
    def test_relplot_vectors(self, long_df, vector_type):
        """Semantics given as Series, arrays or lists facet like named columns."""
        semantics = dict(x="x", y="y", hue="f", col="c")
        kws = {key: long_df[val] for key, val in semantics.items()}
        if vector_type == "numpy":
            kws = {k: v.to_numpy() for k, v in kws.items()}
        elif vector_type == "list":
            kws = {k: v.to_list() for k, v in kws.items()}
        g = relplot(data=long_df, **kws)
        grouped = long_df.groupby("c")
        assert len(g.axes_dict) == len(grouped)
        self._assert_facet_scatter_data(g, grouped)

    def test_relplot_wide(self, wide_df):
        """relplot accepts wide-form data and leaves the y label empty."""
        g = relplot(data=wide_df)
        x, y = g.ax.collections[0].get_offsets().T
        assert_array_equal(y, wide_df.to_numpy().T.ravel())
        assert not g.ax.get_ylabel()

    def test_relplot_hues(self, long_df):
        """A list palette is assigned to hue levels in order of appearance."""
        palette = ["r", "b", "g"]
        g = relplot(
            x="x", y="y", hue="a", style="b", col="c",
            palette=palette, data=long_df
        )

        palette = dict(zip(long_df["a"].unique(), palette))
        grouped = long_df.groupby("c")
        for (_, grp_df), ax in zip(grouped, g.axes.flat):
            points = ax.collections[0]
            expected_hues = [palette[val] for val in grp_df["a"]]
            assert same_color(points.get_facecolors(), expected_hues)

    def test_relplot_sizes(self, long_df):
        """A list of sizes is assigned to size levels in order of appearance."""
        sizes = [5, 12, 7]
        g = relplot(
            data=long_df,
            x="x", y="y", size="a", hue="b", col="c",
            sizes=sizes,
        )

        sizes = dict(zip(long_df["a"].unique(), sizes))
        grouped = long_df.groupby("c")
        for (_, grp_df), ax in zip(grouped, g.axes.flat):
            points = ax.collections[0]
            expected_sizes = [sizes[val] for val in grp_df["a"]]
            assert_array_equal(points.get_sizes(), expected_sizes)

    def test_relplot_styles(self, long_df):
        """A list of markers is assigned to style levels in order of appearance."""
        markers = ["o", "d", "s"]
        g = relplot(
            data=long_df,
            x="x", y="y", style="a", hue="b", col="c",
            markers=markers,
        )

        # Build the transformed marker paths that the collections should hold.
        paths = []
        for m in markers:
            m = mpl.markers.MarkerStyle(m)
            paths.append(m.get_path().transformed(m.get_transform()))
        paths = dict(zip(long_df["a"].unique(), paths))

        grouped = long_df.groupby("c")
        for (_, grp_df), ax in zip(grouped, g.axes.flat):
            points = ax.collections[0]
            expected_paths = [paths[val] for val in grp_df["a"]]
            assert self.paths_equal(points.get_paths(), expected_paths)

    def test_relplot_weighted_estimator(self, long_df):
        """With ``weights``, line points are weighted means of y per level."""
        g = relplot(data=long_df, x="a", y="y", weights="x", kind="line")
        ydata = g.ax.lines[0].get_ydata()
        for i, level in enumerate(categorical_order(long_df["a"])):
            pos_df = long_df[long_df["a"] == level]
            expected = np.average(pos_df["y"], weights=pos_df["x"])
            assert ydata[i] == pytest.approx(expected)

    def test_relplot_stringy_numerics(self, long_df):
        """Numbers stored as strings work as hue/size without masking points."""
        long_df["x_str"] = long_df["x"].astype(str)

        g = relplot(data=long_df, x="x", y="y", hue="x_str")
        points = g.ax.collections[0]
        xys = points.get_offsets()
        mask = np.ma.getmask(xys)
        assert not mask.any()
        assert_array_equal(xys, long_df[["x", "y"]])

        g = relplot(data=long_df, x="x", y="y", size="x_str")
        points = g.ax.collections[0]
        xys = points.get_offsets()
        mask = np.ma.getmask(xys)
        assert not mask.any()
        assert_array_equal(xys, long_df[["x", "y"]])

    def test_relplot_legend(self, long_df):
        """The figure legend appears only for mapped semantics, with correct entries."""
        g = relplot(data=long_df, x="x", y="y")
        assert g._legend is None

        g = relplot(data=long_df, x="x", y="y", hue="a")
        texts = [t.get_text() for t in g._legend.texts]
        expected_texts = long_df["a"].unique()
        assert_array_equal(texts, expected_texts)

        g = relplot(data=long_df, x="x", y="y", hue="s", size="s")
        texts = [t.get_text() for t in g._legend.texts]
        assert_array_equal(texts, np.sort(texts))

        g = relplot(data=long_df, x="x", y="y", hue="a", legend=False)
        assert g._legend is None

        palette = color_palette("deep", len(long_df["b"].unique()))
        a_like_b = dict(zip(long_df["a"].unique(), long_df["b"].unique()))
        long_df["a_like_b"] = long_df["a"].map(a_like_b)
        g = relplot(
            data=long_df,
            x="x", y="y", hue="b", style="a_like_b",
            palette=palette, kind="line", estimator=None,
        )
        lines = g._legend.get_lines()[1:]  # Chop off title dummy
        for line, color in zip(lines, palette):
            assert line.get_color() == color

    def test_relplot_unshared_axis_labels(self, long_df):
        """With unshared axes, labels still appear only on the outer facets."""
        col, row = "a", "b"
        g = relplot(
            data=long_df, x="x", y="y", col=col, row=row,
            facet_kws=dict(sharex=False, sharey=False),
        )

        for ax in g.axes[-1, :].flat:
            assert ax.get_xlabel() == "x"
        for ax in g.axes[:-1, :].flat:
            assert ax.get_xlabel() == ""
        for ax in g.axes[:, 0].flat:
            assert ax.get_ylabel() == "y"
        for ax in g.axes[:, 1:].flat:
            assert ax.get_ylabel() == ""

    def test_relplot_data(self, long_df):
        """``g.data`` combines dict data with extra vectors under derived names."""
        g = relplot(
            data=long_df.to_dict(orient="list"),
            x="x",
            y=long_df["y"].rename("y_var"),
            hue=long_df["a"].to_numpy(),
            col="c",
        )
        expected_cols = set(long_df.columns.to_list() + ["_hue_", "y_var"])
        assert set(g.data.columns) == expected_cols
        assert_array_equal(g.data["y_var"], long_df["y"])
        assert_array_equal(g.data["_hue_"], long_df["a"])

    def test_facet_variable_collision(self, long_df):
        """A facet column named "size" is not mistaken for the size semantic."""
        # https://github.com/mwaskom/seaborn/issues/2488
        col_data = long_df["c"]
        long_df = long_df.assign(size=col_data)

        g = relplot(
            data=long_df,
            x="x", y="y", col="size",
        )
        assert g.axes.shape == (1, len(col_data.unique()))

    def test_relplot_scatter_unused_variables(self, long_df):
        """``units`` and ``weights`` warn with the default (scatter) kind."""
        with pytest.warns(UserWarning, match="The `units` parameter"):
            g = relplot(long_df, x="x", y="y", units="a")
        assert g.ax is not None

        with pytest.warns(UserWarning, match="The `weights` parameter"):
            g = relplot(long_df, x="x", y="y", weights="x")
        assert g.ax is not None

    def test_ax_kwarg_removal(self, long_df):
        """An ``ax`` argument triggers a warning and the given axes stays empty."""
        f, ax = plt.subplots()
        with pytest.warns(UserWarning):
            g = relplot(data=long_df, x="x", y="y", ax=ax)
        assert len(ax.collections) == 0
        assert len(g.ax.collections) > 0

    def test_legend_has_no_offset(self, long_df):
        """Large numeric hue values are shown in full in the legend text."""
        g = relplot(data=long_df, x="x", y="y", hue=long_df["z"] + 1e8)
        for text in g.legend.texts:
            assert float(text.get_text()) > 1e7

    def test_lineplot_2d_dashes(self, long_df):
        """Explicit dash specs for wide-form input produce dashed lines."""
        ax = lineplot(data=long_df[["x", "y"]], dashes=[(5, 5), (10, 10)])
        for line in ax.get_lines():
            assert line.is_dashed()

    def test_legend_attributes_hue(self, long_df):
        """Hue legend handles carry palette colors and the scatter kwargs."""
        kws = {"s": 50, "linewidth": 1, "marker": "X"}
        g = relplot(long_df, x="x", y="y", hue="a", **kws)
        palette = color_palette()
        for i, pt in enumerate(get_legend_handles(g.legend)):
            assert same_color(pt.get_color(), palette[i])
            # Scatter size ``s`` is an area; legend markersize is its sqrt.
            assert pt.get_markersize() == np.sqrt(kws["s"])
            assert pt.get_markeredgewidth() == kws["linewidth"]
            if not _version_predates(mpl, "3.7.0"):
                assert pt.get_marker() == kws["marker"]

    def test_legend_attributes_style(self, long_df):
        """Style legend handles keep the fixed color and the scatter kwargs."""
        kws = {"s": 50, "linewidth": 1, "color": "r"}
        g = relplot(long_df, x="x", y="y", style="a", **kws)
        for pt in get_legend_handles(g.legend):
            assert pt.get_markersize() == np.sqrt(kws["s"])
            assert pt.get_markeredgewidth() == kws["linewidth"]
            assert same_color(pt.get_color(), "r")

    def test_legend_attributes_hue_and_style(self, long_df):
        """Non-title legend entries carry the scatter kwargs with hue and style."""
        kws = {"s": 50, "linewidth": 1}
        g = relplot(long_df, x="x", y="y", hue="a", style="b", **kws)
        for pt in get_legend_handles(g.legend):
            # Skip the section-title entries labeled with the variable names.
            if pt.get_label() not in ["a", "b"]:
                assert pt.get_markersize() == np.sqrt(kws["s"])
                assert pt.get_markeredgewidth() == kws["linewidth"]
class TestLinePlotter(SharedAxesLevelTests, Helpers):
func = staticmethod(lineplot)
def get_last_color(self, ax):
    """Return the RGBA color of the most recently added line on ``ax``."""
    newest_line = ax.lines[-1]
    return to_rgba(newest_line.get_color())
def test_legend_no_semantics(self, long_df):
    """Without semantic mappings, lineplot adds no legend handles."""
    ax = lineplot(long_df, x="x", y="y")
    legend_handles = ax.get_legend_handles_labels()[0]
    assert legend_handles == []
def test_legend_hue_categorical(self, long_df, levels):
    """A categorical hue gives one palette-colored legend entry per level."""
    ax = lineplot(long_df, x="x", y="y", hue="a")
    handles, labels = ax.get_legend_handles_labels()
    handle_colors = [handle.get_color() for handle in handles]
    assert labels == levels["a"]
    assert handle_colors == color_palette(n_colors=len(labels))
def test_legend_hue_and_style_same(self, long_df, levels):
    """When hue and style share a variable, each entry has color and marker."""
    ax = lineplot(long_df, x="x", y="y", hue="a", style="a", markers=True)
    handles, labels = ax.get_legend_handles_labels()
    n_levels = len(labels)
    handle_colors = [h.get_color() for h in handles]
    handle_markers = [h.get_marker() for h in handles]
    assert labels == levels["a"]
    assert handle_colors == color_palette(n_colors=n_levels)
    assert handle_markers == unique_markers(n_levels)
def test_legend_hue_and_style_diff(self, long_df, levels):
    """Distinct hue and style variables produce two titled legend sections."""
    ax = lineplot(long_df, x="x", y="y", hue="a", style="b", markers=True)
    handles, labels = ax.get_legend_handles_labels()
    n_hue = len(levels["a"])
    n_style = len(levels["b"])

    # Section titles ("a", "b") use color "w" and marker "".
    # Hue entries have no marker; style entries are drawn in ".2".
    expected_labels = ["a", *levels["a"], "b", *levels["b"]]
    expected_colors = [
        "w", *color_palette(n_colors=n_hue),
        "w", *[".2"] * n_style,
    ]
    expected_markers = ["", *["None"] * n_hue, "", *unique_markers(n_style)]

    assert labels == expected_labels
    assert [h.get_color() for h in handles] == expected_colors
    assert [h.get_marker() for h in handles] == expected_markers
def test_legend_hue_and_size_same(self, long_df, levels):
    """Shared hue/size entries combine palette colors with scaled line widths."""
    ax = lineplot(long_df, x="x", y="y", hue="a", size="a")
    handles, labels = ax.get_legend_handles_labels()
    n_levels = len(levels["a"])
    handle_colors = [h.get_color() for h in handles]
    handle_widths = [h.get_linewidth() for h in handles]
    assert labels == levels["a"]
    assert handle_colors == color_palette(n_colors=n_levels)
    # Widths run from 2x down to 0.5x the rc default line width.
    base_width = mpl.rcParams["lines.linewidth"]
    scales = np.linspace(2, 0.5, n_levels)
    assert handle_widths == [scale * base_width for scale in scales]
@pytest.mark.parametrize("var", ["hue", "size", "style"])
def test_legend_numerical_full(self, long_df, var):
    """legend="full" lists every unique numeric level in sorted order."""
    x, y = np.random.randn(2, 40)
    z = np.tile(np.arange(20), 2)
    ax = lineplot(x=x, y=y, **{var: z}, legend="full")
    labels = ax.get_legend_handles_labels()[1]
    assert labels == list(map(str, sorted(set(z))))
@pytest.mark.parametrize("var", ["hue", "size", "style"])
def test_legend_numerical_brief(self, var):
    """legend="brief" thins numeric hue/size levels; style lists all levels."""
    x, y = np.random.randn(2, 40)
    z = np.tile(np.arange(20), 2)
    ax = lineplot(x=x, y=y, **{var: z}, legend="brief")
    labels = ax.get_legend_handles_labels()[1]
    if var != "style":
        expected = ["0", "4", "8", "12", "16"]
    else:
        expected = [str(level) for level in sorted(set(z))]
    assert labels == expected
def test_legend_value_error(self, long_df):
    """An unrecognized ``legend`` value raises ValueError."""
    bad_kws = {"hue": "a", "legend": "bad_value"}
    with pytest.raises(ValueError, match=r"`legend` must be"):
        lineplot(long_df, x="x", y="y", **bad_kws)
@pytest.mark.parametrize("var", ["hue", "size"])
def test_legend_log_norm(self, var):
    """With a LogNorm, the first two legend levels differ by a factor of 10."""
    x, y = np.random.randn(2, 40)
    z = np.tile(np.arange(20), 2)
    mapping_kws = {var: z + 1, f"{var}_norm": mpl.colors.LogNorm()}
    ax = lineplot(x=x, y=y, **mapping_kws)
    _, labels = ax.get_legend_handles_labels()
    first_level, second_level = float(labels[0]), float(labels[1])
    assert second_level / first_level == 10
@pytest.mark.parametrize("var", ["hue", "size"])
def test_legend_binary_var(self, var):
    """A two-level numeric semantic gets one legend entry per level."""
    x, y = np.random.randn(2, 40)
    z = np.tile(np.arange(20), 2)
    # Route the binary values through the parametrized semantic so both the
    # hue and the size legends are exercised (hue was previously hard-coded).
    ax = lineplot(x=x, y=y, **{var: z % 2})
    _, labels = ax.get_legend_handles_labels()
    assert labels == ["0", "1"]
@pytest.mark.parametrize("var", ["hue", "size"])
def test_legend_binary_numberic_brief(self, long_df, var):
    """A brief legend for numeric column "f" shows evenly spaced labels.

    Column "f" is presumably two-valued (per the test name); confirm in the
    ``long_df`` fixture.
    """
    ax = lineplot(long_df, x="x", y="y", **{var: "f"}, legend="brief")
    labels = ax.get_legend_handles_labels()[1]
    assert labels == ["0.20", "0.22", "0.24", "0.26", "0.28"]
def test_plot(self, long_df, repeated_df):
    """Exercise ``_LinePlotter.plot`` across semantics, estimators and errors.

    Each section builds a plotter (or mutates the previous one), clears the
    shared axes, plots, and checks the resulting artists.
    """
    f, ax = plt.subplots()

    # Unsorted raw data without aggregation reproduces the input order.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y"),
        sort=False,
        estimator=None
    )
    p.plot(ax, {})
    line, = ax.lines
    assert_array_equal(line.get_xdata(), long_df.x.to_numpy())
    assert_array_equal(line.get_ydata(), long_df.y.to_numpy())

    # Keyword arguments given to plot() reach the line artist.
    ax.clear()
    p.plot(ax, {"color": "k", "label": "test"})
    line, = ax.lines
    assert line.get_color() == "k"
    assert line.get_label() == "test"

    # sort=True orders the unaggregated data by x, then y.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y"),
        sort=True, estimator=None
    )
    ax.clear()
    p.plot(ax, {})
    line, = ax.lines
    sorted_data = long_df.sort_values(["x", "y"])
    assert_array_equal(line.get_xdata(), sorted_data.x.to_numpy())
    assert_array_equal(line.get_ydata(), sorted_data.y.to_numpy())

    # Hue: one line per level, colored by the hue mapping.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", hue="a"),
    )
    ax.clear()
    p.plot(ax, {})
    assert len(ax.lines) == len(p._hue_map.levels)
    for line, level in zip(ax.lines, p._hue_map.levels):
        assert line.get_color() == p._hue_map(level)

    # Size: one line per level, width from the size mapping.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", size="a"),
    )
    ax.clear()
    p.plot(ax, {})
    assert len(ax.lines) == len(p._size_map.levels)
    for line, level in zip(ax.lines, p._size_map.levels):
        assert line.get_linewidth() == p._size_map(level)

    # Hue and style on the same variable: one line per shared level.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", hue="a", style="a"),
    )
    p.map_style(markers=True)
    ax.clear()
    p.plot(ax, {})
    assert len(ax.lines) == len(p._hue_map.levels)
    assert len(ax.lines) == len(p._style_map.levels)
    for line, level in zip(ax.lines, p._hue_map.levels):
        assert line.get_color() == p._hue_map(level)
        assert line.get_marker() == p._style_map(level, "marker")

    # Hue and style on different variables: one line per level combination.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", hue="a", style="b"),
    )
    p.map_style(markers=True)
    ax.clear()
    p.plot(ax, {})
    levels = product(p._hue_map.levels, p._style_map.levels)
    expected_line_count = len(p._hue_map.levels) * len(p._style_map.levels)
    assert len(ax.lines) == expected_line_count
    for line, (hue, style) in zip(ax.lines, levels):
        assert line.get_color() == p._hue_map(hue)
        assert line.get_marker() == p._style_map(style, "marker")

    # Mean aggregation with a standard-deviation band.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y"),
        estimator="mean", err_style="band", errorbar="sd", sort=True
    )
    ax.clear()
    p.plot(ax, {})
    line, = ax.lines
    expected_data = long_df.groupby("x").y.mean()
    assert_array_equal(line.get_xdata(), expected_data.index.to_numpy())
    assert np.allclose(line.get_ydata(), expected_data.to_numpy())
    assert len(ax.collections) == 1

    # Test that nans do not propagate to means or CIs
    p = _LinePlotter(
        variables=dict(
            x=[1, 1, 1, 2, 2, 2, 3, 3, 3],
            y=[1, 2, 3, 3, np.nan, 5, 4, 5, 6],
        ),
        estimator="mean", err_style="band", errorbar="ci", n_boot=100, sort=True,
    )
    ax.clear()
    p.plot(ax, {})
    line, = ax.lines
    assert line.get_xdata().tolist() == [1, 2, 3]
    err_band = ax.collections[0].get_paths()
    assert len(err_band) == 1
    assert len(err_band[0].vertices) == 9

    # Bands per hue level are PolyCollections.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", hue="a"),
        estimator="mean", err_style="band", errorbar="sd"
    )
    ax.clear()
    p.plot(ax, {})
    assert len(ax.lines) == len(ax.collections) == len(p._hue_map.levels)
    for c in ax.collections:
        assert isinstance(c, mpl.collections.PolyCollection)

    # Error bars: one LineCollection per hue level, and twice as many lines.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", hue="a"),
        estimator="mean", err_style="bars", errorbar="sd"
    )
    ax.clear()
    p.plot(ax, {})
    n_lines = len(ax.lines)
    assert n_lines / 2 == len(ax.collections) == len(p._hue_map.levels)
    assert len(ax.collections) == len(p._hue_map.levels)
    for c in ax.collections:
        assert isinstance(c, mpl.collections.LineCollection)

    # Units without an estimator: one line per sampling unit.
    p = _LinePlotter(
        data=repeated_df,
        variables=dict(x="x", y="y", units="u"),
        estimator=None
    )
    ax.clear()
    p.plot(ax, {})
    n_units = len(repeated_df["u"].unique())
    assert len(ax.lines) == n_units

    # Units nested within hue levels.
    p = _LinePlotter(
        data=repeated_df,
        variables=dict(x="x", y="y", hue="a", units="u"),
        estimator=None
    )
    ax.clear()
    p.plot(ax, {})
    n_units *= len(repeated_df["a"].unique())
    assert len(ax.lines) == n_units

    # Combining units with an estimator is rejected.
    p.estimator = "mean"
    with pytest.raises(ValueError):
        p.plot(ax, {})

    # err_kws reach the band artists.
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", hue="a"),
        err_style="band", err_kws={"alpha": .5},
    )
    ax.clear()
    p.plot(ax, {})
    for band in ax.collections:
        assert band.get_alpha() == .5

    # err_kws reach the error bar artists: ``elinewidth`` sets the width of
    # the bar LineCollections. (get_linestyles() returns dash tuples and can
    # never equal a number, so the width is what must be checked.)
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", hue="a"),
        err_style="bars", err_kws={"elinewidth": 2},
    )
    ax.clear()
    p.plot(ax, {})
    for lines in ax.collections:
        assert_array_equal(lines.get_linewidths(), 2)

    # An unknown err_style is rejected.
    p.err_style = "invalid"
    with pytest.raises(ValueError):
        p.plot(ax, {})

    # Numeric values stored as strings can be used as hue or size.
    x_str = long_df["x"].astype(str)
    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", hue=x_str),
    )
    ax.clear()
    p.plot(ax, {})

    p = _LinePlotter(
        data=long_df,
        variables=dict(x="x", y="y", size=x_str),
    )
    ax.clear()
    p.plot(ax, {})
def test_weights(self, long_df):
    """Each aggregated point is the weighted mean of y at its level."""
    ax = lineplot(long_df, x="a", y="y", weights="x")
    ydata = ax.lines[0].get_ydata()
    for pos, level in enumerate(categorical_order(long_df["a"])):
        subset = long_df.loc[long_df["a"] == level]
        weighted_mean = np.average(subset["y"], weights=subset["x"])
        assert ydata[pos] == pytest.approx(weighted_mean)
def test_non_aggregated_data(self):
    """Data with unique x values are drawn as given."""
    xs, ys = [1, 2, 3, 4], [2, 4, 6, 8]
    ax = lineplot(x=xs, y=ys)
    (only_line,) = ax.lines
    assert_array_equal(only_line.get_xdata(), xs)
    assert_array_equal(only_line.get_ydata(), ys)
def test_orient(self, long_df):
    """orient="y" aggregates x within y groups, for bands and for bars."""
    frame = long_df.drop("x", axis=1).rename(columns={"s": "y", "y": "x"})

    band_ax = plt.figure().subplots()
    lineplot(data=frame, x="x", y="y", orient="y", errorbar="sd")
    assert len(band_ax.lines) == len(band_ax.collections)
    (line,) = band_ax.lines
    expected = frame.groupby("y").agg({"x": "mean"}).reset_index()
    assert_array_almost_equal(line.get_xdata(), expected["x"])
    assert_array_almost_equal(line.get_ydata(), expected["y"])
    ribbon_y = band_ax.collections[0].get_paths()[0].vertices[:, 1]
    assert_array_equal(np.unique(ribbon_y), frame["y"].sort_values().unique())

    bars_ax = plt.figure().subplots()
    lineplot(
        data=frame, x="x", y="y", orient="y", errorbar="sd", err_style="bars"
    )
    segments = bars_ax.collections[0].get_segments()
    for idx, y_level in enumerate(sorted(frame["y"].unique())):
        assert (segments[idx][:, 1] == y_level).all()

    with pytest.raises(ValueError, match="`orient` must be either 'x' or 'y'"):
        lineplot(frame, x="y", y="x", orient="bad")
def test_log_scale(self):
    """Lines and error bars on log axes use untransformed data values."""
    _, ax = plt.subplots()
    ax.set_xscale("log")
    xs, ys = [1, 10, 100], [1, 2, 3]
    lineplot(x=xs, y=ys)
    first_line = ax.lines[0]
    assert_array_equal(first_line.get_xdata(), xs)
    assert_array_equal(first_line.get_ydata(), ys)

    _, ax = plt.subplots()
    ax.set_xscale("log")
    ax.set_yscale("log")
    xs, ys = [1, 1, 2, 2], [1, 10, 1, 100]
    lineplot(x=xs, y=ys, err_style="bars", errorbar=("pi", 100))
    # At x=2 the values are 1 and 100; 10 is consistent with log-space averaging.
    assert ax.lines[0].get_ydata()[1] == 10
    bars = ax.collections[0].get_segments()
    assert_array_equal(bars[0][:, 1], ys[:2])
    assert_array_equal(bars[1][:, 1], ys[2:])
def test_axis_labels(self, long_df):
    """Variable names become axis labels; a shared y label is hidden on ax2."""
    f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
    plotter = _LinePlotter(data=long_df, variables=dict(x="x", y="y"))
    for target in (ax1, ax2):
        plotter.plot(target, {})
        assert target.get_xlabel() == "x"
        assert target.get_ylabel() == "y"
    assert not ax2.yaxis.label.get_visible()
def test_matplotlib_kwargs(self, long_df):
    """Extra keyword arguments should be forwarded to the matplotlib line."""
    line_kws = dict(
        linestyle="--",
        linewidth=3,
        color=(1, .5, .2),
        markeredgecolor=(.2, .5, .2),
        markeredgewidth=1,
    )
    ax = lineplot(data=long_df, x="x", y="y", **line_kws)
    first, *_ = ax.lines
    for attr, expected in line_kws.items():
        getter = getattr(first, f"get_{attr}")
        assert getter() == expected
def test_nonmapped_dashes(self):
    """A literal ``dashes`` tuple should produce a dashed line."""
    ax = lineplot(x=[1, 2], y=[1, 2], dashes=(2, 1))
    # Lines do not publicly expose their dash pattern, so the test only
    # checks that the style is reported as dashed
    assert ax.lines[0].get_linestyle() == "--"
def test_lineplot_axes(self, wide_df):
    """lineplot should draw on the current axes unless ``ax`` is given."""
    f1, ax1 = plt.subplots()
    f2, ax2 = plt.subplots()
    # ax2 belongs to the most recently created figure, so it is current
    ax = lineplot(data=wide_df)
    assert ax is ax2
    # An explicit ax= overrides the current axes
    ax = lineplot(data=wide_df, ax=ax1)
    assert ax is ax1
def test_legend_attributes_with_hue(self, long_df):
    """Hue legend handles should carry palette colors and the line kwargs."""
    kws = {"marker": "o", "linewidth": 3}
    ax = lineplot(long_df, x="x", y="y", hue="a", **kws)
    palette = color_palette()
    # Marker attributes on legend handles are only checked on mpl >= 3.7
    check_marker = not _version_predates(mpl, "3.7.0")
    for idx, handle in enumerate(get_legend_handles(ax.get_legend())):
        assert same_color(handle.get_color(), palette[idx])
        assert handle.get_linewidth() == kws["linewidth"]
        if check_marker:
            assert handle.get_marker() == kws["marker"]
def test_legend_attributes_with_style(self, long_df):
    """Style legend handles should carry the fixed color and line kwargs."""
    kws = {"color": "r", "marker": "o", "linewidth": 3}
    ax = lineplot(long_df, x="x", y="y", style="a", **kws)
    # Marker attributes on legend handles are only checked on mpl >= 3.7
    check_marker = not _version_predates(mpl, "3.7.0")
    for handle in get_legend_handles(ax.get_legend()):
        assert same_color(handle.get_color(), kws["color"])
        if check_marker:
            assert handle.get_marker() == kws["marker"]
        assert handle.get_linewidth() == kws["linewidth"]
def test_legend_attributes_with_hue_and_style(self, long_df):
    """With hue and style, data legend entries keep the line kwargs."""
    kws = {"marker": "o", "linewidth": 3}
    ax = lineplot(long_df, x="x", y="y", hue="a", style="b", **kws)
    # Marker attributes on legend handles are only checked on mpl >= 3.7
    check_marker = not _version_predates(mpl, "3.7.0")
    for handle in get_legend_handles(ax.get_legend()):
        # Skip entries labelled with the variable names themselves
        # (presumably the per-semantic subtitle rows)
        if handle.get_label() in ["a", "b"]:
            continue
        if check_marker:
            assert handle.get_marker() == kws["marker"]
        assert handle.get_linewidth() == kws["linewidth"]
def test_lineplot_vs_relplot(self, long_df, long_semantics):
    """Axes-level lineplot and figure-level relplot should draw identical lines.

    Both are drawn without a legend so that only data lines are compared.
    """
    ax = lineplot(data=long_df, legend=False, **long_semantics)
    g = relplot(data=long_df, kind="line", legend=False, **long_semantics)
    lin_lines = ax.lines
    rel_lines = g.ax.lines
    # zip() truncates silently, so a missing or extra line in either plot
    # would otherwise go unnoticed
    assert len(lin_lines) == len(rel_lines)
    for l1, l2 in zip(lin_lines, rel_lines):
        assert_array_equal(l1.get_xydata(), l2.get_xydata())
        assert same_color(l1.get_color(), l2.get_color())
        assert l1.get_linewidth() == l2.get_linewidth()
        assert l1.get_linestyle() == l2.get_linestyle()
def test_lineplot_smoke(
    self,
    wide_df, wide_array,
    wide_list_of_series, wide_list_of_arrays, wide_list_of_lists,
    flat_array, flat_series, flat_list,
    long_df, null_df, object_df
):
    """Smoke test: lineplot should accept many input formats and mappings.

    Each call only has to complete without raising. The shared axes is
    cleared after every call.
    """
    f, ax = plt.subplots()
    call_kwargs = [
        dict(x=[], y=[]),
        dict(data=wide_df),
        dict(data=wide_array),
        dict(data=wide_list_of_series),
        dict(data=wide_list_of_arrays),
        dict(data=wide_list_of_lists),
        dict(data=flat_series),
        dict(data=flat_array),
        dict(data=flat_list),
        dict(x="x", y="y", data=long_df),
        dict(x=long_df.x, y=long_df.y),
        dict(x=long_df.x, y="y", data=long_df),
        dict(x="x", y=long_df.y.to_numpy(), data=long_df),
        dict(x="x", y="t", data=long_df),
        dict(x="x", y="y", hue="a", data=long_df),
        dict(x="x", y="y", hue="a", style="a", data=long_df),
        dict(x="x", y="y", hue="a", style="b", data=long_df),
        dict(x="x", y="y", hue="a", style="a", data=null_df),
        dict(x="x", y="y", hue="a", style="b", data=null_df),
        dict(x="x", y="y", hue="a", size="a", data=long_df),
        dict(x="x", y="y", hue="a", size="s", data=long_df),
        dict(x="x", y="y", hue="a", size="a", data=null_df),
        dict(x="x", y="y", hue="a", size="s", data=null_df),
        dict(x="x", y="y", hue="f", data=object_df),
        dict(x="x", y="y", hue="c", size="f", data=object_df),
        dict(x="x", y="y", hue="f", size="s", data=object_df),
        dict(x="x", y="y", hue="a", data=long_df.iloc[:0]),
    ]
    for kws in call_kwargs:
        lineplot(**kws)
        ax.clear()
def test_ci_deprecation(self, long_df):
    """The deprecated ``ci`` parameter warns and matches the ``errorbar`` result."""
    cases = [
        (("ci", 95), 95, {"seed": 0}),
        ("sd", "sd", {}),
    ]
    for errorbar, ci, extra in cases:
        axs = plt.figure().subplots(2)
        lineplot(data=long_df, x="x", y="y", errorbar=errorbar, ax=axs[0], **extra)
        with pytest.warns(FutureWarning, match="\n\nThe `ci` parameter is deprecated"):
            lineplot(data=long_df, x="x", y="y", ci=ci, ax=axs[1], **extra)
        assert_plots_equal(*axs)
class TestScatterPlotter(SharedAxesLevelTests, Helpers):
    """Tests for :func:`scatterplot` and its ``_ScatterPlotter`` backend."""

    func = staticmethod(scatterplot)

    def get_last_color(self, ax):
        """Return the single RGBA face color of the most recent collection.

        Fails if that collection has more than one distinct face color.
        """
        colors = ax.collections[-1].get_facecolors()
        unique_colors = np.unique(colors, axis=0)
        assert len(unique_colors) == 1
        return to_rgba(unique_colors.squeeze())

    def test_color(self, long_df):
        """Run the shared color test, then check each facecolor keyword alias."""
        super().test_color(long_df)
        ax = plt.figure().subplots()
        self.func(data=long_df, x="x", y="y", facecolor="C5", ax=ax)
        assert self.get_last_color(ax) == to_rgba("C5")
        ax = plt.figure().subplots()
        self.func(data=long_df, x="x", y="y", facecolors="C6", ax=ax)
        assert self.get_last_color(ax) == to_rgba("C6")
        ax = plt.figure().subplots()
        self.func(data=long_df, x="x", y="y", fc="C4", ax=ax)
        assert self.get_last_color(ax) == to_rgba("C4")

    def test_legend_no_semantics(self, long_df):
        """Without semantic mappings there should be no legend handles."""
        ax = scatterplot(long_df, x="x", y="y")
        handles, _ = ax.get_legend_handles_labels()
        assert not handles

    def test_legend_hue(self, long_df):
        """A hue legend uses palette colors and categorical level order."""
        ax = scatterplot(long_df, x="x", y="y", hue="a")
        handles, labels = ax.get_legend_handles_labels()
        colors = [h.get_color() for h in handles]
        expected_colors = color_palette(n_colors=len(handles))
        assert same_color(colors, expected_colors)
        assert labels == categorical_order(long_df["a"])

    def test_legend_hue_style_same(self, long_df):
        """Hue and style on one variable combine into a single set of entries."""
        ax = scatterplot(long_df, x="x", y="y", hue="a", style="a")
        handles, labels = ax.get_legend_handles_labels()
        colors = [h.get_color() for h in handles]
        expected_colors = color_palette(n_colors=len(labels))
        markers = [h.get_marker() for h in handles]
        expected_markers = unique_markers(len(handles))
        assert same_color(colors, expected_colors)
        assert markers == expected_markers
        assert labels == categorical_order(long_df["a"])

    def test_legend_hue_style_different(self, long_df):
        """Hue and style on different variables get separate titled sections."""
        ax = scatterplot(long_df, x="x", y="y", hue="a", style="b")
        handles, labels = ax.get_legend_handles_labels()
        colors = [h.get_color() for h in handles]
        # The entries labelled "a" and "b" are the section titles; they are
        # white and have no marker
        expected_colors = [
            "w", *color_palette(n_colors=long_df["a"].nunique()),
            "w", *[".2" for _ in long_df["b"].unique()],
        ]
        markers = [h.get_marker() for h in handles]
        expected_markers = [
            "", *["o" for _ in long_df["a"].unique()],
            "", *unique_markers(long_df["b"].nunique()),
        ]
        assert same_color(colors, expected_colors)
        assert markers == expected_markers
        assert labels == [
            "a", *categorical_order(long_df["a"]),
            "b", *categorical_order(long_df["b"]),
        ]

    def test_legend_data_hue_size_same(self, long_df):
        """Hue and size on one variable combine colors and marker sizes."""
        ax = scatterplot(long_df, x="x", y="y", hue="a", size="a")
        handles, labels = ax.get_legend_handles_labels()
        colors = [h.get_color() for h in handles]
        expected_colors = color_palette(n_colors=len(labels))
        sizes = [h.get_markersize() for h in handles]
        ms = mpl.rcParams["lines.markersize"] ** 2
        # Legend markersize is a diameter, so it is sqrt of the area scale
        expected_sizes = np.sqrt(
            [ms * scl for scl in np.linspace(2, 0.5, len(handles))]
        ).tolist()
        assert same_color(colors, expected_colors)
        assert sizes == expected_sizes
        assert labels == categorical_order(long_df["a"])
        assert ax.get_legend().get_title().get_text() == "a"

    def test_legend_size_numeric_list(self, long_df):
        """A list of sizes maps to the sorted numeric levels in order."""
        size_list = [10, 100, 200]
        ax = scatterplot(long_df, x="x", y="y", size="s", sizes=size_list)
        handles, labels = ax.get_legend_handles_labels()
        sizes = [h.get_markersize() for h in handles]
        expected_sizes = list(np.sqrt(size_list))
        assert sizes == expected_sizes
        assert labels == list(map(str, categorical_order(long_df["s"])))
        assert ax.get_legend().get_title().get_text() == "s"

    def test_legend_size_numeric_dict(self, long_df):
        """A dict of sizes maps each numeric level to its own size."""
        size_dict = {2: 10, 4: 100, 8: 200}
        ax = scatterplot(long_df, x="x", y="y", size="s", sizes=size_dict)
        handles, labels = ax.get_legend_handles_labels()
        sizes = [h.get_markersize() for h in handles]
        order = categorical_order(long_df["s"])
        expected_sizes = [np.sqrt(size_dict[k]) for k in order]
        assert sizes == expected_sizes
        assert labels == list(map(str, order))
        assert ax.get_legend().get_title().get_text() == "s"

    def test_legend_numeric_hue_full(self):
        """legend="full" lists every unique numeric hue level."""
        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)
        ax = scatterplot(x=x, y=y, hue=z, legend="full")
        _, labels = ax.get_legend_handles_labels()
        assert labels == [str(z_i) for z_i in sorted(set(z))]
        assert ax.get_legend().get_title().get_text() == ""

    def test_legend_numeric_hue_brief(self):
        """legend="brief" shows fewer entries than there are hue levels."""
        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)
        ax = scatterplot(x=x, y=y, hue=z, legend="brief")
        _, labels = ax.get_legend_handles_labels()
        assert len(labels) < len(set(z))

    def test_legend_numeric_size_full(self):
        """legend="full" lists every unique numeric size level."""
        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)
        ax = scatterplot(x=x, y=y, size=z, legend="full")
        _, labels = ax.get_legend_handles_labels()
        assert labels == [str(z_i) for z_i in sorted(set(z))]

    def test_legend_numeric_size_brief(self):
        """legend="brief" shows fewer entries than there are size levels."""
        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)
        ax = scatterplot(x=x, y=y, size=z, legend="brief")
        _, labels = ax.get_legend_handles_labels()
        assert len(labels) < len(set(z))

    def test_legend_attributes_hue(self, long_df):
        """Hue legend handles should carry size, edge width and marker kwargs."""
        kws = {"s": 50, "linewidth": 1, "marker": "X"}
        ax = scatterplot(long_df, x="x", y="y", hue="a", **kws)
        palette = color_palette()
        for i, pt in enumerate(get_legend_handles(ax.get_legend())):
            assert same_color(pt.get_color(), palette[i])
            assert pt.get_markersize() == np.sqrt(kws["s"])
            assert pt.get_markeredgewidth() == kws["linewidth"]
            if not _version_predates(mpl, "3.7.0"):
                # This attribute is empty on older matplotlibs
                # but the legend looks correct so I assume it is a bug
                assert pt.get_marker() == kws["marker"]

    def test_legend_attributes_style(self, long_df):
        """Style legend handles should carry size, edge width and fixed color."""
        kws = {"s": 50, "linewidth": 1, "color": "r"}
        ax = scatterplot(long_df, x="x", y="y", style="a", **kws)
        for pt in get_legend_handles(ax.get_legend()):
            assert pt.get_markersize() == np.sqrt(kws["s"])
            assert pt.get_markeredgewidth() == kws["linewidth"]
            assert same_color(pt.get_color(), "r")

    def test_legend_attributes_hue_and_style(self, long_df):
        """Non-title legend entries keep size and edge width with hue + style."""
        kws = {"s": 50, "linewidth": 1}
        ax = scatterplot(long_df, x="x", y="y", hue="a", style="b", **kws)
        for pt in get_legend_handles(ax.get_legend()):
            if pt.get_label() not in ["a", "b"]:
                assert pt.get_markersize() == np.sqrt(kws["s"])
                assert pt.get_markeredgewidth() == kws["linewidth"]

    def test_legend_value_error(self):
        pass

    def test_plot(self, long_df, repeated_df):
        """_ScatterPlotter.plot should apply each semantic mapping to the points."""
        f, ax = plt.subplots()
        p = _ScatterPlotter(data=long_df, variables=dict(x="x", y="y"))
        p.plot(ax, {})
        points = ax.collections[0]
        assert_array_equal(points.get_offsets(), long_df[["x", "y"]].to_numpy())
        ax.clear()
        p.plot(ax, {"color": "k", "label": "test"})
        points = ax.collections[0]
        assert same_color(points.get_facecolor(), "k")
        assert points.get_label() == "test"
        p = _ScatterPlotter(
            data=long_df, variables=dict(x="x", y="y", hue="a")
        )
        ax.clear()
        p.plot(ax, {})
        points = ax.collections[0]
        expected_colors = p._hue_map(p.plot_data["hue"])
        assert same_color(points.get_facecolors(), expected_colors)
        p = _ScatterPlotter(
            data=long_df,
            variables=dict(x="x", y="y", style="c"),
        )
        p.map_style(markers=["+", "x"])
        ax.clear()
        color = (1, .3, .8)
        p.plot(ax, {"color": color})
        points = ax.collections[0]
        # Unfilled markers take the color on their edges
        assert same_color(points.get_edgecolors(), [color])
        p = _ScatterPlotter(
            data=long_df, variables=dict(x="x", y="y", size="a"),
        )
        ax.clear()
        p.plot(ax, {})
        points = ax.collections[0]
        expected_sizes = p._size_map(p.plot_data["size"])
        assert_array_equal(points.get_sizes(), expected_sizes)
        p = _ScatterPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a", style="a"),
        )
        p.map_style(markers=True)
        ax.clear()
        p.plot(ax, {})
        points = ax.collections[0]
        expected_colors = p._hue_map(p.plot_data["hue"])
        expected_paths = p._style_map(p.plot_data["style"], "path")
        assert same_color(points.get_facecolors(), expected_colors)
        assert self.paths_equal(points.get_paths(), expected_paths)
        p = _ScatterPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a", style="b"),
        )
        p.map_style(markers=True)
        ax.clear()
        p.plot(ax, {})
        points = ax.collections[0]
        expected_colors = p._hue_map(p.plot_data["hue"])
        expected_paths = p._style_map(p.plot_data["style"], "path")
        assert same_color(points.get_facecolors(), expected_colors)
        assert self.paths_equal(points.get_paths(), expected_paths)
        # String-valued hue/size vectors should plot without raising
        x_str = long_df["x"].astype(str)
        p = _ScatterPlotter(
            data=long_df, variables=dict(x="x", y="y", hue=x_str),
        )
        ax.clear()
        p.plot(ax, {})
        p = _ScatterPlotter(
            data=long_df, variables=dict(x="x", y="y", size=x_str),
        )
        ax.clear()
        p.plot(ax, {})

    def test_axis_labels(self, long_df):
        """Variable names become axis labels; a shared y label is hidden on ax2."""
        f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
        p = _ScatterPlotter(data=long_df, variables=dict(x="x", y="y"))
        p.plot(ax1, {})
        assert ax1.get_xlabel() == "x"
        assert ax1.get_ylabel() == "y"
        p.plot(ax2, {})
        assert ax2.get_xlabel() == "x"
        assert ax2.get_ylabel() == "y"
        assert not ax2.yaxis.label.get_visible()

    def test_scatterplot_axes(self, wide_df):
        """scatterplot should draw on the current axes unless ``ax`` is given."""
        f1, ax1 = plt.subplots()
        f2, ax2 = plt.subplots()
        ax = scatterplot(data=wide_df)
        assert ax is ax2
        ax = scatterplot(data=wide_df, ax=ax1)
        assert ax is ax1

    def test_literal_attribute_vectors(self):
        """Literal ``c`` and ``s`` vectors should be passed through unchanged."""
        f, ax = plt.subplots()
        x = y = [1, 2, 3]
        s = [5, 10, 15]
        c = [(1, 1, 0, 1), (1, 0, 1, .5), (.5, 1, 0, 1)]
        scatterplot(x=x, y=y, c=c, s=s, ax=ax)
        points, = ax.collections
        assert_array_equal(points.get_sizes().squeeze(), s)
        assert_array_equal(points.get_facecolors(), c)

    def test_supplied_color_array(self, long_df):
        """RGBA arrays under any facecolor alias, or c + cmap, color the points."""
        cmap = get_colormap("Blues")
        norm = mpl.colors.Normalize()
        colors = cmap(norm(long_df["y"].to_numpy()))
        keys = ["c", "fc", "facecolor", "facecolors"]
        for key in keys:
            ax = plt.figure().subplots()
            scatterplot(data=long_df, x="x", y="y", **{key: colors})
            _draw_figure(ax.figure)
            assert_array_equal(ax.collections[0].get_facecolors(), colors)
        ax = plt.figure().subplots()
        scatterplot(data=long_df, x="x", y="y", c=long_df["y"], cmap=cmap)
        _draw_figure(ax.figure)
        assert_array_equal(ax.collections[0].get_facecolors(), colors)

    def test_hue_order(self, long_df):
        """Levels omitted from hue_order are invisible and absent from the legend."""
        order = categorical_order(long_df["a"])
        unused = order.pop()
        ax = scatterplot(data=long_df, x="x", y="y", hue="a", hue_order=order)
        points = ax.collections[0]
        assert (points.get_facecolors()[long_df["a"] == unused] == 0).all()
        assert [t.get_text() for t in ax.legend_.texts] == order

    def test_linewidths(self, long_df):
        """Default edge width scales with marker size; explicit linewidth wins."""
        f, ax = plt.subplots()
        scatterplot(data=long_df, x="x", y="y", s=10)
        scatterplot(data=long_df, x="x", y="y", s=20)
        points1, points2 = ax.collections
        assert (
            points1.get_linewidths().item() < points2.get_linewidths().item()
        )
        ax.clear()
        scatterplot(data=long_df, x="x", y="y", s=long_df["x"])
        scatterplot(data=long_df, x="x", y="y", s=long_df["x"] * 2)
        points1, points2 = ax.collections
        assert (
            points1.get_linewidths().item() < points2.get_linewidths().item()
        )
        ax.clear()
        lw = 2
        scatterplot(data=long_df, x="x", y="y", linewidth=lw)
        assert ax.collections[0].get_linewidths().item() == lw

    def test_size_norm_extrapolation(self):
        """A fixed size_norm gives the same sizes for a subset of the data."""
        # https://github.com/mwaskom/seaborn/issues/2539
        x = np.arange(0, 20, 2)
        f, axs = plt.subplots(1, 2, sharex=True, sharey=True)
        slc = 5
        kws = dict(sizes=(50, 200), size_norm=(0, x.max()), legend="brief")
        scatterplot(x=x, y=x, size=x, ax=axs[0], **kws)
        scatterplot(x=x[:slc], y=x[:slc], size=x[:slc], ax=axs[1], **kws)
        assert np.allclose(
            axs[0].collections[0].get_sizes()[:slc],
            axs[1].collections[0].get_sizes()
        )
        legends = [ax.legend_ for ax in axs]
        legend_data = [
            {
                label.get_text(): handle.get_markersize()
                for label, handle in zip(legend.get_texts(), get_legend_handles(legend))
            } for legend in legends
        ]
        for key in set(legend_data[0]) & set(legend_data[1]):
            if key == "y":
                # At some point (circa 3.0) matplotlib auto-added pandas series
                # with a valid name into the legend, which messes up this test.
                # I can't track down when that was added (or removed), so let's
                # just anticipate and ignore it here.
                continue
            assert legend_data[0][key] == legend_data[1][key]

    def test_datetime_scale(self, long_df):
        """Datetime x data should not trigger matplotlib's odd default autoscaling."""
        ax = scatterplot(data=long_df, x="t", y="y")
        # Check that we avoid weird matplotlib default auto scaling
        # https://github.com/matplotlib/matplotlib/issues/17586
        assert ax.get_xlim()[0] > ax.xaxis.convert_units(np.datetime64("2002-01-01"))

    def test_unfilled_marker_edgecolor_warning(self, long_df):  # GH2636
        """Unfilled markers should not emit any warning."""
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            scatterplot(data=long_df, x="x", y="y", marker="+")

    def test_short_form_kwargs(self, long_df):
        """Short matplotlib aliases such as ``ec`` should be honoured."""
        ax = scatterplot(data=long_df, x="x", y="y", ec="g")
        pts = ax.collections[0]
        assert same_color(pts.get_edgecolors().squeeze(), "g")

    def test_scatterplot_vs_relplot(self, long_df, long_semantics):
        """Axes-level scatterplot and relplot should draw identical points."""
        ax = scatterplot(data=long_df, **long_semantics)
        g = relplot(data=long_df, kind="scatter", **long_semantics)
        for s_pts, r_pts in zip(ax.collections, g.ax.collections):
            assert_array_equal(s_pts.get_offsets(), r_pts.get_offsets())
            assert_array_equal(s_pts.get_sizes(), r_pts.get_sizes())
            assert_array_equal(s_pts.get_facecolors(), r_pts.get_facecolors())
            assert self.paths_equal(s_pts.get_paths(), r_pts.get_paths())

    def test_scatterplot_smoke(
        self,
        wide_df, wide_array,
        flat_series, flat_array, flat_list,
        wide_list_of_series, wide_list_of_arrays, wide_list_of_lists,
        long_df, null_df, object_df
    ):
        """Smoke test: scatterplot should accept many input formats and mappings."""
        f, ax = plt.subplots()
        scatterplot(x=[], y=[])
        ax.clear()
        scatterplot(data=wide_df)
        ax.clear()
        scatterplot(data=wide_array)
        ax.clear()
        scatterplot(data=wide_list_of_series)
        ax.clear()
        scatterplot(data=wide_list_of_arrays)
        ax.clear()
        scatterplot(data=wide_list_of_lists)
        ax.clear()
        scatterplot(data=flat_series)
        ax.clear()
        scatterplot(data=flat_array)
        ax.clear()
        scatterplot(data=flat_list)
        ax.clear()
        scatterplot(x="x", y="y", data=long_df)
        ax.clear()
        scatterplot(x=long_df.x, y=long_df.y)
        ax.clear()
        scatterplot(x=long_df.x, y="y", data=long_df)
        ax.clear()
        scatterplot(x="x", y=long_df.y.to_numpy(), data=long_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="a", data=long_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="a", style="a", data=long_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="a", style="b", data=long_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="a", style="a", data=null_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="a", style="b", data=null_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="a", size="a", data=long_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="a", size="s", data=long_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="a", size="a", data=null_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="a", size="s", data=null_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="f", data=object_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="c", size="f", data=object_df)
        ax.clear()
        scatterplot(x="x", y="y", hue="f", size="s", data=object_df)
        ax.clear()
================================================
FILE: tests/test_statistics.py
================================================
import numpy as np
import pandas as pd
try:
import statsmodels.distributions as smdist
except ImportError:
smdist = None
import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
from seaborn._statistics import (
KDE,
Histogram,
ECDF,
EstimateAggregator,
LetterValues,
WeightedAggregator,
_validate_errorbar_arg,
_no_scipy,
)
class DistributionFixtures:
    """Random-sample fixtures shared by the distribution statistic tests.

    Every fixture draws from the ``rng`` fixture (presumably a seeded numpy
    random generator supplied by conftest; verify there).
    """

    @pytest.fixture
    def x(self, rng):
        """100 draws from a normal with mean 0 and scale 1."""
        return rng.normal(loc=0, scale=1, size=100)

    @pytest.fixture
    def x2(self, rng):
        """742 draws from a normal with mean 0 and scale 1."""
        # An arbitrary, non-round sample size avoids edge cases
        return rng.normal(loc=0, scale=1, size=742)

    @pytest.fixture
    def y(self, rng):
        """100 draws from a normal with mean 0 and scale 5."""
        return rng.normal(loc=0, scale=5, size=100)

    @pytest.fixture
    def weights(self, rng):
        """100 weights drawn uniformly from [0, 5)."""
        return rng.uniform(low=0, high=5, size=100)
class TestKDE:
    """Tests for the ``KDE`` statistic: support grid, bandwidth and normalization."""

    def integrate(self, y, x):
        """Approximate the integral of ``y`` over ``x`` with the trapezoid rule."""
        ys = np.asarray(y)
        xs = np.asarray(x)
        widths = np.diff(xs)
        left = widths * ys[:-1]
        right = widths * ys[1:]
        return (left + right).sum() / 2

    def test_gridsize(self, rng):
        """Density and support both have ``gridsize`` points."""
        x = rng.normal(0, 3, 1000)
        gridsize = 200
        density, support = KDE(gridsize=gridsize)(x)
        assert density.size == gridsize
        assert support.size == gridsize

    def test_cut(self, rng):
        """The support extends ``cut`` bandwidths beyond the data range."""
        x = rng.normal(0, 3, 1000)
        _, support = KDE(cut=0)(x)
        assert support.min() == x.min()
        assert support.max() == x.max()

        cut, bw_scale = 2, .5
        bw = x.std() * bw_scale
        _, support = KDE(cut=cut, bw_method=bw_scale, gridsize=1000)(x)
        lower = x.min() - bw * cut
        upper = x.max() + bw * cut
        assert support.min() == pytest.approx(lower, abs=1e-2)
        assert support.max() == pytest.approx(upper, abs=1e-2)

    def test_clip(self, rng):
        """The support never leaves the ``clip`` interval."""
        x = rng.normal(0, 3, 100)
        lo, hi = -1, 1
        _, support = KDE(clip=(lo, hi))(x)
        assert support.min() >= lo
        assert support.max() <= hi

    def test_density_normalization(self, rng):
        """The univariate density integrates to one."""
        x = rng.normal(0, 3, 1000)
        density, support = KDE()(x)
        area = self.integrate(density, support)
        assert area == pytest.approx(1, abs=1e-5)

    @pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
    def test_cumulative(self, rng):
        """The cumulative estimate rises from zero to one."""
        x = rng.normal(0, 3, 1000)
        cdf, _ = KDE(cumulative=True)(x)
        assert cdf[0] == pytest.approx(0, abs=1e-5)
        assert cdf[-1] == pytest.approx(1, abs=1e-5)

    def test_cached_support(self, rng):
        """A support defined up front is reused for later (subset) data."""
        x = rng.normal(0, 3, 100)
        kde = KDE()
        kde.define_support(x)
        inner = x[(x > -1) & (x < 1)]
        _, support = kde(inner)
        assert_array_equal(support, kde.support)

    def test_bw_method(self, rng):
        """A smaller ``bw_method`` gives a rougher density curve."""
        x = rng.normal(0, 3, 100)
        narrow, _ = KDE(bw_method=.2)(x)
        wide, _ = KDE(bw_method=2)(x)
        rough_narrow, rough_wide = (np.abs(np.diff(d)).mean() for d in (narrow, wide))
        assert rough_narrow > rough_wide

    def test_bw_adjust(self, rng):
        """A smaller ``bw_adjust`` gives a rougher density curve."""
        x = rng.normal(0, 3, 100)
        narrow, _ = KDE(bw_adjust=.2)(x)
        wide, _ = KDE(bw_adjust=2)(x)
        rough_narrow, rough_wide = (np.abs(np.diff(d)).mean() for d in (narrow, wide))
        assert rough_narrow > rough_wide

    def test_bivariate_grid(self, rng):
        """A bivariate density is ``gridsize`` x ``gridsize``."""
        gridsize = 100
        x, y = rng.normal(0, 3, (2, 50))
        density, (xx, yy) = KDE(gridsize=gridsize)(x, y)
        assert density.shape == (gridsize, gridsize)
        assert xx.size == gridsize
        assert yy.size == gridsize

    def test_bivariate_normalization(self, rng):
        """The bivariate density sums to roughly one over the grid area."""
        x, y = rng.normal(0, 3, (2, 50))
        density, (xx, yy) = KDE(gridsize=100)(x, y)
        dx = xx[1] - xx[0]
        dy = yy[1] - yy[0]
        total = density.sum() * (dx * dy)
        assert total == pytest.approx(1, abs=1e-2)

    @pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
    def test_bivariate_cumulative(self, rng):
        """The bivariate cumulative estimate runs from zero to one corner to corner."""
        x, y = rng.normal(0, 3, (2, 50))
        cdf, _ = KDE(gridsize=100, cumulative=True)(x, y)
        assert cdf[0, 0] == pytest.approx(0, abs=1e-2)
        assert cdf[-1, -1] == pytest.approx(1, abs=1e-2)
class TestHistogram(DistributionFixtures):
    """Tests for the ``Histogram`` statistic: binning rules and normalizations.

    Count totals are exact integers. Density, probability and frequency
    totals are float sums, so they are compared with ``pytest.approx``.
    """

    def test_string_bins(self, x):
        """A named binning rule resolves to a bin count over the data range."""
        h = Histogram(bins="sqrt")
        bin_kws = h.define_bin_params(x)
        assert bin_kws["range"] == (x.min(), x.max())
        assert bin_kws["bins"] == int(np.sqrt(len(x)))

    def test_int_bins(self, x):
        """An integer bin count is used as-is over the data range."""
        n = 24
        h = Histogram(bins=n)
        bin_kws = h.define_bin_params(x)
        assert bin_kws["range"] == (x.min(), x.max())
        assert bin_kws["bins"] == n

    def test_array_bins(self, x):
        """Explicit bin edges are passed through."""
        bins = [-3, -2, 1, 2, 3]
        h = Histogram(bins=bins)
        bin_kws = h.define_bin_params(x)
        assert_array_equal(bin_kws["bins"], bins)

    def test_bivariate_string_bins(self, x, y):
        """A named rule applies to both axes; a pair applies per axis."""
        s1, s2 = "sqrt", "fd"
        h = Histogram(bins=s1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert_array_equal(e1, np.histogram_bin_edges(x, s1))
        assert_array_equal(e2, np.histogram_bin_edges(y, s1))
        h = Histogram(bins=(s1, s2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert_array_equal(e1, np.histogram_bin_edges(x, s1))
        assert_array_equal(e2, np.histogram_bin_edges(y, s2))

    def test_bivariate_int_bins(self, x, y):
        """An integer count applies to both axes; a pair applies per axis."""
        b1, b2 = 5, 10
        h = Histogram(bins=b1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert len(e1) == b1 + 1
        assert len(e2) == b1 + 1
        h = Histogram(bins=(b1, b2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert len(e1) == b1 + 1
        assert len(e2) == b2 + 1

    def test_bivariate_array_bins(self, x, y):
        """Explicit edges apply to both axes; a pair applies per axis."""
        b1 = [-3, -2, 1, 2, 3]
        b2 = [-5, -2, 3, 6]
        h = Histogram(bins=b1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert_array_equal(e1, b1)
        assert_array_equal(e2, b1)
        h = Histogram(bins=(b1, b2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert_array_equal(e1, b1)
        assert_array_equal(e2, b2)

    def test_binwidth(self, x):
        """The bin count is chosen so that each bin has width ``binwidth``."""
        binwidth = .5
        h = Histogram(binwidth=binwidth)
        bin_kws = h.define_bin_params(x)
        n_bins = bin_kws["bins"]
        left, right = bin_kws["range"]
        assert (right - left) / n_bins == pytest.approx(binwidth)

    def test_bivariate_binwidth(self, x, y):
        """A scalar width applies to both axes; a pair applies per axis."""
        w1, w2 = .5, 1
        h = Histogram(binwidth=w1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert np.all(np.diff(e1) == w1)
        assert np.all(np.diff(e2) == w1)
        h = Histogram(binwidth=(w1, w2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert np.all(np.diff(e1) == w1)
        assert np.all(np.diff(e2) == w2)

    def test_binrange(self, x):
        """An explicit ``binrange`` overrides the data range."""
        binrange = (-4, 4)
        h = Histogram(binrange=binrange)
        bin_kws = h.define_bin_params(x)
        assert bin_kws["range"] == binrange

    def test_bivariate_binrange(self, x, y):
        """A single range applies to both axes; a pair applies per axis."""
        r1, r2 = (-4, 4), (-10, 10)
        h = Histogram(binrange=r1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert e1.min() == r1[0]
        assert e1.max() == r1[1]
        assert e2.min() == r1[0]
        assert e2.max() == r1[1]
        h = Histogram(binrange=(r1, r2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert e1.min() == r1[0]
        assert e1.max() == r1[1]
        assert e2.min() == r2[0]
        assert e2.max() == r2[1]

    def test_discrete_bins(self, rng):
        """Discrete data gets one unit-width bin centred on each integer."""
        x = rng.binomial(20, .5, 100)
        h = Histogram(discrete=True)
        bin_kws = h.define_bin_params(x)
        assert bin_kws["range"] == (x.min() - .5, x.max() + .5)
        assert bin_kws["bins"] == (x.max() - x.min() + 1)

    def test_odd_single_observation(self):
        """A single observation still gets one bin of the requested width."""
        # GH2721
        x = np.array([0.49928])
        h, e = Histogram(binwidth=0.03)(x)
        assert len(h) == 1
        assert (e[1] - e[0]) == pytest.approx(.03)

    def test_binwidth_roundoff(self):
        """Float round-off in the bin edges must not drop observations."""
        # GH2785
        x = np.array([2.4, 2.5, 2.6])
        h, e = Histogram(binwidth=0.01)(x)
        assert h.sum() == 3

    def test_histogram(self, x):
        """Default binning matches ``np.histogram(bins="auto")``."""
        h = Histogram()
        heights, edges = h(x)
        heights_mpl, edges_mpl = np.histogram(x, bins="auto")
        assert_array_equal(heights, heights_mpl)
        assert_array_equal(edges, edges_mpl)

    def test_count_stat(self, x):
        """Counts sum to the number of observations."""
        h = Histogram(stat="count")
        heights, _ = h(x)
        assert heights.sum() == len(x)

    def test_density_stat(self, x):
        """Density heights integrate to one."""
        h = Histogram(stat="density")
        heights, edges = h(x)
        assert (heights * np.diff(edges)).sum() == pytest.approx(1)

    def test_probability_stat(self, x):
        """Probability heights sum to one."""
        h = Histogram(stat="probability")
        heights, _ = h(x)
        assert heights.sum() == pytest.approx(1)

    def test_frequency_stat(self, x):
        """Frequency heights times bin widths sum to the number of observations."""
        h = Histogram(stat="frequency")
        heights, edges = h(x)
        assert (heights * np.diff(edges)).sum() == pytest.approx(len(x))

    def test_cumulative_count(self, x):
        """The last cumulative count equals the number of observations."""
        h = Histogram(stat="count", cumulative=True)
        heights, _ = h(x)
        assert heights[-1] == len(x)

    def test_cumulative_density(self, x):
        """The last cumulative density value is one."""
        h = Histogram(stat="density", cumulative=True)
        heights, _ = h(x)
        assert heights[-1] == pytest.approx(1)

    def test_cumulative_probability(self, x):
        """The last cumulative probability is one."""
        h = Histogram(stat="probability", cumulative=True)
        heights, _ = h(x)
        assert heights[-1] == pytest.approx(1)

    def test_cumulative_frequency(self, x):
        """The last cumulative frequency equals the number of observations."""
        h = Histogram(stat="frequency", cumulative=True)
        heights, _ = h(x)
        assert heights[-1] == pytest.approx(len(x))

    def test_bivariate_histogram(self, x, y):
        """Default bivariate binning matches ``np.histogram2d`` with "auto" edges."""
        h = Histogram()
        heights, edges = h(x, y)
        bins_mpl = (
            np.histogram_bin_edges(x, "auto"),
            np.histogram_bin_edges(y, "auto"),
        )
        heights_mpl, *edges_mpl = np.histogram2d(x, y, bins_mpl)
        assert_array_equal(heights, heights_mpl)
        assert_array_equal(edges[0], edges_mpl[0])
        assert_array_equal(edges[1], edges_mpl[1])

    def test_bivariate_count_stat(self, x, y):
        """Bivariate counts sum to the number of observations."""
        h = Histogram(stat="count")
        heights, _ = h(x, y)
        assert heights.sum() == len(x)

    def test_bivariate_density_stat(self, x, y):
        """Bivariate density integrates to one over the bin areas."""
        h = Histogram(stat="density")
        heights, (edges_x, edges_y) = h(x, y)
        areas = np.outer(np.diff(edges_x), np.diff(edges_y))
        assert (heights * areas).sum() == pytest.approx(1)

    def test_bivariate_probability_stat(self, x, y):
        """Bivariate probabilities sum to one."""
        h = Histogram(stat="probability")
        heights, _ = h(x, y)
        assert heights.sum() == pytest.approx(1)

    def test_bivariate_frequency_stat(self, x, y):
        """Bivariate frequency times bin areas sums to the number of observations."""
        h = Histogram(stat="frequency")
        heights, (x_edges, y_edges) = h(x, y)
        area = np.outer(np.diff(x_edges), np.diff(y_edges))
        assert (heights * area).sum() == pytest.approx(len(x))

    def test_bivariate_cumulative_count(self, x, y):
        """The top-right cumulative count equals the number of observations."""
        h = Histogram(stat="count", cumulative=True)
        heights, _ = h(x, y)
        assert heights[-1, -1] == len(x)

    def test_bivariate_cumulative_density(self, x, y):
        """The top-right cumulative density is one."""
        h = Histogram(stat="density", cumulative=True)
        heights, _ = h(x, y)
        assert heights[-1, -1] == pytest.approx(1)

    def test_bivariate_cumulative_frequency(self, x, y):
        """The top-right cumulative frequency equals the number of observations."""
        h = Histogram(stat="frequency", cumulative=True)
        heights, _ = h(x, y)
        assert heights[-1, -1] == pytest.approx(len(x))

    def test_bivariate_cumulative_probability(self, x, y):
        """The top-right cumulative probability is one."""
        h = Histogram(stat="probability", cumulative=True)
        heights, _ = h(x, y)
        assert heights[-1, -1] == pytest.approx(1)

    def test_bad_stat(self):
        """An unknown stat name is rejected at construction."""
        with pytest.raises(ValueError):
            Histogram(stat="invalid")
class TestECDF(DistributionFixtures):
    """Tests for the ``ECDF`` statistic.

    In every test, element 0 of the returned arrays is a leading point with
    statistic 0 that precedes the sorted data; the sorted values are
    ``vals[1:]``.
    """

    def test_univariate_proportion(self, x):
        """The default stat rises in equal steps to one over the sorted data."""
        ecdf = ECDF()
        stat, vals = ecdf(x)
        assert_array_equal(vals[1:], np.sort(x))
        assert_array_almost_equal(stat[1:], np.linspace(0, 1, len(x) + 1)[1:])
        assert stat[0] == 0

    def test_univariate_count(self, x):
        """stat="count" counts observations up to each sorted value."""
        ecdf = ECDF(stat="count")
        stat, vals = ecdf(x)
        assert_array_equal(vals[1:], np.sort(x))
        assert_array_almost_equal(stat[1:], np.arange(len(x)) + 1)
        assert stat[0] == 0

    def test_univariate_percent(self, x2):
        """stat="percent" rises to 100."""
        ecdf = ECDF(stat="percent")
        stat, vals = ecdf(x2)
        assert_array_equal(vals[1:], np.sort(x2))
        assert_array_almost_equal(stat[1:], (np.arange(len(x2)) + 1) / len(x2) * 100)
        assert stat[0] == 0

    def test_univariate_proportion_weights(self, x, weights):
        """Weighted proportions are cumulative weight over total weight."""
        ecdf = ECDF()
        stat, vals = ecdf(x, weights=weights)
        assert_array_equal(vals[1:], np.sort(x))
        expected_stats = weights[x.argsort()].cumsum() / weights.sum()
        assert_array_almost_equal(stat[1:], expected_stats)
        assert stat[0] == 0

    def test_univariate_count_weights(self, x, weights):
        """Weighted counts are the cumulative sum of the sorted weights."""
        ecdf = ECDF(stat="count")
        stat, vals = ecdf(x, weights=weights)
        assert_array_equal(vals[1:], np.sort(x))
        assert_array_almost_equal(stat[1:], weights[x.argsort()].cumsum())
        assert stat[0] == 0

    @pytest.mark.skipif(smdist is None, reason="Requires statsmodels")
    def test_against_statsmodels(self, x):
        """Results, including the complementary ECDF, match statsmodels."""
        sm_ecdf = smdist.empirical_distribution.ECDF(x)
        ecdf = ECDF()
        stat, vals = ecdf(x)
        assert_array_equal(vals, sm_ecdf.x)
        assert_array_almost_equal(stat, sm_ecdf.y)
        ecdf = ECDF(complementary=True)
        stat, vals = ecdf(x)
        assert_array_equal(vals, sm_ecdf.x)
        assert_array_almost_equal(stat, sm_ecdf.y[::-1])

    def test_invalid_stat(self, x):
        """Unsupported stats such as "density" are rejected at construction."""
        with pytest.raises(ValueError, match="`stat` must be one of"):
            ECDF(stat="density")

    def test_bivariate_error(self, x, y):
        """Calling with two variables raises NotImplementedError."""
        # Construct outside the raises block so that only the bivariate call
        # itself can satisfy the expected error
        ecdf = ECDF()
        with pytest.raises(NotImplementedError, match="Bivariate ECDF"):
            ecdf(x, y)
class TestEstimateAggregator:
def test_func_estimator(self, long_df):
    """A callable estimator is applied directly to the column."""
    estimator = np.mean
    result = EstimateAggregator(estimator)(long_df, "x")
    assert result["x"] == estimator(long_df["x"])
def test_name_estimator(self, long_df):
    """A string estimator name resolves to the matching pandas reduction."""
    result = EstimateAggregator("mean")(long_df, "x")
    expected = long_df["x"].mean()
    assert result["x"] == expected
def test_custom_func_estimator(self, long_df):
    """A user-defined callable works as an estimator."""
    def column_min(values):
        return np.asarray(values).min()

    result = EstimateAggregator(column_min)(long_df, "x")
    assert result["x"] == column_min(long_df["x"])
def test_se_errorbars(self, long_df):
    """"se" error bars span the mean +/- scale * standard error (default scale 1)."""
    col = long_df["x"]
    for errorbar, scale in [("se", 1), (("se", 2), 2)]:
        out = EstimateAggregator("mean", errorbar)(long_df, "x")
        center, sem = col.mean(), col.sem()
        assert out["x"] == center
        assert out["xmin"] == (center - scale * sem)
        assert out["xmax"] == (center + scale * sem)
def test_sd_errorbars(self, long_df):
agg = EstimateAggregator("mean", "sd")
out = agg(long_df, "x")
assert out["x"] == long_df["x"].mean()
assert out["xmin"] == (long_df["x"].mean() - long_df["x"].std())
assert out["xmax"] == (long_df["x"].mean() + long_df["x"].std())
agg = EstimateAggregator("mean", ("sd", 2))
out = agg(long_df, "x")
assert out["x"] == long_df["x"].mean()
assert out["xmin"] == (long_df["x"].mean() - 2 * long_df["x"].std())
assert out["xmax"] == (long_df["x"].mean() + 2 * long_df["x"].std())
def test_pi_errorbars(self, long_df):
agg = EstimateAggregator("mean", "pi")
out = agg(long_df, "y")
assert out["ymin"] == np.percentile(long_df["y"], 2.5)
assert out["ymax"] == np.percentile(long_df["y"], 97.5)
agg = EstimateAggregator("mean", ("pi", 50))
out = agg(long_df, "y")
assert out["ymin"] == np.percentile(long_df["y"], 25)
assert out["ymax"] == np.percentile(long_df["y"], 75)
def test_ci_errorbars(self, long_df):
agg = EstimateAggregator("mean", "ci", n_boot=100000, seed=0)
out = agg(long_df, "y")
agg_ref = EstimateAggregator("mean", ("se", 1.96))
out_ref = agg_ref(long_df, "y")
assert out["ymin"] == pytest.approx(out_ref["ymin"], abs=1e-2)
assert out["ymax"] == pytest.approx(out_ref["ymax"], abs=1e-2)
agg = EstimateAggregator("mean", ("ci", 68), n_boot=100000, seed=0)
out = agg(long_df, "y")
agg_ref = EstimateAggregator("mean", ("se", 1))
out_ref = agg_ref(long_df, "y")
assert out["ymin"] == pytest.approx(out_ref["ymin"], abs=1e-2)
assert out["ymax"] == pytest.approx(out_ref["ymax"], abs=1e-2)
agg = EstimateAggregator("mean", "ci", seed=0)
out_orig = agg_ref(long_df, "y")
out_test = agg_ref(long_df, "y")
assert_array_equal(out_orig, out_test)
def test_custom_errorbars(self, long_df):
f = lambda x: (x.min(), x.max()) # noqa: E731
agg = EstimateAggregator("mean", f)
out = agg(long_df, "y")
assert out["ymin"] == long_df["y"].min()
assert out["ymax"] == long_df["y"].max()
def test_singleton_errorbars(self):
agg = EstimateAggregator("mean", "ci")
val = 7
out = agg(pd.DataFrame(dict(y=[val])), "y")
assert out["y"] == val
assert pd.isna(out["ymin"])
assert pd.isna(out["ymax"])
def test_errorbar_validation(self):
method, level = _validate_errorbar_arg(("ci", 99))
assert method == "ci"
assert level == 99
method, level = _validate_errorbar_arg("sd")
assert method == "sd"
assert level == 1
f = lambda x: (x.min(), x.max()) # noqa: E731
method, level = _validate_errorbar_arg(f)
assert method is f
assert level is None
bad_args = [
("sem", ValueError),
(("std", 2), ValueError),
(("pi", 5, 95), ValueError),
(95, TypeError),
(("ci", "large"), TypeError),
]
for arg, exception in bad_args:
with pytest.raises(exception, match="`errorbar` must be"):
_validate_errorbar_arg(arg)
class TestWeightedAggregator:
    """Tests for the weighted-mean aggregator and its restricted options."""

    def test_weighted_mean(self, long_df):
        long_df["weight"] = long_df["x"]
        out = WeightedAggregator("mean")(long_df, "y")

        weighted_mean = np.average(long_df["y"], weights=long_df["weight"])
        assert_array_equal(out["y"], weighted_mean)
        # No error bar method was requested, so both bounds are NaN
        for bound in ("ymin", "ymax"):
            assert_array_equal(out[bound], np.nan)

    def test_weighted_ci(self, long_df):
        long_df["weight"] = long_df["x"]
        out = WeightedAggregator("mean", "ci")(long_df, "y")

        weighted_mean = np.average(long_df["y"], weights=long_df["weight"])
        assert_array_equal(out["y"], weighted_mean)
        # The interval must bracket the estimate
        assert (out["ymin"] <= out["y"]).all()
        assert (out["ymax"] >= out["y"]).all()

    def test_limited_estimator(self):
        expected = "Weighted estimator must be 'mean'"
        with pytest.raises(ValueError, match=expected):
            WeightedAggregator("median")

    def test_limited_ci(self):
        expected = "Error bar method must be 'ci'"
        with pytest.raises(ValueError, match=expected):
            WeightedAggregator("mean", "sd")
class TestLetterValues:
    """Tests for the letter-value statistics computed by ``LetterValues``."""

    @pytest.fixture
    def x(self, rng):
        return pd.Series(rng.standard_t(10, 10_000))

    @staticmethod
    def _tukey_values(data):
        # Tukey depth rule with the outlier/trust adjustments turned off
        return LetterValues(k_depth="tukey", outlier_prop=0, trust_alpha=0)(data)

    def test_levels(self, x):
        res = self._tukey_values(x)
        k = res["k"]
        rising = np.arange(k)
        falling = np.arange(k - 1)[::-1]
        assert_array_equal(res["levels"], np.concatenate([rising, falling]))

    def test_values(self, x):
        res = self._tukey_values(x)
        assert_array_equal(np.percentile(x, res["percs"]), res["values"])

    def test_fliers(self, x):
        res = self._tukey_values(x)
        lo, hi = res["values"].min(), res["values"].max()
        fliers = res["fliers"]
        # Every flier lies strictly outside the letter-value range
        assert ((fliers < lo) | (fliers > hi)).all()

    def test_median(self, x):
        res = self._tukey_values(x)
        assert res["median"] == np.median(x)

    def test_k_depth_int(self, x):
        k = 12
        res = LetterValues(k_depth=k, outlier_prop=0, trust_alpha=0)(x)
        assert res["k"] == k
        assert len(res["levels"]) == 2 * k - 1

    def test_trust_alpha(self, x):
        loose, strict = (
            LetterValues(k_depth="trustworthy", outlier_prop=0, trust_alpha=alpha)(x)
            for alpha in (.1, .001)
        )
        assert loose["k"] > strict["k"]

    def test_outlier_prop(self, x):
        small, large = (
            LetterValues(k_depth="proportion", outlier_prop=prop, trust_alpha=0)(x)
            for prop in (.001, .1)
        )
        assert small["k"] > large["k"]
================================================
FILE: tests/test_utils.py
================================================
"""Tests for seaborn utility functions."""
import re
import tempfile
from types import ModuleType
from urllib.request import urlopen
from http.client import HTTPException
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from cycler import cycler
import pytest
from numpy.testing import (
assert_array_equal,
)
from pandas.testing import (
assert_series_equal,
assert_frame_equal,
)
from seaborn import utils, rcmod, scatterplot
from seaborn.utils import (
get_dataset_names,
get_color_cycle,
remove_na,
load_dataset,
_assign_default_kwargs,
_check_argument,
_draw_figure,
_deprecate_ci,
_version_predates, DATASET_NAMES_URL,
)
from seaborn._compat import get_legend_handles
# NOTE(review): module-level sample drawn from NumPy's global RNG at import
# time. No test in this part of the file references it; confirm whether it is
# still needed.
a_norm = np.random.randn(100)
def _network(t=None, url="https://github.com"):
"""
Decorator that will skip a test if `url` is unreachable.
Parameters
----------
t : function, optional
url : str, optional
"""
if t is None:
return lambda x: _network(x, url=url)
def wrapper(*args, **kwargs):
# attempt to connect
try:
f = urlopen(url)
except (OSError, HTTPException):
pytest.skip("No internet connection")
else:
f.close()
return t(*args, **kwargs)
return wrapper
def test_ci_to_errsize():
    """Convert absolute interval bounds into sizes relative to the heights."""
    intervals = [[.5, .5], [1.25, 1.5]]
    heights = [1, 1.5]
    expected = np.array([[.5, 1], [.25, 0]])
    assert_array_equal(expected, utils.ci_to_errsize(intervals, heights))
def test_desaturate():
    """Desaturation works for named, hex, and RGB-tuple color inputs."""
    cases = [
        ("red", .5, (.75, .25, .25)),
        ("#00FF00", .5, (.25, .75, .25)),
        ((0, 0, 1), .5, (.25, .25, .75)),
        ("red", .5, (.75, .25, .25)),
    ]
    for color, prop, expected in cases:
        assert utils.desaturate(color, prop) == expected

    # A proportion of 1 returns the color's plain RGB value
    assert utils.desaturate("lightblue", 1) == mpl.colors.to_rgb("lightblue")
def test_desaturation_prop():
    """A desaturation proportion outside [0, 1] raises ValueError."""
    out_of_range = 50
    with pytest.raises(ValueError):
        utils.desaturate("blue", out_of_range)
def test_saturate():
    """Fully saturating a muted red gives pure red."""
    muted_red = (.75, .25, .25)
    assert utils.saturate(muted_red) == (1, 0, 0)
@pytest.mark.parametrize(
    "s,exp",
    [
        ("a", "a"),
        ("abc", "abc"),
        (b"a", "a"),
        (b"abc", "abc"),
        (bytearray("abc", "utf-8"), "abc"),
        (bytearray(), ""),
        (1, "1"),
        (0, "0"),
        ([], str([])),
    ],
)
def test_to_utf8(s, exp):
    """``to_utf8`` turns str, bytes-like, and other objects into ``str``."""
    converted = utils.to_utf8(s)
    assert isinstance(converted, str)
    assert converted == exp
class TestSpineUtils:
    """Tests for ``utils.despine``: visibility, offsets, and trimming."""

    sides = ["left", "right", "bottom", "top"]
    outer_sides = ["top", "right"]
    inner_sides = ["left", "bottom"]

    offset = 10
    original_position = ("outward", 0)
    offset_position = ("outward", offset)

    def test_despine(self):
        f, ax = plt.subplots()

        def visible(which):
            return [ax.spines[s].get_visible() for s in which]

        assert all(visible(self.sides))

        # Default call hides top/right and keeps left/bottom
        utils.despine()
        assert not any(visible(self.outer_sides))
        assert all(visible(self.inner_sides))

        # Asking for every side hides all four spines
        utils.despine(**{s: True for s in self.sides})
        assert not any(visible(self.sides))

    def test_despine_specific_axes(self):
        f, (ax1, ax2) = plt.subplots(2, 1)
        utils.despine(ax=ax2)

        # The untargeted Axes keeps all spines
        assert all(ax1.spines[s].get_visible() for s in self.sides)
        assert not any(ax2.spines[s].get_visible() for s in self.outer_sides)
        assert all(ax2.spines[s].get_visible() for s in self.inner_sides)

    def test_despine_with_offset(self):
        f, ax = plt.subplots()
        assert all(
            ax.spines[s].get_position() == self.original_position
            for s in self.sides
        )

        utils.despine(ax=ax, offset=self.offset)

        # Only spines that remain visible are moved outward
        for side in self.sides:
            spine = ax.spines[side]
            expected = (
                self.offset_position if spine.get_visible()
                else self.original_position
            )
            assert spine.get_position() == expected

    def test_despine_side_specific_offset(self):
        f, ax = plt.subplots()
        utils.despine(ax=ax, offset=dict(left=self.offset))

        # The offset dict only names "left", so only that spine moves
        for side in self.sides:
            spine = ax.spines[side]
            moved = spine.get_visible() and side == "left"
            expected = self.offset_position if moved else self.original_position
            assert spine.get_position() == expected

    def test_despine_with_offset_specific_axes(self):
        f, (ax1, ax2) = plt.subplots(2, 1)
        utils.despine(offset=self.offset, ax=ax2)

        for side in self.sides:
            assert ax1.spines[side].get_position() == self.original_position
            spine2 = ax2.spines[side]
            expected = (
                self.offset_position if spine2.get_visible()
                else self.original_position
            )
            assert spine2.get_position() == expected

    def test_despine_trim_spines(self):
        f, ax = plt.subplots()
        ax.plot([1, 2, 3], [1, 2, 3])
        ax.set_xlim(.75, 3.25)

        utils.despine(trim=True)
        for side in self.inner_sides:
            assert ax.spines[side].get_bounds() == (1, 3)

    def test_despine_trim_inverted(self):
        f, ax = plt.subplots()
        ax.plot([1, 2, 3], [1, 2, 3])
        ax.set_ylim(.85, 3.15)
        ax.invert_yaxis()

        # Trimming gives the same bounds on an inverted axis
        utils.despine(trim=True)
        for side in self.inner_sides:
            assert ax.spines[side].get_bounds() == (1, 3)

    def test_despine_trim_noticks(self):
        f, ax = plt.subplots()
        ax.plot([1, 2, 3], [1, 2, 3])
        ax.set_yticks([])

        utils.despine(trim=True)
        # An axis with no ticks keeps no ticks after trimming
        assert ax.get_yticks().size == 0

    def test_despine_trim_categorical(self):
        f, ax = plt.subplots()
        ax.plot(["a", "b", "c"], [1, 2, 3])

        utils.despine(trim=True)
        expected_bounds = {"left": (1, 3), "bottom": (0, 2)}
        for side, bounds in expected_bounds.items():
            assert ax.spines[side].get_bounds() == bounds

    def test_despine_moved_ticks(self):
        # For each axis, hiding the primary spine should carry the tick1line
        # visibility over to tick2line on the opposite side.
        cases = [
            ("yaxis", dict(left=True, right=False)),
            ("xaxis", dict(bottom=True, top=False)),
        ]
        for axis_name, despine_kws in cases:
            for tick1_on in (True, False):
                f, ax = plt.subplots()
                for t in getattr(ax, axis_name).majorTicks:
                    t.tick1line.set_visible(tick1_on)

                utils.despine(ax=ax, **despine_kws)

                for t in getattr(ax, axis_name).majorTicks:
                    assert bool(t.tick2line.get_visible()) == tick1_on
                plt.close(f)
def test_ticklabels_overlap():
    """Long tick labels on a narrow figure are detected as overlapping."""
    rcmod.set()
    f, ax = plt.subplots(figsize=(2, 2))
    f.tight_layout()  # This gets the Agg renderer working

    assert not utils.axis_ticklabels_overlap(ax.get_xticklabels())

    ax.set_xlim(-.5, 1.5)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(("abcdefgh", "ijklmnop"))
    assert utils.axis_ticklabels_overlap(ax.get_xticklabels())

    x_overlaps, y_overlaps = utils.axes_ticklabels_overlap(ax)
    assert x_overlaps
    assert not y_overlaps
def test_locator_to_legend_entries():
    """Legend entries follow the locator's tick choices and value dtype."""
    linear = mpl.ticker.MaxNLocator(nbins=3)
    linear_cases = [
        ((0.09, 0.4), float, ["0.15", "0.30"]),
        ((0.8, 0.9), float, ["0.80", "0.84", "0.88"]),
        ((1, 6), int, ["2", "4", "6"]),
    ]
    for limits, dtype, expected in linear_cases:
        _, labels = utils.locator_to_legend_entries(linear, limits, dtype)
        assert labels == expected

    log = mpl.ticker.LogLocator(numticks=5)
    _, labels = utils.locator_to_legend_entries(log, (5, 1425), int)
    assert labels == ['10', '100', '1000']

    _, labels = utils.locator_to_legend_entries(log, (0.00003, 0.02), float)
    for idx, exponent in enumerate([4, 3, 2]):
        # Use regex as mpl switched to minus sign, not hyphen, in 3.6
        assert re.match(f"1e.0{exponent}", labels[idx])
def test_move_legend_matplotlib_objects():
    """move_legend relocates and restyles Axes- and Figure-level legends."""
    fig, ax = plt.subplots()

    colors = "C2", "C5"
    labels = "first label", "second label"
    title = "the legend"

    for color, label in zip(colors, labels):
        ax.plot([0, 1], color=color, label=label)
    ax.legend(loc="upper right", title=title)
    _draw_figure(fig)
    to_axes = ax.transAxes.inverted().transform

    # --- Axes legend: moves toward the lower left, keeps title, new fontsize
    before = to_axes(ax.legend_.legendPatch.get_extents())
    title_size = 14
    utils.move_legend(ax, "lower left", title_fontsize=title_size)
    _draw_figure(fig)
    after = to_axes(ax.legend_.legendPatch.get_extents())
    assert (after < before).all()
    assert ax.legend_.get_title().get_text() == title
    assert ax.legend_.get_title().get_size() == title_size

    # --- The title text can be replaced
    new_title = "new title"
    utils.move_legend(ax, "lower left", title=new_title)
    _draw_figure(fig)
    assert ax.legend_.get_title().get_text() == new_title

    # --- Figure legend
    fig.legend(loc="upper right", title=title)
    _draw_figure(fig)
    to_figure = fig.transFigure.inverted().transform
    before = to_figure(fig.legends[0].legendPatch.get_extents())
    utils.move_legend(fig, "lower left", title=new_title)
    _draw_figure(fig)
    after = to_figure(fig.legends[0].legendPatch.get_extents())
    assert (after < before).all()
    assert fig.legends[0].get_title().get_text() == new_title
def test_move_legend_grid_object(long_df):
    """move_legend works on a FacetGrid legend and keeps its handles' colors."""
    from seaborn.axisgrid import FacetGrid

    hue_var = "a"
    g = FacetGrid(long_df, hue=hue_var)
    g.map(plt.plot, "x", "y")
    g.add_legend()
    _draw_figure(g.figure)

    to_figure = g.figure.transFigure.inverted().transform
    before = to_figure(g.legend.legendPatch.get_extents())

    fontsize = 20
    utils.move_legend(g, "lower left", title_fontsize=fontsize)
    _draw_figure(g.figure)
    after = to_figure(g.legend.legendPatch.get_extents())

    assert (after < before).all()
    assert g.legend.get_title().get_text() == hue_var
    assert g.legend.get_title().get_size() == fontsize

    handles = get_legend_handles(g.legend)
    assert handles
    # Handles keep the default color cycle order
    for idx, handle in enumerate(handles):
        assert mpl.colors.to_rgb(handle.get_color()) == mpl.colors.to_rgb(f"C{idx}")
def test_move_legend_input_checks():
    """Bad targets raise TypeError; targets without a legend raise ValueError."""
    ax = plt.figure().subplots()

    # An Axis object is not an accepted target
    with pytest.raises(TypeError):
        utils.move_legend(ax.xaxis, "best")

    # Nothing has been plotted here, so neither object has a legend yet
    for target in (ax, ax.figure):
        with pytest.raises(ValueError):
            utils.move_legend(target, "best")
def test_move_legend_with_labels(long_df):
    """New labels replace the text and leave the legend handles unchanged."""
    order = long_df["a"].unique()
    labels = [s.capitalize() for s in order]
    ax = scatterplot(long_df, x="x", y="y", hue="a", hue_order=order)

    def facecolors():
        return [
            h.get_markerfacecolor()
            for h in get_legend_handles(ax.get_legend())
        ]

    colors_before = facecolors()
    utils.move_legend(ax, "best", labels=labels)
    _draw_figure(ax.figure)

    assert [t.get_text() for t in ax.get_legend().get_texts()] == labels
    assert colors_before == facecolors()

    # A label list with the wrong length is rejected
    with pytest.raises(ValueError, match="Length of new labels"):
        utils.move_legend(ax, "best", labels=labels[:-1])
def check_load_dataset(name):
    """Load ``name`` without caching and check that a DataFrame comes back."""
    loaded = load_dataset(name, cache=False)
    assert isinstance(loaded, pd.DataFrame)
def check_load_cached_dataset(name):
    """Load ``name`` twice through a temporary cache; both loads must match."""
    with tempfile.TemporaryDirectory() as cache_dir:
        # First call populates the cache, second call reads from it
        downloaded = load_dataset(name, cache=True, data_home=cache_dir)
        from_cache = load_dataset(name, cache=True, data_home=cache_dir)
        assert_frame_equal(downloaded, from_cache)
@_network(url=DATASET_NAMES_URL)
def test_get_dataset_names():
    """The dataset listing is non-empty and includes "tips"."""
    available = get_dataset_names()
    assert available
    assert "tips" in available
@_network(url=DATASET_NAMES_URL)
def test_load_datasets():
    """Heavy test: every listed dataset must load without caching."""
    for dataset in get_dataset_names():
        check_load_dataset(dataset)
@_network(url=DATASET_NAMES_URL)
def test_load_dataset_string_error():
    """An unknown dataset name raises a ValueError that names it."""
    name = "bad_name"
    err = f"'{name}' is not one of the example datasets."
    # ``match`` is a regular expression; escape it so the message is matched
    # literally (the trailing "." would otherwise match any character).
    with pytest.raises(ValueError, match=re.escape(err)):
        load_dataset(name)
def test_load_dataset_passed_data_error():
    """Passing a DataFrame instead of a dataset name raises TypeError."""
    with pytest.raises(TypeError, match="This function accepts only strings"):
        load_dataset(pd.DataFrame())
@_network(url=DATASET_NAMES_URL)
def test_load_cached_datasets():
    """Heavy test: every listed dataset must round-trip through the cache."""
    # Gate on DATASET_NAMES_URL, the same URL the other get_dataset_names()
    # tests in this module probe. The skip condition then matches the
    # resource this test depends on first.
    for name in get_dataset_names():
        check_load_cached_dataset(name)
def test_relative_luminance():
    """Luminance of known colors, and agreement between scalar and array input."""
    assert utils.relative_luminance("white") == 1
    assert utils.relative_luminance("#000000") == 0
    assert utils.relative_luminance((.25, .5, .75)) == pytest.approx(0.201624536)

    # Passing an array of colors must match evaluating each color on its own
    rgbs = mpl.cm.RdBu(np.linspace(0, 1, 10))
    one_by_one = [utils.relative_luminance(rgb) for rgb in rgbs]
    all_at_once = utils.relative_luminance(rgbs)
    for single, batched in zip(one_by_one, all_at_once):
        assert single == pytest.approx(batched)
@pytest.mark.parametrize(
    "cycler,result",
    [
        (cycler(color=["y"]), ["y"]),
        (cycler(color=["k"]), ["k"]),
        (cycler(color=["k", "y"]), ["k", "y"]),
        (cycler(color=["y", "k"]), ["y", "k"]),
        (cycler(color=["b", "r"]), ["b", "r"]),
        (cycler(color=["r", "b"]), ["r", "b"]),
        (cycler(lw=[1, 2]), [".15"]),  # no color in cycle
    ],
)
def test_get_color_cycle(cycler, result):
    """The color cycle is read from the active ``axes.prop_cycle``."""
    rc = {"axes.prop_cycle": cycler}
    with mpl.rc_context(rc=rc):
        observed = get_color_cycle()
    assert observed == result
def test_remove_na():
    """NaNs are dropped from arrays and Series; Series keep their index labels."""
    arr = np.array([1, 2, np.nan, 3])
    assert_array_equal(remove_na(arr), np.array([1, 2, 3]))

    ser = pd.Series([1, 2, np.nan, 3])
    expected = pd.Series([1., 2, 3], [0, 1, 3])
    assert_series_equal(remove_na(ser), expected)
def test_assign_default_kwargs():
    """An explicit kwarg is kept and the missing one takes g's default."""
    def f(a, b, c, d):
        pass

    def g(c=1, d=2):
        pass

    merged = _assign_default_kwargs({"c": 3}, f, g)
    assert merged == {"c": 3, "d": 2}
def test_check_argument():
    """Valid values (including prefix matches) pass through; others raise."""
    opts = ["a", "b", None]

    assert _check_argument("arg", opts, "a") == "a"
    assert _check_argument("arg", opts, None) is None
    assert _check_argument("arg", opts, "aa", prefix=True) == "aa"
    assert _check_argument("arg", opts, None, prefix=True) is None

    invalid_calls = [
        (opts, "c", {}),
        (opts, "c", {"prefix": True}),
        (opts[:-1], None, {}),
        (opts[:-1], None, {"prefix": True}),
    ]
    for options, value, extra in invalid_calls:
        with pytest.raises(ValueError, match="The value for `arg`"):
            _check_argument("arg", options, value, **extra)
def test_draw_figure():
    """After ``_draw_figure`` the figure is not stale and tick labels exist."""
    f, ax = plt.subplots()
    ax.plot(["a", "b", "c"], [1, 2, 3])
    _draw_figure(f)

    assert not f.stale
    # ticklabels are not populated until a draw, but this may change
    first_label = ax.get_xticklabels()[0]
    assert first_label.get_text() == "a"
def test_deprecate_ci():
    """Each legacy ``ci`` value warns and maps to the matching ``errorbar``."""
    prefix = "\n\nThe `ci` parameter is deprecated. Use `errorbar="

    # Second argument None -> errorbar=None
    with pytest.warns(FutureWarning, match=prefix + "None"):
        converted = _deprecate_ci(None, None)
    assert converted is None

    # Second argument "sd" passes through unchanged
    with pytest.warns(FutureWarning, match=prefix + "'sd'"):
        converted = _deprecate_ci(None, "sd")
    assert converted == "sd"

    # A numeric level becomes a ("ci", level) tuple
    with pytest.warns(FutureWarning, match=prefix + r"\('ci', 68\)"):
        converted = _deprecate_ci(None, 68)
    assert converted == ("ci", 68)
def test_version_predates():
    """A module at 1.2.3 predates strictly newer versions only."""
    mock = ModuleType("mock")
    mock.__version__ = "1.2.3"

    for newer in ("1.2.4", "1.3"):
        assert _version_predates(mock, newer)
    for not_newer in ("1.2.3", "0.8", "1"):
        assert not _version_predates(mock, not_newer)